spotforecast2 0.0.2__py3-none-any.whl → 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spotforecast2/.DS_Store +0 -0
- spotforecast2/forecaster/.DS_Store +0 -0
- spotforecast2/forecaster/utils.py +17 -1
- spotforecast2/processing/n2n_predict.py +19 -10
- spotforecast2/processing/n2n_predict_with_covariates.py +937 -0
- spotforecast2/weather/__init__.py +5 -0
- {spotforecast2-0.0.2.dist-info → spotforecast2-0.0.3.dist-info}/METADATA +1 -1
- {spotforecast2-0.0.2.dist-info → spotforecast2-0.0.3.dist-info}/RECORD +9 -6
- {spotforecast2-0.0.2.dist-info → spotforecast2-0.0.3.dist-info}/WHEEL +1 -1
spotforecast2/.DS_Store
ADDED
|
Binary file
|
|
Binary file
|
|
@@ -20,6 +20,11 @@ from spotforecast2.utils import (
|
|
|
20
20
|
)
|
|
21
21
|
from spotforecast2.exceptions import set_skforecast_warnings, UnknownLevelWarning
|
|
22
22
|
|
|
23
|
+
try:
|
|
24
|
+
from tqdm.auto import tqdm
|
|
25
|
+
except ImportError: # pragma: no cover - fallback when tqdm is not installed
|
|
26
|
+
tqdm = None
|
|
27
|
+
|
|
23
28
|
|
|
24
29
|
def check_preprocess_series(series):
|
|
25
30
|
pass
|
|
@@ -785,6 +790,7 @@ def predict_multivariate(
|
|
|
785
790
|
forecasters: dict[str, Any],
|
|
786
791
|
steps_ahead: int,
|
|
787
792
|
exog: pd.DataFrame | None = None,
|
|
793
|
+
show_progress: bool = False,
|
|
788
794
|
) -> pd.DataFrame:
|
|
789
795
|
"""
|
|
790
796
|
Generate multi-output predictions using multiple baseline forecasters.
|
|
@@ -796,6 +802,8 @@ def predict_multivariate(
|
|
|
796
802
|
steps_ahead (int): Number of steps to forecast.
|
|
797
803
|
exog (pd.DataFrame, optional): Exogenous variables for prediction.
|
|
798
804
|
If provided, will be passed to each forecaster's predict method.
|
|
805
|
+
show_progress (bool, optional): Show progress bar while predicting
|
|
806
|
+
per target forecaster. Default: False.
|
|
799
807
|
|
|
800
808
|
Returns:
|
|
801
809
|
pd.DataFrame: DataFrame with predictions for all targets.
|
|
@@ -824,7 +832,15 @@ def predict_multivariate(
|
|
|
824
832
|
|
|
825
833
|
predictions = {}
|
|
826
834
|
|
|
827
|
-
|
|
835
|
+
target_iter = forecasters.items()
|
|
836
|
+
if show_progress and tqdm is not None:
|
|
837
|
+
target_iter = tqdm(
|
|
838
|
+
forecasters.items(),
|
|
839
|
+
desc="Predicting targets",
|
|
840
|
+
unit="model",
|
|
841
|
+
)
|
|
842
|
+
|
|
843
|
+
for target, forecaster in target_iter:
|
|
828
844
|
# Generate predictions for this target
|
|
829
845
|
if exog is not None:
|
|
830
846
|
pred = forecaster.predict(steps=steps_ahead, exog=exog)
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import pandas as pd
|
|
2
|
-
from typing import List, Optional
|
|
2
|
+
from typing import List, Optional
|
|
3
3
|
from spotforecast2.forecaster.recursive import ForecasterEquivalentDate
|
|
4
4
|
from spotforecast2.data.fetch_data import fetch_data
|
|
5
5
|
from spotforecast2.preprocessing.curate_data import basic_ts_checks
|
|
@@ -8,9 +8,13 @@ from spotforecast2.preprocessing.outlier import mark_outliers
|
|
|
8
8
|
|
|
9
9
|
from spotforecast2.preprocessing.split import split_rel_train_val_test
|
|
10
10
|
from spotforecast2.forecaster.utils import predict_multivariate
|
|
11
|
-
from spotforecast2.model_selection import TimeSeriesFold, backtesting_forecaster
|
|
12
11
|
from spotforecast2.preprocessing.curate_data import get_start_end
|
|
13
12
|
|
|
13
|
+
try:
|
|
14
|
+
from tqdm.auto import tqdm
|
|
15
|
+
except ImportError: # pragma: no cover - fallback when tqdm is not installed
|
|
16
|
+
tqdm = None
|
|
17
|
+
|
|
14
18
|
|
|
15
19
|
def n2n_predict(
|
|
16
20
|
columns: Optional[List[str]] = None,
|
|
@@ -18,22 +22,22 @@ def n2n_predict(
|
|
|
18
22
|
contamination: float = 0.01,
|
|
19
23
|
window_size: int = 72,
|
|
20
24
|
verbose: bool = True,
|
|
21
|
-
|
|
25
|
+
show_progress: bool = True,
|
|
26
|
+
) -> pd.DataFrame:
|
|
22
27
|
"""
|
|
23
28
|
End-to-end prediction function replicating the workflow from 01_base_predictor combined with fetch_data.
|
|
24
29
|
|
|
25
30
|
Args:
|
|
26
31
|
columns: List of target columns to forecast. If None, uses a default set (defined internally or from data).
|
|
27
|
-
Note: fetch_data
|
|
32
|
+
Note: fetch_data supports None to return all columns.
|
|
28
33
|
forecast_horizon: Number of steps to forecast.
|
|
29
34
|
contamination: Contamination factor for outlier detection.
|
|
30
35
|
window_size: Window size for weighting (not fully utilized in main flow but kept for consistency).
|
|
31
36
|
verbose: Whether to print progress logs.
|
|
37
|
+
show_progress: Show progress bar during training and prediction.
|
|
32
38
|
|
|
33
39
|
Returns:
|
|
34
|
-
|
|
35
|
-
- predictions (pd.DataFrame): The multi-output predictions.
|
|
36
|
-
- metrics (Optional[Dict]): Dictionary containing backtesting metrics if performed.
|
|
40
|
+
pd.DataFrame: The multi-output predictions.
|
|
37
41
|
"""
|
|
38
42
|
if columns is not None:
|
|
39
43
|
TARGET = columns
|
|
@@ -95,7 +99,11 @@ def n2n_predict(
|
|
|
95
99
|
|
|
96
100
|
baseline_forecasters = {}
|
|
97
101
|
|
|
98
|
-
|
|
102
|
+
target_iter = data.columns
|
|
103
|
+
if show_progress and tqdm is not None:
|
|
104
|
+
target_iter = tqdm(data.columns, desc="Training forecasters", unit="model")
|
|
105
|
+
|
|
106
|
+
for target in target_iter:
|
|
99
107
|
forecaster = ForecasterEquivalentDate(offset=pd.DateOffset(days=1), n_offsets=1)
|
|
100
108
|
|
|
101
109
|
forecaster.fit(y=data.loc[:end_validation, target])
|
|
@@ -105,13 +113,14 @@ def n2n_predict(
|
|
|
105
113
|
if verbose:
|
|
106
114
|
print("✓ Multi-output baseline system trained")
|
|
107
115
|
|
|
108
|
-
|
|
109
116
|
# --- Predict ---
|
|
110
117
|
if verbose:
|
|
111
118
|
print("Generating predictions...")
|
|
112
119
|
|
|
113
120
|
predictions = predict_multivariate(
|
|
114
|
-
baseline_forecasters,
|
|
121
|
+
baseline_forecasters,
|
|
122
|
+
steps_ahead=forecast_horizon,
|
|
123
|
+
show_progress=show_progress,
|
|
115
124
|
)
|
|
116
125
|
|
|
117
126
|
return predictions
|
|
@@ -0,0 +1,937 @@
|
|
|
1
|
+
"""
|
|
2
|
+
End-to-end recursive forecasting with exogenous covariates.
|
|
3
|
+
|
|
4
|
+
This module provides a complete pipeline for time series forecasting using
|
|
5
|
+
recursive forecasters with exogenous variables (weather, holidays, calendar features).
|
|
6
|
+
It handles data preparation, feature engineering, model training, and prediction
|
|
7
|
+
in a single integrated function.
|
|
8
|
+
|
|
9
|
+
Examples:
|
|
10
|
+
Basic usage with default parameters:
|
|
11
|
+
|
|
12
|
+
>>> from spotforecast2.processing.n2n_predict_with_covariates import (
|
|
13
|
+
... n2n_predict_with_covariates
|
|
14
|
+
... )
|
|
15
|
+
>>> predictions, metadata, forecasters = n2n_predict_with_covariates(
|
|
16
|
+
... forecast_horizon=24,
|
|
17
|
+
... verbose=True
|
|
18
|
+
... )
|
|
19
|
+
|
|
20
|
+
With custom parameters:
|
|
21
|
+
|
|
22
|
+
>>> predictions, metadata, forecasters = n2n_predict_with_covariates(
|
|
23
|
+
... forecast_horizon=48,
|
|
24
|
+
... contamination=0.02,
|
|
25
|
+
... window_size=100,
|
|
26
|
+
... lags=48,
|
|
27
|
+
... train_ratio=0.75,
|
|
28
|
+
... verbose=True
|
|
29
|
+
... )
|
|
30
|
+
"""
|
|
31
|
+
|
|
32
|
+
from typing import Dict, List, Optional, Tuple, Union
|
|
33
|
+
|
|
34
|
+
import numpy as np
|
|
35
|
+
import pandas as pd
|
|
36
|
+
from astral import LocationInfo
|
|
37
|
+
from lightgbm import LGBMRegressor
|
|
38
|
+
from sklearn.preprocessing import PolynomialFeatures
|
|
39
|
+
|
|
40
|
+
try:
|
|
41
|
+
from tqdm.auto import tqdm
|
|
42
|
+
except ImportError: # pragma: no cover - fallback when tqdm is not installed
|
|
43
|
+
tqdm = None
|
|
44
|
+
|
|
45
|
+
from spotforecast2.data.fetch_data import (
|
|
46
|
+
fetch_data,
|
|
47
|
+
fetch_holiday_data,
|
|
48
|
+
fetch_weather_data,
|
|
49
|
+
)
|
|
50
|
+
from spotforecast2.forecaster.recursive import ForecasterRecursive
|
|
51
|
+
from spotforecast2.forecaster.utils import predict_multivariate
|
|
52
|
+
from spotforecast2.preprocessing import RollingFeatures
|
|
53
|
+
from spotforecast2.preprocessing.curate_data import (
|
|
54
|
+
agg_and_resample_data,
|
|
55
|
+
basic_ts_checks,
|
|
56
|
+
curate_holidays,
|
|
57
|
+
curate_weather,
|
|
58
|
+
get_start_end,
|
|
59
|
+
)
|
|
60
|
+
from spotforecast2.preprocessing.imputation import custom_weights, get_missing_weights
|
|
61
|
+
from spotforecast2.preprocessing.outlier import mark_outliers
|
|
62
|
+
from spotforecast2.preprocessing.split import split_rel_train_val_test
|
|
63
|
+
|
|
64
|
+
try:
|
|
65
|
+
from feature_engine.creation import CyclicalFeatures
|
|
66
|
+
from feature_engine.datetime import DatetimeFeatures
|
|
67
|
+
from feature_engine.timeseries.forecasting import WindowFeatures
|
|
68
|
+
except ImportError:
|
|
69
|
+
raise ImportError(
|
|
70
|
+
"feature_engine is required. Install with: pip install feature-engine"
|
|
71
|
+
)
|
|
72
|
+
|
|
73
|
+
try:
|
|
74
|
+
from astral.sun import sun
|
|
75
|
+
except ImportError:
|
|
76
|
+
raise ImportError("astral is required. Install with: pip install astral")
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
# ============================================================================
|
|
80
|
+
# Helper Functions for Feature Engineering
|
|
81
|
+
# ============================================================================
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def _get_weather_features(
    data: pd.DataFrame,
    start: Union[str, pd.Timestamp],
    cov_end: Union[str, pd.Timestamp],
    forecast_horizon: int,
    latitude: float = 51.5136,
    longitude: float = 7.4653,
    timezone: str = "UTC",
    freq: str = "h",
    window_periods: Optional[List[str]] = None,
    window_functions: Optional[List[str]] = None,
    fallback_on_failure: bool = True,
    cached: bool = True,
    verbose: bool = False,
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Fetch weather data and derive rolling-window weather features.

    The fetched weather frame is reindexed onto a regular
    ``start``..``cov_end`` grid (forward fill), its numeric columns are
    back-filled where still missing, and ``WindowFeatures`` is applied to
    them.

    Args:
        data: Target time series; passed to ``curate_weather`` together with
            the fetched weather frame.
        start: First timestamp of the covariate range. Strings are parsed
            as UTC.
        cov_end: Last timestamp of the covariate range. Strings are parsed
            as UTC.
        forecast_horizon: Number of forecast steps, forwarded to
            ``curate_weather``.
        latitude: Latitude of the location. Default: 51.5136 (Dortmund).
        longitude: Longitude of the location. Default: 7.4653 (Dortmund).
        timezone: Timezone used for fetching and for the regular index.
            Default: "UTC".
        freq: Frequency of the regular index. Default: "h".
        window_periods: Rolling window sizes. Default: ["1D", "7D"].
        window_functions: Rolling aggregations. Default:
            ["mean", "max", "min"].
        fallback_on_failure: Forwarded to ``fetch_weather_data``.
            Default: True.
        cached: Forwarded to ``fetch_weather_data``. Default: True.
        verbose: Print progress messages. Default: False.

    Returns:
        Tuple ``(weather_features, weather_aligned)``: the output of the
        window transformer and the reindexed (forward-filled) weather frame.

    Raises:
        ValueError: If no numeric weather column exists, or if missing
            values remain after back-filling.
    """

    def _backfill_or_raise(frame: pd.DataFrame, error_message: str) -> pd.DataFrame:
        # Only touch the frame when it actually contains missing values.
        if not frame.isnull().any().any():
            return frame
        frame = frame.bfill()
        if frame.isnull().any().any():
            raise ValueError(error_message)
        return frame

    periods = ["1D", "7D"] if window_periods is None else window_periods
    functions = (
        ["mean", "max", "min"] if window_functions is None else window_functions
    )

    range_start = pd.to_datetime(start, utc=True) if isinstance(start, str) else start
    range_end = (
        pd.to_datetime(cov_end, utc=True) if isinstance(cov_end, str) else cov_end
    )

    if verbose:
        print("Fetching weather data...")

    weather_df = fetch_weather_data(
        cov_start=range_start,
        cov_end=range_end,
        latitude=latitude,
        longitude=longitude,
        timezone=timezone,
        freq=freq,
        fallback_on_failure=fallback_on_failure,
        cached=cached,
    )

    # NOTE(review): the return value is discarded, so this is presumably a
    # validation step that raises on bad coverage -- confirm.
    curate_weather(weather_df, data, forecast_horizon=forecast_horizon)

    if verbose:
        print("Processing weather features...")

    full_index = pd.date_range(
        start=range_start, end=range_end, freq=freq, tz=timezone
    )
    weather_aligned = weather_df.reindex(full_index, method="ffill")

    numeric_columns = weather_aligned.select_dtypes(
        include=[np.number]
    ).columns.tolist()
    if not numeric_columns:
        raise ValueError("No numeric weather columns found")

    # Forward fill cannot fill gaps at the head of the range; back-fill them.
    numeric_weather = _backfill_or_raise(
        weather_aligned[numeric_columns].copy(),
        "Missing values in weather data could not be filled",
    )

    window_transformer = WindowFeatures(
        variables=numeric_columns,
        window=periods,
        functions=functions,
        freq=freq,
    )
    weather_features = _backfill_or_raise(
        window_transformer.fit_transform(numeric_weather),
        "Missing values in weather features could not be filled",
    )

    if verbose:
        print(f"Weather features shape: {weather_features.shape}")

    return weather_features, weather_aligned
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
def _get_calendar_features(
    start: Union[str, pd.Timestamp],
    cov_end: Union[str, pd.Timestamp],
    freq: str = "h",
    timezone: str = "UTC",
    features_to_extract: Optional[List[str]] = None,
) -> pd.DataFrame:
    """Build calendar features for every timestamp of a regular range.

    Args:
        start: First timestamp of the range. Strings are parsed as UTC.
        cov_end: Last timestamp of the range. Strings are parsed as UTC.
        freq: Frequency of the generated index. Default: "h".
        timezone: Timezone of the generated index. Default: "UTC".
        features_to_extract: Datetime features to derive from the index.
            Default: ["month", "week", "day_of_week", "hour"].

    Returns:
        DataFrame indexed by the generated range, containing exactly the
        requested feature columns.
    """
    wanted = (
        ["month", "week", "day_of_week", "hour"]
        if features_to_extract is None
        else features_to_extract
    )

    range_start = pd.to_datetime(start, utc=True) if isinstance(start, str) else start
    range_end = (
        pd.to_datetime(cov_end, utc=True) if isinstance(cov_end, str) else cov_end
    )

    index_features = DatetimeFeatures(
        variables="index",
        features_to_extract=wanted,
        drop_original=True,
    )

    full_index = pd.date_range(
        start=range_start, end=range_end, freq=freq, tz=timezone
    )
    # Placeholder column so the frame is not empty; it is dropped again by
    # the column selection on the transformed result.
    carrier = pd.DataFrame(index=full_index).assign(dummy=0)

    transformed = index_features.fit_transform(carrier)
    return transformed[wanted]
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
def _get_day_night_features(
    start: Union[str, pd.Timestamp],
    cov_end: Union[str, pd.Timestamp],
    location: LocationInfo,
    freq: str = "h",
    timezone: str = "UTC",
) -> pd.DataFrame:
    """Create day/night features from sunrise and sunset times.

    Args:
        start: First timestamp of the range. Strings are parsed as UTC.
        cov_end: Last timestamp of the range. Strings are parsed as UTC.
        location: Astral ``LocationInfo``; its ``observer`` and ``timezone``
            are passed to ``sun``.
        freq: Frequency of the generated index. Default: "h".
        timezone: Timezone of the generated index. Default: "UTC".

    Returns:
        DataFrame indexed by the generated range with the columns
        ``sunrise_hour``, ``sunset_hour`` (both rounded to the nearest
        hour), ``daylight_hours`` (sunset hour minus sunrise hour) and
        ``is_daylight`` (1 when the timestamp's hour lies in
        ``[sunrise_hour, sunset_hour)``, otherwise 0).
    """
    if isinstance(start, str):
        start = pd.to_datetime(start, utc=True)
    if isinstance(cov_end, str):
        cov_end = pd.to_datetime(cov_end, utc=True)

    extended_index = pd.date_range(start=start, end=cov_end, freq=freq, tz=timezone)

    # Query astral once per timestamp and reuse the result for both sunrise
    # and sunset instead of running the solar computation twice.
    sun_times = [
        sun(location.observer, date=stamp, tzinfo=location.timezone)
        for stamp in extended_index
    ]

    sunrise_hour = (
        pd.Series([times["sunrise"] for times in sun_times], index=extended_index)
        .dt.round("h")
        .dt.hour
    )
    sunset_hour = (
        pd.Series([times["sunset"] for times in sun_times], index=extended_index)
        .dt.round("h")
        .dt.hour
    )

    sun_light_features = pd.DataFrame(
        {
            "sunrise_hour": sunrise_hour,
            "sunset_hour": sunset_hour,
        }
    )
    sun_light_features["daylight_hours"] = (
        sun_light_features["sunset_hour"] - sun_light_features["sunrise_hour"]
    )
    sun_light_features["is_daylight"] = np.where(
        (extended_index.hour >= sun_light_features["sunrise_hour"])
        & (extended_index.hour < sun_light_features["sunset_hour"]),
        1,
        0,
    )

    return sun_light_features
