ml4t-diagnostic 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ml4t/diagnostic/AGENT.md +25 -0
- ml4t/diagnostic/__init__.py +166 -0
- ml4t/diagnostic/backends/__init__.py +10 -0
- ml4t/diagnostic/backends/adapter.py +192 -0
- ml4t/diagnostic/backends/polars_backend.py +899 -0
- ml4t/diagnostic/caching/__init__.py +40 -0
- ml4t/diagnostic/caching/cache.py +331 -0
- ml4t/diagnostic/caching/decorators.py +131 -0
- ml4t/diagnostic/caching/smart_cache.py +339 -0
- ml4t/diagnostic/config/AGENT.md +24 -0
- ml4t/diagnostic/config/README.md +267 -0
- ml4t/diagnostic/config/__init__.py +219 -0
- ml4t/diagnostic/config/barrier_config.py +277 -0
- ml4t/diagnostic/config/base.py +301 -0
- ml4t/diagnostic/config/event_config.py +148 -0
- ml4t/diagnostic/config/feature_config.py +404 -0
- ml4t/diagnostic/config/multi_signal_config.py +55 -0
- ml4t/diagnostic/config/portfolio_config.py +215 -0
- ml4t/diagnostic/config/report_config.py +391 -0
- ml4t/diagnostic/config/sharpe_config.py +202 -0
- ml4t/diagnostic/config/signal_config.py +206 -0
- ml4t/diagnostic/config/trade_analysis_config.py +310 -0
- ml4t/diagnostic/config/validation.py +279 -0
- ml4t/diagnostic/core/__init__.py +29 -0
- ml4t/diagnostic/core/numba_utils.py +315 -0
- ml4t/diagnostic/core/purging.py +372 -0
- ml4t/diagnostic/core/sampling.py +471 -0
- ml4t/diagnostic/errors/__init__.py +205 -0
- ml4t/diagnostic/evaluation/AGENT.md +26 -0
- ml4t/diagnostic/evaluation/__init__.py +437 -0
- ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
- ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
- ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
- ml4t/diagnostic/evaluation/dashboard.py +715 -0
- ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
- ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
- ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
- ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
- ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
- ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
- ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
- ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
- ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
- ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
- ml4t/diagnostic/evaluation/event_analysis.py +647 -0
- ml4t/diagnostic/evaluation/excursion.py +390 -0
- ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
- ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
- ml4t/diagnostic/evaluation/framework.py +935 -0
- ml4t/diagnostic/evaluation/metric_registry.py +255 -0
- ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
- ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
- ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
- ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
- ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
- ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
- ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
- ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
- ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
- ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
- ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
- ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
- ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
- ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
- ml4t/diagnostic/evaluation/multi_signal.py +550 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
- ml4t/diagnostic/evaluation/report_generation.py +824 -0
- ml4t/diagnostic/evaluation/signal_selector.py +452 -0
- ml4t/diagnostic/evaluation/stat_registry.py +139 -0
- ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
- ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
- ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
- ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
- ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
- ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
- ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
- ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
- ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
- ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
- ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
- ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
- ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
- ml4t/diagnostic/evaluation/stats/moments.py +164 -0
- ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
- ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
- ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
- ml4t/diagnostic/evaluation/themes.py +330 -0
- ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
- ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
- ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
- ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
- ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
- ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
- ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
- ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
- ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
- ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
- ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
- ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
- ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
- ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
- ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
- ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
- ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
- ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
- ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
- ml4t/diagnostic/evaluation/validated_cv.py +535 -0
- ml4t/diagnostic/evaluation/visualization.py +1050 -0
- ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
- ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
- ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
- ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
- ml4t/diagnostic/integration/__init__.py +48 -0
- ml4t/diagnostic/integration/backtest_contract.py +671 -0
- ml4t/diagnostic/integration/data_contract.py +316 -0
- ml4t/diagnostic/integration/engineer_contract.py +226 -0
- ml4t/diagnostic/logging/__init__.py +77 -0
- ml4t/diagnostic/logging/logger.py +245 -0
- ml4t/diagnostic/logging/performance.py +234 -0
- ml4t/diagnostic/logging/progress.py +234 -0
- ml4t/diagnostic/logging/wandb.py +412 -0
- ml4t/diagnostic/metrics/__init__.py +9 -0
- ml4t/diagnostic/metrics/percentiles.py +128 -0
- ml4t/diagnostic/py.typed +1 -0
- ml4t/diagnostic/reporting/__init__.py +43 -0
- ml4t/diagnostic/reporting/base.py +130 -0
- ml4t/diagnostic/reporting/html_renderer.py +275 -0
- ml4t/diagnostic/reporting/json_renderer.py +51 -0
- ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
- ml4t/diagnostic/results/AGENT.md +24 -0
- ml4t/diagnostic/results/__init__.py +105 -0
- ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
- ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
- ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
- ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
- ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
- ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
- ml4t/diagnostic/results/barrier_results/validation.py +38 -0
- ml4t/diagnostic/results/base.py +177 -0
- ml4t/diagnostic/results/event_results.py +349 -0
- ml4t/diagnostic/results/feature_results.py +787 -0
- ml4t/diagnostic/results/multi_signal_results.py +431 -0
- ml4t/diagnostic/results/portfolio_results.py +281 -0
- ml4t/diagnostic/results/sharpe_results.py +448 -0
- ml4t/diagnostic/results/signal_results/__init__.py +74 -0
- ml4t/diagnostic/results/signal_results/ic.py +581 -0
- ml4t/diagnostic/results/signal_results/irtc.py +110 -0
- ml4t/diagnostic/results/signal_results/quantile.py +392 -0
- ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
- ml4t/diagnostic/results/signal_results/turnover.py +213 -0
- ml4t/diagnostic/results/signal_results/validation.py +147 -0
- ml4t/diagnostic/signal/AGENT.md +17 -0
- ml4t/diagnostic/signal/__init__.py +69 -0
- ml4t/diagnostic/signal/_report.py +152 -0
- ml4t/diagnostic/signal/_utils.py +261 -0
- ml4t/diagnostic/signal/core.py +275 -0
- ml4t/diagnostic/signal/quantile.py +148 -0
- ml4t/diagnostic/signal/result.py +214 -0
- ml4t/diagnostic/signal/signal_ic.py +129 -0
- ml4t/diagnostic/signal/turnover.py +182 -0
- ml4t/diagnostic/splitters/AGENT.md +19 -0
- ml4t/diagnostic/splitters/__init__.py +36 -0
- ml4t/diagnostic/splitters/base.py +501 -0
- ml4t/diagnostic/splitters/calendar.py +421 -0
- ml4t/diagnostic/splitters/calendar_config.py +91 -0
- ml4t/diagnostic/splitters/combinatorial.py +1064 -0
- ml4t/diagnostic/splitters/config.py +322 -0
- ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
- ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
- ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
- ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
- ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
- ml4t/diagnostic/splitters/group_isolation.py +329 -0
- ml4t/diagnostic/splitters/persistence.py +316 -0
- ml4t/diagnostic/splitters/utils.py +207 -0
- ml4t/diagnostic/splitters/walk_forward.py +757 -0
- ml4t/diagnostic/utils/__init__.py +42 -0
- ml4t/diagnostic/utils/config.py +542 -0
- ml4t/diagnostic/utils/dependencies.py +318 -0
- ml4t/diagnostic/utils/sessions.py +127 -0
- ml4t/diagnostic/validation/__init__.py +54 -0
- ml4t/diagnostic/validation/dataframe.py +274 -0
- ml4t/diagnostic/validation/returns.py +280 -0
- ml4t/diagnostic/validation/timeseries.py +299 -0
- ml4t/diagnostic/visualization/AGENT.md +19 -0
- ml4t/diagnostic/visualization/__init__.py +223 -0
- ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
- ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
- ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
- ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
- ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
- ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
- ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
- ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
- ml4t/diagnostic/visualization/barrier_plots.py +782 -0
- ml4t/diagnostic/visualization/core.py +1060 -0
- ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
- ml4t/diagnostic/visualization/dashboards/base.py +582 -0
- ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
- ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
- ml4t/diagnostic/visualization/dashboards.py +43 -0
- ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
- ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
- ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
- ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
- ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
- ml4t/diagnostic/visualization/feature_plots.py +888 -0
- ml4t/diagnostic/visualization/interaction_plots.py +618 -0
- ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
- ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
- ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
- ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
- ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
- ml4t/diagnostic/visualization/report_generation.py +1343 -0
- ml4t/diagnostic/visualization/signal/__init__.py +103 -0
- ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
- ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
- ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
- ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
- ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
- ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
- ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
- ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
- ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
- ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
- ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
- ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,413 @@
|
|
|
1
|
+
"""Pattern characterization with proper statistical testing.
|
|
2
|
+
|
|
3
|
+
This module provides PatternCharacterizer for characterizing error patterns
|
|
4
|
+
identified through clustering, with:
|
|
5
|
+
- Welch's t-test (doesn't assume equal variance)
|
|
6
|
+
- Mann-Whitney U test (non-parametric)
|
|
7
|
+
- Benjamini-Hochberg FDR correction for multiple testing
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
from dataclasses import dataclass
|
|
13
|
+
from typing import TYPE_CHECKING, Any
|
|
14
|
+
|
|
15
|
+
import numpy as np
|
|
16
|
+
from scipy import stats
|
|
17
|
+
|
|
18
|
+
from ml4t.diagnostic.evaluation.trade_shap.models import ErrorPattern
|
|
19
|
+
|
|
20
|
+
if TYPE_CHECKING:
|
|
21
|
+
from numpy.typing import NDArray
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@dataclass()
class CharacterizationConfig:
    """Settings that control how an error-pattern cluster is characterized.

    Attributes:
        alpha: Significance threshold applied to the statistical tests (default 0.05).
        top_n_features: How many features (ranked by |mean SHAP|) are kept per pattern.
        use_fdr_correction: Apply Benjamini-Hochberg FDR correction across features.
        min_samples_per_test: Smallest group size (cluster and rest) for which tests run.
    """

    # Threshold for rejecting the null hypothesis.
    alpha: float = 0.05
    # Number of highest-|mean SHAP| features kept in the pattern.
    top_n_features: int = 5
    # When True, p-values are Benjamini-Hochberg adjusted before judging significance.
    use_fdr_correction: bool = True
    # Groups smaller than this skip the t-test / Mann-Whitney U test.
    min_samples_per_test: int = 3
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
@dataclass()
class FeatureStatistics:
    """Per-feature outcome of comparing one cluster's SHAP values to the rest.

    Attributes:
        feature_name: Feature the tests were run on.
        mean_shap: Average SHAP value inside the cluster.
        mean_shap_other: Average SHAP value across all other clusters.
        p_value_t: Welch's t-test p-value.
        p_value_mw: Mann-Whitney U test p-value.
        q_value_t: FDR-adjusted t-test p-value (None when no correction was applied).
        q_value_mw: FDR-adjusted Mann-Whitney p-value (None when no correction was applied).
        is_significant: True when the feature passed the significance check.
    """

    # Identity of the feature.
    feature_name: str
    # Average SHAP inside vs. outside the cluster.
    mean_shap: float
    mean_shap_other: float
    # Raw two-sample test p-values.
    p_value_t: float
    p_value_mw: float
    # FDR-adjusted p-values; stay None when no correction was applied.
    q_value_t: float | None = None
    q_value_mw: float | None = None
    # Final significance verdict.
    is_significant: bool = False
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def benjamini_hochberg(
    p_values: list[float], alpha: float = 0.05
) -> tuple[list[float], list[bool]]:
    """Apply Benjamini-Hochberg FDR correction to p-values.

    Args:
        p_values: List of raw p-values
        alpha: Target false discovery rate / significance level (default: 0.05)

    Returns:
        Tuple of (q_values, is_significant), both in the same order as
        ``p_values``:
        - q_values: FDR-adjusted p-values (monotone in the raw p-values,
          capped at 1.0)
        - is_significant: Boolean mask, True where ``q_value <= alpha``

    Note:
        BH procedure controls False Discovery Rate (FDR) - the expected
        proportion of false discoveries among rejected hypotheses.
        This is less conservative than Bonferroni correction. Rejecting
        where the adjusted p-value is <= alpha is equivalent to the classic
        step-up rule "reject H_(1..k) for the largest k with
        p_(k) <= k * alpha / n". NaN p-values yield NaN q-values and are
        never significant.
    """
    if not p_values:
        return [], []

    p_array = np.asarray(p_values, dtype=float)
    n = p_array.size

    # Sort ascending, remembering where each p-value came from.
    order = np.argsort(p_array)
    ranks = np.arange(1, n + 1)

    # Raw BH adjustment: p_(i) * n / i, capped at 1.0.
    q_sorted = np.minimum(p_array[order] * n / ranks, 1.0)

    # Enforce monotonicity: q_(i) = min(q_(i), q_(i+1), ..., q_(n)).
    # A reverse cumulative fmin does this in one pass; fmin (not minimum)
    # ignores NaN so a trailing NaN does not poison the finite q-values.
    q_sorted = np.fmin.accumulate(q_sorted[::-1])[::-1]

    # Restore the caller's original order.
    q_values = np.empty(n)
    q_values[order] = q_sorted

    # BH rejects at equality (p_(k) <= k * alpha / n), hence <=, not <.
    is_significant = q_values <= alpha

    return q_values.tolist(), is_significant.tolist()
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
class PatternCharacterizer:
    """Characterizes error patterns with proper statistical testing.

    Uses Welch's t-test (doesn't assume equal variance) and Mann-Whitney U test,
    with optional Benjamini-Hochberg FDR correction for multiple testing. When
    FDR correction is disabled, significance is judged on the raw p-values.

    Attributes:
        config: Characterization configuration
        feature_names: List of all feature names

    Example:
        >>> characterizer = PatternCharacterizer(feature_names)
        >>> pattern = characterizer.characterize_cluster(
        ...     cluster_shap=cluster_vectors,
        ...     other_shap=other_vectors,
        ...     cluster_id=0,
        ... )
        >>> print(pattern.top_features)
    """

    def __init__(
        self,
        feature_names: list[str],
        config: CharacterizationConfig | None = None,
    ) -> None:
        """Initialize characterizer.

        Args:
            feature_names: List of all feature names, one per SHAP column
            config: Characterization configuration (uses defaults if None)
        """
        self.feature_names = feature_names
        self.config = config or CharacterizationConfig()

    def characterize_cluster(
        self,
        cluster_shap: NDArray[np.floating[Any]],
        other_shap: NDArray[np.floating[Any]],
        cluster_id: int,
        centroids: NDArray[np.floating[Any]] | None = None,
    ) -> ErrorPattern:
        """Characterize a single cluster as an error pattern.

        Args:
            cluster_shap: SHAP vectors for trades in this cluster (n_cluster x n_features)
            other_shap: SHAP vectors for all other trades (n_other x n_features)
            cluster_id: Cluster identifier (0-indexed)
            centroids: Optional cluster centroids for separation score calculation

        Returns:
            ErrorPattern with statistical characterization. ``top_features``
            holds ``(name, mean_shap, p_value_t, p_value_mw, is_significant)``
            tuples with raw (uncorrected) p-values; ``is_significant`` reflects
            the FDR-corrected decision when correction is enabled.
        """
        n_trades = cluster_shap.shape[0]
        n_features = len(self.feature_names)

        # Mean SHAP per feature. Empty groups fall back to zeros instead of
        # letting np.mean return NaNs (with a "Mean of empty slice" warning).
        mean_shap_cluster = (
            np.mean(cluster_shap, axis=0) if n_trades > 0 else np.zeros(n_features)
        )
        mean_shap_other = (
            np.mean(other_shap, axis=0) if len(other_shap) > 0 else np.zeros(n_features)
        )

        # Statistical tests for each feature
        feature_stats = self._compute_feature_statistics(
            cluster_shap, other_shap, mean_shap_cluster, mean_shap_other
        )

        if self.config.use_fdr_correction:
            feature_stats = self._apply_fdr_correction(feature_stats)
        else:
            # Without correction, judge significance on the raw p-values;
            # otherwise every feature would remain is_significant=False.
            feature_stats = self._apply_raw_significance(feature_stats)

        # Rank by absolute mean SHAP (descending) and keep the top N.
        feature_stats.sort(key=lambda fs: abs(fs.mean_shap), reverse=True)
        top_stats = feature_stats[: self.config.top_n_features]

        top_features = [
            (fs.feature_name, fs.mean_shap, fs.p_value_t, fs.p_value_mw, fs.is_significant)
            for fs in top_stats
        ]

        return ErrorPattern(
            cluster_id=cluster_id,
            n_trades=n_trades,
            description=self._generate_description(top_stats),
            top_features=top_features,
            separation_score=self._compute_separation_score(
                mean_shap_cluster, centroids, cluster_id
            ),
            distinctiveness=self._compute_distinctiveness(
                mean_shap_cluster, mean_shap_other
            ),
        )

    def _compute_feature_statistics(
        self,
        cluster_shap: NDArray[np.floating[Any]],
        other_shap: NDArray[np.floating[Any]],
        mean_shap_cluster: NDArray[np.floating[Any]],
        mean_shap_other: NDArray[np.floating[Any]],
    ) -> list[FeatureStatistics]:
        """Compute per-feature two-sample tests (this cluster vs. all other trades).

        Uses Welch's t-test (equal_var=False) instead of the standard t-test
        to handle unequal variances between groups, plus a two-sided
        Mann-Whitney U test. Both p-values are 1.0 when either group has
        fewer than ``config.min_samples_per_test`` rows, or when a test
        returns NaN or fails. ``is_significant`` is left False here; it is
        decided afterwards from raw or FDR-corrected p-values.
        """
        n_cluster = len(cluster_shap)
        n_other = len(other_shap)
        min_n = self.config.min_samples_per_test
        # Every feature column has the same length, so this check is loop-invariant.
        enough_samples = n_cluster >= min_n and n_other >= min_n

        results: list[FeatureStatistics] = []
        for idx, feature_name in enumerate(self.feature_names):
            p_value_t = 1.0
            p_value_mw = 1.0

            if enough_samples:
                cluster_values = cluster_shap[:, idx]
                other_values = other_shap[:, idx] if n_other > 0 else np.array([])

                # Welch's t-test (doesn't assume equal variance).
                try:
                    _, raw_t = stats.ttest_ind(cluster_values, other_values, equal_var=False)
                    p_value_t = 1.0 if np.isnan(raw_t) else float(raw_t)
                except Exception:
                    # Deliberate best-effort: a failed test counts as "no evidence".
                    p_value_t = 1.0

                # Mann-Whitney U test (non-parametric).
                try:
                    _, raw_mw = stats.mannwhitneyu(
                        cluster_values, other_values, alternative="two-sided"
                    )
                    p_value_mw = 1.0 if np.isnan(raw_mw) else float(raw_mw)
                except ValueError:
                    # Can fail if all values are identical
                    p_value_mw = 1.0

            results.append(
                FeatureStatistics(
                    feature_name=feature_name,
                    mean_shap=float(mean_shap_cluster[idx]),
                    mean_shap_other=float(mean_shap_other[idx]),
                    p_value_t=p_value_t,
                    p_value_mw=p_value_mw,
                    is_significant=False,
                )
            )

        return results

    def _apply_fdr_correction(
        self, feature_stats: list[FeatureStatistics]
    ) -> list[FeatureStatistics]:
        """Apply Benjamini-Hochberg FDR correction to all p-values.

        This corrects for multiple testing across all features, reducing
        false positive rate at the cost of some statistical power. The
        t-test and Mann-Whitney p-values are corrected as two separate
        families; a feature is significant if either corrected test rejects.
        """
        if not feature_stats:
            return feature_stats

        alpha = self.config.alpha
        q_values_t, sig_t = benjamini_hochberg([fs.p_value_t for fs in feature_stats], alpha)
        q_values_mw, sig_mw = benjamini_hochberg([fs.p_value_mw for fs in feature_stats], alpha)

        return [
            FeatureStatistics(
                feature_name=fs.feature_name,
                mean_shap=fs.mean_shap,
                mean_shap_other=fs.mean_shap_other,
                p_value_t=fs.p_value_t,
                p_value_mw=fs.p_value_mw,
                q_value_t=q_t,
                q_value_mw=q_mw,
                is_significant=s_t or s_mw,
            )
            for fs, q_t, q_mw, s_t, s_mw in zip(
                feature_stats, q_values_t, q_values_mw, sig_t, sig_mw
            )
        ]

    def _apply_raw_significance(
        self, feature_stats: list[FeatureStatistics]
    ) -> list[FeatureStatistics]:
        """Mark features significant from uncorrected p-values (FDR disabled).

        A feature is significant when either the Welch t-test or the
        Mann-Whitney U p-value is <= ``config.alpha``. The q-value fields are
        left untouched. The statistics objects are updated in place (they are
        freshly built by ``_compute_feature_statistics``) and returned.
        """
        alpha = self.config.alpha
        for fs in feature_stats:
            fs.is_significant = fs.p_value_t <= alpha or fs.p_value_mw <= alpha
        return feature_stats

    def _generate_description(self, top_stats: list[FeatureStatistics]) -> str:
        """Generate human-readable pattern description.

        Uses up to three significant features, or the top two features when
        none are significant, e.g. ``"High rsi (↑0.50) + Low vol (↓-0.25) → Losses"``.
        """
        if not top_stats:
            return "Unknown pattern"

        sig_features = [fs for fs in top_stats if fs.is_significant]
        # Fall back to top features if none significant
        features_to_use = sig_features[:3] if sig_features else top_stats[:2]

        components = []
        for fs in features_to_use:
            direction, arrow = ("High", "↑") if fs.mean_shap > 0 else ("Low", "↓")
            components.append(f"{direction} {fs.feature_name} ({arrow}{fs.mean_shap:.2f})")

        return " + ".join(components) + " → Losses"

    def _compute_separation_score(
        self,
        centroid: NDArray[np.floating[Any]],
        all_centroids: NDArray[np.floating[Any]] | None,
        cluster_id: int,
    ) -> float:
        """Compute separation score (Euclidean distance to nearest other centroid).

        Returns 0.0 when no centroids (or only one) are supplied, or when no
        finite distance is found.
        """
        if all_centroids is None or len(all_centroids) <= 1:
            return 0.0

        min_distance = float("inf")
        for i, other_centroid in enumerate(all_centroids):
            if i != cluster_id:
                distance = float(np.linalg.norm(centroid - other_centroid))
                min_distance = min(min_distance, distance)

        return min_distance if min_distance != float("inf") else 0.0

    def _compute_distinctiveness(
        self,
        cluster_centroid: NDArray[np.floating[Any]],
        other_mean: NDArray[np.floating[Any]],
    ) -> float:
        """Compute distinctiveness (ratio of max |SHAP| in cluster vs. other trades).

        When the other group's max |SHAP| is zero, returns the cluster's max
        |SHAP| if positive, else 1.0.
        """
        max_cluster = np.max(np.abs(cluster_centroid))
        max_other = np.max(np.abs(other_mean))

        if max_other == 0:
            return float(max_cluster) if max_cluster > 0 else 1.0

        return float(max_cluster / max_other)

    def characterize_all_clusters(
        self,
        shap_vectors: NDArray[np.floating[Any]],
        cluster_labels: list[int],
        n_clusters: int,
        centroids: NDArray[np.floating[Any]] | None = None,
    ) -> list[ErrorPattern]:
        """Characterize all clusters.

        Each cluster is compared against every sample not carrying its label.

        Args:
            shap_vectors: All SHAP vectors (n_samples x n_features)
            cluster_labels: Cluster assignment for each sample
            n_clusters: Total number of clusters (ids 0..n_clusters-1)
            centroids: Optional cluster centroids

        Returns:
            List of ErrorPattern for each cluster, ordered by cluster id
        """
        labels_array = np.asarray(cluster_labels)
        patterns = []

        for cluster_id in range(n_clusters):
            mask = labels_array == cluster_id
            patterns.append(
                self.characterize_cluster(
                    cluster_shap=shap_vectors[mask],
                    other_shap=shap_vectors[~mask],
                    cluster_id=cluster_id,
                    centroids=centroids,
                )
            )

        return patterns
|
|
@@ -0,0 +1,302 @@
|
|
|
1
|
+
"""Hierarchical clustering for trade error patterns.
|
|
2
|
+
|
|
3
|
+
Provides clustering of SHAP vectors to identify distinct error patterns,
|
|
4
|
+
with proper handling of small sample sizes.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from dataclasses import dataclass
|
|
10
|
+
from typing import TYPE_CHECKING, Any, Literal
|
|
11
|
+
|
|
12
|
+
import numpy as np
|
|
13
|
+
|
|
14
|
+
from ml4t.diagnostic.evaluation.trade_shap.models import ClusteringResult
|
|
15
|
+
|
|
16
|
+
if TYPE_CHECKING:
|
|
17
|
+
from numpy.typing import NDArray
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
# Supported pairwise distance metrics (used as the pdist metric name).
DistanceMetric = Literal[
    "euclidean",
    "cosine",
    "correlation",
    "cityblock",
]
# Supported linkage criteria for building the cluster hierarchy.
LinkageMethod = Literal[
    "ward",
    "average",
    "complete",
    "single",
]
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@dataclass()
class ClusteringConfig:
    """Settings for hierarchical clustering of trade SHAP vectors.

    Attributes:
        distance_metric: Metric name used for pdist ('euclidean', 'cosine', ...).
        linkage_method: Linkage criterion used to build the hierarchy.
        min_cluster_size: Smallest number of trades allowed in one cluster.
        min_trades_for_clustering: Fewest trades needed before clustering is attempted.
    """

    distance_metric: DistanceMetric = "euclidean"  # pairwise distance for pdist
    linkage_method: LinkageMethod = "ward"  # hierarchical linkage criterion
    min_cluster_size: int = 5  # lower bound on trades per cluster
    min_trades_for_clustering: int = 10  # below this, clustering is not attempted
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def find_optimal_clusters(
    linkage_matrix: NDArray[np.floating[Any]],
    n_samples: int,
    min_cluster_size: int = 5,
) -> int:
    """Find optimal number of clusters using the elbow method.

    Uses the acceleration of merge distances (second difference) to find the
    "elbow" point in the dendrogram. The result is then capped so that the
    average cluster can hold ``min_cluster_size`` samples.

    Args:
        linkage_matrix: Linkage matrix from hierarchical clustering; only its
            merge-distance column (index 2) is read
        n_samples: Total number of samples
        min_cluster_size: Minimum samples per cluster. Values below 1 are
            treated as "no minimum" (previously 0 raised ZeroDivisionError).

    Returns:
        Number of clusters, always >= 1. It never exceeds
        ``max(1, n_samples // min_cluster_size)``. It is at least 2 whenever
        that cap allows 2 or more clusters; otherwise it is 1.

    Note:
        The cap limits only the number of clusters. On its own it does not
        guarantee that every individual cluster reaches ``min_cluster_size``.
    """
    # Row i of a scipy linkage matrix describes merge i; column 2 is its distance.
    merge_distances = linkage_matrix[:, 2]

    # Second difference of the merge distances = "acceleration" of the dendrogram.
    acceleration = np.diff(merge_distances, n=2)

    if acceleration.size > 0:
        # acceleration[i] peaks when merge i+2 jumps sharply. Stopping right
        # before that merge (after i+2 merges) leaves n_samples - i - 2 clusters.
        elbow_idx = int(np.argmax(acceleration))
        n_clusters = max(1, n_samples - elbow_idx - 2)
    else:
        # Fewer than three merges: no second difference exists, use sqrt(n).
        n_clusters = max(1, int(np.sqrt(n_samples)))

    # Cap the count so the average cluster can hold min_cluster_size samples.
    # Clamping the divisor to >= 1 avoids ZeroDivisionError for min_cluster_size <= 0.
    max_clusters = max(1, n_samples // max(1, min_cluster_size))
    n_clusters = min(n_clusters, max_clusters)

    # Prefer at least 2 clusters, but only when the cap leaves room for them;
    # otherwise a single cluster is the honest answer for small samples.
    if max_clusters >= 2:
        n_clusters = max(2, n_clusters)

    return int(n_clusters)
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def compute_cluster_sizes(
    labels: NDArray[np.intp] | list[int],
    n_clusters: int,
) -> list[int]:
    """Count how many samples were assigned to each cluster.

    Args:
        labels: 0-indexed cluster id of every sample
        n_clusters: Total number of clusters. The result has at least this
            many entries, and clusters with no members count as 0.

    Returns:
        Per-cluster sample counts as plain Python ints, indexed by cluster id
    """
    label_idx = np.asarray(labels, dtype=np.intp)
    # bincount tallies all ids in one pass; minlength pads trailing empty clusters.
    tallies = np.bincount(label_idx, minlength=n_clusters)
    return [int(tally) for tally in tallies]
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def compute_centroids(
    vectors: NDArray[np.floating[Any]],
    labels: NDArray[np.intp] | list[int],
    n_clusters: int,
) -> NDArray[np.floating[Any]]:
    """Compute the mean vector of every cluster.

    Iterates over the cluster ids that actually occur in ``labels``. Ids
    outside ``0..n_clusters-1`` are ignored, and clusters with no members
    keep an all-zero centroid.

    Args:
        vectors: SHAP vectors of shape (n_samples, n_features)
        labels: 0-indexed cluster id of every sample
        n_clusters: Total number of clusters

    Returns:
        float64 array of shape (n_clusters, n_features)
    """
    label_idx = np.asarray(labels, dtype=np.intp)
    result = np.zeros((n_clusters, vectors.shape[1]), dtype=np.float64)

    for cluster in np.unique(label_idx):
        if 0 <= cluster < n_clusters:
            result[cluster] = vectors[label_idx == cluster].mean(axis=0)

    return result
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
class HierarchicalClusterer:
    """Hierarchical clustering for SHAP vectors.

    Clusters trade SHAP vectors with scipy's agglomerative clustering to
    identify distinct error patterns. The result is scored with silhouette,
    Davies-Bouldin and Calinski-Harabasz metrics, computed via scikit-learn
    when it is installed.

    Attributes:
        config: Clustering configuration

    Note:
        scipy documents the ``ward`` linkage as correctly defined only for
        Euclidean distances; pair it with ``distance_metric="euclidean"``.

    Example:
        >>> clusterer = HierarchicalClusterer()
        >>> result = clusterer.cluster(shap_vectors, n_clusters=3)
        >>> print(f"Silhouette: {result.silhouette_score:.3f}")
    """

    def __init__(self, config: ClusteringConfig | None = None) -> None:
        """Initialize clusterer.

        Args:
            config: Clustering configuration (uses defaults if None)
        """
        self.config = config or ClusteringConfig()

    def cluster(
        self,
        vectors: NDArray[np.floating[Any]],
        n_clusters: int | None = None,
    ) -> ClusteringResult:
        """Cluster SHAP vectors using hierarchical clustering.

        Args:
            vectors: SHAP vectors of shape (n_samples, n_features)
            n_clusters: Requested number of clusters. If None, it is chosen
                automatically with ``find_optimal_clusters``. scipy's
                ``maxclust`` cut may form fewer clusters than requested; the
                result reports the number actually formed.

        Returns:
            ClusteringResult with 0-indexed assignments, linkage matrix,
            centroids, cluster sizes and quality metrics

        Raises:
            ValueError: If vectors are empty or not 2D, if there are fewer
                than ``config.min_trades_for_clustering`` samples, or if
                ``n_clusters`` is less than 1
            ImportError: If scipy is not installed
        """
        vectors = np.asarray(vectors)

        # Validate inputs
        if vectors.size == 0:
            raise ValueError("Cannot cluster empty vectors")

        if vectors.ndim != 2:
            raise ValueError(
                f"vectors must be 2D array (n_samples, n_features), got shape {vectors.shape}"
            )

        n_samples = vectors.shape[0]

        if n_samples < self.config.min_trades_for_clustering:
            raise ValueError(
                f"Insufficient samples for clustering: {n_samples} < "
                f"{self.config.min_trades_for_clustering}"
            )

        if n_clusters is not None and n_clusters < 1:
            raise ValueError(f"n_clusters must be at least 1, got {n_clusters}")

        try:
            import scipy.cluster.hierarchy as sch
            from scipy.spatial.distance import pdist
        except ImportError as e:
            raise ImportError(
                "scipy required for clustering. Install with: pip install scipy"
            ) from e

        # Pairwise distances, then the agglomerative merge tree
        distances = pdist(vectors, metric=self.config.distance_metric)
        linkage_matrix = sch.linkage(distances, method=self.config.linkage_method)

        if n_clusters is None:
            n_clusters = find_optimal_clusters(
                linkage_matrix, n_samples, self.config.min_cluster_size
            )

        raw_labels = sch.fcluster(linkage_matrix, t=int(n_clusters), criterion="maxclust")
        # "maxclust" forms *at most* t flat clusters. It forms fewer when merge
        # distances tie (e.g. duplicate vectors). Re-index fcluster's 1-based ids
        # to a dense 0..k-1 range and report the k clusters actually formed, so no
        # empty clusters with fake all-zero centroids leak into the result.
        cluster_ids, labels = np.unique(raw_labels, return_inverse=True)
        labels = labels.astype(np.intp, copy=False)
        n_clusters = int(cluster_ids.size)

        # Cluster summaries
        cluster_sizes = compute_cluster_sizes(labels, n_clusters)
        centroids = compute_centroids(vectors, labels, n_clusters)

        # Quality metrics
        silhouette = self._compute_silhouette(vectors, labels)
        davies_bouldin = self._compute_davies_bouldin(vectors, labels)
        calinski_harabasz = self._compute_calinski_harabasz(vectors, labels)

        return ClusteringResult(
            n_clusters=n_clusters,
            cluster_assignments=labels.tolist(),
            linkage_matrix=linkage_matrix,
            centroids=centroids,
            silhouette_score=silhouette,
            davies_bouldin_score=davies_bouldin,
            calinski_harabasz_score=calinski_harabasz,
            cluster_sizes=cluster_sizes,
            distance_metric=self.config.distance_metric,
            linkage_method=self.config.linkage_method,
        )

    @staticmethod
    def _has_scorable_labels(labels: NDArray[np.intp]) -> bool:
        """Return True if sklearn's cluster-quality metrics accept ``labels``.

        ``silhouette_score``, ``davies_bouldin_score`` and
        ``calinski_harabasz_score`` raise ValueError unless the number of
        distinct labels lies in ``[2, n_samples - 1]``.
        """
        labels_array = np.asarray(labels)
        n_labels = len(np.unique(labels_array))
        return 2 <= n_labels <= labels_array.shape[0] - 1

    def _compute_silhouette(
        self,
        vectors: NDArray[np.floating[Any]],
        labels: NDArray[np.intp],
    ) -> float:
        """Compute silhouette score for clustering quality.

        The score uses the same distance metric as the clustering itself.

        Returns:
            Silhouette score (-1 to 1, higher is better). Returns 0.0 when
            sklearn is unavailable or the label count is outside
            ``[2, n_samples - 1]``.
        """
        if not self._has_scorable_labels(labels):
            return 0.0

        try:
            from sklearn.metrics import silhouette_score
        except ImportError:
            return 0.0

        return float(
            silhouette_score(vectors, labels, metric=self.config.distance_metric)
        )

    def _compute_davies_bouldin(
        self,
        vectors: NDArray[np.floating[Any]],
        labels: NDArray[np.intp],
    ) -> float | None:
        """Compute Davies-Bouldin index (lower is better).

        Returns None when sklearn is unavailable or the label count is
        outside ``[2, n_samples - 1]``.
        """
        if not self._has_scorable_labels(labels):
            return None

        try:
            from sklearn.metrics import davies_bouldin_score
        except ImportError:
            return None

        return float(davies_bouldin_score(vectors, labels))

    def _compute_calinski_harabasz(
        self,
        vectors: NDArray[np.floating[Any]],
        labels: NDArray[np.intp],
    ) -> float | None:
        """Compute Calinski-Harabasz score (higher is better).

        Returns None when sklearn is unavailable or the label count is
        outside ``[2, n_samples - 1]``.
        """
        if not self._has_scorable_labels(labels):
            return None

        try:
            from sklearn.metrics import calinski_harabasz_score
        except ImportError:
            return None

        return float(calinski_harabasz_score(vectors, labels))
|