spotforecast2 0.0.5__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -142,7 +142,7 @@ def prepare_steps_direct(
142
142
  steps: int, list, None, default None
143
143
  Predict n steps. The value of `steps` must be less than or equal to the
144
144
  value of steps defined when initializing the forecaster. Starts at 1.
145
-
145
+
146
146
  - If `int`: Only steps within the range of 1 to int are predicted.
147
147
  - If `list`: List of ints. Only the steps contained in the list
148
148
  are predicted.
@@ -1,43 +1,308 @@
1
+ """
2
+ End-to-end baseline forecasting using equivalent date method.
3
+
4
+ This module provides a complete forecasting pipeline using the ForecasterEquivalentDate
5
+ baseline model. It handles data preparation, outlier detection, imputation, model
6
+ training, and prediction in a single integrated function.
7
+
8
+ Model persistence follows scikit-learn conventions using joblib for efficient
9
+ serialization and deserialization of trained forecasters.
10
+
11
+ Examples:
12
+ Basic usage with default parameters:
13
+
14
+ >>> from spotforecast2.processing.n2n_predict import n2n_predict
15
+ >>> predictions = n2n_predict(forecast_horizon=24, verbose=True)
16
+
17
+ Using cached models:
18
+
19
+ >>> # Load existing models if available, or train new ones
20
+ >>> predictions = n2n_predict(
21
+ ... forecast_horizon=24,
22
+ ... force_train=False,
23
+ ... model_dir="./models",
24
+ ... verbose=True
25
+ ... )
26
+
27
+ Force retraining and update cache:
28
+
29
+ >>> predictions = n2n_predict(
30
+ ... forecast_horizon=24,
31
+ ... force_train=True,
32
+ ... model_dir="./models",
33
+ ... verbose=True
34
+ ... )
35
+ """
36
+
37
+ from pathlib import Path
38
+ from typing import Dict, List, Optional, Tuple, Union
39
+
1
40
  import pandas as pd
2
- from typing import List, Optional
3
41
  from spotforecast2.forecaster.recursive import ForecasterEquivalentDate
4
42
  from spotforecast2.data.fetch_data import fetch_data
5
43
  from spotforecast2.preprocessing.curate_data import basic_ts_checks
6
44
  from spotforecast2.preprocessing.curate_data import agg_and_resample_data
7
45
  from spotforecast2.preprocessing.outlier import mark_outliers
8
-
9
46
  from spotforecast2.preprocessing.split import split_rel_train_val_test
10
47
  from spotforecast2.forecaster.utils import predict_multivariate
11
48
  from spotforecast2.preprocessing.curate_data import get_start_end
12
49
 
50
+ try:
51
+ from joblib import dump, load
52
+ except ImportError:
53
+ raise ImportError("joblib is required. Install with: pip install joblib")
54
+
13
55
  try:
14
56
  from tqdm.auto import tqdm
15
57
  except ImportError: # pragma: no cover - fallback when tqdm is not installed
16
58
  tqdm = None
17
59
 
18
60
 
61
+ # ============================================================================
62
+ # Model Persistence Functions
63
+ # ============================================================================
64
+
65
+
66
+ def _ensure_model_dir(model_dir: Union[str, Path]) -> Path:
67
+ """Ensure model directory exists.
68
+
69
+ Args:
70
+ model_dir: Directory path for model storage.
71
+
72
+ Returns:
73
+ Path: Validated Path object.
74
+
75
+ Raises:
76
+ OSError: If directory cannot be created.
77
+ """
78
+ model_path = Path(model_dir)
79
+ model_path.mkdir(parents=True, exist_ok=True)
80
+ return model_path
81
+
82
+
83
+ def _get_model_filepath(model_dir: Path, target: str) -> Path:
84
+ """Get filepath for a single model.
85
+
86
+ Args:
87
+ model_dir: Directory containing models.
88
+ target: Target variable name.
89
+
90
+ Returns:
91
+ Path: Full filepath for the model.
92
+
93
+ Examples:
94
+ >>> path = _get_model_filepath(Path("./models"), "power")
95
+ >>> str(path)
96
+ './models/forecaster_power.joblib'
97
+ """
98
+ return model_dir / f"forecaster_{target}.joblib"
99
+
100
+
101
def _save_forecasters(
    forecasters: Dict[str, object],
    model_dir: Union[str, Path],
    verbose: bool = False,
) -> Dict[str, Path]:
    """Save trained forecasters to disk using joblib.

    Each forecaster is written to ``forecaster_<target>.joblib`` inside
    ``model_dir`` with ``compress=3``. The directory is created if it does not
    exist. Targets are written one at a time in iteration order.

    Args:
        forecasters: Dictionary mapping target names to trained
            ForecasterEquivalentDate objects.
        model_dir: Directory to save models. Created if it doesn't exist.
        verbose: Print progress messages. Default: False.

    Returns:
        Dict[str, Path]: Dictionary mapping target names to saved model filepaths.

    Raises:
        OSError: If the directory cannot be created, or if any forecaster
            cannot be serialized or written. Serialization errors (e.g.
            unpicklable attributes) are reported as ``OSError`` as well; the
            original exception is chained as ``__cause__``. Files written
            before the failing target remain on disk.

    Examples:
        >>> forecasters = {"power": forecaster_obj}
        >>> paths = _save_forecasters(forecasters, "./models", verbose=True)
        >>> print(paths["power"])
        models/forecaster_power.joblib
    """
    model_path = _ensure_model_dir(model_dir)
    saved_paths: Dict[str, Path] = {}

    for target, forecaster in forecasters.items():
        filepath = _get_model_filepath(model_path, target)
        try:
            dump(forecaster, filepath, compress=3)
        except Exception as e:
            # Single documented error type for callers, but keep the root cause.
            raise OSError(f"Failed to save model for {target}: {e}") from e
        saved_paths[target] = filepath
        if verbose:
            print(f" ✓ Saved forecaster for {target} to {filepath}")

    return saved_paths
143
+
144
+
145
def _load_forecasters(
    target_columns: List[str],
    model_dir: Union[str, Path],
    verbose: bool = False,
) -> Tuple[Dict[str, object], List[str]]:
    """Load trained forecasters from disk using joblib.

    Attempts to load ``forecaster_<target>.joblib`` for every target. Targets
    whose file is absent or cannot be loaded are returned in
    ``missing_targets`` so the caller can retrain them selectively.

    Security: joblib unpickles the files, which can execute arbitrary code.
    Only load from directories you trust.

    Args:
        target_columns: List of target variable names to load.
        model_dir: Directory containing saved models.
        verbose: Print progress messages. Default: False.

    Returns:
        Tuple[Dict[str, object], List[str]]:
            - forecasters: Successfully loaded objects keyed by target. In this
              module they are the ForecasterEquivalentDate instances saved by
              ``n2n_predict``; their type is not validated here.
            - missing_targets: Target names without a loadable saved model.

    Warns:
        RuntimeWarning: When ``verbose`` is False and an existing model file
            fails to load, so that the fallback is not silent.

    Examples:
        >>> forecasters, missing = _load_forecasters(
        ...     ["power", "energy"],
        ...     "./models",
        ...     verbose=True
        ... )
        >>> print(missing)
        ['energy']
    """
    import warnings

    model_path = Path(model_dir)
    forecasters: Dict[str, object] = {}
    missing_targets: List[str] = []

    for target in target_columns:
        filepath = _get_model_filepath(model_path, target)
        if not filepath.exists():
            missing_targets.append(target)
            continue

        try:
            forecasters[target] = load(filepath)
        except Exception as e:
            # Best effort: an unreadable or incompatible cache entry is treated
            # as missing instead of aborting, but it is always reported.
            if verbose:
                print(f" ✗ Failed to load {target}: {e}")
            else:
                warnings.warn(
                    f"Could not load cached forecaster for {target} from "
                    f"{filepath}: {e}. Treating it as missing.",
                    RuntimeWarning,
                    stacklevel=2,
                )
            missing_targets.append(target)
            continue

        if verbose:
            print(f" ✓ Loaded forecaster for {target} from {filepath}")

    return forecasters, missing_targets
193
+
194
+
195
+ def _model_directory_exists(model_dir: Union[str, Path]) -> bool:
196
+ """Check if model directory exists.
197
+
198
+ Args:
199
+ model_dir: Directory path to check.
200
+
201
+ Returns:
202
+ bool: True if directory exists, False otherwise.
203
+ """
204
+ return Path(model_dir).exists()
205
+
206
+
207
+ # ============================================================================
208
+ # Main Function
209
+ # ============================================================================
210
+
211
+
19
212
  def n2n_predict(
20
213
  columns: Optional[List[str]] = None,
21
214
  forecast_horizon: int = 24,
22
215
  contamination: float = 0.01,
23
216
  window_size: int = 72,
217
+ force_train: bool = False,
218
+ model_dir: Union[str, Path] = "./models_baseline",
24
219
  verbose: bool = True,
25
220
  show_progress: bool = True,
26
- ) -> pd.DataFrame:
27
- """
28
- End-to-end prediction function replicating the workflow from 01_base_predictor combined with fetch_data.
221
+ ) -> Tuple[pd.DataFrame, Dict]:
222
+ """End-to-end baseline forecasting using equivalent date method.
223
+
224
+ This function implements a complete forecasting pipeline that:
225
+ 1. Loads and validates target data
226
+ 2. Detects and removes outliers
227
+ 3. Imputes missing values
228
+ 4. Splits into train/validation/test sets
229
+ 5. Trains or loads equivalent date forecasters
230
+ 6. Generates multi-step ahead predictions
231
+
232
+ Models are persisted to disk following scikit-learn conventions using joblib.
233
+ Existing models are reused for prediction unless force_train=True.
29
234
 
30
235
  Args:
31
- columns: List of target columns to forecast. If None, uses a default set (defined internally or from data).
32
- Note: fetch_data supports None to return all columns.
33
- forecast_horizon: Number of steps to forecast.
34
- contamination: Contamination factor for outlier detection.
35
- window_size: Window size for weighting (not fully utilized in main flow but kept for consistency).
36
- verbose: Whether to print progress logs.
37
- show_progress: Show progress bar during training and prediction.
236
+ columns: List of target columns to forecast. If None, uses all available columns.
237
+ Default: None.
238
+ forecast_horizon: Number of time steps to forecast ahead. Default: 24.
239
+ contamination: Contamination parameter for outlier detection. Default: 0.01.
240
+ window_size: Rolling window size for gap detection. Default: 72.
241
+ force_train: Force retraining of all models, ignoring cached models.
242
+ Default: False.
243
+ model_dir: Directory for saving/loading trained models.
244
+ Default: "./models_baseline".
245
+ verbose: Print progress messages. Default: True.
246
+ show_progress: Show progress bar during training and prediction. Default: True.
38
247
 
39
248
  Returns:
40
- pd.DataFrame: The multi-output predictions.
249
+ Tuple containing:
250
+ - predictions: DataFrame with forecast values for each target variable.
251
+ - forecasters: Dictionary of trained ForecasterEquivalentDate objects keyed by target.
252
+
253
+ Raises:
254
+ ValueError: If data validation fails or required data cannot be retrieved.
255
+ ImportError: If required dependencies are not installed.
256
+ OSError: If models cannot be saved to disk.
257
+
258
+ Examples:
259
+ Basic usage with automatic model caching:
260
+
261
+ >>> predictions, forecasters = n2n_predict(
262
+ ... forecast_horizon=24,
263
+ ... verbose=True
264
+ ... )
265
+ >>> print(predictions.shape)
266
+ (24, 11)
267
+
268
+ Load cached models (if available):
269
+
270
+ >>> predictions, forecasters = n2n_predict(
271
+ ... forecast_horizon=24,
272
+ ... force_train=False,
273
+ ... model_dir="./saved_models",
274
+ ... verbose=True
275
+ ... )
276
+
277
+ Force retraining and update cache:
278
+
279
+ >>> predictions, forecasters = n2n_predict(
280
+ ... forecast_horizon=24,
281
+ ... force_train=True,
282
+ ... model_dir="./saved_models",
283
+ ... verbose=True
284
+ ... )
285
+
286
+ With specific target columns:
287
+
288
+ >>> predictions, forecasters = n2n_predict(
289
+ ... columns=["power", "energy"],
290
+ ... forecast_horizon=48,
291
+ ... force_train=False,
292
+ ... verbose=True
293
+ ... )
294
+
295
+ Notes:
296
+ - Trained models are saved to disk using joblib for fast reuse.
297
+ - When force_train=False, existing models are loaded and prediction
298
+ proceeds without retraining. This significantly speeds up prediction
299
+ for repeated calls with the same configuration.
300
+ - The model_dir directory is created automatically if it doesn't exist.
301
+
302
+ Performance Notes:
303
+ - First run: Full training (~2-5 minutes depending on data size)
304
+ - Subsequent runs (force_train=False): Model loading only (~1-2 seconds)
305
+ - Force retrain (force_train=True): Full training again (~2-5 minutes)
41
306
  """
42
307
  if columns is not None:
43
308
  TARGET = columns
@@ -98,20 +363,66 @@ def n2n_predict(
98
363
  end_validation = pd.concat([data_train, data_val]).index[-1]
99
364
 
100
365
  baseline_forecasters = {}
366
+ targets_to_train = list(data.columns)
367
+
368
+ # Attempt to load cached models if force_train=False
369
+ if not force_train and _model_directory_exists(model_dir):
370
+ if verbose:
371
+ print(" Attempting to load cached models...")
372
+ cached_forecasters, missing_targets = _load_forecasters(
373
+ target_columns=list(data.columns),
374
+ model_dir=model_dir,
375
+ verbose=verbose,
376
+ )
377
+ baseline_forecasters.update(cached_forecasters)
378
+ targets_to_train = missing_targets
379
+
380
+ if len(cached_forecasters) == len(data.columns):
381
+ if verbose:
382
+ print(f" ✓ All {len(data.columns)} forecasters loaded from cache")
383
+ elif len(cached_forecasters) > 0:
384
+ if verbose:
385
+ print(
386
+ f" ✓ Loaded {len(cached_forecasters)} forecasters, "
387
+ f"will train {len(targets_to_train)} new ones"
388
+ )
389
+
390
+ # Train missing or forced models
391
+ if len(targets_to_train) > 0:
392
+ if force_train and len(baseline_forecasters) > 0:
393
+ if verbose:
394
+ print(f" Force retraining all {len(data.columns)} forecasters...")
395
+ targets_to_train = list(data.columns)
396
+ baseline_forecasters.clear()
397
+
398
+ target_iter = targets_to_train
399
+ if show_progress and tqdm is not None:
400
+ target_iter = tqdm(
401
+ targets_to_train,
402
+ desc="Training forecasters",
403
+ unit="model",
404
+ )
101
405
 
102
- target_iter = data.columns
103
- if show_progress and tqdm is not None:
104
- target_iter = tqdm(data.columns, desc="Training forecasters", unit="model")
406
+ for target in target_iter:
407
+ forecaster = ForecasterEquivalentDate(
408
+ offset=pd.DateOffset(days=1), n_offsets=1
409
+ )
105
410
 
106
- for target in target_iter:
107
- forecaster = ForecasterEquivalentDate(offset=pd.DateOffset(days=1), n_offsets=1)
411
+ forecaster.fit(y=data.loc[:end_validation, target])
108
412
 
109
- forecaster.fit(y=data.loc[:end_validation, target])
413
+ baseline_forecasters[target] = forecaster
110
414
 
111
- baseline_forecasters[target] = forecaster
415
+ # Save newly trained models to disk
416
+ if verbose:
417
+ print(f" Saving {len(targets_to_train)} trained forecasters to disk...")
418
+ _save_forecasters(
419
+ forecasters={t: baseline_forecasters[t] for t in targets_to_train},
420
+ model_dir=model_dir,
421
+ verbose=verbose,
422
+ )
112
423
 
113
424
  if verbose:
114
- print("✓ Multi-output baseline system trained")
425
+ print(f" Total forecasters available: {len(baseline_forecasters)}")
115
426
 
116
427
  # --- Predict ---
117
428
  if verbose:
@@ -123,4 +434,4 @@ def n2n_predict(
123
434
  show_progress=show_progress,
124
435
  )
125
436
 
126
- return predictions
437
+ return predictions, baseline_forecasters
@@ -6,6 +6,9 @@ recursive forecasters with exogenous variables (weather, holidays, calendar feat
6
6
  It handles data preparation, feature engineering, model training, and prediction
7
7
  in a single integrated function.
8
8
 
9
+ Model persistence follows scikit-learn conventions using joblib for efficient
10
+ serialization and deserialization of trained forecasters.
11
+
9
12
  Examples:
10
13
  Basic usage with default parameters:
11
14
 
@@ -27,8 +30,28 @@ Examples:
27
30
  ... train_ratio=0.75,
28
31
  ... verbose=True
29
32
  ... )
33
+
34
+ Using cached models:
35
+
36
+ >>> # Load existing models if available, or train new ones
37
+ >>> predictions, metadata, forecasters = n2n_predict_with_covariates(
38
+ ... forecast_horizon=24,
39
+ ... force_train=False,
40
+ ... model_dir="./models",
41
+ ... verbose=True
42
+ ... )
43
+
44
+ Force retraining and update cache:
45
+
46
+ >>> predictions, metadata, forecasters = n2n_predict_with_covariates(
47
+ ... forecast_horizon=24,
48
+ ... force_train=True,
49
+ ... model_dir="./models",
50
+ ... verbose=True
51
+ ... )
30
52
  """
31
53
 
54
+ from pathlib import Path
32
55
  from typing import Dict, List, Optional, Tuple, Union
33
56
 
34
57
  import numpy as np
@@ -37,6 +60,11 @@ from astral import LocationInfo
37
60
  from lightgbm import LGBMRegressor
38
61
  from sklearn.preprocessing import PolynomialFeatures
39
62
 
63
+ try:
64
+ from joblib import dump, load
65
+ except ImportError:
66
+ raise ImportError("joblib is required. Install with: pip install joblib")
67
+
40
68
  try:
41
69
  from tqdm.auto import tqdm
42
70
  except ImportError: # pragma: no cover - fallback when tqdm is not installed
@@ -547,6 +575,152 @@ def _merge_data_and_covariates(
547
575
  return data_with_exog, exo_tmp, exo_pred
548
576
 
549
577
 
578
+ # ============================================================================
579
+ # Model Persistence Functions
580
+ # ============================================================================
581
+
582
+
583
+ def _ensure_model_dir(model_dir: Union[str, Path]) -> Path:
584
+ """Ensure model directory exists.
585
+
586
+ Args:
587
+ model_dir: Directory path for model storage.
588
+
589
+ Returns:
590
+ Path: Validated Path object.
591
+
592
+ Raises:
593
+ OSError: If directory cannot be created.
594
+ """
595
+ model_path = Path(model_dir)
596
+ model_path.mkdir(parents=True, exist_ok=True)
597
+ return model_path
598
+
599
+
600
+ def _get_model_filepath(model_dir: Path, target: str) -> Path:
601
+ """Get filepath for a single model.
602
+
603
+ Args:
604
+ model_dir: Directory containing models.
605
+ target: Target variable name.
606
+
607
+ Returns:
608
+ Path: Full filepath for the model.
609
+
610
+ Examples:
611
+ >>> path = _get_model_filepath(Path("./models"), "power")
612
+ >>> str(path)
613
+ './models/forecaster_power.joblib'
614
+ """
615
+ return model_dir / f"forecaster_{target}.joblib"
616
+
617
+
618
def _save_forecasters(
    forecasters: Dict[str, object],
    model_dir: Union[str, Path],
    verbose: bool = False,
) -> Dict[str, Path]:
    """Save trained forecasters to disk using joblib.

    Each forecaster is written to ``forecaster_<target>.joblib`` inside
    ``model_dir`` with ``compress=3``. The directory is created if it does not
    exist. Targets are written one at a time in iteration order.

    Args:
        forecasters: Dictionary mapping target names to trained
            ForecasterRecursive objects.
        model_dir: Directory to save models. Created if it doesn't exist.
        verbose: Print progress messages. Default: False.

    Returns:
        Dict[str, Path]: Dictionary mapping target names to saved model filepaths.

    Raises:
        OSError: If the directory cannot be created, or if any forecaster
            cannot be serialized or written. Serialization errors (e.g. a
            forecaster holding a locally defined ``weight_func``, which cannot
            be pickled) are reported as ``OSError`` as well; the original
            exception is chained as ``__cause__``. Files written before the
            failing target remain on disk.

    Examples:
        >>> forecasters = {"power": forecaster_obj}
        >>> paths = _save_forecasters(forecasters, "./models", verbose=True)
        >>> print(paths["power"])
        models/forecaster_power.joblib
    """
    model_path = _ensure_model_dir(model_dir)
    saved_paths: Dict[str, Path] = {}

    for target, forecaster in forecasters.items():
        filepath = _get_model_filepath(model_path, target)
        try:
            dump(forecaster, filepath, compress=3)
        except Exception as e:
            # Single documented error type for callers, but keep the root cause.
            raise OSError(f"Failed to save model for {target}: {e}") from e
        saved_paths[target] = filepath
        if verbose:
            print(f" ✓ Saved forecaster for {target} to {filepath}")

    return saved_paths
660
+
661
+
662
def _load_forecasters(
    target_columns: List[str],
    model_dir: Union[str, Path],
    verbose: bool = False,
) -> Tuple[Dict[str, object], List[str]]:
    """Load trained forecasters from disk using joblib.

    Attempts to load ``forecaster_<target>.joblib`` for every target. Targets
    whose file is absent or cannot be loaded are returned in
    ``missing_targets`` so the caller can retrain them selectively.

    Security: joblib unpickles the files, which can execute arbitrary code.
    Only load from directories you trust.

    Args:
        target_columns: List of target variable names to load.
        model_dir: Directory containing saved models.
        verbose: Print progress messages. Default: False.

    Returns:
        Tuple[Dict[str, object], List[str]]:
            - forecasters: Successfully loaded objects keyed by target
              (expected to be ForecasterRecursive instances; their type is not
              validated here).
            - missing_targets: Target names without a loadable saved model.

    Warns:
        RuntimeWarning: When ``verbose`` is False and an existing model file
            fails to load, so that the fallback is not silent.

    Examples:
        >>> forecasters, missing = _load_forecasters(
        ...     ["power", "energy"],
        ...     "./models",
        ...     verbose=True
        ... )
        >>> print(missing)
        ['energy']
    """
    import warnings

    model_path = Path(model_dir)
    forecasters: Dict[str, object] = {}
    missing_targets: List[str] = []

    for target in target_columns:
        filepath = _get_model_filepath(model_path, target)
        if not filepath.exists():
            missing_targets.append(target)
            continue

        try:
            forecasters[target] = load(filepath)
        except Exception as e:
            # Best effort: an unreadable or incompatible cache entry is treated
            # as missing instead of aborting, but it is always reported.
            if verbose:
                print(f" ✗ Failed to load {target}: {e}")
            else:
                warnings.warn(
                    f"Could not load cached forecaster for {target} from "
                    f"{filepath}: {e}. Treating it as missing.",
                    RuntimeWarning,
                    stacklevel=2,
                )
            missing_targets.append(target)
            continue

        if verbose:
            print(f" ✓ Loaded forecaster for {target} from {filepath}")

    return forecasters, missing_targets
710
+
711
+
712
+ def _model_directory_exists(model_dir: Union[str, Path]) -> bool:
713
+ """Check if model directory exists.
714
+
715
+ Args:
716
+ model_dir: Directory path to check.
717
+
718
+ Returns:
719
+ bool: True if directory exists, False otherwise.
720
+ """
721
+ return Path(model_dir).exists()
722
+
723
+
550
724
  # ============================================================================
551
725
  # Main Function
552
726
  # ============================================================================
@@ -567,8 +741,10 @@ def n2n_predict_with_covariates(
567
741
  include_weather_windows: bool = False,
568
742
  include_holiday_features: bool = False,
569
743
  include_poly_features: bool = False,
744
+ force_train: bool = False,
745
+ model_dir: Union[str, Path] = "./forecaster_models",
570
746
  verbose: bool = True,
571
- show_progress: bool = True,
747
+ show_progress: bool = False,
572
748
  ) -> Tuple[pd.DataFrame, Dict, Dict]:
573
749
  """End-to-end recursive forecasting with exogenous covariates.
574
750
 
@@ -580,9 +756,12 @@ def n2n_predict_with_covariates(
580
756
  5. Performs feature engineering (cyclical encoding, interactions)
581
757
  6. Merges target and exogenous data
582
758
  7. Splits into train/validation/test sets
583
- 8. Trains recursive forecasters with sample weighting
759
+ 8. Trains or loads recursive forecasters with sample weighting
584
760
  9. Generates multi-step ahead predictions
585
761
 
762
+ Models are persisted to disk following scikit-learn conventions using joblib.
763
+ Existing models are reused for prediction unless force_train=True.
764
+
586
765
  Args:
587
766
  forecast_horizon: Number of time steps to forecast ahead. Default: 24.
588
767
  contamination: Contamination parameter for outlier detection. Default: 0.01.
@@ -599,8 +778,12 @@ def n2n_predict_with_covariates(
599
778
  include_weather_windows: Include weather window features. Default: False.
600
779
  include_holiday_features: Include holiday features. Default: False.
601
780
  include_poly_features: Include polynomial interaction features. Default: False.
781
+ force_train: Force retraining of all models, ignoring cached models.
782
+ Default: False.
783
+ model_dir: Directory for saving/loading trained models.
784
+ Default: "./models_covariates".
602
785
  verbose: Print progress messages. Default: True.
603
- show_progress: Show progress bar during training. Default: True.
786
+ show_progress: Show progress bar during training. Default: False.
604
787
 
605
788
  Returns:
606
789
  Tuple containing:
@@ -611,9 +794,10 @@ def n2n_predict_with_covariates(
611
794
  Raises:
612
795
  ValueError: If data validation fails or required data cannot be retrieved.
613
796
  ImportError: If required dependencies are not installed.
797
+ OSError: If models cannot be saved to disk.
614
798
 
615
799
  Examples:
616
- Basic usage:
800
+ Basic usage with automatic model caching:
617
801
 
618
802
  >>> predictions, metadata, forecasters = n2n_predict_with_covariates(
619
803
  ... forecast_horizon=24,
@@ -622,6 +806,22 @@ def n2n_predict_with_covariates(
622
806
  >>> print(predictions.shape)
623
807
  (24, 11)
624
808
 
809
+ Load cached models (if available):
810
+
811
+ >>> predictions, metadata, forecasters = n2n_predict_with_covariates(
812
+ ... forecast_horizon=24,
813
+ ... force_train=False,
814
+ ... model_dir="./saved_models"
815
+ ... )
816
+
817
+ Force retraining and update cache:
818
+
819
+ >>> predictions, metadata, forecasters = n2n_predict_with_covariates(
820
+ ... forecast_horizon=24,
821
+ ... force_train=True,
822
+ ... model_dir="./saved_models"
823
+ ... )
824
+
625
825
  Custom location and features:
626
826
 
627
827
  >>> predictions, metadata, forecasters = n2n_predict_with_covariates(
@@ -630,6 +830,7 @@ def n2n_predict_with_covariates(
630
830
  ... longitude=13.4050,
631
831
  ... lags=48,
632
832
  ... include_poly_features=True,
833
+ ... force_train=False,
633
834
  ... verbose=True
634
835
  ... )
635
836
 
@@ -641,6 +842,16 @@ def n2n_predict_with_covariates(
641
842
  near missing data.
642
843
  - Train/validation splits are temporal (80/20 by default).
643
844
  - All features are cast to float32 for memory efficiency.
845
+ - Trained models are saved to disk using joblib for fast reuse.
846
+ - When force_train=False, existing models are loaded and prediction
847
+ proceeds without retraining. This significantly speeds up prediction
848
+ for repeated calls with the same configuration.
849
+ - The model_dir directory is created automatically if it doesn't exist.
850
+
851
+ Performance Notes:
852
+ - First run: Full training (~5-10 minutes depending on data size)
853
+ - Subsequent runs (force_train=False): Model loading only (~1-2 seconds)
854
+ - Force retrain (force_train=True): Full training again (~5-10 minutes)
644
855
  """
645
856
  if verbose:
646
857
  print("=" * 80)
@@ -702,6 +913,10 @@ def n2n_predict_with_covariates(
702
913
  """Return sample weights for given index."""
703
914
  return custom_weights(index, weights_series)
704
915
 
916
+ # Note: weight_func is a local function and cannot be pickled.
917
+ # Model persistence is disabled when using weight_func.
918
+ use_model_persistence = False
919
+
705
920
  # ========================================================================
706
921
  # 4. EXOGENOUS FEATURES ENGINEERING
707
922
  # ========================================================================
@@ -845,11 +1060,13 @@ def n2n_predict_with_covariates(
845
1060
  )
846
1061
 
847
1062
  # ========================================================================
848
- # 9. MODEL TRAINING
1063
+ # 9. MODEL TRAINING OR LOADING
849
1064
  # ========================================================================
850
1065
 
851
1066
  if verbose:
852
- print("\n[8/9] Training recursive forecasters with exogenous variables...")
1067
+ print(
1068
+ "\n[8/9] Loading or training recursive forecasters with exogenous variables..."
1069
+ )
853
1070
 
854
1071
  if estimator is None:
855
1072
  estimator = LGBMRegressor(random_state=1234, verbose=-1)
@@ -857,35 +1074,85 @@ def n2n_predict_with_covariates(
857
1074
  window_features = RollingFeatures(stats=["mean"], window_sizes=window_size)
858
1075
  end_validation = pd.concat([data_train, data_val]).index[-1]
859
1076
 
1077
+ # Attempt to load cached models if force_train=False and persistence is enabled
860
1078
  recursive_forecasters = {}
1079
+ targets_to_train = target_columns
861
1080
 
862
- target_iter = target_columns
863
- if show_progress and tqdm is not None:
864
- target_iter = tqdm(target_columns, desc="Training forecasters", unit="model")
865
-
866
- for target in target_iter:
1081
+ if use_model_persistence and not force_train and _model_directory_exists(model_dir):
867
1082
  if verbose:
868
- print(f" Training forecaster for {target}...")
869
-
870
- forecaster = ForecasterRecursive(
871
- estimator=estimator,
872
- lags=lags,
873
- window_features=window_features,
874
- weight_func=weight_func,
1083
+ print(" Attempting to load cached models...")
1084
+ cached_forecasters, missing_targets = _load_forecasters(
1085
+ target_columns=target_columns,
1086
+ model_dir=model_dir,
1087
+ verbose=verbose,
875
1088
  )
876
-
877
- forecaster.fit(
878
- y=data_with_exog[target].loc[:end_validation].squeeze(),
879
- exog=data_with_exog[exog_features].loc[:end_validation],
880
- )
881
-
882
- recursive_forecasters[target] = forecaster
883
-
884
- if verbose:
885
- print(f" Forecaster trained for {target}")
1089
+ recursive_forecasters.update(cached_forecasters)
1090
+ targets_to_train = missing_targets
1091
+
1092
+ if len(cached_forecasters) == len(target_columns):
1093
+ if verbose:
1094
+ print(f" ✓ All {len(target_columns)} forecasters loaded from cache")
1095
+ elif len(cached_forecasters) > 0:
1096
+ if verbose:
1097
+ print(
1098
+ f" Loaded {len(cached_forecasters)} forecasters, "
1099
+ f"will train {len(targets_to_train)} new ones"
1100
+ )
1101
+
1102
+ # Train missing or forced models
1103
+ if len(targets_to_train) > 0:
1104
+ if force_train and len(recursive_forecasters) > 0:
1105
+ if verbose:
1106
+ print(f" Force retraining all {len(target_columns)} forecasters...")
1107
+ targets_to_train = target_columns
1108
+ recursive_forecasters.clear()
1109
+
1110
+ target_iter = targets_to_train
1111
+ if show_progress and tqdm is not None:
1112
+ target_iter = tqdm(
1113
+ targets_to_train,
1114
+ desc="Training forecasters",
1115
+ unit="model",
1116
+ )
1117
+
1118
+ for target in target_iter:
1119
+ if verbose:
1120
+ print(f" Training forecaster for {target}...")
1121
+
1122
+ forecaster = ForecasterRecursive(
1123
+ estimator=estimator,
1124
+ lags=lags,
1125
+ window_features=window_features,
1126
+ weight_func=weight_func,
1127
+ )
1128
+
1129
+ forecaster.fit(
1130
+ y=data_with_exog[target].loc[:end_validation].squeeze(),
1131
+ exog=data_with_exog[exog_features].loc[:end_validation],
1132
+ )
1133
+
1134
+ recursive_forecasters[target] = forecaster
1135
+
1136
+ if verbose:
1137
+ print(f" ✓ Forecaster trained for {target}")
1138
+
1139
+ # Save newly trained models to disk (only if persistence is enabled)
1140
+ if use_model_persistence:
1141
+ if verbose:
1142
+ print(
1143
+ f" Saving {len(targets_to_train)} trained forecasters to disk..."
1144
+ )
1145
+ _save_forecasters(
1146
+ forecasters={t: recursive_forecasters[t] for t in targets_to_train},
1147
+ model_dir=model_dir,
1148
+ verbose=verbose,
1149
+ )
1150
+ else:
1151
+ if verbose:
1152
+ print(" ⚠ Model persistence disabled (weight_func cannot be pickled)")
886
1153
 
887
1154
  if verbose:
888
- print(f" ✓ Total forecasters trained: {len(recursive_forecasters)}")
1155
+ print(f" ✓ Total forecasters available: {len(recursive_forecasters)}")
889
1156
 
890
1157
  # ========================================================================
891
1158
  # 10. PREDICTION
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: spotforecast2
3
- Version: 0.0.5
3
+ Version: 0.1.1
4
4
  Summary: Forecasting with spot
5
5
  Author: bartzbeielstein
6
6
  Author-email: bartzbeielstein <32470350+bartzbeielstein@users.noreply.github.com>
@@ -10,7 +10,7 @@ spotforecast2/forecaster/recursive/__init__.py,sha256=YNVxLReLEwSFDasmjXXMSKJqNL
10
10
  spotforecast2/forecaster/recursive/_forecaster_equivalent_date.py,sha256=Mdr-3D1lUivXO07Rp4T8NIgQ2H_2y4IR4BqCwjBtZsw,48261
11
11
  spotforecast2/forecaster/recursive/_forecaster_recursive.py,sha256=oU2zCOI0UaGIn8doLJGphP7jcNL5FF6Y972UCwlxDJI,35739
12
12
  spotforecast2/forecaster/recursive/_warnings.py,sha256=BtZ3UoycywjEQ0ceXe4TL1WEdFcLAi1EnDMvZXHw_U8,325
13
- spotforecast2/forecaster/utils.py,sha256=0cHegSO3WmEXq5Q6NLQGcMgBNASZ6qvbhGC9wg5ZdBA,36600
13
+ spotforecast2/forecaster/utils.py,sha256=eOx_Ayf2WtW3JVUsOWvMzPHQ17ImKLIZZV-hejJArKk,36588
14
14
  spotforecast2/model_selection/__init__.py,sha256=uP60TkgDzs_x5V60rnKanc12S9-yXx2ZLsXsXdqAYEA,208
15
15
  spotforecast2/model_selection/bayesian_search.py,sha256=Vwb_LatDnt22LhIWyzqNhCdlDQ_UgVCyFcXmOxF3Pic,17407
16
16
  spotforecast2/model_selection/grid_search.py,sha256=a5rNEndTXlx1ghT7ws5qs7WM0XBFMqEiK3Q5k7P0EJg,10998
@@ -31,8 +31,8 @@ spotforecast2/preprocessing/imputation.py,sha256=lmH-HumI_QLLm9aMESe_oZq84Axn60w
31
31
  spotforecast2/preprocessing/outlier.py,sha256=jZxAR870QtYner7b4gXk6LLGJw0juLq1VU4CGklYd3c,4208
32
32
  spotforecast2/preprocessing/split.py,sha256=mzzt5ltUZdVzfWtBBTQjp8E2MyqVdWUFtz7nN11urbU,5011
33
33
  spotforecast2/processing/agg_predict.py,sha256=VKlruB0x-eJKokkHyJxR87rZ4m53si3ODbrd0ibPlow,2378
34
- spotforecast2/processing/n2n_predict.py,sha256=Jkf-fMw2RSKY8-0UDc8D0yiiZxiF9s5DyfeRpfx90ks,4060
35
- spotforecast2/processing/n2n_predict_with_covariates.py,sha256=Py9oMSUFv_9Tw5S9TfNF__MzEZNmGaN85lPbg6GBluw,31111
34
+ spotforecast2/processing/n2n_predict.py,sha256=dAj5yXD2JGXSqtl0VDkq0O_8FO_K9BCYG6osbJbWDFg,14494
35
+ spotforecast2/processing/n2n_predict_with_covariates.py,sha256=5a1lYIQE1d-t4ZvSQDoW87G705eiIZxtrCn4w7U2bVw,40420
36
36
  spotforecast2/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
37
37
  spotforecast2/utils/__init__.py,sha256=NrMt_xJLe4rbTFbsbgSQYeREohEOiYG5S-97e6Jj07I,1018
38
38
  spotforecast2/utils/convert_to_utc.py,sha256=hz8mJUHK9jDLUiN5LdNX5l3KZuOKlklyycB4zFdB9Ng,1405
@@ -42,6 +42,6 @@ spotforecast2/utils/generate_holiday.py,sha256=SHaPvPMt-abis95cChHf5ObyPwCTrzJ87
42
42
  spotforecast2/utils/validation.py,sha256=x9ypQzcneDhWJA_piiY4Q3_ogoGd1LTsZ7__MFeG9Fc,21618
43
43
  spotforecast2/weather/__init__.py,sha256=1Jco88pl0deNESgNATin83Nf5i9c58pxN7G-vNiOiu0,120
44
44
  spotforecast2/weather/weather_client.py,sha256=Ec_ywug6uoa71MfXM8RNbXEvtBtBzr-SUS5xq_HKtZE,9837
45
- spotforecast2-0.0.5.dist-info/WHEEL,sha256=5DEXXimM34_d4Gx1AuF9ysMr1_maoEtGKjaILM3s4w4,80
46
- spotforecast2-0.0.5.dist-info/METADATA,sha256=tqA8nKykujUdpZHhanNyF0KF57nS_ZNQJ061FEEeceg,3481
47
- spotforecast2-0.0.5.dist-info/RECORD,,
45
+ spotforecast2-0.1.1.dist-info/WHEEL,sha256=5DEXXimM34_d4Gx1AuF9ysMr1_maoEtGKjaILM3s4w4,80
46
+ spotforecast2-0.1.1.dist-info/METADATA,sha256=TMwW-WMXSoNRVw7oDLU3Ys_8JXhODgvXxXjSeokWaXs,3481
47
+ spotforecast2-0.1.1.dist-info/RECORD,,