|
|
279
|
+
|
|
280
|
+
|
|
281
|
+
def _get_holiday_features(
    data: pd.DataFrame,
    start: Union[str, pd.Timestamp],
    cov_end: Union[str, pd.Timestamp],
    forecast_horizon: int,
    tz: str = "UTC",
    freq: str = "h",
    country_code: str = "DE",
    state: str = "NW",
) -> pd.DataFrame:
    """Fetch holiday data and align it to a regular time index.

    Args:
        data: Target time series; passed to ``curate_holidays`` together
            with the fetched holiday frame.
        start: First timestamp of the range. Strings are parsed as UTC.
        cov_end: Last timestamp of the range. Strings are parsed as UTC.
        forecast_horizon: Number of forecast steps, forwarded to
            ``curate_holidays``.
        tz: Timezone for fetching and for the regular index. Default: "UTC".
        freq: Frequency for fetching and for the regular index. Default: "h".
        country_code: Country code for the holiday lookup. Default: "DE".
        state: State code for the holiday lookup. Default: "NW".

    Returns:
        Holiday frame reindexed onto the regular range; timestamps missing
        from the fetched data are filled with 0 and all values are cast to
        int.
    """
    range_start = pd.to_datetime(start, utc=True) if isinstance(start, str) else start
    range_end = (
        pd.to_datetime(cov_end, utc=True) if isinstance(cov_end, str) else cov_end
    )

    holiday_df = fetch_holiday_data(
        start=range_start,
        end=range_end,
        tz=tz,
        freq=freq,
        country_code=country_code,
        state=state,
    )

    # NOTE(review): return value is ignored; presumably a validation step
    # that raises on insufficient coverage -- confirm.
    curate_holidays(holiday_df, data, forecast_horizon=forecast_horizon)

    full_index = pd.date_range(start=range_start, end=range_end, freq=freq, tz=tz)
    aligned = holiday_df.reindex(full_index, fill_value=0)
    return aligned.astype(int)
