ml4t-diagnostic 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ml4t/diagnostic/AGENT.md +25 -0
- ml4t/diagnostic/__init__.py +166 -0
- ml4t/diagnostic/backends/__init__.py +10 -0
- ml4t/diagnostic/backends/adapter.py +192 -0
- ml4t/diagnostic/backends/polars_backend.py +899 -0
- ml4t/diagnostic/caching/__init__.py +40 -0
- ml4t/diagnostic/caching/cache.py +331 -0
- ml4t/diagnostic/caching/decorators.py +131 -0
- ml4t/diagnostic/caching/smart_cache.py +339 -0
- ml4t/diagnostic/config/AGENT.md +24 -0
- ml4t/diagnostic/config/README.md +267 -0
- ml4t/diagnostic/config/__init__.py +219 -0
- ml4t/diagnostic/config/barrier_config.py +277 -0
- ml4t/diagnostic/config/base.py +301 -0
- ml4t/diagnostic/config/event_config.py +148 -0
- ml4t/diagnostic/config/feature_config.py +404 -0
- ml4t/diagnostic/config/multi_signal_config.py +55 -0
- ml4t/diagnostic/config/portfolio_config.py +215 -0
- ml4t/diagnostic/config/report_config.py +391 -0
- ml4t/diagnostic/config/sharpe_config.py +202 -0
- ml4t/diagnostic/config/signal_config.py +206 -0
- ml4t/diagnostic/config/trade_analysis_config.py +310 -0
- ml4t/diagnostic/config/validation.py +279 -0
- ml4t/diagnostic/core/__init__.py +29 -0
- ml4t/diagnostic/core/numba_utils.py +315 -0
- ml4t/diagnostic/core/purging.py +372 -0
- ml4t/diagnostic/core/sampling.py +471 -0
- ml4t/diagnostic/errors/__init__.py +205 -0
- ml4t/diagnostic/evaluation/AGENT.md +26 -0
- ml4t/diagnostic/evaluation/__init__.py +437 -0
- ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
- ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
- ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
- ml4t/diagnostic/evaluation/dashboard.py +715 -0
- ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
- ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
- ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
- ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
- ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
- ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
- ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
- ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
- ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
- ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
- ml4t/diagnostic/evaluation/event_analysis.py +647 -0
- ml4t/diagnostic/evaluation/excursion.py +390 -0
- ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
- ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
- ml4t/diagnostic/evaluation/framework.py +935 -0
- ml4t/diagnostic/evaluation/metric_registry.py +255 -0
- ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
- ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
- ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
- ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
- ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
- ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
- ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
- ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
- ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
- ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
- ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
- ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
- ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
- ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
- ml4t/diagnostic/evaluation/multi_signal.py +550 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
- ml4t/diagnostic/evaluation/report_generation.py +824 -0
- ml4t/diagnostic/evaluation/signal_selector.py +452 -0
- ml4t/diagnostic/evaluation/stat_registry.py +139 -0
- ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
- ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
- ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
- ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
- ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
- ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
- ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
- ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
- ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
- ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
- ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
- ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
- ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
- ml4t/diagnostic/evaluation/stats/moments.py +164 -0
- ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
- ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
- ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
- ml4t/diagnostic/evaluation/themes.py +330 -0
- ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
- ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
- ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
- ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
- ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
- ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
- ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
- ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
- ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
- ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
- ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
- ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
- ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
- ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
- ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
- ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
- ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
- ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
- ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
- ml4t/diagnostic/evaluation/validated_cv.py +535 -0
- ml4t/diagnostic/evaluation/visualization.py +1050 -0
- ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
- ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
- ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
- ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
- ml4t/diagnostic/integration/__init__.py +48 -0
- ml4t/diagnostic/integration/backtest_contract.py +671 -0
- ml4t/diagnostic/integration/data_contract.py +316 -0
- ml4t/diagnostic/integration/engineer_contract.py +226 -0
- ml4t/diagnostic/logging/__init__.py +77 -0
- ml4t/diagnostic/logging/logger.py +245 -0
- ml4t/diagnostic/logging/performance.py +234 -0
- ml4t/diagnostic/logging/progress.py +234 -0
- ml4t/diagnostic/logging/wandb.py +412 -0
- ml4t/diagnostic/metrics/__init__.py +9 -0
- ml4t/diagnostic/metrics/percentiles.py +128 -0
- ml4t/diagnostic/py.typed +1 -0
- ml4t/diagnostic/reporting/__init__.py +43 -0
- ml4t/diagnostic/reporting/base.py +130 -0
- ml4t/diagnostic/reporting/html_renderer.py +275 -0
- ml4t/diagnostic/reporting/json_renderer.py +51 -0
- ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
- ml4t/diagnostic/results/AGENT.md +24 -0
- ml4t/diagnostic/results/__init__.py +105 -0
- ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
- ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
- ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
- ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
- ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
- ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
- ml4t/diagnostic/results/barrier_results/validation.py +38 -0
- ml4t/diagnostic/results/base.py +177 -0
- ml4t/diagnostic/results/event_results.py +349 -0
- ml4t/diagnostic/results/feature_results.py +787 -0
- ml4t/diagnostic/results/multi_signal_results.py +431 -0
- ml4t/diagnostic/results/portfolio_results.py +281 -0
- ml4t/diagnostic/results/sharpe_results.py +448 -0
- ml4t/diagnostic/results/signal_results/__init__.py +74 -0
- ml4t/diagnostic/results/signal_results/ic.py +581 -0
- ml4t/diagnostic/results/signal_results/irtc.py +110 -0
- ml4t/diagnostic/results/signal_results/quantile.py +392 -0
- ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
- ml4t/diagnostic/results/signal_results/turnover.py +213 -0
- ml4t/diagnostic/results/signal_results/validation.py +147 -0
- ml4t/diagnostic/signal/AGENT.md +17 -0
- ml4t/diagnostic/signal/__init__.py +69 -0
- ml4t/diagnostic/signal/_report.py +152 -0
- ml4t/diagnostic/signal/_utils.py +261 -0
- ml4t/diagnostic/signal/core.py +275 -0
- ml4t/diagnostic/signal/quantile.py +148 -0
- ml4t/diagnostic/signal/result.py +214 -0
- ml4t/diagnostic/signal/signal_ic.py +129 -0
- ml4t/diagnostic/signal/turnover.py +182 -0
- ml4t/diagnostic/splitters/AGENT.md +19 -0
- ml4t/diagnostic/splitters/__init__.py +36 -0
- ml4t/diagnostic/splitters/base.py +501 -0
- ml4t/diagnostic/splitters/calendar.py +421 -0
- ml4t/diagnostic/splitters/calendar_config.py +91 -0
- ml4t/diagnostic/splitters/combinatorial.py +1064 -0
- ml4t/diagnostic/splitters/config.py +322 -0
- ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
- ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
- ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
- ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
- ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
- ml4t/diagnostic/splitters/group_isolation.py +329 -0
- ml4t/diagnostic/splitters/persistence.py +316 -0
- ml4t/diagnostic/splitters/utils.py +207 -0
- ml4t/diagnostic/splitters/walk_forward.py +757 -0
- ml4t/diagnostic/utils/__init__.py +42 -0
- ml4t/diagnostic/utils/config.py +542 -0
- ml4t/diagnostic/utils/dependencies.py +318 -0
- ml4t/diagnostic/utils/sessions.py +127 -0
- ml4t/diagnostic/validation/__init__.py +54 -0
- ml4t/diagnostic/validation/dataframe.py +274 -0
- ml4t/diagnostic/validation/returns.py +280 -0
- ml4t/diagnostic/validation/timeseries.py +299 -0
- ml4t/diagnostic/visualization/AGENT.md +19 -0
- ml4t/diagnostic/visualization/__init__.py +223 -0
- ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
- ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
- ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
- ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
- ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
- ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
- ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
- ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
- ml4t/diagnostic/visualization/barrier_plots.py +782 -0
- ml4t/diagnostic/visualization/core.py +1060 -0
- ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
- ml4t/diagnostic/visualization/dashboards/base.py +582 -0
- ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
- ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
- ml4t/diagnostic/visualization/dashboards.py +43 -0
- ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
- ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
- ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
- ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
- ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
- ml4t/diagnostic/visualization/feature_plots.py +888 -0
- ml4t/diagnostic/visualization/interaction_plots.py +618 -0
- ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
- ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
- ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
- ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
- ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
- ml4t/diagnostic/visualization/report_generation.py +1343 -0
- ml4t/diagnostic/visualization/signal/__init__.py +103 -0
- ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
- ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
- ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
- ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
- ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
- ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
- ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
- ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
- ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
- ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
- ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
- ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,1136 @@
|
|
|
1
|
+
"""Trade-level analysis for backtest diagnostics and SHAP attribution.
|
|
2
|
+
|
|
3
|
+
This module provides tools for analyzing individual trades from backtests,
|
|
4
|
+
identifying worst/best performers, and computing trade-level statistics.
|
|
5
|
+
|
|
6
|
+
Core Components:
|
|
7
|
+
- TradeMetrics: Enriched trade data with computed metrics
|
|
8
|
+
- TradeAnalysis: Main analyzer for extracting worst/best trades
|
|
9
|
+
- TradeStatistics: Aggregate statistics across trades
|
|
10
|
+
- TradeAnalysisResult: Result schema with serialization
|
|
11
|
+
|
|
12
|
+
Integration with ml4t-diagnostics workflow:
|
|
13
|
+
1. Load backtest results → Extract trades (TradeRecord instances)
|
|
14
|
+
2. Analyze trades → Identify worst performers (TradeAnalysis)
|
|
15
|
+
3. Compute statistics → Understand trade distribution (TradeStatistics)
|
|
16
|
+
4. Feed to SHAP → Explain failures (trade_shap_diagnostics.py)
|
|
17
|
+
|
|
18
|
+
Example - Basic usage:
|
|
19
|
+
>>> from ml4t.diagnostic.integration import TradeRecord
|
|
20
|
+
>>> from ml4t.diagnostic.evaluation import TradeAnalysis
|
|
21
|
+
>>> from datetime import datetime, timedelta
|
|
22
|
+
>>>
|
|
23
|
+
>>> # Create trade records from backtest
|
|
24
|
+
>>> trades = [
|
|
25
|
+
... TradeRecord(
|
|
26
|
+
... timestamp=datetime(2024, 1, 15),
|
|
27
|
+
... symbol="AAPL",
|
|
28
|
+
... entry_price=150.0,
|
|
29
|
+
... exit_price=155.0,
|
|
30
|
+
... pnl=500.0,
|
|
31
|
+
... duration=timedelta(days=5),
|
|
32
|
+
... direction="long"
|
|
33
|
+
... ),
|
|
34
|
+
... # ... more trades
|
|
35
|
+
... ]
|
|
36
|
+
>>>
|
|
37
|
+
>>> # Analyze trades
|
|
38
|
+
>>> analyzer = TradeAnalysis(trades)
|
|
39
|
+
>>> worst = analyzer.worst_trades(n=10)
|
|
40
|
+
>>> best = analyzer.best_trades(n=10)
|
|
41
|
+
>>> stats = analyzer.compute_statistics()
|
|
42
|
+
>>>
|
|
43
|
+
>>> print(f"Win rate: {stats.win_rate:.2%}")
|
|
44
|
+
>>> print(f"Average PnL: ${stats.avg_pnl:.2f}")
|
|
45
|
+
|
|
46
|
+
Example - Advanced usage with config:
|
|
47
|
+
>>> from ml4t.diagnostic.config import TradeConfig, ExtractionSettings, FilterSettings
|
|
48
|
+
>>>
|
|
49
|
+
>>> config = TradeConfig(
|
|
50
|
+
... extraction=ExtractionSettings(n_worst=20, n_best=10),
|
|
51
|
+
... filter=FilterSettings(
|
|
52
|
+
... min_duration=timedelta(hours=1),
|
|
53
|
+
... min_pnl=-1000.0
|
|
54
|
+
... )
|
|
55
|
+
... )
|
|
56
|
+
>>>
|
|
57
|
+
>>> analyzer = TradeAnalysis.from_config(trades, config)
|
|
58
|
+
>>> result = analyzer.analyze()
|
|
59
|
+
>>>
|
|
60
|
+
>>> # Export for storage
|
|
61
|
+
>>> result.to_json_string()
|
|
62
|
+
>>> result.get_dataframe("worst_trades")
|
|
63
|
+
>>> result.get_dataframe("statistics")
|
|
64
|
+
|
|
65
|
+
Example - Integration with SHAP diagnostics:
|
|
66
|
+
>>> from ml4t.diagnostic.evaluation import TradeShapAnalyzer
|
|
67
|
+
>>>
|
|
68
|
+
>>> # Get worst trades
|
|
69
|
+
>>> worst_trades = analyzer.worst_trades(n=20)
|
|
70
|
+
>>>
|
|
71
|
+
>>> # Explain with SHAP
|
|
72
|
+
>>> shap_analyzer = TradeShapAnalyzer(model, features, shap_values)
|
|
73
|
+
>>> patterns = shap_analyzer.explain_worst_trades(worst_trades)
|
|
74
|
+
>>>
|
|
75
|
+
>>> for pattern in patterns:
|
|
76
|
+
... print(pattern.hypothesis)
|
|
77
|
+
... print(pattern.actions)
|
|
78
|
+
"""
|
|
79
|
+
|
|
80
|
+
from __future__ import annotations
|
|
81
|
+
|
|
82
|
+
import heapq
|
|
83
|
+
from datetime import UTC, datetime, timedelta
|
|
84
|
+
from typing import Any, Literal, SupportsFloat, cast
|
|
85
|
+
|
|
86
|
+
import polars as pl
|
|
87
|
+
from pydantic import BaseModel, Field, field_validator
|
|
88
|
+
|
|
89
|
+
from ml4t.diagnostic.integration.backtest_contract import TradeRecord
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
class TradeMetrics(BaseModel):
    """Enriched trade data with computed metrics for analysis.

    Extends TradeRecord with additional computed fields useful for
    trade analysis, ranking, and diagnostics. Provides methods for
    DataFrame conversion and serialization.

    This class wraps TradeRecord and adds:
    - Return percentage calculation
    - Duration in hours/days for easy filtering
    - PnL per day (duration-adjusted performance metric)

    Required Fields (from TradeRecord):
        timestamp: Trade exit timestamp
        symbol: Asset symbol
        entry_price: Average entry price
        exit_price: Average exit price
        pnl: Realized profit/loss
        duration: Time between entry and exit
        direction: Trade direction (long/short), optional

    Computed Fields (properties):
        return_pct: Return as percentage of entry price
        duration_hours: Duration in hours
        duration_days: Duration in days
        pnl_per_day: PnL normalized by duration

    Example - Create from TradeRecord:
        >>> metrics = TradeMetrics.from_trade_record(trade_record)
        >>> print(f"Return: {metrics.return_pct:.2%}")
        >>> print(f"PnL per day: ${metrics.pnl_per_day:.2f}")

    Example - Convert to DataFrame:
        >>> trades = [TradeMetrics.from_trade_record(tr) for tr in trade_records]
        >>> df = TradeMetrics.to_dataframe(trades)
        >>> print(df.select(["symbol", "pnl", "return_pct"]))
    """

    # Core fields (from TradeRecord)
    timestamp: datetime = Field(..., description="Trade exit timestamp")
    symbol: str = Field(..., min_length=1, description="Asset symbol")
    entry_price: float = Field(..., gt=0.0, description="Average entry price")
    exit_price: float = Field(..., gt=0.0, description="Average exit price")
    pnl: float = Field(..., description="Realized profit/loss")
    duration: timedelta = Field(..., description="Time between entry and exit")
    direction: Literal["long", "short"] | None = Field(None, description="Trade direction")

    # Optional fields (from TradeRecord)
    quantity: float | None = Field(None, gt=0.0, description="Position size")
    entry_timestamp: datetime | None = Field(None, description="Position entry timestamp")
    fees: float | None = Field(None, ge=0.0, description="Total transaction fees")
    slippage: float | None = Field(None, ge=0.0, description="Slippage cost")
    metadata: dict[str, Any] | None = Field(None, description="Arbitrary metadata")
    regime_info: dict[str, str] | None = Field(None, description="Market regime info")

    @field_validator("duration")
    @classmethod
    def validate_duration_positive(cls, v: timedelta) -> timedelta:
        """Ensure duration is strictly positive (zero or negative is invalid)."""
        if v.total_seconds() <= 0:
            raise ValueError(f"Duration must be positive, got {v}")
        return v

    @property
    def return_pct(self) -> float:
        """Return as percentage of entry price.

        Formula:
            - Long: (exit_price - entry_price) / entry_price
            - Short: (entry_price - exit_price) / entry_price
            - Unknown direction: absolute price change / entry_price (unsigned)

        Returns:
            Return percentage (e.g., 0.05 = 5% return)

        Example:
            >>> metrics.return_pct  # 0.0333 = 3.33%
        """
        if self.direction == "long":
            return (self.exit_price - self.entry_price) / self.entry_price
        elif self.direction == "short":
            return (self.entry_price - self.exit_price) / self.entry_price
        else:
            # Unknown direction - use absolute price change (unsigned return)
            return abs(self.exit_price - self.entry_price) / self.entry_price

    @property
    def duration_hours(self) -> float:
        """Duration in hours.

        Returns:
            Duration as float hours

        Example:
            >>> metrics.duration_hours  # 120.5
        """
        return self.duration.total_seconds() / 3600.0

    @property
    def duration_days(self) -> float:
        """Duration in days.

        Returns:
            Duration as float days

        Example:
            >>> metrics.duration_days  # 5.02
        """
        return self.duration.total_seconds() / 86400.0

    @property
    def pnl_per_day(self) -> float:
        """PnL normalized by duration in days.

        Provides a duration-adjusted performance metric. Useful for
        comparing trades of different holding periods.

        Returns:
            PnL per day (can be negative)

        Example:
            >>> metrics.pnl_per_day  # 100.0 (earned $100/day)
        """
        # Defensive guard: the validator enforces positive duration, but a
        # zero check keeps this property total if validation is ever bypassed.
        if self.duration_days == 0:
            return 0.0
        return self.pnl / self.duration_days

    @classmethod
    def from_trade_record(cls, trade: TradeRecord) -> TradeMetrics:
        """Create TradeMetrics from TradeRecord.

        Args:
            trade: TradeRecord instance from backtest

        Returns:
            TradeMetrics with computed fields

        Example:
            >>> metrics = TradeMetrics.from_trade_record(trade_record)
        """
        return cls(
            timestamp=trade.timestamp,
            symbol=trade.symbol,
            entry_price=trade.entry_price,
            exit_price=trade.exit_price,
            pnl=trade.pnl,
            duration=trade.duration,
            direction=trade.direction,
            quantity=trade.quantity,
            entry_timestamp=trade.entry_timestamp,
            fees=trade.fees,
            slippage=trade.slippage,
            metadata=trade.metadata,
            regime_info=trade.regime_info,
        )

    def to_dict(self) -> dict[str, Any]:
        """Export to dictionary format.

        The ``duration`` field is replaced by ``duration_seconds`` (a float)
        for JSON compatibility, and all computed properties are included.

        Returns:
            Dictionary with all trade data including computed fields

        Example:
            >>> metrics.to_dict()
            {
                'timestamp': '2024-01-15T10:30:00',
                'symbol': 'AAPL',
                'pnl': 500.0,
                'return_pct': 0.0333,
                'duration_hours': 120.0,
                ...
            }
        """
        data = self.model_dump(mode="json")
        # Convert timedelta to total seconds for JSON compatibility
        if "duration" in data:
            data["duration_seconds"] = self.duration.total_seconds()
            del data["duration"]
        # Include computed properties
        data["return_pct"] = self.return_pct
        data["duration_hours"] = self.duration_hours
        data["duration_days"] = self.duration_days
        data["pnl_per_day"] = self.pnl_per_day
        return data

    @staticmethod
    def to_dataframe(trades: list[TradeMetrics]) -> pl.DataFrame:
        """Convert list of TradeMetrics to Polars DataFrame.

        Args:
            trades: List of TradeMetrics instances

        Returns:
            Polars DataFrame with all trade data and computed metrics.
            The schema (column names and dtypes) is identical whether
            ``trades`` is empty or not.

        Example:
            >>> df = TradeMetrics.to_dataframe(metrics_list)
            >>> print(df.select(["symbol", "pnl", "return_pct"]))
            >>> df.sort("pnl").head(10)  # Worst 10 trades
        """
        # Single explicit schema shared by the empty and non-empty paths.
        # Without it, polars would infer a Null dtype for any optional
        # column (direction, quantity, entry_timestamp, fees, slippage)
        # that is None for every trade, making the output dtypes depend
        # on the data instead of matching the documented schema.
        schema: dict[str, Any] = {
            "timestamp": pl.Datetime,
            "symbol": pl.String,
            "entry_price": pl.Float64,
            "exit_price": pl.Float64,
            "pnl": pl.Float64,
            "duration_seconds": pl.Float64,
            "direction": pl.String,
            "quantity": pl.Float64,
            "entry_timestamp": pl.Datetime,
            "fees": pl.Float64,
            "slippage": pl.Float64,
            "return_pct": pl.Float64,
            "duration_hours": pl.Float64,
            "duration_days": pl.Float64,
            "pnl_per_day": pl.Float64,
        }

        if not trades:
            return pl.DataFrame(schema=schema)

        # Convert to list of dicts (metadata/regime_info intentionally
        # excluded - they are not tabular).
        data = []
        for trade in trades:
            trade_dict = {
                "timestamp": trade.timestamp,
                "symbol": trade.symbol,
                "entry_price": trade.entry_price,
                "exit_price": trade.exit_price,
                "pnl": trade.pnl,
                "duration_seconds": trade.duration.total_seconds(),
                "direction": trade.direction,
                "quantity": trade.quantity,
                "entry_timestamp": trade.entry_timestamp,
                "fees": trade.fees,
                "slippage": trade.slippage,
                "return_pct": trade.return_pct,
                "duration_hours": trade.duration_hours,
                "duration_days": trade.duration_days,
                "pnl_per_day": trade.pnl_per_day,
            }
            data.append(trade_dict)

        return pl.DataFrame(data, schema=schema)
