ml4t_diagnostic-0.1.0a1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ml4t/diagnostic/AGENT.md +25 -0
- ml4t/diagnostic/__init__.py +166 -0
- ml4t/diagnostic/backends/__init__.py +10 -0
- ml4t/diagnostic/backends/adapter.py +192 -0
- ml4t/diagnostic/backends/polars_backend.py +899 -0
- ml4t/diagnostic/caching/__init__.py +40 -0
- ml4t/diagnostic/caching/cache.py +331 -0
- ml4t/diagnostic/caching/decorators.py +131 -0
- ml4t/diagnostic/caching/smart_cache.py +339 -0
- ml4t/diagnostic/config/AGENT.md +24 -0
- ml4t/diagnostic/config/README.md +267 -0
- ml4t/diagnostic/config/__init__.py +219 -0
- ml4t/diagnostic/config/barrier_config.py +277 -0
- ml4t/diagnostic/config/base.py +301 -0
- ml4t/diagnostic/config/event_config.py +148 -0
- ml4t/diagnostic/config/feature_config.py +404 -0
- ml4t/diagnostic/config/multi_signal_config.py +55 -0
- ml4t/diagnostic/config/portfolio_config.py +215 -0
- ml4t/diagnostic/config/report_config.py +391 -0
- ml4t/diagnostic/config/sharpe_config.py +202 -0
- ml4t/diagnostic/config/signal_config.py +206 -0
- ml4t/diagnostic/config/trade_analysis_config.py +310 -0
- ml4t/diagnostic/config/validation.py +279 -0
- ml4t/diagnostic/core/__init__.py +29 -0
- ml4t/diagnostic/core/numba_utils.py +315 -0
- ml4t/diagnostic/core/purging.py +372 -0
- ml4t/diagnostic/core/sampling.py +471 -0
- ml4t/diagnostic/errors/__init__.py +205 -0
- ml4t/diagnostic/evaluation/AGENT.md +26 -0
- ml4t/diagnostic/evaluation/__init__.py +437 -0
- ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
- ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
- ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
- ml4t/diagnostic/evaluation/dashboard.py +715 -0
- ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
- ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
- ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
- ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
- ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
- ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
- ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
- ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
- ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
- ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
- ml4t/diagnostic/evaluation/event_analysis.py +647 -0
- ml4t/diagnostic/evaluation/excursion.py +390 -0
- ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
- ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
- ml4t/diagnostic/evaluation/framework.py +935 -0
- ml4t/diagnostic/evaluation/metric_registry.py +255 -0
- ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
- ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
- ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
- ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
- ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
- ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
- ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
- ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
- ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
- ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
- ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
- ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
- ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
- ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
- ml4t/diagnostic/evaluation/multi_signal.py +550 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
- ml4t/diagnostic/evaluation/report_generation.py +824 -0
- ml4t/diagnostic/evaluation/signal_selector.py +452 -0
- ml4t/diagnostic/evaluation/stat_registry.py +139 -0
- ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
- ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
- ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
- ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
- ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
- ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
- ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
- ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
- ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
- ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
- ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
- ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
- ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
- ml4t/diagnostic/evaluation/stats/moments.py +164 -0
- ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
- ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
- ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
- ml4t/diagnostic/evaluation/themes.py +330 -0
- ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
- ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
- ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
- ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
- ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
- ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
- ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
- ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
- ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
- ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
- ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
- ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
- ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
- ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
- ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
- ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
- ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
- ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
- ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
- ml4t/diagnostic/evaluation/validated_cv.py +535 -0
- ml4t/diagnostic/evaluation/visualization.py +1050 -0
- ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
- ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
- ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
- ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
- ml4t/diagnostic/integration/__init__.py +48 -0
- ml4t/diagnostic/integration/backtest_contract.py +671 -0
- ml4t/diagnostic/integration/data_contract.py +316 -0
- ml4t/diagnostic/integration/engineer_contract.py +226 -0
- ml4t/diagnostic/logging/__init__.py +77 -0
- ml4t/diagnostic/logging/logger.py +245 -0
- ml4t/diagnostic/logging/performance.py +234 -0
- ml4t/diagnostic/logging/progress.py +234 -0
- ml4t/diagnostic/logging/wandb.py +412 -0
- ml4t/diagnostic/metrics/__init__.py +9 -0
- ml4t/diagnostic/metrics/percentiles.py +128 -0
- ml4t/diagnostic/py.typed +1 -0
- ml4t/diagnostic/reporting/__init__.py +43 -0
- ml4t/diagnostic/reporting/base.py +130 -0
- ml4t/diagnostic/reporting/html_renderer.py +275 -0
- ml4t/diagnostic/reporting/json_renderer.py +51 -0
- ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
- ml4t/diagnostic/results/AGENT.md +24 -0
- ml4t/diagnostic/results/__init__.py +105 -0
- ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
- ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
- ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
- ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
- ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
- ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
- ml4t/diagnostic/results/barrier_results/validation.py +38 -0
- ml4t/diagnostic/results/base.py +177 -0
- ml4t/diagnostic/results/event_results.py +349 -0
- ml4t/diagnostic/results/feature_results.py +787 -0
- ml4t/diagnostic/results/multi_signal_results.py +431 -0
- ml4t/diagnostic/results/portfolio_results.py +281 -0
- ml4t/diagnostic/results/sharpe_results.py +448 -0
- ml4t/diagnostic/results/signal_results/__init__.py +74 -0
- ml4t/diagnostic/results/signal_results/ic.py +581 -0
- ml4t/diagnostic/results/signal_results/irtc.py +110 -0
- ml4t/diagnostic/results/signal_results/quantile.py +392 -0
- ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
- ml4t/diagnostic/results/signal_results/turnover.py +213 -0
- ml4t/diagnostic/results/signal_results/validation.py +147 -0
- ml4t/diagnostic/signal/AGENT.md +17 -0
- ml4t/diagnostic/signal/__init__.py +69 -0
- ml4t/diagnostic/signal/_report.py +152 -0
- ml4t/diagnostic/signal/_utils.py +261 -0
- ml4t/diagnostic/signal/core.py +275 -0
- ml4t/diagnostic/signal/quantile.py +148 -0
- ml4t/diagnostic/signal/result.py +214 -0
- ml4t/diagnostic/signal/signal_ic.py +129 -0
- ml4t/diagnostic/signal/turnover.py +182 -0
- ml4t/diagnostic/splitters/AGENT.md +19 -0
- ml4t/diagnostic/splitters/__init__.py +36 -0
- ml4t/diagnostic/splitters/base.py +501 -0
- ml4t/diagnostic/splitters/calendar.py +421 -0
- ml4t/diagnostic/splitters/calendar_config.py +91 -0
- ml4t/diagnostic/splitters/combinatorial.py +1064 -0
- ml4t/diagnostic/splitters/config.py +322 -0
- ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
- ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
- ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
- ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
- ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
- ml4t/diagnostic/splitters/group_isolation.py +329 -0
- ml4t/diagnostic/splitters/persistence.py +316 -0
- ml4t/diagnostic/splitters/utils.py +207 -0
- ml4t/diagnostic/splitters/walk_forward.py +757 -0
- ml4t/diagnostic/utils/__init__.py +42 -0
- ml4t/diagnostic/utils/config.py +542 -0
- ml4t/diagnostic/utils/dependencies.py +318 -0
- ml4t/diagnostic/utils/sessions.py +127 -0
- ml4t/diagnostic/validation/__init__.py +54 -0
- ml4t/diagnostic/validation/dataframe.py +274 -0
- ml4t/diagnostic/validation/returns.py +280 -0
- ml4t/diagnostic/validation/timeseries.py +299 -0
- ml4t/diagnostic/visualization/AGENT.md +19 -0
- ml4t/diagnostic/visualization/__init__.py +223 -0
- ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
- ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
- ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
- ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
- ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
- ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
- ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
- ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
- ml4t/diagnostic/visualization/barrier_plots.py +782 -0
- ml4t/diagnostic/visualization/core.py +1060 -0
- ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
- ml4t/diagnostic/visualization/dashboards/base.py +582 -0
- ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
- ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
- ml4t/diagnostic/visualization/dashboards.py +43 -0
- ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
- ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
- ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
- ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
- ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
- ml4t/diagnostic/visualization/feature_plots.py +888 -0
- ml4t/diagnostic/visualization/interaction_plots.py +618 -0
- ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
- ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
- ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
- ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
- ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
- ml4t/diagnostic/visualization/report_generation.py +1343 -0
- ml4t/diagnostic/visualization/signal/__init__.py +103 -0
- ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
- ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
- ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
- ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
- ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
- ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
- ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
- ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
- ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
- ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
- ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
- ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,899 @@
"""Polars-specific optimizations for DataFrame operations.

This module provides optimized implementations of common operations
when working with Polars DataFrames, leveraging Polars' lazy evaluation
and columnar operations for improved performance.
"""

from collections.abc import Callable
from typing import TYPE_CHECKING, Any

import numpy as np
import polars as pl

if TYPE_CHECKING:
    from numpy.typing import NDArray


class PolarsBackend:
    """Optimized operations for Polars DataFrames.

    This backend provides performance-optimized implementations
    of common operations used throughout ml4t-diagnostic when working with
    Polars DataFrames. Includes memory-efficient streaming methods
    for handling large datasets (10M+ samples) without memory issues.

    Key Features:
    - Vectorized rolling correlations
    - Memory-efficient streaming for large datasets
    - Adaptive chunk sizing based on available memory
    - Multi-horizon Information Coefficient calculations
    """

    @staticmethod
    def fast_rolling_correlation(
        x: pl.Series,
        y: pl.Series,
        window: int,
        min_periods: int | None = None,
    ) -> pl.Series:
        """Compute rolling correlation efficiently.

        Parameters
        ----------
        x : pl.Series
            First series
        y : pl.Series
            Second series
        window : int
            Rolling window size
        min_periods : int, optional
            Minimum number of observations required

        Returns
        -------
        pl.Series
            Rolling correlation values
        """
        if min_periods is None:
            min_periods = window

        # Compute the correlation from Polars rolling moments
        df = pl.DataFrame({"x": x, "y": y})

        # Compute rolling stats needed for correlation
        rolling_df = df.select(
            [
                pl.col("x").rolling_mean(window, min_samples=min_periods).alias("x_mean"),
                pl.col("y").rolling_mean(window, min_samples=min_periods).alias("y_mean"),
                (pl.col("x") * pl.col("y"))
                .rolling_mean(window, min_samples=min_periods)
                .alias("xy_mean"),
                (pl.col("x") ** 2).rolling_mean(window, min_samples=min_periods).alias("x2_mean"),
                (pl.col("y") ** 2).rolling_mean(window, min_samples=min_periods).alias("y2_mean"),
            ],
        )

        # Calculate correlation from components
        result = rolling_df.select(
            [
                (
                    (pl.col("xy_mean") - pl.col("x_mean") * pl.col("y_mean"))
                    / (
                        (
                            (pl.col("x2_mean") - pl.col("x_mean") ** 2)
                            * (pl.col("y2_mean") - pl.col("y_mean") ** 2)
                        )
                        ** 0.5
                    )
                ).alias("correlation"),
            ],
        )

        return result["correlation"]

    @staticmethod
    def fast_rolling_spearman_correlation(
        x: pl.Series,
        y: pl.Series,
        window: int,
        min_periods: int | None = None,
    ) -> pl.Series:
        """Compute rolling Spearman correlation with window-local ranking.

        This implementation calculates ranks within each rolling window to avoid
        lookahead bias, ensuring that the rank at time T only uses data up to time T.

        Parameters
        ----------
        x : pl.Series
            First series
        y : pl.Series
            Second series
        window : int
            Rolling window size
        min_periods : int, optional
            Minimum number of observations required

        Returns
        -------
        pl.Series
            Rolling Spearman correlation values
        """
        if min_periods is None:
            min_periods = max(2, window // 2)

        # Import scipy lazily for rank calculation
        from scipy.stats import rankdata

        # Convert to numpy for processing
        x_values = x.to_numpy()
        y_values = y.to_numpy()
        n = len(x_values)

        # Initialize result array
        result = np.full(n, np.nan)

        # Calculate rolling Spearman correlation
        for i in range(n):
            # Define window boundaries
            start_idx = max(0, i - window + 1)
            end_idx = i + 1

            # Extract window data
            x_window = x_values[start_idx:end_idx]
            y_window = y_values[start_idx:end_idx]

            # Check minimum periods
            if len(x_window) < min_periods:
                continue

            # Handle NaN values
            mask = ~(np.isnan(x_window) | np.isnan(y_window))
            if np.sum(mask) < min_periods:
                continue

            x_clean = x_window[mask]
            y_clean = y_window[mask]

            # Calculate ranks within window
            if len(x_clean) > 1:
                x_ranks = rankdata(x_clean, method="average")
                y_ranks = rankdata(y_clean, method="average")

                # Calculate correlation on ranks
                x_std = np.std(x_ranks, ddof=1)
                y_std = np.std(y_ranks, ddof=1)

                if x_std > 0 and y_std > 0:
                    # Pearson correlation on ranks = Spearman correlation
                    corr = np.corrcoef(x_ranks, y_ranks)[0, 1]
                    result[i] = corr

        return pl.Series(result)

    @staticmethod
    def fast_multi_horizon_ic(
        predictions: pl.Series,
        returns_matrix: pl.DataFrame,
        window: int,
        min_periods: int | None = None,
    ) -> pl.DataFrame:
        """Calculate rolling IC for multiple return horizons efficiently.

        This is a specialized function for IC heatmap calculations that processes
        each return horizon column in turn. Ranks are calculated within each rolling
        window to avoid lookahead bias, ensuring that the rank at time T only
        uses data up to time T.

        Parameters
        ----------
        predictions : pl.Series
            Model predictions
        returns_matrix : pl.DataFrame
            Returns for different horizons (columns = horizons)
        window : int
            Rolling window size
        min_periods : int, optional
            Minimum periods required

        Returns
        -------
        pl.DataFrame
            DataFrame with rolling IC for each horizon
        """
        if min_periods is None:
            min_periods = max(2, window // 2)

        from scipy.stats import rankdata

        pred_values = predictions.to_numpy()
        n = len(pred_values)

        # Initialize result matrix
        result_data = {}

        for col in returns_matrix.columns:
            ret_values = returns_matrix[col].to_numpy()
            ic_result = np.full(n, np.nan)

            # Calculate IC for each window position
            for i in range(n):
                start_idx = max(0, i - window + 1)
                end_idx = i + 1

                pred_window = pred_values[start_idx:end_idx]
                ret_window = ret_values[start_idx:end_idx]

                if len(pred_window) < min_periods:
                    continue

                # Remove NaN values
                mask = ~(np.isnan(pred_window) | np.isnan(ret_window))
                if np.sum(mask) < min_periods:
                    continue

                pred_clean = pred_window[mask]
                ret_clean = ret_window[mask]

                if len(pred_clean) > 1:
                    # Calculate ranks within this window only
                    pred_ranks = rankdata(pred_clean, method="average")
                    ret_ranks = rankdata(ret_clean, method="average")

                    # Check for constant values (all ranks identical)
                    pred_std = np.std(pred_ranks, ddof=1)
                    ret_std = np.std(ret_ranks, ddof=1)

                    if pred_std > 0 and ret_std > 0:
                        # Compute Spearman correlation (Pearson on ranks)
                        corr = np.corrcoef(pred_ranks, ret_ranks)[0, 1]
                        ic_result[i] = corr

            result_data[f"ic_{col}"] = ic_result

        # Convert to Polars DataFrame
        return pl.DataFrame(result_data)

    @staticmethod
    def _rolling_correlation_expr(
        x_col: str,
        y_col: str,
        window: int,
        min_periods: int,
    ) -> pl.Expr:
        """Create a Polars expression for rolling correlation between two columns."""
        # Rolling means
        x_mean = pl.col(x_col).rolling_mean(window, min_samples=min_periods)
        y_mean = pl.col(y_col).rolling_mean(window, min_samples=min_periods)

        # Rolling products and squares
        xy_mean = (pl.col(x_col) * pl.col(y_col)).rolling_mean(
            window,
            min_samples=min_periods,
        )
        x2_mean = (pl.col(x_col) ** 2).rolling_mean(window, min_samples=min_periods)
        y2_mean = (pl.col(y_col) ** 2).rolling_mean(window, min_samples=min_periods)

        # Compute correlation using the formula: corr = cov(x,y) / (std(x) * std(y))
        # where cov(x,y) = E[xy] - E[x]E[y] and var(x) = E[x²] - E[x]²
        numerator = xy_mean - (x_mean * y_mean)
        denominator = ((x2_mean - x_mean**2) * (y2_mean - y_mean**2)) ** 0.5

        # Handle division by zero
        correlation = pl.when(denominator > 1e-10).then(numerator / denominator).otherwise(0.0)

        return correlation

    @staticmethod
    def fast_quantile_assignment(
        data: pl.DataFrame,
        column: str,
        n_quantiles: int,
        by_group: str | None = None,
    ) -> pl.DataFrame:
        """Assign quantile labels efficiently.

        Parameters
        ----------
        data : pl.DataFrame
            Input data
        column : str
            Column to compute quantiles on
        n_quantiles : int
            Number of quantiles
        by_group : str, optional
            Column to group by before quantile assignment

        Returns
        -------
        pl.DataFrame
            Data with quantile labels added
        """
        if by_group is not None:
            # Group-wise quantile assignment
            result = data.with_columns(
                pl.col(column)
                .qcut(n_quantiles, labels=[str(i) for i in range(1, n_quantiles + 1)])
                .over(by_group)
                .alias(f"{column}_quantile"),
            )
        else:
            # Global quantile assignment
            result = data.with_columns(
                pl.col(column)
                .qcut(n_quantiles, labels=[str(i) for i in range(1, n_quantiles + 1)])
                .alias(f"{column}_quantile"),
            )

        return result

    @staticmethod
    def fast_time_aware_split(
        data: pl.DataFrame,
        time_column: str,
        test_start: Any,
        test_end: Any,
        buffer_before: int | None = None,
        buffer_after: int | None = None,
    ) -> tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame]:
        """Split data into train/test/buffer sets efficiently.

        Parameters
        ----------
        data : pl.DataFrame
            Input data with time column
        time_column : str
            Name of time column
        test_start : Any
            Test period start time
        test_end : Any
            Test period end time
        buffer_before : int, optional
            Purge buffer before test
        buffer_after : int, optional
            Embargo buffer after test

        Returns
        -------
        train_df : pl.DataFrame
            Training data
        test_df : pl.DataFrame
            Test data
        buffer_df : pl.DataFrame
            Buffer zone data
        """
        # Create efficient filters
        test_mask = (pl.col(time_column) >= test_start) & (pl.col(time_column) < test_end)

        # Apply test filter
        test_df = data.filter(test_mask)

        # Create buffer masks if needed
        if buffer_before is not None:
            buffer_start = test_start - buffer_before
            before_buffer_mask = (pl.col(time_column) >= buffer_start) & (
                pl.col(time_column) < test_start
            )
        else:
            before_buffer_mask = pl.lit(False)

        if buffer_after is not None:
            buffer_end = test_end + buffer_after
            after_buffer_mask = (pl.col(time_column) >= test_end) & (
                pl.col(time_column) < buffer_end
            )
        else:
            after_buffer_mask = pl.lit(False)

        # Combine buffer masks
        buffer_mask = before_buffer_mask | after_buffer_mask
        buffer_df = data.filter(buffer_mask)

        # Train is everything not in test or buffer
        train_mask = ~(test_mask | buffer_mask)
        train_df = data.filter(train_mask)

        return train_df, test_df, buffer_df

    @staticmethod
    def fast_group_statistics(
        data: pl.DataFrame,
        group_column: str,
        value_column: str,
        statistics: list[str],
    ) -> pl.DataFrame:
        """Compute group statistics efficiently.

        Parameters
        ----------
        data : pl.DataFrame
            Input data
        group_column : str
            Column to group by
        value_column : str
            Column to compute statistics on
        statistics : list[str]
            List of statistics to compute

        Returns
        -------
        pl.DataFrame
            Group statistics
        """
        # Map statistic names to Polars expressions
        stat_exprs = []

        for stat in statistics:
            if stat == "mean":
                stat_exprs.append(
                    pl.col(value_column).mean().alias(f"{value_column}_mean"),
                )
            elif stat == "std":
                stat_exprs.append(
                    pl.col(value_column).std().alias(f"{value_column}_std"),
                )
            elif stat == "min":
                stat_exprs.append(
                    pl.col(value_column).min().alias(f"{value_column}_min"),
                )
            elif stat == "max":
                stat_exprs.append(
                    pl.col(value_column).max().alias(f"{value_column}_max"),
                )
            elif stat == "count":
                stat_exprs.append(
                    pl.col(value_column).count().alias(f"{value_column}_count"),
                )
            elif stat == "sum":
                stat_exprs.append(
                    pl.col(value_column).sum().alias(f"{value_column}_sum"),
                )
            elif stat == "median":
                stat_exprs.append(
                    pl.col(value_column).median().alias(f"{value_column}_median"),
                )
            else:
                raise ValueError(f"Unknown statistic: {stat}")

        # Compute all statistics in one pass
        result = data.group_by(group_column).agg(stat_exprs)

        return result

    @staticmethod
    def fast_expanding_window(
        data: pl.DataFrame,
        columns: list[str],
        operation: str = "mean",
        min_periods: int = 1,
    ) -> pl.DataFrame:
        """Compute expanding window statistics efficiently.

        Parameters
        ----------
        data : pl.DataFrame
            Input data
        columns : list[str]
            Columns to compute expanding statistics on
        operation : str
            Operation to apply (mean, std, sum, etc.)
        min_periods : int
            Minimum number of observations required

        Returns
        -------
        pl.DataFrame
            Data with expanding statistics added
        """
        result = data

        for col in columns:
            if operation == "mean":
                expr = pl.col(col).cum_sum() / pl.int_range(1, pl.len() + 1)
            elif operation == "std":
                # O(n) expanding standard deviation from cumulative sums
                # Var(X) = E[X²] - E[X]² → std = sqrt(Var * n/(n-1)) with Bessel correction
                n_expr = pl.int_range(1, pl.len() + 1).cast(pl.Float64)
                cum_sum = pl.col(col).cum_sum()
                cum_sum_sq = (pl.col(col) ** 2).cum_sum()

                # Expanding mean of squares and square of mean
                mean_of_sq = cum_sum_sq / n_expr
                mean_sq = (cum_sum / n_expr) ** 2

                # Population variance (before Bessel correction)
                variance = mean_of_sq - mean_sq

                # Apply Bessel correction: sample_var = pop_var * n / (n-1)
                # Handle n=1 case (variance undefined for single observation)
                # and min_periods requirement
                expr = (
                    pl.when(n_expr >= max(min_periods, 2))
                    .then((variance * n_expr / (n_expr - 1)).sqrt())
                    .otherwise(None)
                )
            elif operation == "sum":
                expr = pl.col(col).cum_sum()
            elif operation == "min":
                expr = pl.col(col).cum_min()
            elif operation == "max":
                expr = pl.col(col).cum_max()
            else:
                raise ValueError(f"Unknown operation: {operation}")

            result = result.with_columns(expr.alias(f"{col}_expanding_{operation}"))

        return result

    @staticmethod
    def to_numpy_batch(
        data: pl.DataFrame,
        columns: list[str] | None = None,
        batch_size: int = 10000,
    ) -> "NDArray[Any]":
        """Convert DataFrame to numpy array in batches for memory efficiency.

        Parameters
        ----------
        data : pl.DataFrame
            Input DataFrame
        columns : list[str], optional
            Columns to convert (all if None)
        batch_size : int
            Batch size for conversion

        Returns
        -------
        np.ndarray
            Numpy array
        """
        if columns is not None:
            data = data.select(columns)

        # For small data, convert directly
        if len(data) <= batch_size:
            return data.to_numpy()

        # For large data, convert in batches
        n_rows = len(data)
        n_cols = data.shape[1]
        result = np.empty((n_rows, n_cols), dtype=np.float64)

        for i in range(0, n_rows, batch_size):
            end_idx = min(i + batch_size, n_rows)
            result[i:end_idx] = data[i:end_idx].to_numpy()

        return result

    @staticmethod
    def fast_rolling_correlation_streaming(
        x: pl.Series,
        y: pl.Series,
        window: int,
        min_periods: int | None = None,
        chunk_size: int = 50000,
    ) -> pl.Series:
        """Compute rolling correlation for large datasets using streaming.

        This method processes data in chunks to manage memory usage for
        very large datasets while maintaining accuracy through proper
        overlap handling.

        Parameters
        ----------
        x : pl.Series
            First series
        y : pl.Series
            Second series
        window : int
            Rolling window size
        min_periods : int, optional
            Minimum number of observations required
        chunk_size : int, default 50000
            Size of chunks to process. Larger chunks use more memory
            but may be more efficient.

        Returns
        -------
        pl.Series
            Rolling correlation values

        Notes
        -----
        This function is designed for datasets larger than 100k samples.
        For smaller datasets, use fast_rolling_correlation() directly
        as it will be more efficient.

        Examples
        --------
        >>> x = pl.Series("x", range(200000))
        >>> y = pl.Series("y", np.random.randn(200000))
        >>> corr = PolarsBackend.fast_rolling_correlation_streaming(x, y, 100)
        """
        if min_periods is None:
            min_periods = window

        n_samples = len(x)

        # For small datasets, use the standard method
        if n_samples <= chunk_size:
            return PolarsBackend.fast_rolling_correlation(x, y, window, min_periods)

        # For very large datasets, use streaming approach
        results = []
        overlap = window - 1  # Overlap needed to maintain continuity

        for start in range(0, n_samples, chunk_size - overlap):
            end = min(start + chunk_size, n_samples)

            # Extract chunk with overlap
            x_chunk = x[start:end]
            y_chunk = y[start:end]

            # Process chunk
            chunk_result = PolarsBackend.fast_rolling_correlation(
                x_chunk,
                y_chunk,
                window,
                min_periods,
            )

            # Handle overlap: remove duplicate results from previous chunks
            if start > 0:
                # Remove the overlapping window-1 results
                chunk_result = chunk_result[overlap:]

            results.append(chunk_result)

        # Concatenate all results - results list contains Series, so concat returns Series
        concatenated = pl.concat(results)
        assert isinstance(concatenated, pl.Series), "Expected Series from concatenating Series"
        return concatenated

    @staticmethod
    def fast_multi_horizon_ic_streaming(
        predictions: pl.Series,
        returns_matrix: pl.DataFrame,
        window: int,
        min_periods: int | None = None,
        chunk_size: int = 50000,
    ) -> pl.DataFrame:
        """Calculate rolling IC for multiple horizons using streaming for large datasets.

        This is a memory-efficient version of fast_multi_horizon_ic that processes
        data in chunks to handle very large datasets without memory issues.

        Parameters
        ----------
        predictions : pl.Series
            Model predictions
        returns_matrix : pl.DataFrame
            Returns for different horizons (columns = horizons)
        window : int
            Rolling window size
        min_periods : int, optional
            Minimum periods required
        chunk_size : int, default 50000
            Size of chunks to process

        Returns
        -------
        pl.DataFrame
            DataFrame with rolling IC for each horizon

        Examples
        --------
        >>> predictions = pl.Series("pred", np.random.randn(200000))
        >>> returns = pl.DataFrame({
        ...     "1d": np.random.randn(200000),
        ...     "5d": np.random.randn(200000),
        ...     "20d": np.random.randn(200000)
        ... })
        >>> ic_matrix = PolarsBackend.fast_multi_horizon_ic_streaming(
        ...     predictions, returns, window=100
        ... )
        """
        if min_periods is None:
            min_periods = max(2, window // 2)

        n_samples = len(predictions)

        # For small datasets, use the standard method
        if n_samples <= chunk_size:
            return PolarsBackend.fast_multi_horizon_ic(
                predictions,
                returns_matrix,
                window,
                min_periods,
            )

        # For large datasets, use streaming approach
        results = []
        overlap = window - 1

        for start in range(0, n_samples, chunk_size - overlap):
            end = min(start + chunk_size, n_samples)

            # Extract chunks
            pred_chunk = predictions[start:end]
            returns_chunk = returns_matrix[start:end]

            # Process chunk
            chunk_result = PolarsBackend.fast_multi_horizon_ic(
                pred_chunk,
                returns_chunk,
                window,
                min_periods,
            )

            # Handle overlap
            if start > 0:
                chunk_result = chunk_result[overlap:]

            results.append(chunk_result)

        # Concatenate all results
        return pl.concat(results)

    @staticmethod
    def estimate_memory_usage(
        n_samples: int,
        n_features: int,
        data_type: str = "float64",
    ) -> dict[str, float]:
        """Estimate memory usage for different operations.

        Parameters
        ----------
        n_samples : int
            Number of samples
        n_features : int
            Number of features
        data_type : str, default "float64"
            Data type (float64, float32, int64, etc.)

        Returns
        -------
        dict
            Memory usage estimates in MB
        """
        # Bytes per element
        type_sizes = {"float64": 8, "float32": 4, "int64": 8, "int32": 4, "bool": 1}

        bytes_per_element = type_sizes.get(data_type, 8)

        # Basic DataFrame memory
        base_memory_mb = (n_samples * n_features * bytes_per_element) / (1024 * 1024)

        # Rolling operations typically need 2-3x memory for intermediate calculations
        rolling_memory_mb = base_memory_mb * 2.5

        # Multi-horizon IC needs additional memory for ranks and correlations
        ic_memory_mb = base_memory_mb * 3.0

        return {
            "base_dataframe_mb": base_memory_mb,
            "rolling_operations_mb": rolling_memory_mb,
            "multi_horizon_ic_mb": ic_memory_mb,
            "recommended_chunk_size": max(
                10000,
                min(100000, int(500 * 1024 * 1024 / (n_features * bytes_per_element))),
            ),
        }

    @staticmethod
    def adaptive_chunk_size(
        total_samples: int,
        n_features: int = 1,
        target_memory_mb: int = 500,
        min_chunk_size: int = 10000,
        max_chunk_size: int = 100000,
    ) -> int:
        """Calculate optimal chunk size based on available memory.

        Parameters
        ----------
        total_samples : int
            Total number of samples in dataset
        n_features : int, default 1
            Number of features (affects memory per sample)
        target_memory_mb : int, default 500
            Target memory usage in MB
        min_chunk_size : int, default 10000
            Minimum chunk size
        max_chunk_size : int, default 100000
            Maximum chunk size

        Returns
        -------
        int
            Optimal chunk size
        """
        # Estimate memory per sample (assuming float64)
        memory_per_sample = n_features * 8 * 2.5  # 2.5x factor for processing overhead

        # Calculate chunk size to fit in target memory
        target_chunk_size = int((target_memory_mb * 1024 * 1024) / memory_per_sample)

        # Apply bounds
        chunk_size = max(min_chunk_size, min(max_chunk_size, target_chunk_size))

        # Don't chunk if dataset is small
        if total_samples <= max_chunk_size:
            return total_samples

        return chunk_size

    @staticmethod
    def memory_efficient_operation(
        data: pl.DataFrame,
        operation_func: Callable[..., pl.DataFrame],
        chunk_size: int | None = None,
        overlap: int = 0,
        **kwargs,
    ) -> pl.DataFrame:
        """Apply an operation to large DataFrame in memory-efficient chunks.

        This is a generic streaming framework that can be used for any
        operation that can be applied to DataFrame chunks.

        Parameters
        ----------
        data : pl.DataFrame
            Input DataFrame
        operation_func : callable
            Function to apply to each chunk
        chunk_size : int, optional
            Chunk size (auto-calculated if None)
        overlap : int, default 0
            Number of rows to overlap between chunks
        **kwargs
            Additional arguments passed to operation_func

        Returns
        -------
        pl.DataFrame
            Result of applying operation to entire DataFrame

        Examples
        --------
        >>> def rolling_mean_op(chunk_df, window=10):
        ...     return chunk_df.select([
        ...         pl.col("value").rolling_mean(window).alias("rolling_mean")
        ...     ])
        >>>
        >>> result = PolarsBackend.memory_efficient_operation(
        ...     large_df, rolling_mean_op, overlap=9, window=10
        ... )
        """
        n_samples = len(data)

        # Auto-calculate chunk size if not provided
        if chunk_size is None:
            chunk_size = PolarsBackend.adaptive_chunk_size(n_samples, data.shape[1])

        # For small data, process directly
        if n_samples <= chunk_size:
            return operation_func(data, **kwargs)

        results = []

        for start in range(0, n_samples, chunk_size - overlap):
            end = min(start + chunk_size, n_samples)

            # Extract chunk
            chunk = data[start:end]

            # Apply operation
            chunk_result = operation_func(chunk, **kwargs)

            # Handle overlap
            if start > 0 and overlap > 0:
                chunk_result = chunk_result[overlap:]

            results.append(chunk_result)

        return pl.concat(results)
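For orientation, a minimal usage sketch of the backend added above. The synthetic data, column names, and parameter values are illustrative assumptions and not part of the packaged file; the import path and method signatures come from the source shown in this diff.

# Usage sketch (assumes ml4t-diagnostic plus its polars/numpy/scipy dependencies are installed)
import numpy as np
import polars as pl

from ml4t.diagnostic.backends.polars_backend import PolarsBackend

rng = np.random.default_rng(0)
n = 5_000

signal = pl.Series("signal", rng.standard_normal(n))
returns = pl.DataFrame({
    "1d": rng.standard_normal(n),
    "5d": rng.standard_normal(n),
})

# Rolling Pearson correlation between the signal and the 1-day return
corr = PolarsBackend.fast_rolling_correlation(signal, returns["1d"], window=252)

# Rolling rank IC per horizon; result columns are named "ic_1d" and "ic_5d"
ic = PolarsBackend.fast_multi_horizon_ic(signal, returns, window=252)

# Pick a chunk size before switching to the *_streaming variants on much larger data
chunk = PolarsBackend.adaptive_chunk_size(total_samples=20_000_000, n_features=3)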