ml4t-diagnostic 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ml4t/diagnostic/AGENT.md +25 -0
- ml4t/diagnostic/__init__.py +166 -0
- ml4t/diagnostic/backends/__init__.py +10 -0
- ml4t/diagnostic/backends/adapter.py +192 -0
- ml4t/diagnostic/backends/polars_backend.py +899 -0
- ml4t/diagnostic/caching/__init__.py +40 -0
- ml4t/diagnostic/caching/cache.py +331 -0
- ml4t/diagnostic/caching/decorators.py +131 -0
- ml4t/diagnostic/caching/smart_cache.py +339 -0
- ml4t/diagnostic/config/AGENT.md +24 -0
- ml4t/diagnostic/config/README.md +267 -0
- ml4t/diagnostic/config/__init__.py +219 -0
- ml4t/diagnostic/config/barrier_config.py +277 -0
- ml4t/diagnostic/config/base.py +301 -0
- ml4t/diagnostic/config/event_config.py +148 -0
- ml4t/diagnostic/config/feature_config.py +404 -0
- ml4t/diagnostic/config/multi_signal_config.py +55 -0
- ml4t/diagnostic/config/portfolio_config.py +215 -0
- ml4t/diagnostic/config/report_config.py +391 -0
- ml4t/diagnostic/config/sharpe_config.py +202 -0
- ml4t/diagnostic/config/signal_config.py +206 -0
- ml4t/diagnostic/config/trade_analysis_config.py +310 -0
- ml4t/diagnostic/config/validation.py +279 -0
- ml4t/diagnostic/core/__init__.py +29 -0
- ml4t/diagnostic/core/numba_utils.py +315 -0
- ml4t/diagnostic/core/purging.py +372 -0
- ml4t/diagnostic/core/sampling.py +471 -0
- ml4t/diagnostic/errors/__init__.py +205 -0
- ml4t/diagnostic/evaluation/AGENT.md +26 -0
- ml4t/diagnostic/evaluation/__init__.py +437 -0
- ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
- ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
- ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
- ml4t/diagnostic/evaluation/dashboard.py +715 -0
- ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
- ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
- ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
- ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
- ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
- ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
- ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
- ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
- ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
- ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
- ml4t/diagnostic/evaluation/event_analysis.py +647 -0
- ml4t/diagnostic/evaluation/excursion.py +390 -0
- ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
- ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
- ml4t/diagnostic/evaluation/framework.py +935 -0
- ml4t/diagnostic/evaluation/metric_registry.py +255 -0
- ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
- ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
- ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
- ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
- ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
- ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
- ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
- ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
- ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
- ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
- ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
- ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
- ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
- ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
- ml4t/diagnostic/evaluation/multi_signal.py +550 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
- ml4t/diagnostic/evaluation/report_generation.py +824 -0
- ml4t/diagnostic/evaluation/signal_selector.py +452 -0
- ml4t/diagnostic/evaluation/stat_registry.py +139 -0
- ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
- ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
- ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
- ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
- ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
- ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
- ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
- ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
- ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
- ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
- ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
- ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
- ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
- ml4t/diagnostic/evaluation/stats/moments.py +164 -0
- ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
- ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
- ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
- ml4t/diagnostic/evaluation/themes.py +330 -0
- ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
- ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
- ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
- ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
- ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
- ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
- ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
- ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
- ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
- ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
- ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
- ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
- ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
- ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
- ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
- ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
- ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
- ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
- ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
- ml4t/diagnostic/evaluation/validated_cv.py +535 -0
- ml4t/diagnostic/evaluation/visualization.py +1050 -0
- ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
- ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
- ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
- ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
- ml4t/diagnostic/integration/__init__.py +48 -0
- ml4t/diagnostic/integration/backtest_contract.py +671 -0
- ml4t/diagnostic/integration/data_contract.py +316 -0
- ml4t/diagnostic/integration/engineer_contract.py +226 -0
- ml4t/diagnostic/logging/__init__.py +77 -0
- ml4t/diagnostic/logging/logger.py +245 -0
- ml4t/diagnostic/logging/performance.py +234 -0
- ml4t/diagnostic/logging/progress.py +234 -0
- ml4t/diagnostic/logging/wandb.py +412 -0
- ml4t/diagnostic/metrics/__init__.py +9 -0
- ml4t/diagnostic/metrics/percentiles.py +128 -0
- ml4t/diagnostic/py.typed +1 -0
- ml4t/diagnostic/reporting/__init__.py +43 -0
- ml4t/diagnostic/reporting/base.py +130 -0
- ml4t/diagnostic/reporting/html_renderer.py +275 -0
- ml4t/diagnostic/reporting/json_renderer.py +51 -0
- ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
- ml4t/diagnostic/results/AGENT.md +24 -0
- ml4t/diagnostic/results/__init__.py +105 -0
- ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
- ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
- ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
- ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
- ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
- ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
- ml4t/diagnostic/results/barrier_results/validation.py +38 -0
- ml4t/diagnostic/results/base.py +177 -0
- ml4t/diagnostic/results/event_results.py +349 -0
- ml4t/diagnostic/results/feature_results.py +787 -0
- ml4t/diagnostic/results/multi_signal_results.py +431 -0
- ml4t/diagnostic/results/portfolio_results.py +281 -0
- ml4t/diagnostic/results/sharpe_results.py +448 -0
- ml4t/diagnostic/results/signal_results/__init__.py +74 -0
- ml4t/diagnostic/results/signal_results/ic.py +581 -0
- ml4t/diagnostic/results/signal_results/irtc.py +110 -0
- ml4t/diagnostic/results/signal_results/quantile.py +392 -0
- ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
- ml4t/diagnostic/results/signal_results/turnover.py +213 -0
- ml4t/diagnostic/results/signal_results/validation.py +147 -0
- ml4t/diagnostic/signal/AGENT.md +17 -0
- ml4t/diagnostic/signal/__init__.py +69 -0
- ml4t/diagnostic/signal/_report.py +152 -0
- ml4t/diagnostic/signal/_utils.py +261 -0
- ml4t/diagnostic/signal/core.py +275 -0
- ml4t/diagnostic/signal/quantile.py +148 -0
- ml4t/diagnostic/signal/result.py +214 -0
- ml4t/diagnostic/signal/signal_ic.py +129 -0
- ml4t/diagnostic/signal/turnover.py +182 -0
- ml4t/diagnostic/splitters/AGENT.md +19 -0
- ml4t/diagnostic/splitters/__init__.py +36 -0
- ml4t/diagnostic/splitters/base.py +501 -0
- ml4t/diagnostic/splitters/calendar.py +421 -0
- ml4t/diagnostic/splitters/calendar_config.py +91 -0
- ml4t/diagnostic/splitters/combinatorial.py +1064 -0
- ml4t/diagnostic/splitters/config.py +322 -0
- ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
- ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
- ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
- ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
- ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
- ml4t/diagnostic/splitters/group_isolation.py +329 -0
- ml4t/diagnostic/splitters/persistence.py +316 -0
- ml4t/diagnostic/splitters/utils.py +207 -0
- ml4t/diagnostic/splitters/walk_forward.py +757 -0
- ml4t/diagnostic/utils/__init__.py +42 -0
- ml4t/diagnostic/utils/config.py +542 -0
- ml4t/diagnostic/utils/dependencies.py +318 -0
- ml4t/diagnostic/utils/sessions.py +127 -0
- ml4t/diagnostic/validation/__init__.py +54 -0
- ml4t/diagnostic/validation/dataframe.py +274 -0
- ml4t/diagnostic/validation/returns.py +280 -0
- ml4t/diagnostic/validation/timeseries.py +299 -0
- ml4t/diagnostic/visualization/AGENT.md +19 -0
- ml4t/diagnostic/visualization/__init__.py +223 -0
- ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
- ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
- ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
- ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
- ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
- ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
- ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
- ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
- ml4t/diagnostic/visualization/barrier_plots.py +782 -0
- ml4t/diagnostic/visualization/core.py +1060 -0
- ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
- ml4t/diagnostic/visualization/dashboards/base.py +582 -0
- ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
- ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
- ml4t/diagnostic/visualization/dashboards.py +43 -0
- ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
- ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
- ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
- ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
- ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
- ml4t/diagnostic/visualization/feature_plots.py +888 -0
- ml4t/diagnostic/visualization/interaction_plots.py +618 -0
- ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
- ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
- ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
- ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
- ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
- ml4t/diagnostic/visualization/report_generation.py +1343 -0
- ml4t/diagnostic/visualization/signal/__init__.py +103 -0
- ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
- ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
- ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
- ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
- ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
- ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
- ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
- ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
- ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
- ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
- ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
- ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0

ml4t/diagnostic/evaluation/trade_shap/models.py

@@ -0,0 +1,386 @@
"""Pydantic models for Trade SHAP diagnostics.

This module contains the data models used throughout the Trade SHAP analysis:
- TradeExplainFailure: Structured failure result for a single trade
- TradeShapExplanation: SHAP explanation for a single trade
- ClusteringResult: Result of error pattern clustering
- ErrorPattern: Characterized error pattern from clustered trades
- TradeShapResult: Complete result of trade-level SHAP analysis
"""

from __future__ import annotations

from datetime import datetime
from typing import TYPE_CHECKING, Any

import numpy as np
from numpy.typing import NDArray
from pydantic import BaseModel, Field

if TYPE_CHECKING:
    pass


class TradeExplainFailure(BaseModel):
    """Structured failure result for trade explanation.

    Used instead of exceptions for expected failure cases (alignment missing,
    feature mismatch, etc.) to enable batch processing without try/except.

    Attributes:
        trade_id: Unique trade identifier
        timestamp: Trade entry timestamp
        reason: Machine-readable failure reason code
        details: Additional context about the failure
    """

    trade_id: str = Field(..., description="Unique trade identifier")
    timestamp: datetime = Field(..., description="Trade entry timestamp")
    reason: str = Field(
        ...,
        description="Failure reason: 'alignment_missing', 'shap_error', 'feature_mismatch'",
    )
    details: dict[str, Any] = Field(default_factory=dict, description="Additional failure context")
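

# Editor's sketch (illustrative only, not part of the published module):
# because expected failures are returned as TradeExplainFailure values
# rather than raised, batch callers can partition a mixed result list
# without try/except. The helper name `_split_failures` is hypothetical.
def _split_failures(
    outcomes: list[Any],
) -> tuple[list[Any], list[TradeExplainFailure]]:
    """Separate successful explanations from structured failures."""
    failures = [o for o in outcomes if isinstance(o, TradeExplainFailure)]
    successes = [o for o in outcomes if not isinstance(o, TradeExplainFailure)]
    return successes, failures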


class TradeShapExplanation(BaseModel):
    """SHAP explanation for a single trade.

    Contains SHAP attribution details for one trade, including:
    - Top contributing features (sorted by absolute SHAP value)
    - Feature values at trade entry
    - Full SHAP vector for all features
    - Waterfall plot data (future enhancement)

    Attributes:
        trade_id: Unique trade identifier (symbol_timestamp)
        timestamp: Trade entry timestamp
        top_features: List of (feature_name, shap_value) sorted by |shap_value| descending
        feature_values: Dictionary of feature values at trade entry
        shap_vector: Full SHAP vector for all features (numpy array)

    Example:
        >>> explanation.top_features[:3]
        [('momentum_20d', 0.342), ('volatility_10d', -0.215), ('rsi_14d', 0.108)]

        >>> explanation.feature_values['momentum_20d']
        1.235

        >>> explanation.shap_vector.shape
        (50,)  # 50 features
    """

    trade_id: str = Field(..., description="Unique trade identifier")
    timestamp: datetime = Field(..., description="Trade entry timestamp")
    top_features: list[tuple[str, float]] = Field(
        ..., description="Top N features by absolute SHAP value (descending)"
    )
    feature_values: dict[str, float] = Field(
        ..., description="Feature values at trade entry timestamp"
    )
    shap_vector: NDArray[np.floating[Any]] = Field(
        ..., description="Full SHAP vector for all features"
    )

    class Config:
        """Pydantic config."""

        arbitrary_types_allowed = True
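

# Editor's sketch (illustrative only, not part of the published module):
# one way to derive the `top_features` ordering documented above, i.e.
# descending absolute SHAP value, from a raw SHAP vector. The helper name
# and the `n_top` default are assumptions for the example.
def _rank_features_by_shap(
    feature_names: list[str],
    shap_vector: NDArray[np.floating[Any]],
    n_top: int = 5,
) -> list[tuple[str, float]]:
    """Return (feature_name, shap_value) pairs sorted by |shap_value| descending."""
    order = np.argsort(-np.abs(shap_vector))[:n_top]
    return [(feature_names[i], float(shap_vector[i])) for i in order]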


class ClusteringResult(BaseModel):
    """Result of error pattern clustering.

    Contains cluster assignments, centroids, quality metrics, and linkage matrix
    for dendrogram visualization.

    Attributes:
        n_clusters: Number of clusters identified
        cluster_assignments: Cluster ID for each trade (0-indexed list)
        linkage_matrix: Scipy linkage matrix for dendrogram plotting
        centroids: Mean SHAP vector for each cluster (shape: n_clusters x n_features)
        silhouette_score: Quality metric (range: -1 to 1, higher is better)
            - 1.0: Perfect separation
            - 0.5: Good separation
            - 0.0: Overlapping clusters
            - <0.0: Poor clustering (mis-assigned trades)
        davies_bouldin_score: Davies-Bouldin Index (lower = better, min: 0)
            - Measures ratio of within-cluster to between-cluster distances
            - < 1.0: Good clustering
            - 1.0-2.0: Acceptable clustering
            - > 2.0: Poor clustering
        calinski_harabasz_score: Calinski-Harabasz Score (higher = better, min: 0)
            - Also known as Variance Ratio Criterion
            - Measures ratio of between-cluster to within-cluster dispersion
            - Higher values indicate better-defined clusters
        cluster_sizes: Number of trades in each cluster
        distance_metric: Distance metric used ('euclidean', 'cosine', etc.)
        linkage_method: Linkage method used ('ward', 'average', 'complete', 'single')

    Example - Basic inspection:
        >>> result = analyzer.cluster_patterns(shap_vectors)
        >>> print(f"Found {result.n_clusters} clusters")
        >>> print(f"Cluster sizes: {result.cluster_sizes}")
        >>> print(f"Quality (silhouette): {result.silhouette_score:.3f}")

    Example - Visualize dendrogram:
        >>> from scipy.cluster.hierarchy import dendrogram
        >>> import matplotlib.pyplot as plt
        >>> dendrogram(result.linkage_matrix)
        >>> plt.title("Error Pattern Dendrogram")
        >>> plt.xlabel("Trade Index")
        >>> plt.ylabel("Distance")
        >>> plt.show()

    Example - Analyze specific cluster:
        >>> cluster_id = 0
        >>> trades_in_cluster = [i for i, c in enumerate(result.cluster_assignments) if c == cluster_id]
        >>> cluster_centroid = result.centroids[cluster_id]
        >>> print(f"Cluster {cluster_id}: {len(trades_in_cluster)} trades")
        >>> print(f"Centroid (mean SHAP): {cluster_centroid}")

    Note:
        - linkage_matrix can be used directly with scipy.cluster.hierarchy.dendrogram()
        - centroids represent "typical" SHAP pattern for each cluster
        - silhouette_score > 0.5 indicates well-separated clusters
    """

    n_clusters: int = Field(..., description="Number of clusters identified")
    cluster_assignments: list[int] = Field(..., description="Cluster ID for each trade (0-indexed)")
    linkage_matrix: NDArray[np.floating[Any]] = Field(
        ..., description="Scipy linkage matrix for dendrogram"
    )
    centroids: NDArray[np.floating[Any]] = Field(
        ..., description="Mean SHAP vector per cluster (n_clusters x n_features)"
    )
    silhouette_score: float = Field(
        ..., description="Cluster quality metric (range: -1 to 1, higher is better)"
    )
    davies_bouldin_score: float | None = Field(
        None,
        description="Davies-Bouldin Index (lower = better, min: 0, no upper bound). "
        "Measures ratio of within-cluster to between-cluster distances. "
        "Values < 1.0 indicate good clustering.",
    )
    calinski_harabasz_score: float | None = Field(
        None,
        description="Calinski-Harabasz Score (higher = better, min: 0, no upper bound). "
        "Also known as Variance Ratio Criterion. "
        "Measures ratio of between-cluster to within-cluster dispersion.",
    )
    cluster_sizes: list[int] = Field(..., description="Number of trades per cluster")
    distance_metric: str = Field(..., description="Distance metric used for clustering")
    linkage_method: str = Field(..., description="Linkage method used for clustering")

    class Config:
        """Pydantic config."""

        arbitrary_types_allowed = True
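

# Editor's sketch (illustrative only, not part of the published module):
# how the fields above map onto scipy/scikit-learn primitives. The linkage
# matrix comes from scipy.cluster.hierarchy and the three quality metrics
# from sklearn.metrics; the ward/euclidean and n_clusters choices are
# assumptions for the example.
def _cluster_shap_vectors(
    vectors: NDArray[np.floating[Any]], n_clusters: int = 3
) -> ClusteringResult:
    from scipy.cluster.hierarchy import fcluster, linkage
    from sklearn.metrics import (
        calinski_harabasz_score,
        davies_bouldin_score,
        silhouette_score,
    )

    linkage_matrix = linkage(vectors, method="ward", metric="euclidean")
    # fcluster returns 1-indexed labels; shift to the 0-indexed convention
    # documented above.
    labels = fcluster(linkage_matrix, t=n_clusters, criterion="maxclust") - 1
    centroids = np.vstack(
        [vectors[labels == c].mean(axis=0) for c in range(n_clusters)]
    )
    return ClusteringResult(
        n_clusters=n_clusters,
        cluster_assignments=labels.tolist(),
        linkage_matrix=linkage_matrix,
        centroids=centroids,
        silhouette_score=float(silhouette_score(vectors, labels)),
        davies_bouldin_score=float(davies_bouldin_score(vectors, labels)),
        calinski_harabasz_score=float(calinski_harabasz_score(vectors, labels)),
        cluster_sizes=[int(np.sum(labels == c)) for c in range(n_clusters)],
        distance_metric="euclidean",
        linkage_method="ward",
    )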


class ErrorPattern(BaseModel):
    """Characterized error pattern from clustered trades.

    Represents a distinct pattern of trading errors identified through SHAP-based
    clustering and statistical characterization. Contains the defining features,
    quality metrics, and (optionally) generated hypotheses and action suggestions.

    Attributes:
        cluster_id: Unique identifier for this error pattern (0-indexed)
        n_trades: Number of trades exhibiting this pattern
        description: Human-readable pattern description
            Format: "High feature_X (up 0.45) + Low feature_Y (down -0.32) -> Losses"
        top_features: Top contributing SHAP features
            List of (feature_name, mean_shap, p_value_t, p_value_mw, is_significant)
        separation_score: Distance to nearest other cluster (higher = more distinct)
        distinctiveness: Ratio of max SHAP vs other clusters (higher = more unique)
        hypothesis: Optional generated hypothesis about why the pattern causes losses
        actions: Optional list of suggested remediation actions
        confidence: Optional confidence score for hypothesis (0-1)

    Example - Basic pattern:
        >>> pattern = ErrorPattern(
        ...     cluster_id=0,
        ...     n_trades=15,
        ...     description="High momentum (up 0.45) + Low volatility (down -0.32) -> Losses",
        ...     top_features=[
        ...         ("momentum_20d", 0.45, 0.001, 0.002, True),
        ...         ("volatility_10d", -0.32, 0.003, 0.004, True)
        ...     ],
        ...     separation_score=1.2,
        ...     distinctiveness=1.8
        ... )
        >>> print(pattern.summary())
        "Pattern 0: 15 trades - High momentum (up 0.45) + Low volatility (down -0.32) -> Losses"

    Example - With hypothesis and actions:
        >>> pattern = ErrorPattern(
        ...     cluster_id=1,
        ...     n_trades=22,
        ...     description="High RSI (up 0.38) + High volume (up 0.29) -> Losses",
        ...     top_features=[("rsi_14", 0.38, 0.001, 0.001, True)],
        ...     separation_score=0.9,
        ...     distinctiveness=1.5,
        ...     hypothesis="Trades entering overbought conditions with high volume (potential reversals)",
        ...     actions=[
        ...         "Add overbought filter: skip trades when RSI > 70",
        ...         "Consider volume profile: avoid high volume in overbought zones",
        ...         "Add mean reversion features to capture reversal dynamics"
        ...     ],
        ...     confidence=0.85
        ... )
        >>> for action in pattern.actions:
        ...     print(f"  - {action}")

    Note:
        - hypothesis, actions, and confidence are populated by HypothesisGenerator
        - top_features are sorted by absolute SHAP value (descending)
        - separation_score and distinctiveness are quality metrics for pattern validation
    """

    cluster_id: int = Field(..., description="Cluster identifier (0-indexed)", ge=0)
    n_trades: int = Field(..., description="Number of trades in this pattern", gt=0)
    description: str = Field(..., description="Human-readable pattern description", min_length=1)
    top_features: list[tuple[str, float, float, float, bool]] = Field(
        ...,
        description="Top SHAP features: (name, mean_shap, p_value_t, p_value_mw, is_significant)",
    )
    separation_score: float = Field(
        ..., description="Distance to nearest other cluster (higher = better)", ge=0.0
    )
    distinctiveness: float = Field(
        ..., description="Ratio of max SHAP vs other clusters (higher = better)", gt=0.0
    )
    hypothesis: str | None = Field(
        None, description="Generated hypothesis about why this pattern causes losses"
    )
    actions: list[str] | None = Field(
        None, description="Suggested remediation actions for this pattern"
    )
    confidence: float | None = Field(
        None, description="Confidence score for hypothesis (0-1)", ge=0.0, le=1.0
    )

    def to_dict(self) -> dict[str, Any]:
        """Convert ErrorPattern to dictionary.

        Returns:
            Dictionary representation suitable for JSON serialization

        Example:
            >>> pattern_dict = pattern.to_dict()
            >>> import json
            >>> json.dumps(pattern_dict, indent=2)
        """
        return {
            "cluster_id": self.cluster_id,
            "n_trades": self.n_trades,
            "description": self.description,
            "top_features": [
                {
                    "feature_name": feat[0],
                    "mean_shap": feat[1],
                    "p_value_t": feat[2],
                    "p_value_mw": feat[3],
                    "is_significant": feat[4],
                }
                for feat in self.top_features
            ],
            "separation_score": self.separation_score,
            "distinctiveness": self.distinctiveness,
            "hypothesis": self.hypothesis,
            "actions": self.actions if self.actions else [],
            "confidence": self.confidence,
        }

    def summary(self, include_actions: bool = False) -> str:
        """Generate human-readable summary of error pattern.

        Args:
            include_actions: Whether to include action suggestions in summary

        Returns:
            Formatted summary string

        Example:
            >>> print(pattern.summary())
            "Pattern 0: 15 trades - High momentum (up 0.45) + Low volatility (down -0.32) -> Losses"

            >>> print(pattern.summary(include_actions=True))
            '''
            Pattern 0: 15 trades
            Description: High momentum (up 0.45) + Low volatility (down -0.32) -> Losses
            Hypothesis: Trades entering overbought conditions
            Actions:
              - Add overbought filter: skip trades when RSI > 70
              - Consider volume profile
            Confidence: 85%
            '''
        """
        if not include_actions or not self.hypothesis:
            # Simple one-line summary
            return f"Pattern {self.cluster_id}: {self.n_trades} trades - {self.description}"

        # Detailed multi-line summary with hypothesis and actions
        lines = [
            f"Pattern {self.cluster_id}: {self.n_trades} trades",
            f"Description: {self.description}",
        ]

        if self.hypothesis:
            lines.append(f"Hypothesis: {self.hypothesis}")

        if self.actions:
            lines.append("Actions:")
            for action in self.actions:
                lines.append(f"  - {action}")

        if self.confidence is not None:
            lines.append(f"Confidence: {self.confidence:.0%}")

        return "\n".join(lines)

    class Config:
        """Pydantic config."""

        arbitrary_types_allowed = True
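

# Editor's sketch (illustrative only, not part of the published module):
# one way to produce the (name, mean_shap, p_value_t, p_value_mw,
# is_significant) tuples documented above. Each feature's SHAP values
# inside the cluster are compared against those outside it with Welch's
# t-test and a Mann-Whitney U test; requiring both p-values below alpha
# is an assumption for the example.
def _characterize_cluster(
    feature_names: list[str],
    shap_matrix: NDArray[np.floating[Any]],
    in_cluster: NDArray[np.bool_],
    alpha: float = 0.05,
) -> list[tuple[str, float, float, float, bool]]:
    from scipy.stats import mannwhitneyu, ttest_ind

    rows: list[tuple[str, float, float, float, bool]] = []
    for j, name in enumerate(feature_names):
        inside = shap_matrix[in_cluster, j]
        outside = shap_matrix[~in_cluster, j]
        _, p_t = ttest_ind(inside, outside, equal_var=False)
        _, p_mw = mannwhitneyu(inside, outside, alternative="two-sided")
        significant = bool(p_t < alpha and p_mw < alpha)
        rows.append((name, float(inside.mean()), float(p_t), float(p_mw), significant))
    # Sort by absolute mean SHAP, descending, matching top_features ordering.
    return sorted(rows, key=lambda r: abs(r[1]), reverse=True)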


class TradeShapResult(BaseModel):
    """Complete result of trade-level SHAP analysis.

    Contains SHAP explanations for multiple trades, along with error patterns
    and actionable recommendations.

    Attributes:
        n_trades_analyzed: Total number of trades submitted for analysis
        n_trades_explained: Number of trades successfully explained
        n_trades_failed: Number of trades that failed explanation
        explanations: List of successful TradeShapExplanation objects
        failed_trades: List of (trade_id, error_message) tuples for failed trades
        error_patterns: Identified error patterns from clustering

    Example:
        >>> result = analyzer.explain_worst_trades(trades, n=20)
        >>> print(f"Success rate: {result.n_trades_explained}/{result.n_trades_analyzed}")
        >>> for explanation in result.explanations:
        ...     print(f"Trade {explanation.trade_id}: top feature = {explanation.top_features[0]}")
    """

    n_trades_analyzed: int = Field(..., description="Total trades analyzed")
    n_trades_explained: int = Field(..., description="Trades successfully explained")
    n_trades_failed: int = Field(..., description="Trades that failed explanation")
    explanations: list[TradeShapExplanation] = Field(
        default_factory=list, description="Successful SHAP explanations"
    )
    failed_trades: list[tuple[str, str]] = Field(
        default_factory=list, description="Failed trades: (trade_id, error_message)"
    )
    error_patterns: list[ErrorPattern] = Field(
        default_factory=list,
        description="Identified error patterns (populated by clustering and characterization)",
    )

    class Config:
        """Pydantic config."""

        arbitrary_types_allowed = True


ml4t/diagnostic/evaluation/trade_shap/normalize.py

@@ -0,0 +1,116 @@
"""Normalization functions for SHAP vector clustering.

Provides L1, L2, and standardization normalization with proper
handling of edge cases (zero vectors, zero variance).
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Literal

import numpy as np

if TYPE_CHECKING:
    from numpy.typing import NDArray


NormalizationType = Literal["l1", "l2", "standardize", "none"]


def normalize_l1(vectors: NDArray[np.floating[Any]]) -> NDArray[np.floating[Any]]:
    """L1 normalization: scale each row by the sum of absolute values.

    Args:
        vectors: Input vectors of shape (n_samples, n_features)

    Returns:
        L1-normalized vectors where each row sums to 1.0 (in absolute terms)

    Note:
        Zero vectors are returned unchanged (no division by zero)
    """
    l1_norms = np.sum(np.abs(vectors), axis=1, keepdims=True)
    l1_norms = np.where(l1_norms == 0, 1.0, l1_norms)
    return vectors / l1_norms


def normalize_l2(vectors: NDArray[np.floating[Any]]) -> NDArray[np.floating[Any]]:
    """L2 normalization: scale each row to unit Euclidean norm.

    Args:
        vectors: Input vectors of shape (n_samples, n_features)

    Returns:
        L2-normalized unit vectors (norm = 1.0 per row)

    Note:
        Zero vectors are returned unchanged (no division by zero)
    """
    l2_norms = np.linalg.norm(vectors, axis=1, keepdims=True)
    l2_norms = np.where(l2_norms == 0, 1.0, l2_norms)
    return vectors / l2_norms


def standardize(vectors: NDArray[np.floating[Any]]) -> NDArray[np.floating[Any]]:
    """Z-score standardization: (x - mean) / std per feature.

    Args:
        vectors: Input vectors of shape (n_samples, n_features)

    Returns:
        Standardized vectors (mean=0, std=1 per feature column)

    Note:
        Zero-variance features are returned unchanged
    """
    mean = np.mean(vectors, axis=0, keepdims=True)
    std = np.std(vectors, axis=0, keepdims=True)
    std = np.where(std == 0, 1.0, std)
    return (vectors - mean) / std


def normalize(
    vectors: NDArray[np.floating[Any]],
    method: NormalizationType | None = None,
) -> NDArray[np.floating[Any]]:
    """Apply normalization to vectors.

    Args:
        vectors: Input vectors of shape (n_samples, n_features)
        method: Normalization method: 'l1', 'l2', 'standardize', 'none', or None

    Returns:
        Normalized vectors

    Raises:
        ValueError: If normalization produces NaN/Inf or method is unknown

    Example:
        >>> vectors = np.array([[1, 2, 3], [4, 5, 6]])
        >>> normalize(vectors, method='l2')
        array([[0.267, 0.535, 0.802],
               [0.456, 0.570, 0.684]])
    """
    if method is None or method == "none":
        return vectors.copy()
    elif method == "l1":
        normalized = normalize_l1(vectors)
    elif method == "l2":
        normalized = normalize_l2(vectors)
    elif method == "standardize":
        normalized = standardize(vectors)
    else:
        raise ValueError(
            f"Invalid normalization method: '{method}'. "
            "Valid options: 'l1', 'l2', 'standardize', 'none', None"
        )

    # Validate output
    if not np.all(np.isfinite(normalized)):
        raise ValueError(
            "Normalization produced NaN or Inf values. "
            "This may indicate zero-variance features or numerical instability. "
            f"Normalization method: {method}"
        )

    return normalized
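

# Editor's sketch (illustrative only, not part of the published module):
# a quick numeric check of the invariants documented above: L1 rows sum
# to 1 in absolute value, L2 rows have unit Euclidean norm, and
# standardized columns have zero mean and unit (population) variance.
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x = rng.normal(size=(4, 3))

    assert np.allclose(np.abs(normalize(x, "l1")).sum(axis=1), 1.0)
    assert np.allclose(np.linalg.norm(normalize(x, "l2"), axis=1), 1.0)
    z = normalize(x, "standardize")
    assert np.allclose(z.mean(axis=0), 0.0)
    assert np.allclose(z.std(axis=0), 1.0)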