|
|
348
|
+
|
|
349
|
+
|
|
350
|
+
class TradeFilters(BaseModel):
    """Typed, validated filter options for trade analysis.

    A type-safe alternative to passing a raw ``dict[str, Any]`` of filter
    criteria. Every field is optional; a filter is applied only when its
    field is set.

    Fields:
        symbols: List of symbols to include (None = all symbols)
        min_duration: Minimum trade duration (None = no minimum)
        min_pnl: Minimum PnL to include (None = no minimum)
        max_pnl: Maximum PnL to include (None = no maximum)
        start_date: Start of date range (None = no start bound)
        end_date: End of date range (None = no end bound)

    Example:
        >>> filters = TradeFilters(
        ...     symbols=["AAPL", "MSFT"],
        ...     min_duration=timedelta(hours=1),
        ...     min_pnl=-1000.0
        ... )
        >>> analyzer = TradeAnalysis(trades, filters=filters)
    """

    symbols: list[str] | None = Field(None, description="Symbols to include")
    min_duration: timedelta | None = Field(None, description="Minimum trade duration")
    min_pnl: float | None = Field(None, description="Minimum PnL to include")
    max_pnl: float | None = Field(None, description="Maximum PnL to include")
    start_date: datetime | None = Field(None, description="Start of date range")
    end_date: datetime | None = Field(None, description="End of date range")

    def to_dict(self) -> dict[str, Any]:
        """Convert to legacy dict format for backward compatibility."""
        # min_duration is exported as total seconds under a renamed key;
        # every other filter passes through unchanged. Unset (None) filters
        # are omitted from the output entirely.
        candidates: dict[str, Any] = {
            "symbols": self.symbols,
            "min_duration_seconds": (
                self.min_duration.total_seconds() if self.min_duration is not None else None
            ),
            "min_pnl": self.min_pnl,
            "max_pnl": self.max_pnl,
            "start_date": self.start_date,
            "end_date": self.end_date,
        }
        return {key: value for key, value in candidates.items() if value is not None}
