ml4t-diagnostic 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ml4t/diagnostic/AGENT.md +25 -0
- ml4t/diagnostic/__init__.py +166 -0
- ml4t/diagnostic/backends/__init__.py +10 -0
- ml4t/diagnostic/backends/adapter.py +192 -0
- ml4t/diagnostic/backends/polars_backend.py +899 -0
- ml4t/diagnostic/caching/__init__.py +40 -0
- ml4t/diagnostic/caching/cache.py +331 -0
- ml4t/diagnostic/caching/decorators.py +131 -0
- ml4t/diagnostic/caching/smart_cache.py +339 -0
- ml4t/diagnostic/config/AGENT.md +24 -0
- ml4t/diagnostic/config/README.md +267 -0
- ml4t/diagnostic/config/__init__.py +219 -0
- ml4t/diagnostic/config/barrier_config.py +277 -0
- ml4t/diagnostic/config/base.py +301 -0
- ml4t/diagnostic/config/event_config.py +148 -0
- ml4t/diagnostic/config/feature_config.py +404 -0
- ml4t/diagnostic/config/multi_signal_config.py +55 -0
- ml4t/diagnostic/config/portfolio_config.py +215 -0
- ml4t/diagnostic/config/report_config.py +391 -0
- ml4t/diagnostic/config/sharpe_config.py +202 -0
- ml4t/diagnostic/config/signal_config.py +206 -0
- ml4t/diagnostic/config/trade_analysis_config.py +310 -0
- ml4t/diagnostic/config/validation.py +279 -0
- ml4t/diagnostic/core/__init__.py +29 -0
- ml4t/diagnostic/core/numba_utils.py +315 -0
- ml4t/diagnostic/core/purging.py +372 -0
- ml4t/diagnostic/core/sampling.py +471 -0
- ml4t/diagnostic/errors/__init__.py +205 -0
- ml4t/diagnostic/evaluation/AGENT.md +26 -0
- ml4t/diagnostic/evaluation/__init__.py +437 -0
- ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
- ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
- ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
- ml4t/diagnostic/evaluation/dashboard.py +715 -0
- ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
- ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
- ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
- ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
- ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
- ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
- ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
- ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
- ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
- ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
- ml4t/diagnostic/evaluation/event_analysis.py +647 -0
- ml4t/diagnostic/evaluation/excursion.py +390 -0
- ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
- ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
- ml4t/diagnostic/evaluation/framework.py +935 -0
- ml4t/diagnostic/evaluation/metric_registry.py +255 -0
- ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
- ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
- ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
- ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
- ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
- ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
- ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
- ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
- ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
- ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
- ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
- ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
- ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
- ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
- ml4t/diagnostic/evaluation/multi_signal.py +550 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
- ml4t/diagnostic/evaluation/report_generation.py +824 -0
- ml4t/diagnostic/evaluation/signal_selector.py +452 -0
- ml4t/diagnostic/evaluation/stat_registry.py +139 -0
- ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
- ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
- ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
- ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
- ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
- ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
- ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
- ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
- ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
- ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
- ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
- ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
- ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
- ml4t/diagnostic/evaluation/stats/moments.py +164 -0
- ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
- ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
- ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
- ml4t/diagnostic/evaluation/themes.py +330 -0
- ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
- ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
- ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
- ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
- ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
- ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
- ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
- ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
- ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
- ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
- ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
- ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
- ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
- ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
- ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
- ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
- ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
- ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
- ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
- ml4t/diagnostic/evaluation/validated_cv.py +535 -0
- ml4t/diagnostic/evaluation/visualization.py +1050 -0
- ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
- ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
- ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
- ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
- ml4t/diagnostic/integration/__init__.py +48 -0
- ml4t/diagnostic/integration/backtest_contract.py +671 -0
- ml4t/diagnostic/integration/data_contract.py +316 -0
- ml4t/diagnostic/integration/engineer_contract.py +226 -0
- ml4t/diagnostic/logging/__init__.py +77 -0
- ml4t/diagnostic/logging/logger.py +245 -0
- ml4t/diagnostic/logging/performance.py +234 -0
- ml4t/diagnostic/logging/progress.py +234 -0
- ml4t/diagnostic/logging/wandb.py +412 -0
- ml4t/diagnostic/metrics/__init__.py +9 -0
- ml4t/diagnostic/metrics/percentiles.py +128 -0
- ml4t/diagnostic/py.typed +1 -0
- ml4t/diagnostic/reporting/__init__.py +43 -0
- ml4t/diagnostic/reporting/base.py +130 -0
- ml4t/diagnostic/reporting/html_renderer.py +275 -0
- ml4t/diagnostic/reporting/json_renderer.py +51 -0
- ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
- ml4t/diagnostic/results/AGENT.md +24 -0
- ml4t/diagnostic/results/__init__.py +105 -0
- ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
- ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
- ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
- ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
- ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
- ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
- ml4t/diagnostic/results/barrier_results/validation.py +38 -0
- ml4t/diagnostic/results/base.py +177 -0
- ml4t/diagnostic/results/event_results.py +349 -0
- ml4t/diagnostic/results/feature_results.py +787 -0
- ml4t/diagnostic/results/multi_signal_results.py +431 -0
- ml4t/diagnostic/results/portfolio_results.py +281 -0
- ml4t/diagnostic/results/sharpe_results.py +448 -0
- ml4t/diagnostic/results/signal_results/__init__.py +74 -0
- ml4t/diagnostic/results/signal_results/ic.py +581 -0
- ml4t/diagnostic/results/signal_results/irtc.py +110 -0
- ml4t/diagnostic/results/signal_results/quantile.py +392 -0
- ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
- ml4t/diagnostic/results/signal_results/turnover.py +213 -0
- ml4t/diagnostic/results/signal_results/validation.py +147 -0
- ml4t/diagnostic/signal/AGENT.md +17 -0
- ml4t/diagnostic/signal/__init__.py +69 -0
- ml4t/diagnostic/signal/_report.py +152 -0
- ml4t/diagnostic/signal/_utils.py +261 -0
- ml4t/diagnostic/signal/core.py +275 -0
- ml4t/diagnostic/signal/quantile.py +148 -0
- ml4t/diagnostic/signal/result.py +214 -0
- ml4t/diagnostic/signal/signal_ic.py +129 -0
- ml4t/diagnostic/signal/turnover.py +182 -0
- ml4t/diagnostic/splitters/AGENT.md +19 -0
- ml4t/diagnostic/splitters/__init__.py +36 -0
- ml4t/diagnostic/splitters/base.py +501 -0
- ml4t/diagnostic/splitters/calendar.py +421 -0
- ml4t/diagnostic/splitters/calendar_config.py +91 -0
- ml4t/diagnostic/splitters/combinatorial.py +1064 -0
- ml4t/diagnostic/splitters/config.py +322 -0
- ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
- ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
- ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
- ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
- ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
- ml4t/diagnostic/splitters/group_isolation.py +329 -0
- ml4t/diagnostic/splitters/persistence.py +316 -0
- ml4t/diagnostic/splitters/utils.py +207 -0
- ml4t/diagnostic/splitters/walk_forward.py +757 -0
- ml4t/diagnostic/utils/__init__.py +42 -0
- ml4t/diagnostic/utils/config.py +542 -0
- ml4t/diagnostic/utils/dependencies.py +318 -0
- ml4t/diagnostic/utils/sessions.py +127 -0
- ml4t/diagnostic/validation/__init__.py +54 -0
- ml4t/diagnostic/validation/dataframe.py +274 -0
- ml4t/diagnostic/validation/returns.py +280 -0
- ml4t/diagnostic/validation/timeseries.py +299 -0
- ml4t/diagnostic/visualization/AGENT.md +19 -0
- ml4t/diagnostic/visualization/__init__.py +223 -0
- ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
- ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
- ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
- ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
- ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
- ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
- ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
- ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
- ml4t/diagnostic/visualization/barrier_plots.py +782 -0
- ml4t/diagnostic/visualization/core.py +1060 -0
- ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
- ml4t/diagnostic/visualization/dashboards/base.py +582 -0
- ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
- ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
- ml4t/diagnostic/visualization/dashboards.py +43 -0
- ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
- ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
- ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
- ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
- ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
- ml4t/diagnostic/visualization/feature_plots.py +888 -0
- ml4t/diagnostic/visualization/interaction_plots.py +618 -0
- ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
- ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
- ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
- ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
- ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
- ml4t/diagnostic/visualization/report_generation.py +1343 -0
- ml4t/diagnostic/visualization/signal/__init__.py +103 -0
- ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
- ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
- ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
- ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
- ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
- ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
- ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
- ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
- ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
- ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
- ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
- ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,412 @@
|
|
|
1
|
+
"""Weights & Biases integration for experiment tracking.
|
|
2
|
+
|
|
3
|
+
This module provides hooks for logging ml4t-diagnostic experiments to W&B,
|
|
4
|
+
enabling tracking of evaluation metrics, hyperparameters, and
|
|
5
|
+
visualizations across experiments.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import numbers
|
|
9
|
+
import warnings
|
|
10
|
+
from typing import Any, SupportsFloat, cast
|
|
11
|
+
|
|
12
|
+
import numpy as np
|
|
13
|
+
import pandas as pd
|
|
14
|
+
|
|
15
|
+
try:
|
|
16
|
+
import wandb # type: ignore[import-not-found,unused-ignore]
|
|
17
|
+
|
|
18
|
+
HAS_WANDB = True
|
|
19
|
+
except ImportError:
|
|
20
|
+
HAS_WANDB = False
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class WandbLogger:
|
|
24
|
+
"""Logger for Weights & Biases experiment tracking.
|
|
25
|
+
|
|
26
|
+
This class provides a unified interface for logging ml4t-diagnostic
|
|
27
|
+
experiments to W&B, handling initialization, metric logging,
|
|
28
|
+
and artifact management.
|
|
29
|
+
"""
|
|
30
|
+
|
|
31
|
+
def __init__(
|
|
32
|
+
self,
|
|
33
|
+
project: str | None = None,
|
|
34
|
+
entity: str | None = None,
|
|
35
|
+
name: str | None = None,
|
|
36
|
+
config: dict[str, Any] | None = None,
|
|
37
|
+
tags: list[str] | None = None,
|
|
38
|
+
notes: str | None = None,
|
|
39
|
+
disabled: bool = False,
|
|
40
|
+
):
|
|
41
|
+
"""Initialize W&B logger.
|
|
42
|
+
|
|
43
|
+
Parameters
|
|
44
|
+
----------
|
|
45
|
+
project : str, optional
|
|
46
|
+
W&B project name
|
|
47
|
+
entity : str, optional
|
|
48
|
+
W&B entity (team or username)
|
|
49
|
+
name : str, optional
|
|
50
|
+
Run name
|
|
51
|
+
config : dict, optional
|
|
52
|
+
Configuration dictionary to log
|
|
53
|
+
tags : list[str], optional
|
|
54
|
+
Tags for the run
|
|
55
|
+
notes : str, optional
|
|
56
|
+
Notes about the run
|
|
57
|
+
disabled : bool
|
|
58
|
+
If True, disables W&B logging
|
|
59
|
+
"""
|
|
60
|
+
self.disabled = disabled or not HAS_WANDB
|
|
61
|
+
self.run = None
|
|
62
|
+
|
|
63
|
+
if self.disabled:
|
|
64
|
+
if not HAS_WANDB and not disabled:
|
|
65
|
+
warnings.warn(
|
|
66
|
+
"wandb not installed. Install with: pip install wandb",
|
|
67
|
+
stacklevel=2,
|
|
68
|
+
)
|
|
69
|
+
return
|
|
70
|
+
|
|
71
|
+
# Initialize W&B run
|
|
72
|
+
self.run = wandb.init(
|
|
73
|
+
project=project or "ml4t-diagnostic",
|
|
74
|
+
entity=entity,
|
|
75
|
+
name=name,
|
|
76
|
+
config=config,
|
|
77
|
+
tags=tags or [],
|
|
78
|
+
notes=notes,
|
|
79
|
+
reinit=True,
|
|
80
|
+
)
|
|
81
|
+
|
|
82
|
+
def log_config(self, config: dict[str, Any]) -> None:
|
|
83
|
+
"""Log configuration parameters.
|
|
84
|
+
|
|
85
|
+
Parameters
|
|
86
|
+
----------
|
|
87
|
+
config : dict
|
|
88
|
+
Configuration dictionary
|
|
89
|
+
"""
|
|
90
|
+
if self.disabled or self.run is None:
|
|
91
|
+
return
|
|
92
|
+
|
|
93
|
+
# Flatten nested config for W&B
|
|
94
|
+
flat_config = self._flatten_dict(config)
|
|
95
|
+
wandb.config.update(flat_config)
|
|
96
|
+
|
|
97
|
+
def log_metrics(
|
|
98
|
+
self,
|
|
99
|
+
metrics: dict[str, Any],
|
|
100
|
+
step: int | None = None,
|
|
101
|
+
prefix: str = "",
|
|
102
|
+
) -> None:
|
|
103
|
+
"""Log evaluation metrics.
|
|
104
|
+
|
|
105
|
+
Parameters
|
|
106
|
+
----------
|
|
107
|
+
metrics : dict
|
|
108
|
+
Metrics to log
|
|
109
|
+
step : int, optional
|
|
110
|
+
Step number (e.g., CV fold)
|
|
111
|
+
prefix : str
|
|
112
|
+
Prefix for metric names
|
|
113
|
+
"""
|
|
114
|
+
if self.disabled or self.run is None:
|
|
115
|
+
return
|
|
116
|
+
|
|
117
|
+
# Prepare metrics for logging
|
|
118
|
+
log_dict = {}
|
|
119
|
+
|
|
120
|
+
for name, value in metrics.items():
|
|
121
|
+
key = f"{prefix}{name}" if prefix else name
|
|
122
|
+
|
|
123
|
+
if isinstance(value, dict):
|
|
124
|
+
# Handle nested metrics (e.g., with confidence intervals)
|
|
125
|
+
for sub_key, sub_value in value.items():
|
|
126
|
+
if isinstance(sub_value, numbers.Number):
|
|
127
|
+
log_dict[f"{key}/{sub_key}"] = float(cast(SupportsFloat, sub_value))
|
|
128
|
+
elif isinstance(value, numbers.Number):
|
|
129
|
+
log_dict[key] = float(cast(SupportsFloat, value))
|
|
130
|
+
elif isinstance(value, list | np.ndarray):
|
|
131
|
+
# Log array statistics
|
|
132
|
+
if len(value) > 0:
|
|
133
|
+
log_dict[f"{key}/mean"] = float(np.mean(value))
|
|
134
|
+
log_dict[f"{key}/std"] = float(np.std(value))
|
|
135
|
+
log_dict[f"{key}/min"] = float(np.min(value))
|
|
136
|
+
log_dict[f"{key}/max"] = float(np.max(value))
|
|
137
|
+
|
|
138
|
+
if step is not None:
|
|
139
|
+
log_dict["step"] = step
|
|
140
|
+
|
|
141
|
+
wandb.log(log_dict)
|
|
142
|
+
|
|
143
|
+
def log_fold_results(
|
|
144
|
+
self,
|
|
145
|
+
fold_idx: int,
|
|
146
|
+
train_size: int,
|
|
147
|
+
test_size: int,
|
|
148
|
+
metrics: dict[str, Any],
|
|
149
|
+
) -> None:
|
|
150
|
+
"""Log results from a single CV fold.
|
|
151
|
+
|
|
152
|
+
Parameters
|
|
153
|
+
----------
|
|
154
|
+
fold_idx : int
|
|
155
|
+
Fold index
|
|
156
|
+
train_size : int
|
|
157
|
+
Training set size
|
|
158
|
+
test_size : int
|
|
159
|
+
Test set size
|
|
160
|
+
metrics : dict
|
|
161
|
+
Fold metrics
|
|
162
|
+
"""
|
|
163
|
+
if self.disabled or self.run is None:
|
|
164
|
+
return
|
|
165
|
+
|
|
166
|
+
# Add metrics with fold prefix
|
|
167
|
+
self.log_metrics(metrics, step=fold_idx, prefix="fold/")
|
|
168
|
+
|
|
169
|
+
# Log fold metadata
|
|
170
|
+
wandb.log(
|
|
171
|
+
{
|
|
172
|
+
"fold/train_size": train_size,
|
|
173
|
+
"fold/test_size": test_size,
|
|
174
|
+
"fold/train_test_ratio": train_size / test_size if test_size > 0 else 0,
|
|
175
|
+
},
|
|
176
|
+
step=fold_idx,
|
|
177
|
+
)
|
|
178
|
+
|
|
179
|
+
def log_statistical_tests(self, tests: dict[str, Any]) -> None:
|
|
180
|
+
"""Log statistical test results.
|
|
181
|
+
|
|
182
|
+
Parameters
|
|
183
|
+
----------
|
|
184
|
+
tests : dict
|
|
185
|
+
Statistical test results
|
|
186
|
+
"""
|
|
187
|
+
if self.disabled or self.run is None:
|
|
188
|
+
return
|
|
189
|
+
|
|
190
|
+
log_dict = {}
|
|
191
|
+
|
|
192
|
+
for test_name, result in tests.items():
|
|
193
|
+
if isinstance(result, dict):
|
|
194
|
+
for key, value in result.items():
|
|
195
|
+
if isinstance(value, numbers.Number):
|
|
196
|
+
log_dict[f"stats/{test_name}/{key}"] = float(cast(SupportsFloat, value))
|
|
197
|
+
elif key == "significant" and isinstance(value, bool):
|
|
198
|
+
log_dict[f"stats/{test_name}/{key}"] = int(value)
|
|
199
|
+
|
|
200
|
+
wandb.log(log_dict)
|
|
201
|
+
|
|
202
|
+
def log_figure(
|
|
203
|
+
self,
|
|
204
|
+
figure: Any,
|
|
205
|
+
name: str,
|
|
206
|
+
step: int | None = None,
|
|
207
|
+
) -> None:
|
|
208
|
+
"""Log a Plotly figure.
|
|
209
|
+
|
|
210
|
+
Parameters
|
|
211
|
+
----------
|
|
212
|
+
figure : plotly.graph_objects.Figure
|
|
213
|
+
Figure to log
|
|
214
|
+
name : str
|
|
215
|
+
Figure name
|
|
216
|
+
step : int, optional
|
|
217
|
+
Step number
|
|
218
|
+
"""
|
|
219
|
+
if self.disabled or self.run is None:
|
|
220
|
+
return
|
|
221
|
+
|
|
222
|
+
# Convert Plotly figure to W&B
|
|
223
|
+
wandb.log({f"plots/{name}": figure}, step=step)
|
|
224
|
+
|
|
225
|
+
def log_evaluation_summary(
|
|
226
|
+
self,
|
|
227
|
+
result: Any, # EvaluationResult
|
|
228
|
+
_predictions: Any | None = None,
|
|
229
|
+
_returns: Any | None = None,
|
|
230
|
+
) -> None:
|
|
231
|
+
"""Log complete evaluation summary.
|
|
232
|
+
|
|
233
|
+
Parameters
|
|
234
|
+
----------
|
|
235
|
+
result : EvaluationResult
|
|
236
|
+
Evaluation result object
|
|
237
|
+
predictions : array-like, optional
|
|
238
|
+
Predictions for additional logging
|
|
239
|
+
returns : array-like, optional
|
|
240
|
+
Returns for additional logging
|
|
241
|
+
"""
|
|
242
|
+
if self.disabled or self.run is None:
|
|
243
|
+
return
|
|
244
|
+
|
|
245
|
+
# Log summary metrics
|
|
246
|
+
summary = result.summary()
|
|
247
|
+
|
|
248
|
+
# Log aggregate metrics
|
|
249
|
+
self.log_metrics(summary["metrics"], prefix="summary/")
|
|
250
|
+
|
|
251
|
+
# Log statistical tests
|
|
252
|
+
if summary.get("statistical_tests"):
|
|
253
|
+
self.log_statistical_tests(summary["statistical_tests"])
|
|
254
|
+
|
|
255
|
+
# Log metadata
|
|
256
|
+
wandb.log(
|
|
257
|
+
{
|
|
258
|
+
"summary/tier": result.tier,
|
|
259
|
+
"summary/n_folds": summary["n_folds"],
|
|
260
|
+
"summary/splitter": result.splitter_name,
|
|
261
|
+
},
|
|
262
|
+
)
|
|
263
|
+
|
|
264
|
+
# Create summary table
|
|
265
|
+
if result.fold_results:
|
|
266
|
+
fold_data = []
|
|
267
|
+
for fold in result.fold_results:
|
|
268
|
+
fold_row = {"fold": fold.get("fold", 0)}
|
|
269
|
+
fold_row.update(
|
|
270
|
+
{k: v for k, v in fold.items() if isinstance(v, numbers.Number)},
|
|
271
|
+
)
|
|
272
|
+
fold_data.append(fold_row)
|
|
273
|
+
|
|
274
|
+
fold_table = wandb.Table(dataframe=pd.DataFrame(fold_data))
|
|
275
|
+
wandb.log({"tables/fold_results": fold_table})
|
|
276
|
+
|
|
277
|
+
def log_artifact(
|
|
278
|
+
self,
|
|
279
|
+
artifact_path: str,
|
|
280
|
+
name: str,
|
|
281
|
+
artifact_type: str = "evaluation",
|
|
282
|
+
metadata: dict[str, Any] | None = None,
|
|
283
|
+
) -> None:
|
|
284
|
+
"""Log an artifact (model, dataset, etc.).
|
|
285
|
+
|
|
286
|
+
Parameters
|
|
287
|
+
----------
|
|
288
|
+
artifact_path : str
|
|
289
|
+
Path to artifact file
|
|
290
|
+
name : str
|
|
291
|
+
Artifact name
|
|
292
|
+
artifact_type : str
|
|
293
|
+
Type of artifact
|
|
294
|
+
metadata : dict, optional
|
|
295
|
+
Additional metadata
|
|
296
|
+
"""
|
|
297
|
+
if self.disabled or self.run is None:
|
|
298
|
+
return
|
|
299
|
+
|
|
300
|
+
artifact = wandb.Artifact(
|
|
301
|
+
name=name,
|
|
302
|
+
type=artifact_type,
|
|
303
|
+
metadata=metadata or {},
|
|
304
|
+
)
|
|
305
|
+
artifact.add_file(artifact_path)
|
|
306
|
+
wandb.log_artifact(artifact)
|
|
307
|
+
|
|
308
|
+
def finish(self) -> None:
|
|
309
|
+
"""Finish the W&B run."""
|
|
310
|
+
if self.disabled or self.run is None:
|
|
311
|
+
return
|
|
312
|
+
|
|
313
|
+
wandb.finish()
|
|
314
|
+
|
|
315
|
+
def __enter__(self):
|
|
316
|
+
"""Context manager entry."""
|
|
317
|
+
return self
|
|
318
|
+
|
|
319
|
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
320
|
+
"""Context manager exit."""
|
|
321
|
+
self.finish()
|
|
322
|
+
|
|
323
|
+
@staticmethod
|
|
324
|
+
def _flatten_dict(
|
|
325
|
+
d: dict[str, Any],
|
|
326
|
+
parent_key: str = "",
|
|
327
|
+
sep: str = "/",
|
|
328
|
+
) -> dict[str, Any]:
|
|
329
|
+
"""Flatten nested dictionary."""
|
|
330
|
+
items: list[tuple[str, Any]] = []
|
|
331
|
+
|
|
332
|
+
for k, v in d.items():
|
|
333
|
+
new_key = f"{parent_key}{sep}{k}" if parent_key else k
|
|
334
|
+
|
|
335
|
+
if isinstance(v, dict):
|
|
336
|
+
items.extend(WandbLogger._flatten_dict(v, new_key, sep=sep).items())
|
|
337
|
+
else:
|
|
338
|
+
items.append((new_key, v))
|
|
339
|
+
|
|
340
|
+
return dict(items)
|
|
341
|
+
|
|
342
|
+
|
|
343
|
+
def log_experiment(
    evaluator: Any,
    X: Any,
    y: Any,
    model: Any,
    project: str | None = None,
    config: dict[str, Any] | None = None,
    tags: list[str] | None = None,
    **kwargs: Any,
) -> Any:
    """Run ``evaluator.evaluate`` on ``model`` and record the experiment in W&B.

    Without wandb installed, a warning is issued and the evaluation runs
    untracked. Otherwise a ``WandbLogger`` run is opened, and the following
    are logged: the evaluator settings, the model's ``get_params()`` (when
    the model has that method) and the evaluation summary. The run is
    finished when the function exits, even on error.

    Parameters
    ----------
    evaluator : ml4t-diagnostic.Evaluator
        Configured evaluator
    X : array-like
        Features
    y : array-like
        Labels
    model : estimator
        Model to evaluate
    project : str, optional
        W&B project name
    config : dict, optional
        Additional config to log
    tags : list[str], optional
        Experiment tags
    **kwargs : Any
        Additional arguments passed to evaluate()

    Returns
    -------
    EvaluationResult
        Whatever ``evaluator.evaluate`` returns
    """
    if not HAS_WANDB:
        # Degrade gracefully: still evaluate, just without tracking.
        warnings.warn(
            "wandb not installed. Running without logging. Install with: pip install wandb",
            stacklevel=2,
        )
        return evaluator.evaluate(X, y, model, **kwargs)

    with WandbLogger(project=project, config=config, tags=tags) as wb_logger:
        # Evaluator attributes are read only once the run exists.
        evaluator_settings = {
            "tier": evaluator.tier,
            "splitter": evaluator.splitter.__class__.__name__,
            "metrics": evaluator.metrics,
            "statistical_tests": evaluator.statistical_tests,
            "confidence_level": evaluator.confidence_level,
            "bootstrap_samples": evaluator.bootstrap_samples,
        }
        wb_logger.log_config({"evaluator": evaluator_settings})

        # scikit-learn style estimators expose their hyperparameters here.
        if hasattr(model, "get_params"):
            wb_logger.log_config({"model": model.get_params()})

        evaluation = evaluator.evaluate(X, y, model, **kwargs)
        wb_logger.log_evaluation_summary(evaluation)

    return evaluation
|
|
@@ -0,0 +1,128 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Percentile computation utilities for threshold-based signal generation.
|
|
3
|
+
|
|
4
|
+
Provides fast percentile computation from fold-specific predictions using Polars,
|
|
5
|
+
intended to be run on training predictions only, so that the resulting thresholds do not leak out-of-sample information (the caller is responsible for this).
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from collections.abc import Sequence
|
|
9
|
+
|
|
10
|
+
import pandas as pd
|
|
11
|
+
import polars as pl
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def compute_fold_percentiles(
    predictions: pd.DataFrame | pl.DataFrame,
    percentiles: Sequence[float],
    fold_col: str = "fold_id",
    iteration_col: str = "iteration",
    prediction_col: str = "prediction",
    verbose: bool = True,
) -> pd.DataFrame:
    """
    Compute percentiles from predictions grouped by fold and iteration.

    Uses a single Polars group_by/agg instead of nested Python loops.
    Designed for threshold-based signal generation, where thresholds must be
    computed from TRAINING predictions only to prevent data leakage.

    Performance (as reported by the author, hardware-dependent): roughly
    50-100ms for 89M predictions with 26 percentiles, vs 5-10s with loops.

    Args:
        predictions: DataFrame (pandas or Polars) with predictions to compute
            percentiles from. Must contain: fold_col, iteration_col, prediction_col
        percentiles: Percentiles to compute (e.g., [0.1, 0.5, 1, ..., 99, 99.5, 99.9]).
            Values must be in range [0, 100].
        fold_col: Name of fold identifier column (default: "fold_id")
        iteration_col: Name of iteration/checkpoint column (default: "iteration")
        prediction_col: Name of prediction values column (default: "prediction")
        verbose: Print progress information (default: True)

    Returns:
        pandas DataFrame with columns: [fold_col, iteration_col, p{percentile}, ...]
        - One row per (fold, iteration) combination, sorted by both keys
        - Percentile columns named f"p{percentile}", e.g. "p0.1", "p95", "p99.9"
          (so 95 -> "p95" but 95.0 -> "p95.0")

    Raises:
        ValueError: If any percentile is outside [0, 100] or NaN, or if a
            required column is missing.

    Example:
        >>> # Training predictions: 2 folds × 2 iterations × 500 samples
        >>> import numpy as np
        >>> import pandas as pd
        >>> predictions = pd.DataFrame({
        ...     'fold_id': [0] * 1000 + [1] * 1000,
        ...     'iteration': [50] * 500 + [100] * 500 + [50] * 500 + [100] * 500,
        ...     'prediction': np.random.rand(2000)
        ... })
        >>>
        >>> # Compute percentiles for LONG and SHORT strategies
        >>> percentiles = [0.1, 0.5, 1, 5, 10, 90, 95, 99, 99.5, 99.9]
        >>> thresholds = compute_fold_percentiles(predictions, percentiles, verbose=False)
        >>>
        >>> # 2 folds × 2 iterations = 4 rows; 2 key columns + 10 percentile columns
        >>> thresholds.shape
        (4, 12)
        >>>
        >>> # Use for signal generation
        >>> fold_0_iter_100 = thresholds[
        ...     (thresholds['fold_id'] == 0) & (thresholds['iteration'] == 100)
        ... ]
        >>> long_threshold = fold_0_iter_100['p95'].values[0]
        >>> short_threshold = fold_0_iter_100['p5'].values[0]

    Methodology:
        1. Convert predictions to Polars (if pandas)
        2. Group by (fold_id, iteration)
        3. Compute all percentiles in single aggregation (linear interpolation)
        4. Return as pandas DataFrame

    Data Leakage Prevention:
        CRITICAL: This function should ONLY be called on TRAINING predictions;
        it cannot detect misuse itself.
        - Training: compute_fold_percentiles(train_predictions) → save thresholds
        - Validation: Apply saved thresholds to OOS predictions
        - NEVER: compute_fold_percentiles(val_predictions) → data leakage!

    Performance Notes:
        - One group_by pass instead of one Python loop iteration per group
        - Recommended for predictions > 1M rows
    """
    percentiles = list(percentiles)

    # Validate before touching Polars so a bad value gives a clear error
    # (the negated comparison also rejects NaN).
    invalid = [p for p in percentiles if not 0 <= p <= 100]
    if invalid:
        raise ValueError(f"Percentiles must be in range [0, 100]; got {invalid}")

    if verbose:
        print("\nComputing fold-specific percentiles (Fast Polars Method)...")

    # Convert to Polars if pandas
    preds_pl = pl.from_pandas(predictions) if isinstance(predictions, pd.DataFrame) else predictions

    # Validate required columns
    required_cols = {fold_col, iteration_col, prediction_col}
    available_cols = set(preds_pl.columns)
    missing = required_cols - available_cols
    if missing:
        raise ValueError(f"Missing required columns: {missing}. Available: {available_cols}")

    # Output column names, in the caller's order.
    column_names = [f"p{p}" for p in percentiles]

    # Compute percentiles with single group_by operation; quantile() takes 0-1.
    percentiles_df = (
        preds_pl.group_by([fold_col, iteration_col])
        .agg(
            [
                pl.col(prediction_col).quantile(p / 100, interpolation="linear").alias(col_name)
                for p, col_name in zip(percentiles, column_names, strict=True)
            ]
        )
        .sort([fold_col, iteration_col])
    )

    # Convert back to pandas for compatibility
    result = percentiles_df.to_pandas()

    if verbose:
        n_folds = result[fold_col].nunique()
        n_iterations = result[iteration_col].nunique()
        print(f"✓ Computed {len(result)} percentile arrays")
        print(
            f"✓ Structure: {n_folds} folds × {n_iterations} iterations × {len(percentiles)} percentiles"
        )
        # List exactly the generated columns (not every column starting with "p").
        print(f"✓ Percentile columns: {column_names}")

    return result
|
ml4t/diagnostic/py.typed
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# PEP 561 marker file - this package supports type checking
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
"""Report generation module for ML4T Diagnostic results.
|
|
2
|
+
|
|
3
|
+
Provides flexible report generation in multiple formats:
|
|
4
|
+
- HTML: Rich, styled reports with tables and charts
|
|
5
|
+
- JSON: Machine-readable structured output
|
|
6
|
+
- Markdown: Human-readable documentation
|
|
7
|
+
|
|
8
|
+
Examples:
|
|
9
|
+
>>> from ml4t.diagnostic.reporting import ReportFactory, ReportFormat
|
|
10
|
+
>>> from ml4t.diagnostic.results import FeatureDiagnosticsResult
|
|
11
|
+
>>>
|
|
12
|
+
>>> # Generate HTML report
|
|
13
|
+
>>> html_report = ReportFactory.render(result, ReportFormat.HTML)
|
|
14
|
+
>>>
|
|
15
|
+
>>> # Generate JSON report
|
|
16
|
+
>>> json_report = ReportFactory.render(result, ReportFormat.JSON, indent=4)
|
|
17
|
+
>>>
|
|
18
|
+
>>> # Generate Markdown report
|
|
19
|
+
>>> md_report = ReportFactory.render(result, ReportFormat.MARKDOWN)
|
|
20
|
+
>>>
|
|
21
|
+
>>> # Save to file
|
|
22
|
+
>>> generator = ReportFactory.create(ReportFormat.HTML)
|
|
23
|
+
>>> html = generator.render(result)
|
|
24
|
+
>>> generator.save(html, "report.html")
|
|
25
|
+
"""
|
|
26
|
+
|
|
27
|
+
from ml4t.diagnostic.reporting.base import ReportFactory, ReportFormat, ReportGenerator
|
|
28
|
+
|
|
29
|
+
# Import renderers to trigger registration
|
|
30
|
+
from ml4t.diagnostic.reporting.html_renderer import HTMLReportGenerator
|
|
31
|
+
from ml4t.diagnostic.reporting.json_renderer import JSONReportGenerator
|
|
32
|
+
from ml4t.diagnostic.reporting.markdown_renderer import MarkdownReportGenerator
|
|
33
|
+
|
|
34
|
+
__all__ = [
|
|
35
|
+
# Factory and base
|
|
36
|
+
"ReportFactory",
|
|
37
|
+
"ReportFormat",
|
|
38
|
+
"ReportGenerator",
|
|
39
|
+
# Renderers
|
|
40
|
+
"HTMLReportGenerator",
|
|
41
|
+
"JSONReportGenerator",
|
|
42
|
+
"MarkdownReportGenerator",
|
|
43
|
+
]
|