|
|
326
|
+
|
|
327
|
+
|
|
328
|
+
def _apply_cyclical_encoding(
|
|
329
|
+
data: pd.DataFrame,
|
|
330
|
+
features_to_encode: Optional[List[str]] = None,
|
|
331
|
+
max_values: Optional[Dict[str, int]] = None,
|
|
332
|
+
drop_original: bool = False,
|
|
333
|
+
) -> pd.DataFrame:
|
|
334
|
+
"""Apply cyclical encoding to selected features.
|
|
335
|
+
|
|
336
|
+
Args:
|
|
337
|
+
data: DataFrame with features.
|
|
338
|
+
features_to_encode: Features to encode. Default: calendar and sun features.
|
|
339
|
+
max_values: Max values for features. Default: standard calendar/hour ranges.
|
|
340
|
+
drop_original: Drop original columns. Default: False.
|
|
341
|
+
|
|
342
|
+
Returns:
|
|
343
|
+
DataFrame with cyclical encoded features.
|
|
344
|
+
"""
|
|
345
|
+
if features_to_encode is None:
|
|
346
|
+
features_to_encode = [
|
|
347
|
+
"month",
|
|
348
|
+
"week",
|
|
349
|
+
"day_of_week",
|
|
350
|
+
"hour",
|
|
351
|
+
"sunrise_hour",
|
|
352
|
+
"sunset_hour",
|
|
353
|
+
]
|
|
354
|
+
|
|
355
|
+
if max_values is None:
|
|
356
|
+
max_values = {
|
|
357
|
+
"month": 12,
|
|
358
|
+
"week": 52,
|
|
359
|
+
"day_of_week": 6,
|
|
360
|
+
"hour": 24,
|
|
361
|
+
"sunrise_hour": 24,
|
|
362
|
+
"sunset_hour": 24,
|
|
363
|
+
}
|
|
364
|
+
|
|
365
|
+
# Filter features_to_encode to only those that exist in data
|
|
366
|
+
available_features = [f for f in features_to_encode if f in data.columns]
|
|
367
|
+
available_max_values = {
|
|
368
|
+
k: v for k, v in max_values.items() if k in available_features
|
|
369
|
+
}
|
|
370
|
+
|
|
371
|
+
cyclical_encoder = CyclicalFeatures(
|
|
372
|
+
variables=available_features,
|
|
373
|
+
max_values=available_max_values,
|
|
374
|
+
drop_original=drop_original,
|
|
375
|
+
)
|
|
376
|
+
|
|
377
|
+
return cyclical_encoder.fit_transform(data)
|
|
378
|
+
|
|
379
|
+
|
|
380
|
+
def _create_interaction_features(
    exogenous_features: pd.DataFrame,
    weather_aligned: pd.DataFrame,
    base_cols: Optional[List[str]] = None,
    weather_window_pattern: str = "_window_",
    include_weather_funcs: Optional[List[str]] = None,
    holiday_col: str = "holiday",
    degree: int = 1,
) -> pd.DataFrame:
    """Append polynomial interaction features to the exogenous features.

    The interaction inputs are, in order: ``base_cols``, the weather window
    columns, the raw weather columns and (if present) ``holiday_col``.

    Args:
        exogenous_features: DataFrame with base features.
        weather_aligned: DataFrame whose column names identify the raw
            weather columns.
        base_cols: Base columns for interactions. Default:
            "day_of_week_sin", "day_of_week_cos", "hour_sin", "hour_cos".
        weather_window_pattern: Substring that marks weather window
            columns. Default: "_window_".
        include_weather_funcs: Substrings of the window aggregations to
            include. Default: ["_mean", "_min", "_max"].
        holiday_col: Holiday column name. Default: "holiday".
        degree: Degree passed to ``PolynomialFeatures``. Default: 1.

    Returns:
        ``exogenous_features`` with the new columns appended. New columns
        are prefixed with "poly_" and spaces in their names are replaced
        by "__".
    """
    seed_cols = (
        ["day_of_week_sin", "day_of_week_cos", "hour_sin", "hour_cos"]
        if base_cols is None
        else base_cols
    )
    func_tags = (
        ["_mean", "_min", "_max"]
        if include_weather_funcs is None
        else include_weather_funcs
    )

    interaction_builder = PolynomialFeatures(
        degree=degree, interaction_only=True, include_bias=False
    ).set_output(transform="pandas")

    def _is_weather_window(name) -> bool:
        return weather_window_pattern in name and any(
            tag in name for tag in func_tags
        )

    window_cols = [
        name for name in exogenous_features.columns if _is_weather_window(name)
    ]
    raw_cols = [
        name
        for name in exogenous_features.columns
        if name in weather_aligned.columns and name not in window_cols
    ]

    inputs = [*seed_cols, *window_cols, *raw_cols]
    if holiday_col in exogenous_features.columns:
        inputs.append(holiday_col)

    interactions = interaction_builder.fit_transform(exogenous_features[inputs])
    # Keep only the generated terms; the inputs already live in
    # exogenous_features.
    interactions = interactions.drop(columns=inputs)
    interactions.columns = [
        f"poly_{name}".replace(" ", "__") for name in interactions.columns
    ]

    return pd.concat([exogenous_features, interactions], axis=1)