|
|
396
|
+
|
|
397
|
+
|
|
398
|
+
class TradeStatistics(BaseModel):
|
|
399
|
+
"""Aggregate statistics across multiple trades.
|
|
400
|
+
|
|
401
|
+
Computes summary statistics for trade analysis:
|
|
402
|
+
- Win/loss metrics (win rate, profit factor)
|
|
403
|
+
- PnL distribution (mean, std, quartiles, skewness)
|
|
404
|
+
- Duration distribution (mean, median, quartiles)
|
|
405
|
+
- Trade counts and breakdowns
|
|
406
|
+
|
|
407
|
+
Used by TradeAnalysisResult to provide high-level performance summary.
|
|
408
|
+
|
|
409
|
+
Fields:
|
|
410
|
+
n_trades: Total number of trades
|
|
411
|
+
n_winners: Number of profitable trades
|
|
412
|
+
n_losers: Number of losing trades
|
|
413
|
+
win_rate: Fraction of winning trades
|
|
414
|
+
total_pnl: Sum of all PnL
|
|
415
|
+
avg_pnl: Mean PnL per trade
|
|
416
|
+
pnl_std: Standard deviation of PnL
|
|
417
|
+
pnl_skewness: Skewness of PnL distribution
|
|
418
|
+
pnl_kurtosis: Kurtosis of PnL distribution
|
|
419
|
+
pnl_quartiles: 25th, 50th (median), 75th percentiles
|
|
420
|
+
avg_winner: Average PnL of winning trades
|
|
421
|
+
avg_loser: Average PnL of losing trades
|
|
422
|
+
profit_factor: Gross profit / gross loss
|
|
423
|
+
avg_duration_days: Average trade duration in days
|
|
424
|
+
median_duration_days: Median trade duration
|
|
425
|
+
duration_quartiles: Duration percentiles
|
|
426
|
+
|
|
427
|
+
Example:
|
|
428
|
+
>>> stats = TradeStatistics.compute(trades)
|
|
429
|
+
>>> print(f"Win rate: {stats.win_rate:.2%}")
|
|
430
|
+
>>> print(f"Avg PnL: ${stats.avg_pnl:.2f}")
|
|
431
|
+
>>> print(f"Profit factor: {stats.profit_factor:.2f}")
|
|
432
|
+
>>> print(stats.summary())
|
|
433
|
+
"""
|
|
434
|
+
|
|
435
|
+
# Trade counts
|
|
436
|
+
n_trades: int = Field(..., ge=0, description="Total number of trades")
|
|
437
|
+
n_winners: int = Field(..., ge=0, description="Number of profitable trades")
|
|
438
|
+
n_losers: int = Field(..., ge=0, description="Number of losing trades")
|
|
439
|
+
|
|
440
|
+
# Win rate and PnL metrics
|
|
441
|
+
win_rate: float = Field(..., ge=0.0, le=1.0, description="Fraction of winning trades")
|
|
442
|
+
total_pnl: float = Field(..., description="Sum of all PnL")
|
|
443
|
+
avg_pnl: float = Field(..., description="Mean PnL per trade")
|
|
444
|
+
pnl_std: float = Field(..., ge=0.0, description="Standard deviation of PnL")
|
|
445
|
+
|
|
446
|
+
# Distribution metrics
|
|
447
|
+
pnl_skewness: float | None = Field(None, description="PnL distribution skewness")
|
|
448
|
+
pnl_kurtosis: float | None = Field(None, description="PnL distribution kurtosis")
|
|
449
|
+
pnl_quartiles: dict[str, float] = Field(..., description="PnL quartiles (q25, q50, q75)")
|
|
450
|
+
|
|
451
|
+
# Winner/loser breakdown
|
|
452
|
+
avg_winner: float | None = Field(None, description="Average PnL of winners")
|
|
453
|
+
avg_loser: float | None = Field(None, description="Average PnL of losers")
|
|
454
|
+
profit_factor: float | None = Field(None, description="Gross profit / gross loss")
|
|
455
|
+
|
|
456
|
+
# Duration metrics
|
|
457
|
+
avg_duration_days: float = Field(..., ge=0.0, description="Average duration in days")
|
|
458
|
+
median_duration_days: float = Field(..., ge=0.0, description="Median duration in days")
|
|
459
|
+
duration_quartiles: dict[str, float] = Field(
|
|
460
|
+
..., description="Duration quartiles (q25, q50, q75)"
|
|
461
|
+
)
|
|
462
|
+
|
|
463
|
+
@staticmethod
def compute(trades: list[TradeMetrics]) -> TradeStatistics:
    """Aggregate a list of trades into summary statistics.

    Args:
        trades: TradeMetrics instances to aggregate

    Returns:
        TradeStatistics populated with count, PnL, distribution,
        winner/loser, and duration metrics

    Raises:
        ValueError: If trades list is empty

    Example:
        >>> stats = TradeStatistics.compute(metrics_list)
    """
    if not trades:
        raise ValueError("Cannot compute statistics for empty trade list")

    # Work on a DataFrame for vectorized computation
    frame = TradeMetrics.to_dataframe(trades)

    # Counts — strict inequalities: zero-PnL trades fall in neither bucket
    total = len(frame)
    winners_df = frame.filter(pl.col("pnl") > 0)
    losers_df = frame.filter(pl.col("pnl") < 0)
    wins = int(winners_df.height)
    losses = int(losers_df.height)
    rate = wins / total if total > 0 else 0.0

    # PnL summary metrics
    pnl = frame["pnl"]
    pnl_total = float(cast(SupportsFloat, pnl.sum()))
    pnl_mean = float(cast(SupportsFloat, pnl.mean()))
    std_raw = pnl.std()
    # std() returns None for a single observation — report 0.0 then
    pnl_sigma = 0.0 if std_raw is None else float(cast(SupportsFloat, std_raw))

    # Distribution shape — scipy is optional; fall back to None
    try:
        from scipy import stats as scipy_stats

        values = pnl.to_numpy()
        skew_val: float | None = float(scipy_stats.skew(values))
        kurt_val: float | None = float(scipy_stats.kurtosis(values))
    except ImportError:
        skew_val = None
        kurt_val = None

    def quartiles(series: pl.Series) -> dict[str, float]:
        # Shared helper: q25/q50/q75 as plain floats.
        return {
            f"q{int(q * 100)}": float(cast(SupportsFloat, series.quantile(q)))
            for q in (0.25, 0.50, 0.75)
        }

    pnl_q = quartiles(pnl)

    # Winner/loser breakdown — None when the corresponding side is empty
    winner_mean = float(cast(SupportsFloat, winners_df["pnl"].mean())) if wins > 0 else None
    loser_mean = float(cast(SupportsFloat, losers_df["pnl"].mean())) if losses > 0 else None

    # Profit factor is only defined when both sides are present
    if wins > 0 and losses > 0:
        gross_profit = float(winners_df["pnl"].sum())
        gross_loss = abs(float(losers_df["pnl"].sum()))
        pf: float | None = gross_profit / gross_loss
    else:
        pf = None

    # Duration metrics
    durations = frame["duration_days"]
    dur_mean = float(cast(SupportsFloat, durations.mean()))
    dur_median = float(cast(SupportsFloat, durations.median()))
    dur_q = quartiles(durations)

    return TradeStatistics(
        n_trades=total,
        n_winners=wins,
        n_losers=losses,
        win_rate=rate,
        total_pnl=pnl_total,
        avg_pnl=pnl_mean,
        pnl_std=pnl_sigma,
        pnl_skewness=skew_val,
        pnl_kurtosis=kurt_val,
        pnl_quartiles=pnl_q,
        avg_winner=winner_mean,
        avg_loser=loser_mean,
        profit_factor=pf,
        avg_duration_days=dur_mean,
        median_duration_days=dur_median,
        duration_quartiles=dur_q,
    )
