spotforecast2 0.0.5__tar.gz → 0.1.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {spotforecast2-0.0.5 → spotforecast2-0.1.1}/PKG-INFO +1 -1
- {spotforecast2-0.0.5 → spotforecast2-0.1.1}/pyproject.toml +1 -1
- {spotforecast2-0.0.5 → spotforecast2-0.1.1}/src/spotforecast2/forecaster/utils.py +1 -1
- spotforecast2-0.1.1/src/spotforecast2/processing/n2n_predict.py +437 -0
- {spotforecast2-0.0.5 → spotforecast2-0.1.1}/src/spotforecast2/processing/n2n_predict_with_covariates.py +296 -29
- spotforecast2-0.0.5/src/spotforecast2/processing/n2n_predict.py +0 -126
- {spotforecast2-0.0.5 → spotforecast2-0.1.1}/README.md +0 -0
- {spotforecast2-0.0.5 → spotforecast2-0.1.1}/src/spotforecast2/__init__.py +0 -0
- {spotforecast2-0.0.5 → spotforecast2-0.1.1}/src/spotforecast2/data/__init__.py +0 -0
- {spotforecast2-0.0.5 → spotforecast2-0.1.1}/src/spotforecast2/data/data.py +0 -0
- {spotforecast2-0.0.5 → spotforecast2-0.1.1}/src/spotforecast2/data/fetch_data.py +0 -0
- {spotforecast2-0.0.5 → spotforecast2-0.1.1}/src/spotforecast2/exceptions.py +0 -0
- {spotforecast2-0.0.5 → spotforecast2-0.1.1}/src/spotforecast2/forecaster/__init__.py +0 -0
- {spotforecast2-0.0.5 → spotforecast2-0.1.1}/src/spotforecast2/forecaster/base.py +0 -0
- {spotforecast2-0.0.5 → spotforecast2-0.1.1}/src/spotforecast2/forecaster/metrics.py +0 -0
- {spotforecast2-0.0.5 → spotforecast2-0.1.1}/src/spotforecast2/forecaster/recursive/__init__.py +0 -0
- {spotforecast2-0.0.5 → spotforecast2-0.1.1}/src/spotforecast2/forecaster/recursive/_forecaster_equivalent_date.py +0 -0
- {spotforecast2-0.0.5 → spotforecast2-0.1.1}/src/spotforecast2/forecaster/recursive/_forecaster_recursive.py +0 -0
- {spotforecast2-0.0.5 → spotforecast2-0.1.1}/src/spotforecast2/forecaster/recursive/_warnings.py +0 -0
- {spotforecast2-0.0.5 → spotforecast2-0.1.1}/src/spotforecast2/model_selection/__init__.py +0 -0
- {spotforecast2-0.0.5 → spotforecast2-0.1.1}/src/spotforecast2/model_selection/bayesian_search.py +0 -0
- {spotforecast2-0.0.5 → spotforecast2-0.1.1}/src/spotforecast2/model_selection/grid_search.py +0 -0
- {spotforecast2-0.0.5 → spotforecast2-0.1.1}/src/spotforecast2/model_selection/random_search.py +0 -0
- {spotforecast2-0.0.5 → spotforecast2-0.1.1}/src/spotforecast2/model_selection/split_base.py +0 -0
- {spotforecast2-0.0.5 → spotforecast2-0.1.1}/src/spotforecast2/model_selection/split_one_step.py +0 -0
- {spotforecast2-0.0.5 → spotforecast2-0.1.1}/src/spotforecast2/model_selection/split_ts_cv.py +0 -0
- {spotforecast2-0.0.5 → spotforecast2-0.1.1}/src/spotforecast2/model_selection/utils_common.py +0 -0
- {spotforecast2-0.0.5 → spotforecast2-0.1.1}/src/spotforecast2/model_selection/utils_metrics.py +0 -0
- {spotforecast2-0.0.5 → spotforecast2-0.1.1}/src/spotforecast2/model_selection/validation.py +0 -0
- {spotforecast2-0.0.5 → spotforecast2-0.1.1}/src/spotforecast2/preprocessing/__init__.py +0 -0
- {spotforecast2-0.0.5 → spotforecast2-0.1.1}/src/spotforecast2/preprocessing/_binner.py +0 -0
- {spotforecast2-0.0.5 → spotforecast2-0.1.1}/src/spotforecast2/preprocessing/_common.py +0 -0
- {spotforecast2-0.0.5 → spotforecast2-0.1.1}/src/spotforecast2/preprocessing/_differentiator.py +0 -0
- {spotforecast2-0.0.5 → spotforecast2-0.1.1}/src/spotforecast2/preprocessing/_rolling.py +0 -0
- {spotforecast2-0.0.5 → spotforecast2-0.1.1}/src/spotforecast2/preprocessing/curate_data.py +0 -0
- {spotforecast2-0.0.5 → spotforecast2-0.1.1}/src/spotforecast2/preprocessing/imputation.py +0 -0
- {spotforecast2-0.0.5 → spotforecast2-0.1.1}/src/spotforecast2/preprocessing/outlier.py +0 -0
- {spotforecast2-0.0.5 → spotforecast2-0.1.1}/src/spotforecast2/preprocessing/split.py +0 -0
- {spotforecast2-0.0.5 → spotforecast2-0.1.1}/src/spotforecast2/processing/agg_predict.py +0 -0
- {spotforecast2-0.0.5 → spotforecast2-0.1.1}/src/spotforecast2/py.typed +0 -0
- {spotforecast2-0.0.5 → spotforecast2-0.1.1}/src/spotforecast2/utils/__init__.py +0 -0
- {spotforecast2-0.0.5 → spotforecast2-0.1.1}/src/spotforecast2/utils/convert_to_utc.py +0 -0
- {spotforecast2-0.0.5 → spotforecast2-0.1.1}/src/spotforecast2/utils/data_transform.py +0 -0
- {spotforecast2-0.0.5 → spotforecast2-0.1.1}/src/spotforecast2/utils/forecaster_config.py +0 -0
- {spotforecast2-0.0.5 → spotforecast2-0.1.1}/src/spotforecast2/utils/generate_holiday.py +0 -0
- {spotforecast2-0.0.5 → spotforecast2-0.1.1}/src/spotforecast2/utils/validation.py +0 -0
- {spotforecast2-0.0.5 → spotforecast2-0.1.1}/src/spotforecast2/weather/__init__.py +0 -0
- {spotforecast2-0.0.5 → spotforecast2-0.1.1}/src/spotforecast2/weather/weather_client.py +0 -0
|
@@ -142,7 +142,7 @@ def prepare_steps_direct(
|
|
|
142
142
|
steps: int, list, None, default None
|
|
143
143
|
Predict n steps. The value of `steps` must be less than or equal to the
|
|
144
144
|
value of steps defined when initializing the forecaster. Starts at 1.
|
|
145
|
-
|
|
145
|
+
|
|
146
146
|
- If `int`: Only steps within the range of 1 to int are predicted.
|
|
147
147
|
- If `list`: List of ints. Only the steps contained in the list
|
|
148
148
|
are predicted.
|
|
@@ -0,0 +1,437 @@
|
|
|
1
|
+
"""
|
|
2
|
+
End-to-end baseline forecasting using equivalent date method.
|
|
3
|
+
|
|
4
|
+
This module provides a complete forecasting pipeline using the ForecasterEquivalentDate
|
|
5
|
+
baseline model. It handles data preparation, outlier detection, imputation, model
|
|
6
|
+
training, and prediction in a single integrated function.
|
|
7
|
+
|
|
8
|
+
Model persistence follows scikit-learn conventions using joblib for efficient
|
|
9
|
+
serialization and deserialization of trained forecasters.
|
|
10
|
+
|
|
11
|
+
Examples:
|
|
12
|
+
Basic usage with default parameters:
|
|
13
|
+
|
|
14
|
+
>>> from spotforecast2.processing.n2n_predict import n2n_predict
|
|
15
|
+
>>> predictions, forecasters = n2n_predict(forecast_horizon=24, verbose=True)
|
|
16
|
+
|
|
17
|
+
Using cached models:
|
|
18
|
+
|
|
19
|
+
>>> # Load existing models if available, or train new ones
|
|
20
|
+
>>> predictions, forecasters = n2n_predict(
|
|
21
|
+
... forecast_horizon=24,
|
|
22
|
+
... force_train=False,
|
|
23
|
+
... model_dir="./models",
|
|
24
|
+
... verbose=True
|
|
25
|
+
... )
|
|
26
|
+
|
|
27
|
+
Force retraining and update cache:
|
|
28
|
+
|
|
29
|
+
>>> predictions, forecasters = n2n_predict(
|
|
30
|
+
... forecast_horizon=24,
|
|
31
|
+
... force_train=True,
|
|
32
|
+
... model_dir="./models",
|
|
33
|
+
... verbose=True
|
|
34
|
+
... )
|
|
35
|
+
"""
|
|
36
|
+
|
|
37
|
+
from pathlib import Path
|
|
38
|
+
from typing import Dict, List, Optional, Tuple, Union
|
|
39
|
+
|
|
40
|
+
import pandas as pd
|
|
41
|
+
from spotforecast2.forecaster.recursive import ForecasterEquivalentDate
|
|
42
|
+
from spotforecast2.data.fetch_data import fetch_data
|
|
43
|
+
from spotforecast2.preprocessing.curate_data import basic_ts_checks
|
|
44
|
+
from spotforecast2.preprocessing.curate_data import agg_and_resample_data
|
|
45
|
+
from spotforecast2.preprocessing.outlier import mark_outliers
|
|
46
|
+
from spotforecast2.preprocessing.split import split_rel_train_val_test
|
|
47
|
+
from spotforecast2.forecaster.utils import predict_multivariate
|
|
48
|
+
from spotforecast2.preprocessing.curate_data import get_start_end
|
|
49
|
+
|
|
50
|
+
try:
|
|
51
|
+
from joblib import dump, load
|
|
52
|
+
except ImportError:
|
|
53
|
+
raise ImportError("joblib is required. Install with: pip install joblib")
|
|
54
|
+
|
|
55
|
+
try:
|
|
56
|
+
from tqdm.auto import tqdm
|
|
57
|
+
except ImportError: # pragma: no cover - fallback when tqdm is not installed
|
|
58
|
+
tqdm = None
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
# ============================================================================
|
|
62
|
+
# Model Persistence Functions
|
|
63
|
+
# ============================================================================
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def _ensure_model_dir(model_dir: Union[str, Path]) -> Path:
|
|
67
|
+
"""Ensure model directory exists.
|
|
68
|
+
|
|
69
|
+
Args:
|
|
70
|
+
model_dir: Directory path for model storage.
|
|
71
|
+
|
|
72
|
+
Returns:
|
|
73
|
+
Path: Validated Path object.
|
|
74
|
+
|
|
75
|
+
Raises:
|
|
76
|
+
OSError: If directory cannot be created.
|
|
77
|
+
"""
|
|
78
|
+
model_path = Path(model_dir)
|
|
79
|
+
model_path.mkdir(parents=True, exist_ok=True)
|
|
80
|
+
return model_path
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def _get_model_filepath(model_dir: Path, target: str) -> Path:
|
|
84
|
+
"""Get filepath for a single model.
|
|
85
|
+
|
|
86
|
+
Args:
|
|
87
|
+
model_dir: Directory containing models.
|
|
88
|
+
target: Target variable name.
|
|
89
|
+
|
|
90
|
+
Returns:
|
|
91
|
+
Path: Full filepath for the model.
|
|
92
|
+
|
|
93
|
+
Examples:
|
|
94
|
+
>>> path = _get_model_filepath(Path("./models"), "power")
|
|
95
|
+
>>> str(path)
|
|
96
|
+
'./models/forecaster_power.joblib'
|
|
97
|
+
"""
|
|
98
|
+
return model_dir / f"forecaster_{target}.joblib"
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def _save_forecasters(
    forecasters: Dict[str, object],
    model_dir: Union[str, Path],
    verbose: bool = False,
) -> Dict[str, Path]:
    """Save trained forecasters to disk using joblib.

    Follows scikit-learn persistence conventions using joblib for efficient
    serialization of sklearn-compatible estimators. Each forecaster is written
    to ``<model_dir>/forecaster_<target>.joblib`` with ``compress=3``.

    Args:
        forecasters: Dictionary mapping target names to trained ForecasterEquivalentDate objects.
        model_dir: Directory to save models. Created if it doesn't exist.
        verbose: Print progress messages. Default: False.

    Returns:
        Dict[str, Path]: Dictionary mapping target names to saved model filepaths.

    Raises:
        OSError: If the model directory cannot be created, or if any forecaster
            fails to be serialized or written. The underlying exception (e.g. a
            pickling error) is chained as ``__cause__``. Files written for
            earlier targets before the failure are left in place.

    Examples:
        >>> forecasters = {"power": forecaster_obj}
        >>> paths = _save_forecasters(forecasters, "./models", verbose=True)
        >>> print(paths["power"])
        models/forecaster_power.joblib
    """
    model_path = _ensure_model_dir(model_dir)
    saved_paths: Dict[str, Path] = {}

    for target, forecaster in forecasters.items():
        filepath = _get_model_filepath(model_path, target)
        try:
            dump(forecaster, filepath, compress=3)
        except Exception as e:
            # Normalize every failure to OSError for callers, but keep the
            # original exception attached for debugging.
            raise OSError(f"Failed to save model for {target}: {e}") from e
        saved_paths[target] = filepath
        if verbose:
            print(f" ✓ Saved forecaster for {target} to {filepath}")

    return saved_paths
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
def _load_forecasters(
    target_columns: List[str],
    model_dir: Union[str, Path],
    verbose: bool = False,
) -> Tuple[Dict[str, object], List[str]]:
    """Load trained forecasters from disk using joblib.

    Attempts to load all forecasters for given targets. Missing models
    are indicated in the return value for selective retraining. A model file
    that exists but cannot be deserialized is also reported as missing, and a
    ``RuntimeWarning`` is emitted so the failure is visible even when
    ``verbose`` is False.

    Note:
        ``joblib.load`` unpickles the file, which can execute arbitrary code;
        only point ``model_dir`` at a trusted location.

    Args:
        target_columns: List of target variable names to load.
        model_dir: Directory containing saved models.
        verbose: Print progress messages. Default: False.

    Returns:
        Tuple[Dict[str, object], List[str]]:
            - forecasters: Dictionary of successfully loaded ForecasterEquivalentDate objects.
            - missing_targets: List of target names without a loadable saved model.

    Examples:
        >>> forecasters, missing = _load_forecasters(
        ...     ["power", "energy"],
        ...     "./models",
        ...     verbose=True
        ... )
        >>> print(missing)
        ['energy']
    """
    import warnings

    model_path = Path(model_dir)
    forecasters: Dict[str, object] = {}
    missing_targets: List[str] = []

    for target in target_columns:
        filepath = _get_model_filepath(model_path, target)
        if not filepath.exists():
            missing_targets.append(target)
            continue
        try:
            forecasters[target] = load(filepath)
        except Exception as e:
            # Best effort: treat an unreadable cache entry as missing so it can
            # be retrained, but never drop the error silently.
            warnings.warn(
                f"Failed to load cached forecaster for {target} from {filepath}: {e}",
                RuntimeWarning,
                stacklevel=2,
            )
            if verbose:
                print(f" ✗ Failed to load {target}: {e}")
            missing_targets.append(target)
            continue
        if verbose:
            print(f" ✓ Loaded forecaster for {target} from {filepath}")

    return forecasters, missing_targets
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
def _model_directory_exists(model_dir: Union[str, Path]) -> bool:
|
|
196
|
+
"""Check if model directory exists.
|
|
197
|
+
|
|
198
|
+
Args:
|
|
199
|
+
model_dir: Directory path to check.
|
|
200
|
+
|
|
201
|
+
Returns:
|
|
202
|
+
bool: True if directory exists, False otherwise.
|
|
203
|
+
"""
|
|
204
|
+
return Path(model_dir).exists()
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
# ============================================================================
|
|
208
|
+
# Main Function
|
|
209
|
+
# ============================================================================
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
def n2n_predict(
    columns: Optional[List[str]] = None,
    forecast_horizon: int = 24,
    contamination: float = 0.01,
    window_size: int = 72,
    force_train: bool = False,
    model_dir: Union[str, Path] = "./models_baseline",
    verbose: bool = True,
    show_progress: bool = True,
) -> Tuple[pd.DataFrame, Dict]:
    """End-to-end baseline forecasting using equivalent date method.

    This function implements a complete forecasting pipeline that:
    1. Loads and validates target data
    2. Detects and removes outliers
    3. Imputes missing values (forward fill, then backward fill)
    4. Splits into train/validation/test sets
    5. Trains or loads equivalent date forecasters
    6. Generates multi-step ahead predictions

    Models are persisted to disk following scikit-learn conventions using joblib.
    Existing models are reused for prediction unless force_train=True.

    Args:
        columns: List of target columns to forecast. If None, ``fetch_data``
            decides which columns are returned. Default: None.
        forecast_horizon: Number of time steps to forecast ahead. Default: 24.
        contamination: Contamination parameter for outlier detection. Default: 0.01.
        window_size: Rolling window size for gap detection. Currently not used
            by this function; kept for API compatibility. Default: 72.
        force_train: Force retraining of all models, ignoring cached models.
            Default: False.
        model_dir: Directory for saving/loading trained models.
            Default: "./models_baseline".
        verbose: Print progress messages. Default: True.
        show_progress: Show progress bar during training and prediction. Default: True.

    Returns:
        Tuple containing:
        - predictions: DataFrame with forecast values for each target variable.
        - forecasters: Dictionary of trained ForecasterEquivalentDate objects keyed by target.

    Raises:
        ValueError: If data validation fails or required data cannot be retrieved.
        ImportError: If required dependencies are not installed.
        OSError: If models cannot be saved to disk.

    Examples:
        Basic usage with automatic model caching:

        >>> predictions, forecasters = n2n_predict(
        ...     forecast_horizon=24,
        ...     verbose=True
        ... )
        >>> print(predictions.shape)
        (24, 11)

        Load cached models (if available):

        >>> predictions, forecasters = n2n_predict(
        ...     forecast_horizon=24,
        ...     force_train=False,
        ...     model_dir="./saved_models",
        ...     verbose=True
        ... )

        Force retraining and update cache:

        >>> predictions, forecasters = n2n_predict(
        ...     forecast_horizon=24,
        ...     force_train=True,
        ...     model_dir="./saved_models",
        ...     verbose=True
        ... )

        With specific target columns:

        >>> predictions, forecasters = n2n_predict(
        ...     columns=["power", "energy"],
        ...     forecast_horizon=48,
        ...     force_train=False,
        ...     verbose=True
        ... )

    Notes:
        - Trained models are saved to disk using joblib for fast reuse.
        - When force_train=False, existing models are loaded and prediction
          proceeds without retraining those targets; only targets without a
          loadable model are fitted and saved.
        - The model_dir directory is created automatically when models are saved.

    Performance Notes:
        - First run: Full training (~2-5 minutes depending on data size)
        - Subsequent runs (force_train=False): Model loading only (~1-2 seconds)
        - Force retrain (force_train=True): Full training again (~2-5 minutes)
    """
    if verbose:
        print("--- Starting n2n_predict ---")
        print("Fetching data...")

    # --- Fetch and validate data ---
    data = fetch_data(columns=columns)

    # The returned boundaries are not used below; the call is kept so that any
    # checks or verbose output it performs still happen.
    _start, _end, _cov_start, _cov_end = get_start_end(
        data=data,
        forecast_horizon=forecast_horizon,
        verbose=verbose,
    )

    basic_ts_checks(data, verbose=verbose)

    data = agg_and_resample_data(data, verbose=verbose)

    # --- Outlier Handling ---
    if verbose:
        print("Handling outliers...")

    # mark_outliers returns the cleaned data and the detected outliers; only
    # the cleaned data is needed here.
    data, _outliers = mark_outliers(
        data, contamination=contamination, random_state=1234, verbose=verbose
    )

    # --- Missing Data (Imputation) ---
    if verbose:
        print("Imputing missing data...")
        n_missing = int(data.isnull().any(axis=1).sum())
        pct_missing = (n_missing / len(data)) * 100
        print(f"Number of rows with missing values: {n_missing}")
        print(f"Percentage of rows with missing values: {pct_missing:.2f}%")

    # Forward fill first; backward fill only covers leading gaps.
    data = data.ffill()
    data = data.bfill()

    # --- Train, Val, Test Split ---
    if verbose:
        print("Splitting data...")
    data_train, data_val, _data_test = split_rel_train_val_test(
        data, perc_train=0.8, perc_val=0.2, verbose=verbose
    )

    # --- Model Fit ---
    if verbose:
        print("Fitting models...")

    # Models are fitted on everything up to the end of the validation split.
    end_validation = pd.concat([data_train, data_val]).index[-1]

    target_columns = list(data.columns)
    baseline_forecasters: Dict[str, object] = {}
    targets_to_train = list(target_columns)

    if force_train:
        # Cache is ignored entirely: every target is refitted and re-saved.
        if verbose:
            print(f" Force retraining all {len(target_columns)} forecasters...")
    elif _model_directory_exists(model_dir):
        if verbose:
            print(" Attempting to load cached models...")
        cached_forecasters, missing_targets = _load_forecasters(
            target_columns=target_columns,
            model_dir=model_dir,
            verbose=verbose,
        )
        # NOTE(review): cached forecasters are not refit on the freshly fetched
        # data, so their predictions presumably start from the window they were
        # trained on — confirm against ForecasterEquivalentDate.predict.
        baseline_forecasters.update(cached_forecasters)
        targets_to_train = missing_targets

        if verbose:
            if len(cached_forecasters) == len(target_columns):
                print(f" ✓ All {len(target_columns)} forecasters loaded from cache")
            elif cached_forecasters:
                print(
                    f" ✓ Loaded {len(cached_forecasters)} forecasters, "
                    f"will train {len(targets_to_train)} new ones"
                )

    # Train missing or forced models
    if targets_to_train:
        target_iter = targets_to_train
        if show_progress and tqdm is not None:
            target_iter = tqdm(
                targets_to_train,
                desc="Training forecasters",
                unit="model",
            )

        for target in target_iter:
            # Baseline: the value observed one day earlier.
            forecaster = ForecasterEquivalentDate(
                offset=pd.DateOffset(days=1), n_offsets=1
            )
            forecaster.fit(y=data.loc[:end_validation, target])
            baseline_forecasters[target] = forecaster

        # Save newly trained models to disk
        if verbose:
            print(f" Saving {len(targets_to_train)} trained forecasters to disk...")
        _save_forecasters(
            forecasters={t: baseline_forecasters[t] for t in targets_to_train},
            model_dir=model_dir,
            verbose=verbose,
        )

    if verbose:
        print(f" ✓ Total forecasters available: {len(baseline_forecasters)}")

    # --- Predict ---
    if verbose:
        print("Generating predictions...")

    predictions = predict_multivariate(
        baseline_forecasters,
        steps_ahead=forecast_horizon,
        show_progress=show_progress,
    )

    return predictions, baseline_forecasters
|
|
@@ -6,6 +6,9 @@ recursive forecasters with exogenous variables (weather, holidays, calendar feat
|
|
|
6
6
|
It handles data preparation, feature engineering, model training, and prediction
|
|
7
7
|
in a single integrated function.
|
|
8
8
|
|
|
9
|
+
Model persistence follows scikit-learn conventions using joblib for efficient
|
|
10
|
+
serialization and deserialization of trained forecasters.
|
|
11
|
+
|
|
9
12
|
Examples:
|
|
10
13
|
Basic usage with default parameters:
|
|
11
14
|
|
|
@@ -27,8 +30,28 @@ Examples:
|
|
|
27
30
|
... train_ratio=0.75,
|
|
28
31
|
... verbose=True
|
|
29
32
|
... )
|
|
33
|
+
|
|
34
|
+
Using cached models:
|
|
35
|
+
|
|
36
|
+
>>> # Load existing models if available, or train new ones
|
|
37
|
+
>>> predictions, metadata, forecasters = n2n_predict_with_covariates(
|
|
38
|
+
... forecast_horizon=24,
|
|
39
|
+
... force_train=False,
|
|
40
|
+
... model_dir="./models",
|
|
41
|
+
... verbose=True
|
|
42
|
+
... )
|
|
43
|
+
|
|
44
|
+
Force retraining and update cache:
|
|
45
|
+
|
|
46
|
+
>>> predictions, metadata, forecasters = n2n_predict_with_covariates(
|
|
47
|
+
... forecast_horizon=24,
|
|
48
|
+
... force_train=True,
|
|
49
|
+
... model_dir="./models",
|
|
50
|
+
... verbose=True
|
|
51
|
+
... )
|
|
30
52
|
"""
|
|
31
53
|
|
|
54
|
+
from pathlib import Path
|
|
32
55
|
from typing import Dict, List, Optional, Tuple, Union
|
|
33
56
|
|
|
34
57
|
import numpy as np
|
|
@@ -37,6 +60,11 @@ from astral import LocationInfo
|
|
|
37
60
|
from lightgbm import LGBMRegressor
|
|
38
61
|
from sklearn.preprocessing import PolynomialFeatures
|
|
39
62
|
|
|
63
|
+
try:
|
|
64
|
+
from joblib import dump, load
|
|
65
|
+
except ImportError:
|
|
66
|
+
raise ImportError("joblib is required. Install with: pip install joblib")
|
|
67
|
+
|
|
40
68
|
try:
|
|
41
69
|
from tqdm.auto import tqdm
|
|
42
70
|
except ImportError: # pragma: no cover - fallback when tqdm is not installed
|
|
@@ -547,6 +575,152 @@ def _merge_data_and_covariates(
|
|
|
547
575
|
return data_with_exog, exo_tmp, exo_pred
|
|
548
576
|
|
|
549
577
|
|
|
578
|
+
# ============================================================================
|
|
579
|
+
# Model Persistence Functions
|
|
580
|
+
# ============================================================================
|
|
581
|
+
|
|
582
|
+
|
|
583
|
+
def _ensure_model_dir(model_dir: Union[str, Path]) -> Path:
|
|
584
|
+
"""Ensure model directory exists.
|
|
585
|
+
|
|
586
|
+
Args:
|
|
587
|
+
model_dir: Directory path for model storage.
|
|
588
|
+
|
|
589
|
+
Returns:
|
|
590
|
+
Path: Validated Path object.
|
|
591
|
+
|
|
592
|
+
Raises:
|
|
593
|
+
OSError: If directory cannot be created.
|
|
594
|
+
"""
|
|
595
|
+
model_path = Path(model_dir)
|
|
596
|
+
model_path.mkdir(parents=True, exist_ok=True)
|
|
597
|
+
return model_path
|
|
598
|
+
|
|
599
|
+
|
|
600
|
+
def _get_model_filepath(model_dir: Path, target: str) -> Path:
|
|
601
|
+
"""Get filepath for a single model.
|
|
602
|
+
|
|
603
|
+
Args:
|
|
604
|
+
model_dir: Directory containing models.
|
|
605
|
+
target: Target variable name.
|
|
606
|
+
|
|
607
|
+
Returns:
|
|
608
|
+
Path: Full filepath for the model.
|
|
609
|
+
|
|
610
|
+
Examples:
|
|
611
|
+
>>> path = _get_model_filepath(Path("./models"), "power")
|
|
612
|
+
>>> str(path)
|
|
613
|
+
'./models/forecaster_power.joblib'
|
|
614
|
+
"""
|
|
615
|
+
return model_dir / f"forecaster_{target}.joblib"
|
|
616
|
+
|
|
617
|
+
|
|
618
|
+
def _save_forecasters(
    forecasters: Dict[str, object],
    model_dir: Union[str, Path],
    verbose: bool = False,
) -> Dict[str, Path]:
    """Save trained forecasters to disk using joblib.

    Follows scikit-learn persistence conventions using joblib for efficient
    serialization of sklearn-compatible estimators. Each forecaster is written
    to ``<model_dir>/forecaster_<target>.joblib`` with ``compress=3``.

    Args:
        forecasters: Dictionary mapping target names to trained ForecasterRecursive objects.
        model_dir: Directory to save models. Created if it doesn't exist.
        verbose: Print progress messages. Default: False.

    Returns:
        Dict[str, Path]: Dictionary mapping target names to saved model filepaths.

    Raises:
        OSError: If the model directory cannot be created, or if any forecaster
            fails to be serialized or written. The underlying exception (e.g. a
            pickling error) is chained as ``__cause__``. Files written for
            earlier targets before the failure are left in place.

    Examples:
        >>> forecasters = {"power": forecaster_obj}
        >>> paths = _save_forecasters(forecasters, "./models", verbose=True)
        >>> print(paths["power"])
        models/forecaster_power.joblib
    """
    model_path = _ensure_model_dir(model_dir)
    saved_paths: Dict[str, Path] = {}

    for target, forecaster in forecasters.items():
        filepath = _get_model_filepath(model_path, target)
        try:
            dump(forecaster, filepath, compress=3)
        except Exception as e:
            # Normalize every failure to OSError for callers, but keep the
            # original exception attached for debugging.
            raise OSError(f"Failed to save model for {target}: {e}") from e
        saved_paths[target] = filepath
        if verbose:
            print(f" ✓ Saved forecaster for {target} to {filepath}")

    return saved_paths
|
|
660
|
+
|
|
661
|
+
|
|
662
|
+
def _load_forecasters(
    target_columns: List[str],
    model_dir: Union[str, Path],
    verbose: bool = False,
) -> Tuple[Dict[str, object], List[str]]:
    """Load trained forecasters from disk using joblib.

    Attempts to load all forecasters for given targets. Missing models
    are indicated in the return value for selective retraining. A model file
    that exists but cannot be deserialized is also reported as missing, and a
    ``RuntimeWarning`` is emitted so the failure is visible even when
    ``verbose`` is False.

    Note:
        ``joblib.load`` unpickles the file, which can execute arbitrary code;
        only point ``model_dir`` at a trusted location.

    Args:
        target_columns: List of target variable names to load.
        model_dir: Directory containing saved models.
        verbose: Print progress messages. Default: False.

    Returns:
        Tuple[Dict[str, object], List[str]]:
            - forecasters: Dictionary of successfully loaded ForecasterRecursive objects.
            - missing_targets: List of target names without a loadable saved model.

    Examples:
        >>> forecasters, missing = _load_forecasters(
        ...     ["power", "energy"],
        ...     "./models",
        ...     verbose=True
        ... )
        >>> print(missing)
        ['energy']
    """
    import warnings

    model_path = Path(model_dir)
    forecasters: Dict[str, object] = {}
    missing_targets: List[str] = []

    for target in target_columns:
        filepath = _get_model_filepath(model_path, target)
        if not filepath.exists():
            missing_targets.append(target)
            continue
        try:
            forecasters[target] = load(filepath)
        except Exception as e:
            # Best effort: treat an unreadable cache entry as missing so it can
            # be retrained, but never drop the error silently.
            warnings.warn(
                f"Failed to load cached forecaster for {target} from {filepath}: {e}",
                RuntimeWarning,
                stacklevel=2,
            )
            if verbose:
                print(f" ✗ Failed to load {target}: {e}")
            missing_targets.append(target)
            continue
        if verbose:
            print(f" ✓ Loaded forecaster for {target} from {filepath}")

    return forecasters, missing_targets
|
|
710
|
+
|
|
711
|
+
|
|
712
|
+
def _model_directory_exists(model_dir: Union[str, Path]) -> bool:
|
|
713
|
+
"""Check if model directory exists.
|
|
714
|
+
|
|
715
|
+
Args:
|
|
716
|
+
model_dir: Directory path to check.
|
|
717
|
+
|
|
718
|
+
Returns:
|
|
719
|
+
bool: True if directory exists, False otherwise.
|
|
720
|
+
"""
|
|
721
|
+
return Path(model_dir).exists()
|
|
722
|
+
|
|
723
|
+
|
|
550
724
|
# ============================================================================
|
|
551
725
|
# Main Function
|
|
552
726
|
# ============================================================================
|
|
@@ -567,8 +741,10 @@ def n2n_predict_with_covariates(
|
|
|
567
741
|
include_weather_windows: bool = False,
|
|
568
742
|
include_holiday_features: bool = False,
|
|
569
743
|
include_poly_features: bool = False,
|
|
744
|
+
force_train: bool = False,
|
|
745
|
+
model_dir: Union[str, Path] = "./forecaster_models",
|
|
570
746
|
verbose: bool = True,
|
|
571
|
-
show_progress: bool =
|
|
747
|
+
show_progress: bool = False,
|
|
572
748
|
) -> Tuple[pd.DataFrame, Dict, Dict]:
|
|
573
749
|
"""End-to-end recursive forecasting with exogenous covariates.
|
|
574
750
|
|
|
@@ -580,9 +756,12 @@ def n2n_predict_with_covariates(
|
|
|
580
756
|
5. Performs feature engineering (cyclical encoding, interactions)
|
|
581
757
|
6. Merges target and exogenous data
|
|
582
758
|
7. Splits into train/validation/test sets
|
|
583
|
-
8. Trains recursive forecasters with sample weighting
|
|
759
|
+
8. Trains or loads recursive forecasters with sample weighting
|
|
584
760
|
9. Generates multi-step ahead predictions
|
|
585
761
|
|
|
762
|
+
Models are persisted to disk following scikit-learn conventions using joblib.
|
|
763
|
+
Existing models are reused for prediction unless force_train=True.
|
|
764
|
+
|
|
586
765
|
Args:
|
|
587
766
|
forecast_horizon: Number of time steps to forecast ahead. Default: 24.
|
|
588
767
|
contamination: Contamination parameter for outlier detection. Default: 0.01.
|
|
@@ -599,8 +778,12 @@ def n2n_predict_with_covariates(
|
|
|
599
778
|
include_weather_windows: Include weather window features. Default: False.
|
|
600
779
|
include_holiday_features: Include holiday features. Default: False.
|
|
601
780
|
include_poly_features: Include polynomial interaction features. Default: False.
|
|
781
|
+
force_train: Force retraining of all models, ignoring cached models.
|
|
782
|
+
Default: False.
|
|
783
|
+
model_dir: Directory for saving/loading trained models.
|
|
784
|
+
Default: "./models_covariates".
|
|
602
785
|
verbose: Print progress messages. Default: True.
|
|
603
|
-
show_progress: Show progress bar during training. Default:
|
|
786
|
+
show_progress: Show progress bar during training. Default: False.
|
|
604
787
|
|
|
605
788
|
Returns:
|
|
606
789
|
Tuple containing:
|
|
@@ -611,9 +794,10 @@ def n2n_predict_with_covariates(
|
|
|
611
794
|
Raises:
|
|
612
795
|
ValueError: If data validation fails or required data cannot be retrieved.
|
|
613
796
|
ImportError: If required dependencies are not installed.
|
|
797
|
+
OSError: If models cannot be saved to disk.
|
|
614
798
|
|
|
615
799
|
Examples:
|
|
616
|
-
Basic usage:
|
|
800
|
+
Basic usage with automatic model caching:
|
|
617
801
|
|
|
618
802
|
>>> predictions, metadata, forecasters = n2n_predict_with_covariates(
|
|
619
803
|
... forecast_horizon=24,
|
|
@@ -622,6 +806,22 @@ def n2n_predict_with_covariates(
|
|
|
622
806
|
>>> print(predictions.shape)
|
|
623
807
|
(24, 11)
|
|
624
808
|
|
|
809
|
+
Load cached models (if available):
|
|
810
|
+
|
|
811
|
+
>>> predictions, metadata, forecasters = n2n_predict_with_covariates(
|
|
812
|
+
... forecast_horizon=24,
|
|
813
|
+
... force_train=False,
|
|
814
|
+
... model_dir="./saved_models"
|
|
815
|
+
... )
|
|
816
|
+
|
|
817
|
+
Force retraining and update cache:
|
|
818
|
+
|
|
819
|
+
>>> predictions, metadata, forecasters = n2n_predict_with_covariates(
|
|
820
|
+
... forecast_horizon=24,
|
|
821
|
+
... force_train=True,
|
|
822
|
+
... model_dir="./saved_models"
|
|
823
|
+
... )
|
|
824
|
+
|
|
625
825
|
Custom location and features:
|
|
626
826
|
|
|
627
827
|
>>> predictions, metadata, forecasters = n2n_predict_with_covariates(
|
|
@@ -630,6 +830,7 @@ def n2n_predict_with_covariates(
|
|
|
630
830
|
... longitude=13.4050,
|
|
631
831
|
... lags=48,
|
|
632
832
|
... include_poly_features=True,
|
|
833
|
+
... force_train=False,
|
|
633
834
|
... verbose=True
|
|
634
835
|
... )
|
|
635
836
|
|
|
@@ -641,6 +842,16 @@ def n2n_predict_with_covariates(
|
|
|
641
842
|
near missing data.
|
|
642
843
|
- Train/validation splits are temporal (80/20 by default).
|
|
643
844
|
- All features are cast to float32 for memory efficiency.
|
|
845
|
+
- Trained models are saved to disk using joblib for fast reuse.
|
|
846
|
+
- When force_train=False, existing models are loaded and prediction
|
|
847
|
+
proceeds without retraining. This significantly speeds up prediction
|
|
848
|
+
for repeated calls with the same configuration.
|
|
849
|
+
- The model_dir directory is created automatically if it doesn't exist.
|
|
850
|
+
|
|
851
|
+
Performance Notes:
|
|
852
|
+
- First run: Full training (~5-10 minutes depending on data size)
|
|
853
|
+
- Subsequent runs (force_train=False): Model loading only (~1-2 seconds)
|
|
854
|
+
- Force retrain (force_train=True): Full training again (~5-10 minutes)
|
|
644
855
|
"""
|
|
645
856
|
if verbose:
|
|
646
857
|
print("=" * 80)
|
|
@@ -702,6 +913,10 @@ def n2n_predict_with_covariates(
|
|
|
702
913
|
"""Return sample weights for given index."""
|
|
703
914
|
return custom_weights(index, weights_series)
|
|
704
915
|
|
|
916
|
+
# Note: weight_func is a local function and cannot be pickled.
|
|
917
|
+
# Model persistence is disabled when using weight_func.
|
|
918
|
+
use_model_persistence = False
|
|
919
|
+
|
|
705
920
|
# ========================================================================
|
|
706
921
|
# 4. EXOGENOUS FEATURES ENGINEERING
|
|
707
922
|
# ========================================================================
|
|
@@ -845,11 +1060,13 @@ def n2n_predict_with_covariates(
|
|
|
845
1060
|
)
|
|
846
1061
|
|
|
847
1062
|
# ========================================================================
|
|
848
|
-
# 9. MODEL TRAINING
|
|
1063
|
+
# 9. MODEL TRAINING OR LOADING
|
|
849
1064
|
# ========================================================================
|
|
850
1065
|
|
|
851
1066
|
if verbose:
|
|
852
|
-
print(
|
|
1067
|
+
print(
|
|
1068
|
+
"\n[8/9] Loading or training recursive forecasters with exogenous variables..."
|
|
1069
|
+
)
|
|
853
1070
|
|
|
854
1071
|
if estimator is None:
|
|
855
1072
|
estimator = LGBMRegressor(random_state=1234, verbose=-1)
|
|
@@ -857,35 +1074,85 @@ def n2n_predict_with_covariates(
|
|
|
857
1074
|
window_features = RollingFeatures(stats=["mean"], window_sizes=window_size)
|
|
858
1075
|
end_validation = pd.concat([data_train, data_val]).index[-1]
|
|
859
1076
|
|
|
1077
|
+
# Attempt to load cached models if force_train=False and persistence is enabled
|
|
860
1078
|
recursive_forecasters = {}
|
|
1079
|
+
targets_to_train = target_columns
|
|
861
1080
|
|
|
862
|
-
|
|
863
|
-
if show_progress and tqdm is not None:
|
|
864
|
-
target_iter = tqdm(target_columns, desc="Training forecasters", unit="model")
|
|
865
|
-
|
|
866
|
-
for target in target_iter:
|
|
1081
|
+
if use_model_persistence and not force_train and _model_directory_exists(model_dir):
|
|
867
1082
|
if verbose:
|
|
868
|
-
print(
|
|
869
|
-
|
|
870
|
-
|
|
871
|
-
|
|
872
|
-
|
|
873
|
-
window_features=window_features,
|
|
874
|
-
weight_func=weight_func,
|
|
1083
|
+
print(" Attempting to load cached models...")
|
|
1084
|
+
cached_forecasters, missing_targets = _load_forecasters(
|
|
1085
|
+
target_columns=target_columns,
|
|
1086
|
+
model_dir=model_dir,
|
|
1087
|
+
verbose=verbose,
|
|
875
1088
|
)
|
|
876
|
-
|
|
877
|
-
|
|
878
|
-
|
|
879
|
-
|
|
880
|
-
|
|
881
|
-
|
|
882
|
-
|
|
883
|
-
|
|
884
|
-
|
|
885
|
-
|
|
1089
|
+
recursive_forecasters.update(cached_forecasters)
|
|
1090
|
+
targets_to_train = missing_targets
|
|
1091
|
+
|
|
1092
|
+
if len(cached_forecasters) == len(target_columns):
|
|
1093
|
+
if verbose:
|
|
1094
|
+
print(f" ✓ All {len(target_columns)} forecasters loaded from cache")
|
|
1095
|
+
elif len(cached_forecasters) > 0:
|
|
1096
|
+
if verbose:
|
|
1097
|
+
print(
|
|
1098
|
+
f" ✓ Loaded {len(cached_forecasters)} forecasters, "
|
|
1099
|
+
f"will train {len(targets_to_train)} new ones"
|
|
1100
|
+
)
|
|
1101
|
+
|
|
1102
|
+
# Train missing or forced models
|
|
1103
|
+
if len(targets_to_train) > 0:
|
|
1104
|
+
if force_train and len(recursive_forecasters) > 0:
|
|
1105
|
+
if verbose:
|
|
1106
|
+
print(f" Force retraining all {len(target_columns)} forecasters...")
|
|
1107
|
+
targets_to_train = target_columns
|
|
1108
|
+
recursive_forecasters.clear()
|
|
1109
|
+
|
|
1110
|
+
target_iter = targets_to_train
|
|
1111
|
+
if show_progress and tqdm is not None:
|
|
1112
|
+
target_iter = tqdm(
|
|
1113
|
+
targets_to_train,
|
|
1114
|
+
desc="Training forecasters",
|
|
1115
|
+
unit="model",
|
|
1116
|
+
)
|
|
1117
|
+
|
|
1118
|
+
for target in target_iter:
|
|
1119
|
+
if verbose:
|
|
1120
|
+
print(f" Training forecaster for {target}...")
|
|
1121
|
+
|
|
1122
|
+
forecaster = ForecasterRecursive(
|
|
1123
|
+
estimator=estimator,
|
|
1124
|
+
lags=lags,
|
|
1125
|
+
window_features=window_features,
|
|
1126
|
+
weight_func=weight_func,
|
|
1127
|
+
)
|
|
1128
|
+
|
|
1129
|
+
forecaster.fit(
|
|
1130
|
+
y=data_with_exog[target].loc[:end_validation].squeeze(),
|
|
1131
|
+
exog=data_with_exog[exog_features].loc[:end_validation],
|
|
1132
|
+
)
|
|
1133
|
+
|
|
1134
|
+
recursive_forecasters[target] = forecaster
|
|
1135
|
+
|
|
1136
|
+
if verbose:
|
|
1137
|
+
print(f" ✓ Forecaster trained for {target}")
|
|
1138
|
+
|
|
1139
|
+
# Save newly trained models to disk (only if persistence is enabled)
|
|
1140
|
+
if use_model_persistence:
|
|
1141
|
+
if verbose:
|
|
1142
|
+
print(
|
|
1143
|
+
f" Saving {len(targets_to_train)} trained forecasters to disk..."
|
|
1144
|
+
)
|
|
1145
|
+
_save_forecasters(
|
|
1146
|
+
forecasters={t: recursive_forecasters[t] for t in targets_to_train},
|
|
1147
|
+
model_dir=model_dir,
|
|
1148
|
+
verbose=verbose,
|
|
1149
|
+
)
|
|
1150
|
+
else:
|
|
1151
|
+
if verbose:
|
|
1152
|
+
print(" ⚠ Model persistence disabled (weight_func cannot be pickled)")
|
|
886
1153
|
|
|
887
1154
|
if verbose:
|
|
888
|
-
print(f" ✓ Total forecasters
|
|
1155
|
+
print(f" ✓ Total forecasters available: {len(recursive_forecasters)}")
|
|
889
1156
|
|
|
890
1157
|
# ========================================================================
|
|
891
1158
|
# 10. PREDICTION
|
|
@@ -1,126 +0,0 @@
|
|
|
1
|
-
import pandas as pd
|
|
2
|
-
from typing import List, Optional
|
|
3
|
-
from spotforecast2.forecaster.recursive import ForecasterEquivalentDate
|
|
4
|
-
from spotforecast2.data.fetch_data import fetch_data
|
|
5
|
-
from spotforecast2.preprocessing.curate_data import basic_ts_checks
|
|
6
|
-
from spotforecast2.preprocessing.curate_data import agg_and_resample_data
|
|
7
|
-
from spotforecast2.preprocessing.outlier import mark_outliers
|
|
8
|
-
|
|
9
|
-
from spotforecast2.preprocessing.split import split_rel_train_val_test
|
|
10
|
-
from spotforecast2.forecaster.utils import predict_multivariate
|
|
11
|
-
from spotforecast2.preprocessing.curate_data import get_start_end
|
|
12
|
-
|
|
13
|
-
try:
|
|
14
|
-
from tqdm.auto import tqdm
|
|
15
|
-
except ImportError: # pragma: no cover - fallback when tqdm is not installed
|
|
16
|
-
tqdm = None
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
def n2n_predict(
    columns: Optional[List[str]] = None,
    forecast_horizon: int = 24,
    contamination: float = 0.01,
    window_size: int = 72,
    verbose: bool = True,
    show_progress: bool = True,
) -> pd.DataFrame:
    """End-to-end baseline forecast: fetch, clean, split, fit and predict.

    Replicates the workflow from the ``01_base_predictor`` notebook combined
    with ``fetch_data``. The pipeline runs as follows:

    1. Fetch the data and run the basic time-series checks.
    2. Aggregate and resample the data.
    3. Mark outliers.
    4. Impute gaps with a forward fill, then a backward fill.
    5. Split the data 80/20 into train and validation.
    6. Fit one ``ForecasterEquivalentDate`` (1-day offset, a single offset)
       per column on the train+validation span.
    7. Produce multi-step predictions.

    Args:
        columns: Target columns to forecast. ``None`` is passed through to
            ``fetch_data``, which then returns all columns.
        forecast_horizon: Number of steps to forecast.
        contamination: Contamination factor for outlier detection.
        window_size: Accepted for interface consistency with the other
            ``n2n_*`` entry points; not used by this baseline workflow.
        verbose: Whether to print progress logs.
        show_progress: Show progress bars during training and prediction
            (the training bar requires ``tqdm`` to be installed).

    Returns:
        pd.DataFrame: The multi-output predictions.
    """
    if verbose:
        print("--- Starting n2n_predict ---")
        print("Fetching data...")

    # ``columns=None`` requests every available column from fetch_data.
    data = fetch_data(columns=columns)

    # The returned (start, end, cov_start, cov_end) boundaries are not used
    # below. NOTE(review): the call is kept in case get_start_end validates
    # the data or logs when verbose; confirm whether it is still needed.
    get_start_end(
        data=data,
        forecast_horizon=forecast_horizon,
        verbose=verbose,
    )

    basic_ts_checks(data, verbose=verbose)

    data = agg_and_resample_data(data, verbose=verbose)

    # --- Outlier handling ---
    if verbose:
        print("Handling outliers...")

    # Only the cleaned frame is needed; the outlier report is discarded.
    data, _ = mark_outliers(
        data, contamination=contamination, random_state=1234, verbose=verbose
    )

    # --- Missing data (imputation) ---
    if verbose:
        print("Imputing missing data...")
        n_missing = int(data.isnull().any(axis=1).sum())
        # Guard against an empty frame so the report cannot divide by zero.
        pct_missing = (n_missing / len(data)) * 100 if len(data) else 0.0
        print(f"Number of rows with missing values: {n_missing}")
        print(f"Percentage of rows with missing values: {pct_missing:.2f}%")

    # Forward fill first, then backward fill any gaps left at the start.
    data = data.ffill().bfill()

    # --- Train, Val, Test Split ---
    if verbose:
        print("Splitting data...")
    data_train, data_val, _ = split_rel_train_val_test(
        data, perc_train=0.8, perc_val=0.2, verbose=verbose
    )

    # --- Model Fit ---
    if verbose:
        print("Fitting models...")

    # Last timestamp of the combined train+validation span; forecasters are
    # fitted on everything up to and including this point.
    end_validation = pd.concat([data_train, data_val]).index[-1]

    baseline_forecasters = {}

    target_iter = data.columns
    if show_progress and tqdm is not None:
        target_iter = tqdm(data.columns, desc="Training forecasters", unit="model")

    for target in target_iter:
        forecaster = ForecasterEquivalentDate(offset=pd.DateOffset(days=1), n_offsets=1)
        forecaster.fit(y=data.loc[:end_validation, target])
        baseline_forecasters[target] = forecaster

    if verbose:
        print("✓ Multi-output baseline system trained")

    # --- Predict ---
    if verbose:
        print("Generating predictions...")

    return predict_multivariate(
        baseline_forecasters,
        steps_ahead=forecast_horizon,
        show_progress=show_progress,
    )
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{spotforecast2-0.0.5 → spotforecast2-0.1.1}/src/spotforecast2/forecaster/recursive/__init__.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
{spotforecast2-0.0.5 → spotforecast2-0.1.1}/src/spotforecast2/forecaster/recursive/_warnings.py
RENAMED
|
File without changes
|
|
File without changes
|
{spotforecast2-0.0.5 → spotforecast2-0.1.1}/src/spotforecast2/model_selection/bayesian_search.py
RENAMED
|
File without changes
|
{spotforecast2-0.0.5 → spotforecast2-0.1.1}/src/spotforecast2/model_selection/grid_search.py
RENAMED
|
File without changes
|
{spotforecast2-0.0.5 → spotforecast2-0.1.1}/src/spotforecast2/model_selection/random_search.py
RENAMED
|
File without changes
|
|
File without changes
|
{spotforecast2-0.0.5 → spotforecast2-0.1.1}/src/spotforecast2/model_selection/split_one_step.py
RENAMED
|
File without changes
|
{spotforecast2-0.0.5 → spotforecast2-0.1.1}/src/spotforecast2/model_selection/split_ts_cv.py
RENAMED
|
File without changes
|
{spotforecast2-0.0.5 → spotforecast2-0.1.1}/src/spotforecast2/model_selection/utils_common.py
RENAMED
|
File without changes
|
{spotforecast2-0.0.5 → spotforecast2-0.1.1}/src/spotforecast2/model_selection/utils_metrics.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{spotforecast2-0.0.5 → spotforecast2-0.1.1}/src/spotforecast2/preprocessing/_differentiator.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|