ml4t-diagnostic 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ml4t/diagnostic/AGENT.md +25 -0
- ml4t/diagnostic/__init__.py +166 -0
- ml4t/diagnostic/backends/__init__.py +10 -0
- ml4t/diagnostic/backends/adapter.py +192 -0
- ml4t/diagnostic/backends/polars_backend.py +899 -0
- ml4t/diagnostic/caching/__init__.py +40 -0
- ml4t/diagnostic/caching/cache.py +331 -0
- ml4t/diagnostic/caching/decorators.py +131 -0
- ml4t/diagnostic/caching/smart_cache.py +339 -0
- ml4t/diagnostic/config/AGENT.md +24 -0
- ml4t/diagnostic/config/README.md +267 -0
- ml4t/diagnostic/config/__init__.py +219 -0
- ml4t/diagnostic/config/barrier_config.py +277 -0
- ml4t/diagnostic/config/base.py +301 -0
- ml4t/diagnostic/config/event_config.py +148 -0
- ml4t/diagnostic/config/feature_config.py +404 -0
- ml4t/diagnostic/config/multi_signal_config.py +55 -0
- ml4t/diagnostic/config/portfolio_config.py +215 -0
- ml4t/diagnostic/config/report_config.py +391 -0
- ml4t/diagnostic/config/sharpe_config.py +202 -0
- ml4t/diagnostic/config/signal_config.py +206 -0
- ml4t/diagnostic/config/trade_analysis_config.py +310 -0
- ml4t/diagnostic/config/validation.py +279 -0
- ml4t/diagnostic/core/__init__.py +29 -0
- ml4t/diagnostic/core/numba_utils.py +315 -0
- ml4t/diagnostic/core/purging.py +372 -0
- ml4t/diagnostic/core/sampling.py +471 -0
- ml4t/diagnostic/errors/__init__.py +205 -0
- ml4t/diagnostic/evaluation/AGENT.md +26 -0
- ml4t/diagnostic/evaluation/__init__.py +437 -0
- ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
- ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
- ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
- ml4t/diagnostic/evaluation/dashboard.py +715 -0
- ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
- ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
- ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
- ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
- ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
- ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
- ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
- ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
- ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
- ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
- ml4t/diagnostic/evaluation/event_analysis.py +647 -0
- ml4t/diagnostic/evaluation/excursion.py +390 -0
- ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
- ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
- ml4t/diagnostic/evaluation/framework.py +935 -0
- ml4t/diagnostic/evaluation/metric_registry.py +255 -0
- ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
- ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
- ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
- ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
- ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
- ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
- ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
- ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
- ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
- ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
- ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
- ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
- ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
- ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
- ml4t/diagnostic/evaluation/multi_signal.py +550 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
- ml4t/diagnostic/evaluation/report_generation.py +824 -0
- ml4t/diagnostic/evaluation/signal_selector.py +452 -0
- ml4t/diagnostic/evaluation/stat_registry.py +139 -0
- ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
- ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
- ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
- ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
- ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
- ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
- ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
- ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
- ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
- ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
- ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
- ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
- ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
- ml4t/diagnostic/evaluation/stats/moments.py +164 -0
- ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
- ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
- ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
- ml4t/diagnostic/evaluation/themes.py +330 -0
- ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
- ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
- ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
- ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
- ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
- ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
- ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
- ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
- ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
- ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
- ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
- ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
- ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
- ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
- ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
- ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
- ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
- ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
- ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
- ml4t/diagnostic/evaluation/validated_cv.py +535 -0
- ml4t/diagnostic/evaluation/visualization.py +1050 -0
- ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
- ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
- ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
- ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
- ml4t/diagnostic/integration/__init__.py +48 -0
- ml4t/diagnostic/integration/backtest_contract.py +671 -0
- ml4t/diagnostic/integration/data_contract.py +316 -0
- ml4t/diagnostic/integration/engineer_contract.py +226 -0
- ml4t/diagnostic/logging/__init__.py +77 -0
- ml4t/diagnostic/logging/logger.py +245 -0
- ml4t/diagnostic/logging/performance.py +234 -0
- ml4t/diagnostic/logging/progress.py +234 -0
- ml4t/diagnostic/logging/wandb.py +412 -0
- ml4t/diagnostic/metrics/__init__.py +9 -0
- ml4t/diagnostic/metrics/percentiles.py +128 -0
- ml4t/diagnostic/py.typed +1 -0
- ml4t/diagnostic/reporting/__init__.py +43 -0
- ml4t/diagnostic/reporting/base.py +130 -0
- ml4t/diagnostic/reporting/html_renderer.py +275 -0
- ml4t/diagnostic/reporting/json_renderer.py +51 -0
- ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
- ml4t/diagnostic/results/AGENT.md +24 -0
- ml4t/diagnostic/results/__init__.py +105 -0
- ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
- ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
- ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
- ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
- ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
- ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
- ml4t/diagnostic/results/barrier_results/validation.py +38 -0
- ml4t/diagnostic/results/base.py +177 -0
- ml4t/diagnostic/results/event_results.py +349 -0
- ml4t/diagnostic/results/feature_results.py +787 -0
- ml4t/diagnostic/results/multi_signal_results.py +431 -0
- ml4t/diagnostic/results/portfolio_results.py +281 -0
- ml4t/diagnostic/results/sharpe_results.py +448 -0
- ml4t/diagnostic/results/signal_results/__init__.py +74 -0
- ml4t/diagnostic/results/signal_results/ic.py +581 -0
- ml4t/diagnostic/results/signal_results/irtc.py +110 -0
- ml4t/diagnostic/results/signal_results/quantile.py +392 -0
- ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
- ml4t/diagnostic/results/signal_results/turnover.py +213 -0
- ml4t/diagnostic/results/signal_results/validation.py +147 -0
- ml4t/diagnostic/signal/AGENT.md +17 -0
- ml4t/diagnostic/signal/__init__.py +69 -0
- ml4t/diagnostic/signal/_report.py +152 -0
- ml4t/diagnostic/signal/_utils.py +261 -0
- ml4t/diagnostic/signal/core.py +275 -0
- ml4t/diagnostic/signal/quantile.py +148 -0
- ml4t/diagnostic/signal/result.py +214 -0
- ml4t/diagnostic/signal/signal_ic.py +129 -0
- ml4t/diagnostic/signal/turnover.py +182 -0
- ml4t/diagnostic/splitters/AGENT.md +19 -0
- ml4t/diagnostic/splitters/__init__.py +36 -0
- ml4t/diagnostic/splitters/base.py +501 -0
- ml4t/diagnostic/splitters/calendar.py +421 -0
- ml4t/diagnostic/splitters/calendar_config.py +91 -0
- ml4t/diagnostic/splitters/combinatorial.py +1064 -0
- ml4t/diagnostic/splitters/config.py +322 -0
- ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
- ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
- ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
- ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
- ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
- ml4t/diagnostic/splitters/group_isolation.py +329 -0
- ml4t/diagnostic/splitters/persistence.py +316 -0
- ml4t/diagnostic/splitters/utils.py +207 -0
- ml4t/diagnostic/splitters/walk_forward.py +757 -0
- ml4t/diagnostic/utils/__init__.py +42 -0
- ml4t/diagnostic/utils/config.py +542 -0
- ml4t/diagnostic/utils/dependencies.py +318 -0
- ml4t/diagnostic/utils/sessions.py +127 -0
- ml4t/diagnostic/validation/__init__.py +54 -0
- ml4t/diagnostic/validation/dataframe.py +274 -0
- ml4t/diagnostic/validation/returns.py +280 -0
- ml4t/diagnostic/validation/timeseries.py +299 -0
- ml4t/diagnostic/visualization/AGENT.md +19 -0
- ml4t/diagnostic/visualization/__init__.py +223 -0
- ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
- ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
- ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
- ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
- ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
- ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
- ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
- ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
- ml4t/diagnostic/visualization/barrier_plots.py +782 -0
- ml4t/diagnostic/visualization/core.py +1060 -0
- ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
- ml4t/diagnostic/visualization/dashboards/base.py +582 -0
- ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
- ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
- ml4t/diagnostic/visualization/dashboards.py +43 -0
- ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
- ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
- ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
- ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
- ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
- ml4t/diagnostic/visualization/feature_plots.py +888 -0
- ml4t/diagnostic/visualization/interaction_plots.py +618 -0
- ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
- ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
- ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
- ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
- ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
- ml4t/diagnostic/visualization/report_generation.py +1343 -0
- ml4t/diagnostic/visualization/signal/__init__.py +103 -0
- ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
- ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
- ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
- ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
- ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
- ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
- ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
- ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
- ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
- ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
- ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
- ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,129 @@
|
|
|
1
|
+
"""Dashboard data types and configuration.
|
|
2
|
+
|
|
3
|
+
Provides unified data structures for the dashboard to eliminate dict/object branching.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
from dataclasses import dataclass, field
|
|
9
|
+
from typing import Any
|
|
10
|
+
|
|
11
|
+
import numpy as np
|
|
12
|
+
import pandas as pd
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@dataclass
class DashboardConfig:
    """Display and security settings for the trade dashboard.

    Attributes
    ----------
    allow_pickle_upload : bool
        Permit uploading pickle files. Off by default for security reasons
        (unpickling untrusted data can execute arbitrary code).
    styled : bool
        Apply the professional CSS theme when True.
    title : str
        Heading displayed by the dashboard.
    """

    # Security: pickle uploads remain disabled unless explicitly enabled.
    allow_pickle_upload: bool = field(default=False)
    styled: bool = field(default=False)
    title: str = field(default="Trade SHAP Diagnostics")
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@dataclass
class DashboardBundle:
    """Single, normalized payload consumed by every dashboard tab.

    Upstream inputs arrive in several shapes (plain dicts or result objects,
    with differing field names). They are converted into this one container
    so that individual tabs never have to branch on the input format.

    Attributes
    ----------
    trades_df : pd.DataFrame
        One row per trade, ordered chronologically by ``entry_time`` so that
        time-series tests can run on it directly. Expected columns:

        - trade_id: str
        - entry_time: datetime
        - exit_time: datetime (optional)
        - pnl: float
        - return_pct: float (optional)
        - symbol: str (optional)
    returns : np.ndarray | None
        Per-trade returns, taken from ``return_pct`` when it is available
        and from ``pnl`` otherwise.
    returns_label : str
        Source of ``returns``: "return_pct", "pnl", or "none".
    explanations : list[dict]
        Normalized per-trade explanation dicts with the keys:

        - trade_id: str
        - shap_vector: list[float]
        - top_features: list[tuple[str, float]]
        - trade_metrics: dict (optional)
    patterns_df : pd.DataFrame
        One row per error pattern. Expected columns:

        - cluster_id: int
        - n_trades: int
        - description: str
        - top_features: list[tuple]
        - hypothesis: str (optional)
        - actions: list[str] (optional)
        - confidence: float (optional)
    n_trades_analyzed : int
        Count of trades that were analyzed.
    n_trades_explained : int
        Count of trades that received an explanation.
    n_trades_failed : int
        Count of trades whose explanation failed.
    failed_trades : list[tuple[str, str]]
        ``(trade_id, reason)`` pairs describing each failed explanation.
    config : DashboardConfig
        Dashboard settings; each bundle gets its own default
        ``DashboardConfig`` instance when none is supplied.
    """

    # --- per-trade data -------------------------------------------------
    trades_df: pd.DataFrame
    returns: np.ndarray | None
    returns_label: str  # one of "return_pct" | "pnl" | "none"
    explanations: list[dict[str, Any]]
    # --- per-pattern data -----------------------------------------------
    patterns_df: pd.DataFrame
    # --- bookkeeping counters -------------------------------------------
    n_trades_analyzed: int
    n_trades_explained: int
    n_trades_failed: int
    failed_trades: list[tuple[str, str]]
    # default_factory: a fresh config object per bundle (never shared)
    config: DashboardConfig = field(default_factory=DashboardConfig)
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
@dataclass
class ReturnSummary:
    """Descriptive statistics computed over a series of returns.

    Attributes
    ----------
    n_samples : int
        How many observations the summary covers.
    mean : float
        Average return.
    std : float
        Standard deviation of the returns.
    sharpe : float
        Sharpe ratio, defined here as ``mean / std``.
    skewness : float
        Skew of the return distribution.
    kurtosis : float
        Raw (non-excess) kurtosis; a normal distribution gives 3.0.
    min_val : float
        Smallest observed return.
    max_val : float
        Largest observed return.
    win_rate : float
        Share of returns that are positive.
    """

    # Sample size and location/scale
    n_samples: int
    mean: float
    std: float
    sharpe: float
    # Shape of the distribution
    skewness: float
    kurtosis: float
    # Range and hit rate
    min_val: float
    max_val: float
    win_rate: float
|
|
@@ -0,0 +1,102 @@
|
|
|
1
|
+
"""Trade-level SHAP diagnostics models.
|
|
2
|
+
|
|
3
|
+
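This package contains the building blocks for Trade SHAP analysis:
timestamp alignment, SHAP explanation, normalization, clustering, pattern
characterization, hypothesis generation (including HypothesisGenerator),
the end-to-end pipeline, and the result data models. The high-level
TradeShapAnalyzer is obtained from the parent ``evaluation`` module.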
|
|
6
|
+
|
|
7
|
+
For analysis, import from the evaluation module:
|
|
8
|
+
>>> from ml4t.diagnostic.evaluation import TradeShapAnalyzer
|
|
9
|
+
|
|
10
|
+
For models only:
|
|
11
|
+
>>> from ml4t.diagnostic.evaluation.trade_shap import (
|
|
12
|
+
... TradeShapResult,
|
|
13
|
+
... ErrorPattern,
|
|
14
|
+
... ClusteringResult,
|
|
15
|
+
... TradeShapExplanation,
|
|
16
|
+
... )
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
# Re-export the public API of the trade_shap submodules
|
|
20
|
+
from ml4t.diagnostic.evaluation.trade_shap.alignment import (
|
|
21
|
+
AlignmentResult,
|
|
22
|
+
TimestampAligner,
|
|
23
|
+
)
|
|
24
|
+
from ml4t.diagnostic.evaluation.trade_shap.characterize import (
|
|
25
|
+
CharacterizationConfig,
|
|
26
|
+
FeatureStatistics,
|
|
27
|
+
PatternCharacterizer,
|
|
28
|
+
benjamini_hochberg,
|
|
29
|
+
)
|
|
30
|
+
from ml4t.diagnostic.evaluation.trade_shap.cluster import (
|
|
31
|
+
ClusteringConfig,
|
|
32
|
+
HierarchicalClusterer,
|
|
33
|
+
compute_centroids,
|
|
34
|
+
compute_cluster_sizes,
|
|
35
|
+
find_optimal_clusters,
|
|
36
|
+
)
|
|
37
|
+
from ml4t.diagnostic.evaluation.trade_shap.explain import TradeShapExplainer
|
|
38
|
+
from ml4t.diagnostic.evaluation.trade_shap.hypotheses import (
|
|
39
|
+
HypothesisConfig,
|
|
40
|
+
HypothesisGenerator,
|
|
41
|
+
Template,
|
|
42
|
+
TemplateMatcher,
|
|
43
|
+
load_templates,
|
|
44
|
+
)
|
|
45
|
+
from ml4t.diagnostic.evaluation.trade_shap.models import (
|
|
46
|
+
ClusteringResult,
|
|
47
|
+
ErrorPattern,
|
|
48
|
+
TradeExplainFailure,
|
|
49
|
+
TradeShapExplanation,
|
|
50
|
+
TradeShapResult,
|
|
51
|
+
)
|
|
52
|
+
from ml4t.diagnostic.evaluation.trade_shap.normalize import (
|
|
53
|
+
NormalizationType,
|
|
54
|
+
normalize,
|
|
55
|
+
normalize_l1,
|
|
56
|
+
normalize_l2,
|
|
57
|
+
standardize,
|
|
58
|
+
)
|
|
59
|
+
from ml4t.diagnostic.evaluation.trade_shap.pipeline import (
|
|
60
|
+
TradeShapPipeline,
|
|
61
|
+
TradeShapPipelineConfig,
|
|
62
|
+
)
|
|
63
|
+
|
|
64
|
+
# Public names exported by ``from ml4t.diagnostic.evaluation.trade_shap import *``.
# Every entry is imported at the top of this module; entries are grouped by
# the submodule / analysis stage they come from.
__all__ = [
    # Alignment
    "TimestampAligner",
    "AlignmentResult",
    # Explainer
    "TradeShapExplainer",
    # Normalization
    "normalize",
    "normalize_l1",
    "normalize_l2",
    "standardize",
    "NormalizationType",
    # Clustering
    "HierarchicalClusterer",
    "ClusteringConfig",
    "find_optimal_clusters",
    "compute_cluster_sizes",
    "compute_centroids",
    # Characterization
    "PatternCharacterizer",
    "CharacterizationConfig",
    "FeatureStatistics",
    "benjamini_hochberg",
    # Hypothesis generation
    "HypothesisGenerator",
    "HypothesisConfig",
    "TemplateMatcher",
    "Template",
    "load_templates",
    # Pipeline
    "TradeShapPipeline",
    "TradeShapPipelineConfig",
    # Result models
    "TradeShapResult",
    "TradeShapExplanation",
    "TradeExplainFailure",
    "ClusteringResult",
    "ErrorPattern",
]
|
|
@@ -0,0 +1,188 @@
|
|
|
1
|
+
"""Fast timestamp alignment for trade SHAP analysis.
|
|
2
|
+
|
|
3
|
+
This module provides O(log n) timestamp lookup instead of O(n) linear scan,
|
|
4
|
+
using precomputed indices and binary search for nearest-match scenarios.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from dataclasses import dataclass, field
|
|
10
|
+
from datetime import datetime
|
|
11
|
+
from typing import TYPE_CHECKING
|
|
12
|
+
|
|
13
|
+
import numpy as np
|
|
14
|
+
|
|
15
|
+
if TYPE_CHECKING:
|
|
16
|
+
from numpy.typing import NDArray
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@dataclass(frozen=True)
class AlignmentResult:
    """Result of timestamp alignment.

    Attributes:
        index: Index into the feature DataFrame, i.e. the position in the
            timestamp sequence given to ``TimestampAligner.from_datetime_index``,
            or None if no acceptable match was found
        exact: Whether the matched timestamp equals the target exactly
        distance_seconds: Distance in seconds from target (0 if exact,
            ``inf`` when no candidate was examined)
    """

    index: int | None
    exact: bool
    distance_seconds: float


@dataclass
class TimestampAligner:
    """Fast timestamp alignment using precomputed indices.

    Exact matches are looked up in a dict (O(1) on average). Nearest matches
    use binary search over a sorted int64-nanosecond array (O(log n)).

    When the input contains duplicate timestamps, lookups resolve to the
    FIRST occurrence in the original order. This holds for exact and for
    nearest matches.

    Attributes:
        timestamps_ns: Sorted numpy array of timestamps as int64 nanoseconds
        index_by_ts: Dict mapping datetime to original index for exact lookup
        tolerance_seconds: Maximum allowed distance for nearest match
            (<= 0 means exact matches only)
        _sorted_indices: Original indices corresponding to ``timestamps_ns``.
            If empty (e.g. when the aligner is constructed directly),
            positions in ``timestamps_ns`` are used as the original indices.

    Example:
        >>> from datetime import datetime
        >>> import pandas as pd
        >>> timestamps = pd.DatetimeIndex(['2024-01-01', '2024-01-02', '2024-01-03'])
        >>> aligner = TimestampAligner.from_datetime_index(timestamps, tolerance_seconds=3600)
        >>> result = aligner.align(datetime(2024, 1, 2))
        >>> result.index
        1
        >>> result.exact
        True
    """

    timestamps_ns: NDArray[np.int64]
    index_by_ts: dict[datetime, int] = field(default_factory=dict)
    tolerance_seconds: float = 0.0
    _sorted_indices: NDArray[np.intp] = field(default_factory=lambda: np.array([], dtype=np.intp))

    @classmethod
    def from_datetime_index(
        cls,
        timestamps: NDArray | list[datetime],
        tolerance_seconds: float = 0.0,
    ) -> TimestampAligner:
        """Create aligner from datetime index or array.

        Args:
            timestamps: DatetimeIndex, numpy datetime64 array, or list of datetimes
            tolerance_seconds: Maximum allowed distance for nearest match
                (default: 0 = exact only)

        Returns:
            TimestampAligner ready for fast lookups

        Raises:
            ValueError: If timestamps array is empty
        """
        # Convert to numpy datetime64[ns] if needed
        ts_array = np.asarray(timestamps, dtype="datetime64[ns]")

        if len(ts_array) == 0:
            raise ValueError("Cannot create aligner from empty timestamp array")

        # int64 nanoseconds are cheap to compare and to binary-search
        ts_ns = ts_array.astype(np.int64)

        # A *stable* sort keeps equal timestamps in their original order, so
        # the first element of each run of duplicates in the sorted array is
        # the first occurrence in the input. The default quicksort gives no
        # such guarantee, which let nearest matches on duplicated timestamps
        # resolve to an arbitrary row.
        sorted_indices = np.argsort(ts_ns, kind="stable")
        sorted_ts_ns = ts_ns[sorted_indices]

        # Build exact-match dict using original timestamps.
        # For duplicates, keep the FIRST occurrence (setdefault never overwrites).
        index_by_ts: dict[datetime, int] = {}
        for i, ts in enumerate(timestamps):
            if hasattr(ts, "to_pydatetime"):
                # pandas Timestamp
                dt = ts.to_pydatetime()
            elif isinstance(ts, np.datetime64):
                # numpy datetime64 -> python datetime (microsecond precision)
                dt = ts.astype("datetime64[us]").astype(datetime)
            else:
                dt = ts
            index_by_ts.setdefault(dt, i)

        return cls(
            timestamps_ns=sorted_ts_ns,
            index_by_ts=index_by_ts,
            tolerance_seconds=tolerance_seconds,
            _sorted_indices=sorted_indices,
        )

    def align(self, target: datetime) -> AlignmentResult:
        """Find index for target timestamp.

        First attempts an exact match via dict lookup (O(1) on average).
        If that fails and tolerance > 0, binary search finds the nearest
        timestamp (O(log n)). A zero-distance hit found this way (e.g. for a
        ``np.datetime64`` target) is reported as exact. Ties between two
        equidistant neighbours resolve to the earlier timestamp.

        Args:
            target: Target timestamp to align

        Returns:
            AlignmentResult with index (or None), exact flag, and distance
        """
        # Try exact match first
        if target in self.index_by_ts:
            return AlignmentResult(index=self.index_by_ts[target], exact=True, distance_seconds=0.0)

        # Exact-only mode: a dict miss means no match
        if self.tolerance_seconds <= 0:
            return AlignmentResult(index=None, exact=False, distance_seconds=float("inf"))

        sorted_ns = self.timestamps_ns
        n = len(sorted_ns)
        if n == 0:
            return AlignmentResult(index=None, exact=False, distance_seconds=float("inf"))

        # Work with Python ints so differencing extreme values cannot overflow int64
        target_ns = int(np.datetime64(target, "ns").astype(np.int64))
        insert_pos = int(np.searchsorted(sorted_ns, target_ns, side="left"))

        # Neighbours: the last timestamp < target and the first >= target.
        # The right one (side="left") is already the start of its run of
        # duplicates; the left one is stepped back to the start of its run so
        # that, with the stable sort, both map to the first occurrence.
        candidates: list[int] = []
        if insert_pos > 0:
            left_value = sorted_ns[insert_pos - 1]
            candidates.append(int(np.searchsorted(sorted_ns, left_value, side="left")))
        if insert_pos < n:
            candidates.append(insert_pos)

        # min() returns the first minimal element, so a tie keeps the left one
        best_pos = min(candidates, key=lambda pos: abs(int(sorted_ns[pos]) - target_ns))
        best_distance_ns = abs(int(sorted_ns[best_pos]) - target_ns)
        distance_seconds = best_distance_ns / 1e9  # plain Python float

        if distance_seconds > self.tolerance_seconds:
            return AlignmentResult(index=None, exact=False, distance_seconds=distance_seconds)

        # Map the sorted position back to the original index. Without a stored
        # permutation (direct construction), positions are the indices.
        if len(self._sorted_indices):
            original_idx = int(self._sorted_indices[best_pos])
        else:
            original_idx = best_pos

        return AlignmentResult(
            index=original_idx,
            exact=best_distance_ns == 0,
            distance_seconds=distance_seconds,
        )

    def align_many(self, targets: list[datetime]) -> list[AlignmentResult]:
        """Align multiple timestamps.

        Args:
            targets: List of target timestamps

        Returns:
            List of AlignmentResult for each target, in input order
        """
        return [self.align(t) for t in targets]

    def __len__(self) -> int:
        """Number of timestamps in the aligner."""
        return len(self.timestamps_ns)
|