|
|
559
|
+
|
|
560
|
+
def summary(self) -> str:
    """Render the statistics as a human-readable multi-line report.

    Returns:
        Formatted summary string with counts, PnL metrics,
        PnL distribution, and duration sections

    Example:
        >>> print(stats.summary())
        Trade Statistics
        ================
        Total trades: 150
        Win rate: 62.67%
        ...
    """
    rule = "-" * 50
    out: list[str] = ["Trade Statistics", "=" * 50]

    # Trade counts + PnL headline figures
    out += [
        f"Total trades: {self.n_trades}",
        f"Winners: {self.n_winners} | Losers: {self.n_losers}",
        f"Win rate: {self.win_rate:.2%}",
        "",
        "PnL Metrics",
        rule,
        f"Total PnL: ${self.total_pnl:,.2f}",
        f"Average PnL: ${self.avg_pnl:,.2f} ± ${self.pnl_std:,.2f}",
    ]
    # Optional winner/loser figures (absent when one side is empty)
    if self.avg_winner is not None:
        out.append(f"Avg winner: ${self.avg_winner:,.2f}")
    if self.avg_loser is not None:
        out.append(f"Avg loser: ${self.avg_loser:,.2f}")
    if self.profit_factor is not None:
        out.append(f"Profit factor: {self.profit_factor:.2f}")
    out.append("")

    # Distribution section (skew/kurtosis optional — scipy-dependent)
    out += [
        "PnL Distribution",
        rule,
        f"Q25: ${self.pnl_quartiles['q25']:,.2f} | "
        f"Median: ${self.pnl_quartiles['q50']:,.2f} | "
        f"Q75: ${self.pnl_quartiles['q75']:,.2f}",
    ]
    if self.pnl_skewness is not None:
        out.append(f"Skewness: {self.pnl_skewness:.3f}")
    if self.pnl_kurtosis is not None:
        out.append(f"Kurtosis: {self.pnl_kurtosis:.3f}")
    out.append("")

    # Duration section
    out += [
        "Duration Metrics",
        rule,
        f"Average: {self.avg_duration_days:.2f} days",
        f"Median: {self.median_duration_days:.2f} days",
        f"Q25: {self.duration_quartiles['q25']:.2f} | "
        f"Q50: {self.duration_quartiles['q50']:.2f} | "
        f"Q75: {self.duration_quartiles['q75']:.2f}",
    ]

    return "\n".join(out)