|
|
443
|
+
|
|
444
|
+
|
|
445
|
+
def _select_exogenous_features(
|
|
446
|
+
exogenous_features: pd.DataFrame,
|
|
447
|
+
weather_aligned: pd.DataFrame,
|
|
448
|
+
cyclical_regex: str = "_sin$|_cos$",
|
|
449
|
+
include_weather_windows: bool = False,
|
|
450
|
+
include_holiday_features: bool = False,
|
|
451
|
+
include_poly_features: bool = False,
|
|
452
|
+
) -> List[str]:
|
|
453
|
+
"""Select exogenous feature columns for model training.
|
|
454
|
+
|
|
455
|
+
Args:
|
|
456
|
+
exogenous_features: DataFrame with all features.
|
|
457
|
+
weather_aligned: DataFrame with raw weather columns.
|
|
458
|
+
cyclical_regex: Regex for cyclical features. Default: "_sin$|_cos$".
|
|
459
|
+
include_weather_windows: Include weather window features. Default: False.
|
|
460
|
+
include_holiday_features: Include holiday features. Default: False.
|
|
461
|
+
include_poly_features: Include polynomial features. Default: False.
|
|
462
|
+
|
|
463
|
+
Returns:
|
|
464
|
+
List of selected feature column names.
|
|
465
|
+
"""
|
|
466
|
+
exog_features: List[str] = []
|
|
467
|
+
|
|
468
|
+
exog_features.extend(
|
|
469
|
+
exogenous_features.filter(regex=cyclical_regex).columns.tolist()
|
|
470
|
+
)
|
|
471
|
+
|
|
472
|
+
if include_weather_windows:
|
|
473
|
+
weather_window_features = [
|
|
474
|
+
col
|
|
475
|
+
for col in exogenous_features.columns
|
|
476
|
+
if "_window_" in col and ("_mean" in col or "_min" in col or "_max" in col)
|
|
477
|
+
]
|
|
478
|
+
exog_features.extend(weather_window_features)
|
|
479
|
+
|
|
480
|
+
raw_weather_features = [
|
|
481
|
+
col for col in exogenous_features.columns if col in weather_aligned.columns
|
|
482
|
+
]
|
|
483
|
+
exog_features.extend(raw_weather_features)
|
|
484
|
+
|
|
485
|
+
if include_holiday_features:
|
|
486
|
+
holiday_related = [
|
|
487
|
+
col for col in exogenous_features.columns if col.startswith("holiday")
|
|
488
|
+
]
|
|
489
|
+
exog_features.extend(holiday_related)
|
|
490
|
+
|
|
491
|
+
if include_poly_features:
|
|
492
|
+
poly_features_list = [
|
|
493
|
+
col for col in exogenous_features.columns if col.startswith("poly_")
|
|
494
|
+
]
|
|
495
|
+
exog_features.extend(poly_features_list)
|
|
496
|
+
|
|
497
|
+
return list(dict.fromkeys(exog_features))
|
|
498
|
+
|
|
499
|
+
|
|
500
|
+
def _merge_data_and_covariates(
|
|
501
|
+
data: pd.DataFrame,
|
|
502
|
+
exogenous_features: pd.DataFrame,
|
|
503
|
+
target_columns: List[str],
|
|
504
|
+
exog_features: List[str],
|
|
505
|
+
start: Union[str, pd.Timestamp],
|
|
506
|
+
end: Union[str, pd.Timestamp],
|
|
507
|
+
cov_end: Union[str, pd.Timestamp],
|
|
508
|
+
forecast_horizon: int,
|
|
509
|
+
cast_dtype: Optional[str] = "float32",
|
|
510
|
+
) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
|
|
511
|
+
"""Merge target data with exogenous features and build prediction covariates.
|
|
512
|
+
|
|
513
|
+
Args:
|
|
514
|
+
data: DataFrame with target variables.
|
|
515
|
+
exogenous_features: DataFrame with exogenous features.
|
|
516
|
+
target_columns: Target column names.
|
|
517
|
+
exog_features: Exogenous feature column names.
|
|
518
|
+
start: Start date.
|
|
519
|
+
end: End date.
|
|
520
|
+
cov_end: Covariate end date.
|
|
521
|
+
forecast_horizon: Number of forecast steps.
|
|
522
|
+
cast_dtype: Data type for merged data. Default: "float32".
|
|
523
|
+
|
|
524
|
+
Returns:
|
|
525
|
+
Tuple of (data_with_exog, exo_tmp, exo_pred).
|
|
526
|
+
"""
|
|
527
|
+
if isinstance(start, str):
|
|
528
|
+
start = pd.to_datetime(start, utc=True)
|
|
529
|
+
if isinstance(end, str):
|
|
530
|
+
end = pd.to_datetime(end, utc=True)
|
|
531
|
+
if isinstance(cov_end, str):
|
|
532
|
+
cov_end = pd.to_datetime(cov_end, utc=True)
|
|
533
|
+
|
|
534
|
+
exo_tmp = exogenous_features.loc[start:end].copy()
|
|
535
|
+
exo_pred = exogenous_features.loc[end + pd.Timedelta(hours=1) : cov_end].copy()
|
|
536
|
+
|
|
537
|
+
data_with_exog = data[target_columns].merge(
|
|
538
|
+
exo_tmp[exog_features],
|
|
539
|
+
left_index=True,
|
|
540
|
+
right_index=True,
|
|
541
|
+
how="inner",
|
|
542
|
+
)
|
|
543
|
+
|
|
544
|
+
if cast_dtype is not None:
|
|
545
|
+
data_with_exog = data_with_exog.astype(cast_dtype)
|
|
546
|
+
|
|
547
|
+
return data_with_exog, exo_tmp, exo_pred
|
|
548
|
+
|
|
549
|
+
|
|
550
|
+
# ============================================================================
|
|
551
|
+
# Main Function
|
|
552
|
+
# ============================================================================
|
|
553
|
+
|
|
554
|
+
|
|
555
|
+
def n2n_predict_with_covariates(
|
|
556
|
+
forecast_horizon: int = 24,
|
|
557
|
+
contamination: float = 0.01,
|
|
558
|
+
window_size: int = 72,
|
|
559
|
+
lags: int = 24,
|
|
560
|
+
train_ratio: float = 0.8,
|
|
561
|
+
latitude: float = 51.5136,
|
|
562
|
+
longitude: float = 7.4653,
|
|
563
|
+
timezone: str = "UTC",
|
|
564
|
+
country_code: str = "DE",
|
|
565
|
+
state: str = "NW",
|
|
566
|
+
estimator: Optional[object] = None,
|
|
567
|
+
include_weather_windows: bool = False,
|
|
568
|
+
include_holiday_features: bool = False,
|
|
569
|
+
include_poly_features: bool = False,
|
|
570
|
+
verbose: bool = True,
|
|
571
|
+
show_progress: bool = True,
|
|
572
|
+
) -> Tuple[pd.DataFrame, Dict, Dict]:
|
|
573
|
+
"""End-to-end recursive forecasting with exogenous covariates.
|
|
574
|
+
|
|
575
|
+
This function implements a complete forecasting pipeline that:
|
|
576
|
+
1. Loads and validates target data
|
|
577
|
+
2. Detects and removes outliers
|
|
578
|
+
3. Imputes missing values with weighted gaps
|
|
579
|
+
4. Creates exogenous features (weather, holidays, calendar, day/night)
|
|
580
|
+
5. Performs feature engineering (cyclical encoding, interactions)
|
|
581
|
+
6. Merges target and exogenous data
|
|
582
|
+
7. Splits into train/validation/test sets
|
|
583
|
+
8. Trains recursive forecasters with sample weighting
|
|
584
|
+
9. Generates multi-step ahead predictions
|
|
585
|
+
|
|
586
|
+
Args:
|
|
587
|
+
forecast_horizon: Number of time steps to forecast ahead. Default: 24.
|
|
588
|
+
contamination: Contamination parameter for outlier detection. Default: 0.01.
|
|
589
|
+
window_size: Rolling window size for gap detection. Default: 72.
|
|
590
|
+
lags: Number of lags for recursive forecaster. Default: 24.
|
|
591
|
+
train_ratio: Fraction of data for training. Default: 0.8.
|
|
592
|
+
latitude: Location latitude. Default: 51.5136 (Dortmund).
|
|
593
|
+
longitude: Location longitude. Default: 7.4653 (Dortmund).
|
|
594
|
+
timezone: Timezone for data. Default: "UTC".
|
|
595
|
+
country_code: Country code for holidays. Default: "DE".
|
|
596
|
+
state: State code for holidays. Default: "NW".
|
|
597
|
+
estimator: Base estimator for recursive forecaster.
|
|
598
|
+
If None, uses LGBMRegressor. Default: None.
|
|
599
|
+
include_weather_windows: Include weather window features. Default: False.
|
|
600
|
+
include_holiday_features: Include holiday features. Default: False.
|
|
601
|
+
include_poly_features: Include polynomial interaction features. Default: False.
|
|
602
|
+
verbose: Print progress messages. Default: True.
|
|
603
|
+
show_progress: Show progress bar during training. Default: True.
|
|
604
|
+
|
|
605
|
+
Returns:
|
|
606
|
+
Tuple containing:
|
|
607
|
+
- predictions: DataFrame with forecast values for each target variable.
|
|
608
|
+
- metadata: Dictionary with forecast metadata (index, shapes, etc.).
|
|
609
|
+
- forecasters: Dictionary of trained ForecasterRecursive objects keyed by target.
|
|
610
|
+
|
|
611
|
+
Raises:
|
|
612
|
+
ValueError: If data validation fails or required data cannot be retrieved.
|
|
613
|
+
ImportError: If required dependencies are not installed.
|
|
614
|
+
|
|
615
|
+
Examples:
|
|
616
|
+
Basic usage:
|
|
617
|
+
|
|
618
|
+
>>> predictions, metadata, forecasters = n2n_predict_with_covariates(
|
|
619
|
+
... forecast_horizon=24,
|
|
620
|
+
... verbose=True
|
|
621
|
+
... )
|
|
622
|
+
>>> print(predictions.shape)
|
|
623
|
+
(24, 11)
|
|
624
|
+
|
|
625
|
+
Custom location and features:
|
|
626
|
+
|
|
627
|
+
>>> predictions, metadata, forecasters = n2n_predict_with_covariates(
|
|
628
|
+
... forecast_horizon=48,
|
|
629
|
+
... latitude=52.5200, # Berlin
|
|
630
|
+
... longitude=13.4050,
|
|
631
|
+
... lags=48,
|
|
632
|
+
... include_poly_features=True,
|
|
633
|
+
... verbose=True
|
|
634
|
+
... )
|
|
635
|
+
|
|
636
|
+
Notes:
|
|
637
|
+
- The function uses cached weather data when available.
|
|
638
|
+
- Missing values are handled via forward/backward fill with downweighting
|
|
639
|
+
observations near gaps.
|
|
640
|
+
- Sample weights are passed to the forecaster to penalize observations
|
|
641
|
+
near missing data.
|
|
642
|
+
- Train/validation splits are temporal (80/20 by default).
|
|
643
|
+
- All features are cast to float32 for memory efficiency.
|
|
644
|
+
"""
|
|
645
|
+
if verbose:
|
|
646
|
+
print("=" * 80)
|
|
647
|
+
print("N2N Recursive Forecasting with Exogenous Covariates")
|
|
648
|
+
print("=" * 80)
|
|
649
|
+
|
|
650
|
+
# ========================================================================
|
|
651
|
+
# 1. DATA PREPARATION
|
|
652
|
+
# ========================================================================
|
|
653
|
+
|
|
654
|
+
if verbose:
|
|
655
|
+
print("\n[1/9] Loading and preparing target data...")
|
|
656
|
+
|
|
657
|
+
data = fetch_data()
|
|
658
|
+
target_columns = data.columns.tolist()
|
|
659
|
+
|
|
660
|
+
if verbose:
|
|
661
|
+
print(f" Target variables: {target_columns}")
|
|
662
|
+
|
|
663
|
+
start, end, cov_start, cov_end = get_start_end(
|
|
664
|
+
data=data,
|
|
665
|
+
forecast_horizon=forecast_horizon,
|
|
666
|
+
verbose=verbose,
|
|
667
|
+
)
|
|
668
|
+
|
|
669
|
+
basic_ts_checks(data, verbose=verbose)
|
|
670
|
+
data = agg_and_resample_data(data, verbose=verbose)
|
|
671
|
+
|
|
672
|
+
# ========================================================================
|
|
673
|
+
# 2. OUTLIER DETECTION AND REMOVAL
|
|
674
|
+
# ========================================================================
|
|
675
|
+
|
|
676
|
+
if verbose:
|
|
677
|
+
print("\n[2/9] Detecting and marking outliers...")
|
|
678
|
+
|
|
679
|
+
data, outliers = mark_outliers(
|
|
680
|
+
data,
|
|
681
|
+
contamination=contamination,
|
|
682
|
+
random_state=1234,
|
|
683
|
+
verbose=verbose,
|
|
684
|
+
)
|
|
685
|
+
|
|
686
|
+
# ========================================================================
|
|
687
|
+
# 3. MISSING VALUE IMPUTATION WITH WEIGHTING
|
|
688
|
+
# ========================================================================
|
|
689
|
+
|
|
690
|
+
if verbose:
|
|
691
|
+
print("\n[3/9] Processing missing values and creating sample weights...")
|
|
692
|
+
|
|
693
|
+
imputed_data, missing_mask = get_missing_weights(
|
|
694
|
+
data, window_size=window_size, verbose=verbose
|
|
695
|
+
)
|
|
696
|
+
|
|
697
|
+
# Create weight function for forecaster
|
|
698
|
+
# Invert missing_mask: True (missing) -> 0 (weight), False (valid) -> 1 (weight)
|
|
699
|
+
weights_series = (~missing_mask).astype(float)
|
|
700
|
+
|
|
701
|
+
def weight_func(index):
|
|
702
|
+
"""Return sample weights for given index."""
|
|
703
|
+
return custom_weights(index, weights_series)
|
|
704
|
+
|
|
705
|
+
# ========================================================================
|
|
706
|
+
# 4. EXOGENOUS FEATURES ENGINEERING
|
|
707
|
+
# ========================================================================
|
|
708
|
+
|
|
709
|
+
if verbose:
|
|
710
|
+
print("\n[4/9] Creating exogenous features...")
|
|
711
|
+
|
|
712
|
+
# Location for day/night features
|
|
713
|
+
location = LocationInfo(
|
|
714
|
+
latitude=latitude,
|
|
715
|
+
longitude=longitude,
|
|
716
|
+
timezone=timezone,
|
|
717
|
+
)
|
|
718
|
+
|
|
719
|
+
# Holidays
|
|
720
|
+
holiday_features = _get_holiday_features(
|
|
721
|
+
data=imputed_data,
|
|
722
|
+
start=start,
|
|
723
|
+
cov_end=cov_end,
|
|
724
|
+
forecast_horizon=forecast_horizon,
|
|
725
|
+
tz=timezone,
|
|
726
|
+
freq="h",
|
|
727
|
+
country_code=country_code,
|
|
728
|
+
state=state,
|
|
729
|
+
)
|
|
730
|
+
|
|
731
|
+
# Weather
|
|
732
|
+
weather_features, weather_aligned = _get_weather_features(
|
|
733
|
+
data=imputed_data,
|
|
734
|
+
start=start,
|
|
735
|
+
cov_end=cov_end,
|
|
736
|
+
forecast_horizon=forecast_horizon,
|
|
737
|
+
latitude=latitude,
|
|
738
|
+
longitude=longitude,
|
|
739
|
+
timezone=timezone,
|
|
740
|
+
freq="h",
|
|
741
|
+
verbose=verbose,
|
|
742
|
+
)
|
|
743
|
+
|
|
744
|
+
# Calendar
|
|
745
|
+
calendar_features = _get_calendar_features(
|
|
746
|
+
start=start,
|
|
747
|
+
cov_end=cov_end,
|
|
748
|
+
freq="h",
|
|
749
|
+
timezone=timezone,
|
|
750
|
+
)
|
|
751
|
+
|
|
752
|
+
# Day/night
|
|
753
|
+
sun_light_features = _get_day_night_features(
|
|
754
|
+
start=start,
|
|
755
|
+
cov_end=cov_end,
|
|
756
|
+
location=location,
|
|
757
|
+
freq="h",
|
|
758
|
+
timezone=timezone,
|
|
759
|
+
)
|
|
760
|
+
|
|
761
|
+
# ========================================================================
|
|
762
|
+
# 5. COMBINE EXOGENOUS FEATURES
|
|
763
|
+
# ========================================================================
|
|
764
|
+
|
|
765
|
+
if verbose:
|
|
766
|
+
print("\n[5/9] Combining and encoding exogenous features...")
|
|
767
|
+
|
|
768
|
+
exogenous_features = pd.concat(
|
|
769
|
+
[
|
|
770
|
+
calendar_features,
|
|
771
|
+
sun_light_features,
|
|
772
|
+
weather_features,
|
|
773
|
+
holiday_features,
|
|
774
|
+
],
|
|
775
|
+
axis=1,
|
|
776
|
+
)
|
|
777
|
+
|
|
778
|
+
assert (
|
|
779
|
+
sum(exogenous_features.isnull().sum()) == 0
|
|
780
|
+
), "Missing values in exogenous features"
|
|
781
|
+
|
|
782
|
+
# Apply cyclical encoding
|
|
783
|
+
exogenous_features = _apply_cyclical_encoding(
|
|
784
|
+
data=exogenous_features,
|
|
785
|
+
drop_original=False,
|
|
786
|
+
)
|
|
787
|
+
|
|
788
|
+
# Create interactions
|
|
789
|
+
exogenous_features = _create_interaction_features(
|
|
790
|
+
exogenous_features=exogenous_features,
|
|
791
|
+
weather_aligned=weather_aligned,
|
|
792
|
+
)
|
|
793
|
+
|
|
794
|
+
# ========================================================================
|
|
795
|
+
# 6. SELECT EXOGENOUS FEATURES
|
|
796
|
+
# ========================================================================
|
|
797
|
+
|
|
798
|
+
exog_features = _select_exogenous_features(
|
|
799
|
+
exogenous_features=exogenous_features,
|
|
800
|
+
weather_aligned=weather_aligned,
|
|
801
|
+
include_weather_windows=include_weather_windows,
|
|
802
|
+
include_holiday_features=include_holiday_features,
|
|
803
|
+
include_poly_features=include_poly_features,
|
|
804
|
+
)
|
|
805
|
+
|
|
806
|
+
if verbose:
|
|
807
|
+
print(f" Selected {len(exog_features)} exogenous features")
|
|
808
|
+
|
|
809
|
+
# ========================================================================
|
|
810
|
+
# 7. MERGE DATA AND COVARIATES
|
|
811
|
+
# ========================================================================
|
|
812
|
+
|
|
813
|
+
if verbose:
|
|
814
|
+
print("\n[6/9] Merging target and exogenous data...")
|
|
815
|
+
|
|
816
|
+
data_with_exog, exo_tmp, exo_pred = _merge_data_and_covariates(
|
|
817
|
+
data=imputed_data,
|
|
818
|
+
exogenous_features=exogenous_features,
|
|
819
|
+
target_columns=target_columns,
|
|
820
|
+
exog_features=exog_features,
|
|
821
|
+
start=start,
|
|
822
|
+
end=end,
|
|
823
|
+
cov_end=cov_end,
|
|
824
|
+
forecast_horizon=forecast_horizon,
|
|
825
|
+
cast_dtype="float32",
|
|
826
|
+
)
|
|
827
|
+
|
|
828
|
+
if verbose:
|
|
829
|
+
print(f" Merged data shape: {data_with_exog.shape}")
|
|
830
|
+
print(f" Exogenous prediction shape: {exo_pred.shape}")
|
|
831
|
+
|
|
832
|
+
# ========================================================================
|
|
833
|
+
# 8. TRAIN/VALIDATION/TEST SPLIT
|
|
834
|
+
# ========================================================================
|
|
835
|
+
|
|
836
|
+
if verbose:
|
|
837
|
+
print("\n[7/9] Splitting data into train/validation/test...")
|
|
838
|
+
|
|
839
|
+
perc_val = 1.0 - train_ratio
|
|
840
|
+
data_train, data_val, data_test = split_rel_train_val_test(
|
|
841
|
+
data_with_exog,
|
|
842
|
+
perc_train=train_ratio,
|
|
843
|
+
perc_val=perc_val,
|
|
844
|
+
verbose=verbose,
|
|
845
|
+
)
|
|
846
|
+
|
|
847
|
+
# ========================================================================
|
|
848
|
+
# 9. MODEL TRAINING
|
|
849
|
+
# ========================================================================
|
|
850
|
+
|
|
851
|
+
if verbose:
|
|
852
|
+
print("\n[8/9] Training recursive forecasters with exogenous variables...")
|
|
853
|
+
|
|
854
|
+
if estimator is None:
|
|
855
|
+
estimator = LGBMRegressor(random_state=1234, verbose=-1)
|
|
856
|
+
|
|
857
|
+
window_features = RollingFeatures(stats=["mean"], window_sizes=window_size)
|
|
858
|
+
end_validation = pd.concat([data_train, data_val]).index[-1]
|
|
859
|
+
|
|
860
|
+
recursive_forecasters = {}
|
|
861
|
+
|
|
862
|
+
target_iter = target_columns
|
|
863
|
+
if show_progress and tqdm is not None:
|
|
864
|
+
target_iter = tqdm(target_columns, desc="Training forecasters", unit="model")
|
|
865
|
+
|
|
866
|
+
for target in target_iter:
|
|
867
|
+
if verbose:
|
|
868
|
+
print(f" Training forecaster for {target}...")
|
|
869
|
+
|
|
870
|
+
forecaster = ForecasterRecursive(
|
|
871
|
+
estimator=estimator,
|
|
872
|
+
lags=lags,
|
|
873
|
+
window_features=window_features,
|
|
874
|
+
weight_func=weight_func,
|
|
875
|
+
)
|
|
876
|
+
|
|
877
|
+
forecaster.fit(
|
|
878
|
+
y=data_with_exog[target].loc[:end_validation].squeeze(),
|
|
879
|
+
exog=data_with_exog[exog_features].loc[:end_validation],
|
|
880
|
+
)
|
|
881
|
+
|
|
882
|
+
recursive_forecasters[target] = forecaster
|
|
883
|
+
|
|
884
|
+
if verbose:
|
|
885
|
+
print(f" ✓ Forecaster trained for {target}")
|
|
886
|
+
|
|
887
|
+
if verbose:
|
|
888
|
+
print(f" ✓ Total forecasters trained: {len(recursive_forecasters)}")
|
|
889
|
+
|
|
890
|
+
# ========================================================================
|
|
891
|
+
# 10. PREDICTION
|
|
892
|
+
# ========================================================================
|
|
893
|
+
|
|
894
|
+
if verbose:
|
|
895
|
+
print("\n[9/9] Generating predictions...")
|
|
896
|
+
|
|
897
|
+
exo_pred_subset = exo_pred[exog_features]
|
|
898
|
+
|
|
899
|
+
predictions = predict_multivariate(
|
|
900
|
+
recursive_forecasters,
|
|
901
|
+
steps_ahead=forecast_horizon,
|
|
902
|
+
exog=exo_pred_subset,
|
|
903
|
+
show_progress=show_progress,
|
|
904
|
+
)
|
|
905
|
+
|
|
906
|
+
if verbose:
|
|
907
|
+
print(f" Predictions shape: {predictions.shape}")
|
|
908
|
+
print("\n" + "=" * 80)
|
|
909
|
+
print("Forecasting completed successfully!")
|
|
910
|
+
print("=" * 80)
|
|
911
|
+
|
|
912
|
+
# ========================================================================
|
|
913
|
+
# COMPILE METADATA
|
|
914
|
+
# ========================================================================
|
|
915
|
+
|
|
916
|
+
metadata = {
|
|
917
|
+
"forecast_horizon": forecast_horizon,
|
|
918
|
+
"target_columns": target_columns,
|
|
919
|
+
"exog_features": exog_features,
|
|
920
|
+
"n_exog_features": len(exog_features),
|
|
921
|
+
"train_size": len(data_train),
|
|
922
|
+
"val_size": len(data_val),
|
|
923
|
+
"test_size": len(data_test),
|
|
924
|
+
"data_shape_original": data.shape,
|
|
925
|
+
"data_shape_merged": data_with_exog.shape,
|
|
926
|
+
"training_end": end_validation,
|
|
927
|
+
"prediction_start": exo_pred.index[0],
|
|
928
|
+
"prediction_end": exo_pred.index[-1],
|
|
929
|
+
"lags": lags,
|
|
930
|
+
"window_size": window_size,
|
|
931
|
+
"contamination": contamination,
|
|
932
|
+
"n_outliers": (
|
|
933
|
+
outliers.sum() if isinstance(outliers, pd.Series) else len(outliers)
|
|
934
|
+
),
|
|
935
|
+
}
|
|
936
|
+
|
|
937
|
+
return predictions, metadata, recursive_forecasters
|
|
@@ -1,8 +1,10 @@
|
|
|
1
|
+
spotforecast2/.DS_Store,sha256=4yBH5_e0YHcGSgDSeKs4V5_sHINqyWiP33kMXar-lz8,6148
|
|
1
2
|
spotforecast2/__init__.py,sha256=X9sBx15iz8yqr9iDJcrGJM5nhvnpaczXto4XV_GtfhE,59
|
|
2
3
|
spotforecast2/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
3
4
|
spotforecast2/data/data.py,sha256=HEgr-FULaqHvuMeKTviOgYyo3GbxpGRTo3ZnmIU9w2Y,4422
|
|
4
5
|
spotforecast2/data/fetch_data.py,sha256=LcHowE6tnjKPNMTCGr8h29ioGHT4xmj6l6iZmZkJdLU,6842
|
|
5
6
|
spotforecast2/exceptions.py,sha256=gi8rmJWLKEpi3kNB9jWdHcH6XYkmsfyHfXTNg_fAy0w,20497
|
|
7
|
+
spotforecast2/forecaster/.DS_Store,sha256=GXyLvW6LC7GpVyo-vy-zehyHDbffWnsn_ZBT5AX0CQI,6148
|
|
6
8
|
spotforecast2/forecaster/__init__.py,sha256=BbCOS2ouKcPC9VzcdprllVyqlZIyAWXCOvUAiInxDi4,140
|
|
7
9
|
spotforecast2/forecaster/base.py,sha256=rXhcjY4AMpyQhkpbtLIA8OOrGEb8fU57SQiyeR9c9DQ,16748
|
|
8
10
|
spotforecast2/forecaster/metrics.py,sha256=MiZs9MAvT5JjPEGEks1uWR0nFuzYucCWuu4bMV_4HPQ,19316
|
|
@@ -10,7 +12,7 @@ spotforecast2/forecaster/recursive/__init__.py,sha256=YNVxLReLEwSFDasmjXXMSKJqNL
|
|
|
10
12
|
spotforecast2/forecaster/recursive/_forecaster_equivalent_date.py,sha256=Mdr-3D1lUivXO07Rp4T8NIgQ2H_2y4IR4BqCwjBtZsw,48261
|
|
11
13
|
spotforecast2/forecaster/recursive/_forecaster_recursive.py,sha256=oU2zCOI0UaGIn8doLJGphP7jcNL5FF6Y972UCwlxDJI,35739
|
|
12
14
|
spotforecast2/forecaster/recursive/_warnings.py,sha256=BtZ3UoycywjEQ0ceXe4TL1WEdFcLAi1EnDMvZXHw_U8,325
|
|
13
|
-
spotforecast2/forecaster/utils.py,sha256=
|
|
15
|
+
spotforecast2/forecaster/utils.py,sha256=hWyDjDNhoYfu9vheLk2-bTNDX-fbX1oJhIbpzc8nu_I,36530
|
|
14
16
|
spotforecast2/model_selection/__init__.py,sha256=uP60TkgDzs_x5V60rnKanc12S9-yXx2ZLsXsXdqAYEA,208
|
|
15
17
|
spotforecast2/model_selection/bayesian_search.py,sha256=Vwb_LatDnt22LhIWyzqNhCdlDQ_UgVCyFcXmOxF3Pic,17407
|
|
16
18
|
spotforecast2/model_selection/grid_search.py,sha256=a5rNEndTXlx1ghT7ws5qs7WM0XBFMqEiK3Q5k7P0EJg,10998
|
|
@@ -31,7 +33,8 @@ spotforecast2/preprocessing/imputation.py,sha256=lmH-HumI_QLLm9aMESe_oZq84Axn60w
|
|
|
31
33
|
spotforecast2/preprocessing/outlier.py,sha256=jZxAR870QtYner7b4gXk6LLGJw0juLq1VU4CGklYd3c,4208
|
|
32
34
|
spotforecast2/preprocessing/split.py,sha256=mzzt5ltUZdVzfWtBBTQjp8E2MyqVdWUFtz7nN11urbU,5011
|
|
33
35
|
spotforecast2/processing/agg_predict.py,sha256=VKlruB0x-eJKokkHyJxR87rZ4m53si3ODbrd0ibPlow,2378
|
|
34
|
-
spotforecast2/processing/n2n_predict.py,sha256=
|
|
36
|
+
spotforecast2/processing/n2n_predict.py,sha256=Jkf-fMw2RSKY8-0UDc8D0yiiZxiF9s5DyfeRpfx90ks,4060
|
|
37
|
+
spotforecast2/processing/n2n_predict_with_covariates.py,sha256=Py9oMSUFv_9Tw5S9TfNF__MzEZNmGaN85lPbg6GBluw,31111
|
|
35
38
|
spotforecast2/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
36
39
|
spotforecast2/utils/__init__.py,sha256=NrMt_xJLe4rbTFbsbgSQYeREohEOiYG5S-97e6Jj07I,1018
|
|
37
40
|
spotforecast2/utils/convert_to_utc.py,sha256=hz8mJUHK9jDLUiN5LdNX5l3KZuOKlklyycB4zFdB9Ng,1405
|
|
@@ -39,8 +42,8 @@ spotforecast2/utils/data_transform.py,sha256=PhLeZoimM0TLfp34Fp56dQrxlCYNWGVU8h8
|
|
|
39
42
|
spotforecast2/utils/forecaster_config.py,sha256=0jchk_9tjxzttN8btWlRBfAjT2bz27JO4CDrpPsC58E,12875
|
|
40
43
|
spotforecast2/utils/generate_holiday.py,sha256=SHaPvPMt-abis95cChHf5ObyPwCTrzJ87bxffeqZLRc,2707
|
|
41
44
|
spotforecast2/utils/validation.py,sha256=vcfpS6HF7YzVjKUZl-AGrIW71vCXrATJlfg2ZLjUse0,21633
|
|
42
|
-
spotforecast2/weather/__init__.py,sha256=
|
|
45
|
+
spotforecast2/weather/__init__.py,sha256=1Jco88pl0deNESgNATin83Nf5i9c58pxN7G-vNiOiu0,120
|
|
43
46
|
spotforecast2/weather/weather_client.py,sha256=Ec_ywug6uoa71MfXM8RNbXEvtBtBzr-SUS5xq_HKtZE,9837
|
|
44
|
-
spotforecast2-0.0.
|
|
45
|
-
spotforecast2-0.0.
|
|
46
|
-
spotforecast2-0.0.
|
|
47
|
+
spotforecast2-0.0.3.dist-info/WHEEL,sha256=ZyFSCYkV2BrxH6-HRVRg3R9Fo7MALzer9KiPYqNxSbo,79
|
|
48
|
+
spotforecast2-0.0.3.dist-info/METADATA,sha256=4jGNYzlX26Q1h9l0lfe2AU1AVhA-IyfXbRXktBIPKdg,1475
|
|
49
|
+
spotforecast2-0.0.3.dist-info/RECORD,,
|