|
|
621
|
+
|
|
622
|
+
def to_dataframe(self) -> pl.DataFrame:
    """Convert the statistics to a single-row Polars DataFrame.

    Quartile dicts are flattened into scalar columns
    (pnl_q25/q50/q75 and dur_q25/q50/q75).

    Returns:
        Single-row DataFrame with one column per statistic

    Example:
        >>> df = stats.to_dataframe()
    """
    row: dict[str, Any] = {
        "n_trades": self.n_trades,
        "n_winners": self.n_winners,
        "n_losers": self.n_losers,
        "win_rate": self.win_rate,
        "total_pnl": self.total_pnl,
        "avg_pnl": self.avg_pnl,
        "pnl_std": self.pnl_std,
        "pnl_skewness": self.pnl_skewness,
        "pnl_kurtosis": self.pnl_kurtosis,
        "pnl_q25": self.pnl_quartiles["q25"],
        "pnl_q50": self.pnl_quartiles["q50"],
        "pnl_q75": self.pnl_quartiles["q75"],
        "avg_winner": self.avg_winner,
        "avg_loser": self.avg_loser,
        "profit_factor": self.profit_factor,
        "avg_duration_days": self.avg_duration_days,
        "median_duration_days": self.median_duration_days,
        "dur_q25": self.duration_quartiles["q25"],
        "dur_q50": self.duration_quartiles["q50"],
        "dur_q75": self.duration_quartiles["q75"],
    }
    return pl.DataFrame([row])
|
|
657
|
+
|
|
658
|
+
|
|
659
|
+
class TradeAnalysis:
    """High-level analyzer over a backtest's trade records.

    Workflow:
    1. Load trades (TradeRecord instances from backtest)
    2. Extract worst performers -> feed to SHAP diagnostics
    3. Extract best performers -> understand success patterns
    4. Compute statistics -> aggregate performance metrics

    Filtering is supported by symbol, duration, PnL range, and date
    range, supplied either as a typed TradeFilters or a legacy dict.

    Example - Basic usage:
        >>> analyzer = TradeAnalysis(trade_records)
        >>> worst = analyzer.worst_trades(n=20)
        >>> best = analyzer.best_trades(n=10)
        >>> stats = analyzer.compute_statistics()
        >>> print(stats.summary())

    Example - With filtering:
        >>> from ml4t.diagnostic.evaluation.trade_analysis import TradeFilters
        >>>
        >>> filters = TradeFilters(
        ...     symbols=["AAPL", "MSFT"],
        ...     min_duration=timedelta(hours=1),
        ...     min_pnl=-1000.0,
        ...     start_date=datetime(2024, 10, 1)
        ... )
        >>> analyzer = TradeAnalysis(trade_records, filters=filters)
        >>> result = analyzer.analyze(n_worst=20, n_best=10)

    Example - Integration with config:
        >>> from ml4t.diagnostic.config import TradeConfig, ExtractionSettings
        >>>
        >>> config = TradeConfig(
        ...     extraction=ExtractionSettings(n_worst=20, n_best=10),
        ... )
        >>> analyzer = TradeAnalysis.from_config(trade_records, config)
        >>> result = analyzer.analyze()
        >>> result.to_json_string()
    """

    def __init__(
        self,
        trades: list[TradeRecord],
        filter_config: dict[str, Any] | None = None,
        *,
        filters: TradeFilters | None = None,
    ):
        """Initialize the analyzer with trades.

        Args:
            trades: TradeRecord instances from a backtest
            filter_config: Optional legacy dict-format filter criteria
            filters: Optional typed TradeFilters; takes precedence over
                filter_config when both are given

        Raises:
            ValueError: If trades is empty, or no trades survive the
                configured filters

        Example:
            >>> filters = TradeFilters(symbols=["AAPL"], min_pnl=-1000)
            >>> analyzer = TradeAnalysis(trades, filters=filters)
            >>> analyzer = TradeAnalysis(trades, filter_config={"symbols": ["AAPL"]})
        """
        if not trades:
            raise ValueError("Cannot analyze empty trade list")

        converted = [TradeMetrics.from_trade_record(record) for record in trades]

        # Typed filters win over the legacy dict format
        criteria = filters.to_dict() if filters is not None else filter_config
        if criteria:
            converted = self._apply_filters(converted, criteria)

        if not converted:
            raise ValueError("No trades remaining after applying filters")

        self.trades = converted

    @classmethod
    def from_config(
        cls,
        trades: list[TradeRecord],
        config: Any,  # TradeConfig - avoid circular import
    ) -> TradeAnalysis:
        """Build an analyzer from a configuration object.

        Args:
            trades: TradeRecord instances
            config: TradeConfig instance; its ``filters`` attribute
                (when present) is forwarded as the filter config

        Returns:
            Configured TradeAnalysis instance

        Example:
            >>> config = TradeConfig(extraction=ExtractionSettings(n_worst=20, n_best=10))
            >>> analyzer = TradeAnalysis.from_config(trades, config)
        """
        return cls(trades, filter_config=getattr(config, "filters", None))

    @staticmethod
    def _apply_filters(
        trades: list[TradeMetrics],
        filters: dict[str, Any],
    ) -> list[TradeMetrics]:
        """Filter trades by the given criteria in a single pass.

        Recognized keys: symbols, min_duration_seconds, min_pnl,
        max_pnl, start_date, end_date. Absent keys are ignored.

        Args:
            trades: Trades to filter
            filters: Filter criteria

        Returns:
            Trades matching every active criterion
        """
        # Resolve each criterion once up front; None means "inactive"
        raw_symbols = filters.get("symbols")
        allowed: set[str] | None = set(raw_symbols) if raw_symbols else None
        min_seconds: float | None = filters.get("min_duration_seconds")
        pnl_floor: float | None = filters.get("min_pnl")
        pnl_cap: float | None = filters.get("max_pnl")
        earliest: datetime | None = filters.get("start_date")
        latest: datetime | None = filters.get("end_date")

        def keep(trade: TradeMetrics) -> bool:
            # A trade survives only when it violates no active criterion
            return not (
                (allowed is not None and trade.symbol not in allowed)
                or (min_seconds is not None and trade.duration.total_seconds() < min_seconds)
                or (pnl_floor is not None and trade.pnl < pnl_floor)
                or (pnl_cap is not None and trade.pnl > pnl_cap)
                or (earliest is not None and trade.timestamp < earliest)
                or (latest is not None and trade.timestamp > latest)
            )

        return [trade for trade in trades if keep(trade)]

    def worst_trades(self, n: int = 10) -> list[TradeMetrics]:
        """Return the N lowest-PnL trades, sorted ascending by PnL.

        Uses heapq.nsmallest for O(n + k log n) efficiency when k << n.

        Args:
            n: Number of trades to return (must be positive)

        Returns:
            The worst N trades by PnL

        Raises:
            ValueError: If n is not positive

        Example:
            >>> worst = analyzer.worst_trades(n=20)

        See Also
        --------
        best_trades : Extract best performing trades
        compute_statistics : Aggregate performance metrics
        analyze : Complete analysis with worst/best trades
        """
        if n <= 0:
            raise ValueError(f"n must be positive, got {n}")
        return heapq.nsmallest(n, self.trades, key=lambda trade: trade.pnl)

    def best_trades(self, n: int = 10) -> list[TradeMetrics]:
        """Return the N highest-PnL trades, sorted descending by PnL.

        Uses heapq.nlargest for O(n + k log n) efficiency when k << n.

        Args:
            n: Number of trades to return (must be positive)

        Returns:
            The best N trades by PnL

        Raises:
            ValueError: If n is not positive

        Example:
            >>> best = analyzer.best_trades(n=10)

        See Also
        --------
        worst_trades : Extract worst performing trades
        compute_statistics : Aggregate performance metrics
        analyze : Complete analysis with worst/best trades
        """
        if n <= 0:
            raise ValueError(f"n must be positive, got {n}")
        return heapq.nlargest(n, self.trades, key=lambda trade: trade.pnl)

    def compute_statistics(self) -> TradeStatistics:
        """Compute aggregate statistics across all (filtered) trades.

        Returns:
            TradeStatistics with summary metrics

        Example:
            >>> stats = analyzer.compute_statistics()

        See Also
        --------
        TradeStatistics.compute : Static method for statistics computation
        analyze : Complete analysis including statistics
        """
        return TradeStatistics.compute(self.trades)

    def analyze(
        self,
        n_worst: int = 10,
        n_best: int = 10,
    ) -> TradeAnalysisResult:
        """Run the complete analysis and return a result object.

        Args:
            n_worst: Number of worst trades to extract
            n_best: Number of best trades to extract

        Returns:
            TradeAnalysisResult with worst/best trades and statistics

        Example:
            >>> result = analyzer.analyze(n_worst=20, n_best=10)
            >>> print(result.summary())

        See Also
        --------
        worst_trades : Extract worst trades
        best_trades : Extract best trades
        compute_statistics : Compute aggregate statistics
        TradeAnalysisResult : Result schema with serialization
        """
        return TradeAnalysisResult(
            worst_trades=self.worst_trades(n_worst),
            best_trades=self.best_trades(n_best),
            statistics=self.compute_statistics(),
            n_total_trades=len(self.trades),
        )
|
|
918
|
+
|
|
919
|
+
|
|
920
|
+
class TradeAnalysisResult(BaseModel):
    """Serializable result schema for a trade analysis run.

    Bundles the complete output of an analysis:
    - Worst N trades (for SHAP diagnostics)
    - Best N trades (for success pattern analysis)
    - Aggregate statistics across all analyzed trades
    - Metadata (total trade count, analysis type, creation timestamp)

    Provides JSON export via to_json_string(), dict export via
    to_dict(), and Polars DataFrame views via get_dataframe(). Use it
    to store and retrieve analysis results, or to pass data between
    stages of the diagnostics workflow.

    Example - Basic usage:
        >>> result = analyzer.analyze(n_worst=20, n_best=10)
        >>> print(result.summary())
        >>> result.to_json_string()
        >>> df = result.get_dataframe("worst_trades")

    Example - Serialization round-trip:
        >>> with open("analysis_result.json", "w") as f:
        ...     f.write(result.to_json_string())
        >>> with open("analysis_result.json") as f:
        ...     data = json.load(f)
        >>> result = TradeAnalysisResult(**data)

    Example - DataFrame export:
        >>> df_worst = result.get_dataframe("worst_trades")
        >>> df_stats = result.get_dataframe("statistics")
        >>> available = result.list_available_dataframes()
    """

    # Result payload
    worst_trades: list[TradeMetrics] = Field(..., description="List of worst N trades by PnL")
    best_trades: list[TradeMetrics] = Field(..., description="List of best N trades by PnL")
    statistics: TradeStatistics = Field(
        ..., description="Aggregate statistics across all trades"
    )
    n_total_trades: int = Field(..., ge=1, description="Total number of trades analyzed")

    # Metadata
    analysis_type: str = Field(
        default="trade_analysis", description="Type of analysis performed"
    )
    created_at: datetime = Field(
        default_factory=lambda: datetime.now(UTC),
        description="Analysis creation timestamp (UTC)",
    )

    def to_json_string(self, *, indent: int = 2) -> str:
        """Serialize the result to a JSON string.

        Args:
            indent: Indentation level (None for compact)

        Returns:
            JSON string representation

        Example:
            >>> json_str = result.to_json_string()
            >>> with open("result.json", "w") as f:
            ...     f.write(json_str)
        """
        return self.model_dump_json(indent=indent)

    def to_dict(self) -> dict[str, Any]:
        """Serialize the result to a plain Python dictionary.

        Returns:
            Dictionary representation

        Example:
            >>> data = result.to_dict()
            >>> data["statistics"]["win_rate"]
        """
        return self.model_dump(mode="python")

    def get_dataframe(self, name: str = "worst_trades") -> pl.DataFrame:
        """Return one of the available DataFrame views.

        Available views:
        - "worst_trades": Worst N trades with all fields
        - "best_trades": Best N trades with all fields
        - "statistics": Aggregate statistics (single row)
        - "all_trades": Combined worst + best trades

        Args:
            name: Which view to build

        Returns:
            Polars DataFrame with the requested data

        Raises:
            ValueError: If the view name is not available

        Example:
            >>> df_worst = result.get_dataframe("worst_trades")
            >>> df_stats = result.get_dataframe("statistics")
        """
        # Statistics has its own single-row export path
        if name == "statistics":
            return self.statistics.to_dataframe()
        if name == "worst_trades":
            rows = self.worst_trades
        elif name == "best_trades":
            rows = self.best_trades
        elif name == "all_trades":
            rows = self.worst_trades + self.best_trades
        else:
            available = self.list_available_dataframes()
            raise ValueError(f"DataFrame '{name}' not available. Available: {', '.join(available)}")
        return TradeMetrics.to_dataframe(rows)

    def list_available_dataframes(self) -> list[str]:
        """Return the view names accepted by get_dataframe().

        Returns:
            List of available DataFrame names

        Example:
            >>> result.list_available_dataframes()
            ['worst_trades', 'best_trades', 'statistics', 'all_trades']
        """
        return ["worst_trades", "best_trades", "statistics", "all_trades"]

    def summary(self) -> str:
        """Render a human-readable summary of the analysis.

        Returns:
            Formatted summary string

        Example:
            >>> print(result.summary())
            Trade Analysis Summary
            ======================
            ...
        """
        divider = "-" * 60
        stats = self.statistics

        def preview(rank: int, trade: TradeMetrics) -> str:
            # One-line trade preview: symbol, PnL, return %, duration.
            return (
                f"{rank}. {trade.symbol}: ${trade.pnl:,.2f} "
                f"({trade.return_pct:+.2%}) [{trade.duration_days:.1f}d]"
            )

        out = [
            "Trade Analysis Summary",
            "=" * 60,
            f"Analysis timestamp: {self.created_at.isoformat()}",
            f"Total trades analyzed: {self.n_total_trades}",
            f"Worst trades extracted: {len(self.worst_trades)}",
            f"Best trades extracted: {len(self.best_trades)}",
            "",
            "Overall Statistics",
            divider,
            f"Win rate: {stats.win_rate:.2%}",
            f"Total PnL: ${stats.total_pnl:,.2f}",
            f"Average PnL: ${stats.avg_pnl:,.2f} ± ${stats.pnl_std:,.2f}",
        ]
        if stats.profit_factor is not None:
            out.append(f"Profit factor: {stats.profit_factor:.2f}")
        out.append(f"Average duration: {stats.avg_duration_days:.2f} days")
        out.append("")

        out += ["Worst Trades (Top 5)", divider]
        out += [preview(i, t) for i, t in enumerate(self.worst_trades[:5], 1)]
        out.append("")

        out += ["Best Trades (Top 5)", divider]
        out += [preview(i, t) for i, t in enumerate(self.best_trades[:5], 1)]

        return "\n".join(out)

    def __repr__(self) -> str:
        """Concise representation."""
        return (
            f"TradeAnalysisResult(n_worst={len(self.worst_trades)}, "
            f"n_best={len(self.best_trades)}, n_total={self.n_total_trades})"
        )
|