spotforecast2 0.2.2__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,15 @@
1
+ from .fetch_data import (
2
+ get_data_home,
3
+ get_cache_home,
4
+ fetch_data,
5
+ fetch_holiday_data,
6
+ fetch_weather_data,
7
+ )
8
+
9
+ __all__ = [
10
+ "get_data_home",
11
+ "get_cache_home",
12
+ "fetch_data",
13
+ "fetch_holiday_data",
14
+ "fetch_weather_data",
15
+ ]
@@ -43,6 +43,61 @@ def get_data_home(data_home: Optional[Union[str, Path]] = None) -> Path:
43
43
  return data_home
44
44
 
45
45
 
46
def get_cache_home(cache_home: Optional[Union[str, Path]] = None) -> Path:
    """Return the location where persistent models are to be cached.

    The directory is chosen in this order:

    1. The explicit ``cache_home`` argument, if given.
    2. The ``SPOTFORECAST2_CACHE`` environment variable, if it is set to a
       non-empty value.
    3. A folder named ``spotforecast2_cache`` in the user's home directory.

    The ``~`` symbol is expanded to the user home folder and the result is
    made absolute. If the folder does not already exist, it is created,
    including any missing parent directories.

    This directory is used to store pickled trained models for quick reuse
    across forecasting runs, following scikit-learn model persistence
    conventions.

    Args:
        cache_home (str or pathlib.Path, optional):
            The path to the spotforecast cache directory. If `None`, the
            ``SPOTFORECAST2_CACHE`` environment variable is used when it is
            non-empty, otherwise `~/spotforecast2_cache`.

    Returns:
        pathlib.Path:
            The absolute path to the spotforecast cache directory.

    Raises:
        OSError: If the directory cannot be created, e.g. due to missing
            permissions or because a regular file already exists there.

    Examples:
        >>> import os
        >>> import tempfile
        >>> from pathlib import Path
        >>> from spotforecast2.data.fetch_data import get_cache_home

        >>> # Custom cache location
        >>> with tempfile.TemporaryDirectory() as tmp:
        ...     custom_cache = get_cache_home(Path(tmp) / "my_cache")
        ...     custom_cache.exists(), custom_cache.name
        (True, 'my_cache')

        >>> # Using the environment variable (restored afterwards)
        >>> with tempfile.TemporaryDirectory() as tmp:
        ...     previous = os.environ.get("SPOTFORECAST2_CACHE")
        ...     os.environ["SPOTFORECAST2_CACHE"] = str(Path(tmp) / "env_cache")
        ...     try:
        ...         get_cache_home().name
        ...     finally:
        ...         if previous is None:
        ...             del os.environ["SPOTFORECAST2_CACHE"]
        ...         else:
        ...             os.environ["SPOTFORECAST2_CACHE"] = previous
        'env_cache'
    """
    if cache_home is None:
        # An empty SPOTFORECAST2_CACHE is treated like an unset one; otherwise
        # Path("") would silently resolve to the current working directory.
        cache_home = environ.get("SPOTFORECAST2_CACHE") or (
            Path.home() / "spotforecast2_cache"
        )
    # Ensure cache_home is a Path() object pointing to an absolute path
    cache_home = Path(cache_home).expanduser().absolute()
    # Create cache directory if it does not exist (no-op if it already does)
    cache_home.mkdir(parents=True, exist_ok=True)
    return cache_home
99
+
100
+
46
101
  def fetch_data(
47
102
  filename: Optional[str] = None,
48
103
  dataframe: Optional[pd.DataFrame] = None,
@@ -56,7 +111,7 @@ def fetch_data(
56
111
 
57
112
  Args:
58
113
  filename (str, optional):
59
- Filename of the CSV file containing the dataset. Must be located in the
114
+ Filename of the CSV file containing the dataset. Must be located in the
60
115
  data home directory. If both filename and dataframe are None, defaults to "data_in.csv".
61
116
  dataframe (pd.DataFrame, optional):
62
117
  A pandas DataFrame to process. If provided, it will be processed with
@@ -87,13 +142,13 @@ def fetch_data(
87
142
  >>> data = fetch_data(columns=["col1", "col2"])
88
143
  >>> data.head()
89
144
  Header1 Header2 Header3
90
-
145
+
91
146
  Load from specific CSV:
92
147
  >>> data = fetch_data(filename="custom_data.csv")
93
-
148
+
94
149
  Process a DataFrame:
95
150
  >>> import pandas as pd
96
- >>> df = pd.DataFrame({"value": [1, 2, 3]},
151
+ >>> df = pd.DataFrame({"value": [1, 2, 3]},
97
152
  ... index=pd.date_range("2024-01-01", periods=3, freq="h"))
98
153
  >>> data = fetch_data(dataframe=df, timezone="Europe/Berlin")
99
154
  >>> data.index.tz
@@ -101,9 +156,11 @@ def fetch_data(
101
156
  """
102
157
  if columns is not None and len(columns) == 0:
103
158
  raise ValueError("columns must be specified and cannot be empty.")
104
-
159
+
105
160
  if filename is not None and dataframe is not None:
106
- raise ValueError("Cannot specify both filename and dataframe. Please provide only one.")
161
+ raise ValueError(
162
+ "Cannot specify both filename and dataframe. Please provide only one."
163
+ )
107
164
 
108
165
  # Process DataFrame if provided
109
166
  if dataframe is not None:
@@ -6,7 +6,7 @@ from .curate_data import (
6
6
  agg_and_resample_data,
7
7
  )
8
8
  from .outlier import mark_outliers, manual_outlier_removal
9
- from .imputation import custom_weights, get_missing_weights
9
+ from .imputation import custom_weights, get_missing_weights, WeightFunction
10
10
  from .split import split_abs_train_val_test, split_rel_train_val_test
11
11
  from ._differentiator import TimeSeriesDifferentiator
12
12
  from ._binner import QuantileBinner
@@ -22,6 +22,7 @@ __all__ = [
22
22
  "manual_outlier_removal",
23
23
  "custom_weights",
24
24
  "get_missing_weights",
25
+ "WeightFunction",
25
26
  "split_abs_train_val_test",
26
27
  "split_rel_train_val_test",
27
28
  "TimeSeriesDifferentiator",
@@ -1,4 +1,57 @@
1
1
  import pandas as pd
2
+ from typing import Union
3
+ import numpy as np
4
+
5
+
6
class WeightFunction:
    """Picklable callable that maps sample indices to training weights.

    A nested function that closes over a weights series cannot be
    serialized with the standard ``pickle``/``joblib`` machinery. Holding the
    series on an instance of this top-level class keeps the callable
    picklable while exposing the same ``weight_func(index)`` call shape that
    ForecasterRecursive's ``weight_func`` parameter expects.

    Args:
        weights_series: Series of weight values, keyed by the index labels
            that will later be passed to the instance.

    Examples:
        >>> import pandas as pd
        >>> import pickle
        >>> weights = pd.Series([1.0, 0.9, 0.8], index=[0, 1, 2])
        >>> weight_func = WeightFunction(weights)
        >>> weight_func(pd.Index([0, 1]))
        array([1. , 0.9])
        >>> # Can be pickled
        >>> pickled = pickle.dumps(weight_func)
        >>> unpickled = pickle.loads(pickled)
        >>> unpickled(pd.Index([0, 1]))
        array([1. , 0.9])
    """

    def __init__(self, weights_series: pd.Series):
        """Keep a reference to the series consulted on every call.

        Args:
            weights_series: Series of weight values keyed by index label.
        """
        self.weights_series = weights_series

    def __call__(
        self, index: Union[pd.Index, np.ndarray, list]
    ) -> Union[float, np.ndarray]:
        """Look up the weight(s) associated with ``index``.

        The lookup itself is delegated to ``custom_weights`` together with
        the stored series.

        Args:
            index: Index label(s) whose weights are requested.

        Returns:
            Whatever ``custom_weights`` yields for ``index`` and the stored
            series.
        """
        stored = self.weights_series
        return custom_weights(index, stored)

    def __repr__(self):
        """Describe the instance by how many weights it holds."""
        n_entries = len(self.weights_series)
        return f"WeightFunction(weights_series with {n_entries} entries)"
2
55
 
3
56
 
4
57
  def custom_weights(index, weights_series: pd.Series) -> float:
@@ -215,8 +215,8 @@ def n2n_predict(
215
215
  forecast_horizon: int = 24,
216
216
  contamination: float = 0.01,
217
217
  window_size: int = 72,
218
- force_train: bool = False,
219
- model_dir: Union[str, Path] = "./models_baseline",
218
+ force_train: bool = True,
219
+ model_dir: Optional[Union[str, Path]] = None,
220
220
  verbose: bool = True,
221
221
  show_progress: bool = True,
222
222
  ) -> Tuple[pd.DataFrame, Dict]:
@@ -231,7 +231,7 @@ def n2n_predict(
231
231
  6. Generates multi-step ahead predictions
232
232
 
233
233
  Models are persisted to disk following scikit-learn conventions using joblib.
234
- Existing models are reused for prediction unless force_train=True.
234
+ By default, models are retrained (force_train=True). Set force_train=False to reuse existing cached models.
235
235
 
236
236
  Args:
237
237
  data: Optional DataFrame with target time series data. If None, fetches data automatically.
@@ -242,9 +242,8 @@ def n2n_predict(
242
242
  contamination: Contamination parameter for outlier detection. Default: 0.01.
243
243
  window_size: Rolling window size for gap detection. Default: 72.
244
244
  force_train: Force retraining of all models, ignoring cached models.
245
- Default: False.
246
- model_dir: Directory for saving/loading trained models.
247
- Default: "./models_baseline".
245
+ Default: True.
246
+ model_dir: Directory for saving/loading trained models. If None, uses cache directory from get_cache_home(). Default: None (uses ~/spotforecast2_cache/forecasters).
248
247
  verbose: Print progress messages. Default: True.
249
248
  show_progress: Show progress bar during training and prediction. Default: True.
250
249
 
@@ -301,6 +300,8 @@ def n2n_predict(
301
300
  proceeds without retraining. This significantly speeds up prediction
302
301
  for repeated calls with the same configuration.
303
302
  - The model_dir directory is created automatically if it doesn't exist.
303
+ - Default model_dir uses get_cache_home() which respects the
304
+ SPOTFORECAST2_CACHE environment variable.
304
305
 
305
306
  Performance Notes:
306
307
  - First run: Full training (~2-5 minutes depending on data size)
@@ -315,6 +316,12 @@ def n2n_predict(
315
316
  if verbose:
316
317
  print("--- Starting n2n_predict ---")
317
318
 
319
+ # Set default model_dir if not provided
320
+ if model_dir is None:
321
+ from spotforecast2.data.fetch_data import get_cache_home
322
+
323
+ model_dir = get_cache_home() / "forecasters"
324
+
318
325
  # Handle data input - fetch_data handles both CSV and DataFrame
319
326
  if data is not None:
320
327
  if verbose:
@@ -85,7 +85,7 @@ from spotforecast2.preprocessing.curate_data import (
85
85
  curate_weather,
86
86
  get_start_end,
87
87
  )
88
- from spotforecast2.preprocessing.imputation import custom_weights, get_missing_weights
88
+ from spotforecast2.preprocessing.imputation import get_missing_weights
89
89
  from spotforecast2.preprocessing.outlier import mark_outliers
90
90
  from spotforecast2.preprocessing.split import split_rel_train_val_test
91
91
 
@@ -742,8 +742,8 @@ def n2n_predict_with_covariates(
742
742
  include_weather_windows: bool = False,
743
743
  include_holiday_features: bool = False,
744
744
  include_poly_features: bool = False,
745
- force_train: bool = False,
746
- model_dir: Union[str, Path] = "./forecaster_models",
745
+ force_train: bool = True,
746
+ model_dir: Optional[Union[str, Path]] = None,
747
747
  verbose: bool = True,
748
748
  show_progress: bool = False,
749
749
  ) -> Tuple[pd.DataFrame, Dict, Dict]:
@@ -761,7 +761,7 @@ def n2n_predict_with_covariates(
761
761
  9. Generates multi-step ahead predictions
762
762
 
763
763
  Models are persisted to disk following scikit-learn conventions using joblib.
764
- Existing models are reused for prediction unless force_train=True.
764
+ By default, models are retrained (force_train=True). Set force_train=False to reuse existing cached models.
765
765
 
766
766
  Args:
767
767
  data: Optional DataFrame with target time series data. If None, fetches data automatically.
@@ -782,9 +782,10 @@ def n2n_predict_with_covariates(
782
782
  include_holiday_features: Include holiday features. Default: False.
783
783
  include_poly_features: Include polynomial interaction features. Default: False.
784
784
  force_train: Force retraining of all models, ignoring cached models.
785
- Default: False.
786
- model_dir: Directory for saving/loading trained models.
787
- Default: "./models_covariates".
785
+ Default: True.
786
+ model_dir: Directory for saving/loading trained models. If None, uses the
787
+ spotforecast2 cache directory (~/spotforecast2_cache by default, or
788
+ SPOTFORECAST2_CACHE environment variable). Default: None.
788
789
  verbose: Print progress messages. Default: True.
789
790
  show_progress: Show progress bar during training. Default: False.
790
791
 
@@ -850,12 +851,20 @@ def n2n_predict_with_covariates(
850
851
  proceeds without retraining. This significantly speeds up prediction
851
852
  for repeated calls with the same configuration.
852
853
  - The model_dir directory is created automatically if it doesn't exist.
854
+ - By default, models are cached in ~/spotforecast2_cache, which can be
855
+ customized via the SPOTFORECAST2_CACHE environment variable.
853
856
 
854
857
  Performance Notes:
855
858
  - First run: Full training (~5-10 minutes depending on data size)
856
859
  - Subsequent runs (force_train=False): Model loading only (~1-2 seconds)
857
860
  - Force retrain (force_train=True): Full training again (~5-10 minutes)
858
861
  """
862
+ # Set default model_dir if not provided
863
+ if model_dir is None:
864
+ from spotforecast2.data.fetch_data import get_cache_home
865
+
866
+ model_dir = get_cache_home() / "forecasters"
867
+
859
868
  if verbose:
860
869
  print("=" * 80)
861
870
  print("N2N Recursive Forecasting with Exogenous Covariates")
@@ -877,7 +886,7 @@ def n2n_predict_with_covariates(
877
886
  if verbose:
878
887
  print(" Using provided dataframe...")
879
888
  data = fetch_data(dataframe=data, timezone=timezone)
880
-
889
+
881
890
  target_columns = data.columns.tolist()
882
891
 
883
892
  if verbose:
@@ -921,13 +930,13 @@ def n2n_predict_with_covariates(
921
930
  # Invert missing_mask: True (missing) -> 0 (weight), False (valid) -> 1 (weight)
922
931
  weights_series = (~missing_mask).astype(float)
923
932
 
924
- def weight_func(index):
925
- """Return sample weights for given index."""
926
- return custom_weights(index, weights_series)
933
+ # Use WeightFunction class which is picklable (unlike local functions with closures)
934
+ from spotforecast2.preprocessing import WeightFunction
935
+
936
+ weight_func = WeightFunction(weights_series)
927
937
 
928
- # Note: weight_func is a local function and cannot be pickled.
929
- # Model persistence is disabled when using weight_func.
930
- use_model_persistence = False
938
+ # Model persistence enabled: WeightFunction instances can be pickled
939
+ use_model_persistence = True
931
940
 
932
941
  # ========================================================================
933
942
  # 4. EXOGENOUS FEATURES ENGINEERING
@@ -222,14 +222,18 @@ def initialize_weights(
222
222
  for key in weight_func:
223
223
  try:
224
224
  source_code_weight_func[key] = inspect.getsource(weight_func[key])
225
- except OSError:
225
+ except (OSError, TypeError):
226
+ # OSError: source not available, TypeError: callable class instance
226
227
  source_code_weight_func[key] = (
227
228
  f"<source unavailable: {weight_func[key]!r}>"
228
229
  )
229
230
  else:
230
231
  try:
231
232
  source_code_weight_func = inspect.getsource(weight_func)
232
- except OSError:
233
+ except (OSError, TypeError):
234
+ # OSError: source not available (e.g., built-in, lambda in REPL)
235
+ # TypeError: callable class instance (e.g., WeightFunction)
236
+ # In these cases, we can't get source but the object can still be pickled
233
237
  source_code_weight_func = f"<source unavailable: {weight_func!r}>"
234
238
 
235
239
  if "sample_weight" not in inspect.signature(estimator.fit).parameters:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: spotforecast2
3
- Version: 0.2.2
3
+ Version: 0.2.3
4
4
  Summary: Forecasting with spot
5
5
  Author: bartzbeielstein
6
6
  Author-email: bartzbeielstein <32470350+bartzbeielstein@users.noreply.github.com>
@@ -1,7 +1,7 @@
1
1
  spotforecast2/__init__.py,sha256=X9sBx15iz8yqr9iDJcrGJM5nhvnpaczXto4XV_GtfhE,59
2
- spotforecast2/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
+ spotforecast2/data/__init__.py,sha256=_AEH7sDHbeiDma7tn8XJQAiYxujzH6EkF4X9b8U0Xig,259
3
3
  spotforecast2/data/data.py,sha256=HEgr-FULaqHvuMeKTviOgYyo3GbxpGRTo3ZnmIU9w2Y,4422
4
- spotforecast2/data/fetch_data.py,sha256=N99W-NNTC2hbXmx1FofITsvXJfHj9py4r5Kllf5950Y,8464
4
+ spotforecast2/data/fetch_data.py,sha256=37fKCWjRfc2bkfvIVBRU53ZIwsldrc0JUIOlj66duG4,10562
5
5
  spotforecast2/exceptions.py,sha256=6gOji-3cP-YAisPoxXCcrEEbjTnfPN1YqEhGYhmyZ8Y,20499
6
6
  spotforecast2/forecaster/__init__.py,sha256=BbCOS2ouKcPC9VzcdprllVyqlZIyAWXCOvUAiInxDi4,140
7
7
  spotforecast2/forecaster/base.py,sha256=rXhcjY4AMpyQhkpbtLIA8OOrGEb8fU57SQiyeR9c9DQ,16748
@@ -21,27 +21,27 @@ spotforecast2/model_selection/split_ts_cv.py,sha256=uwACVC5m-cRuCtpA5U46K-tdj0zm
21
21
  spotforecast2/model_selection/utils_common.py,sha256=HKDxm4pLwG0cqhE4t8bzNHFtRa6yn_O7b5ud-nx6b7E,31814
22
22
  spotforecast2/model_selection/utils_metrics.py,sha256=mMVKh03-yAvRjEnZlbg3CsktXNcHo7yiTkI5VMg5wQk,3842
23
23
  spotforecast2/model_selection/validation.py,sha256=nwZATc74tVb992HbefP_sAcJaz8ukV_uqjtVFXaySxs,30038
24
- spotforecast2/preprocessing/__init__.py,sha256=Jk1RJRbPkggw70h4Lay4FY7yQHN9_tjRxzp9QJcF3Oo,828
24
+ spotforecast2/preprocessing/__init__.py,sha256=87koxOzPfn3ueVaIgx6u36gNBh27YRGPIVYwLcF6HGg,866
25
25
  spotforecast2/preprocessing/_binner.py,sha256=EYBOwNSOW85bdLUgQ_qLSq8xpujWJezWkNTIL1jNaYo,13723
26
26
  spotforecast2/preprocessing/_common.py,sha256=aP8EIYIg3iBXnijXByHedGEdcubXu-ciRtEgqdDfO_8,3141
27
27
  spotforecast2/preprocessing/_differentiator.py,sha256=otka_TO1edM3zgp16zOjeSKxa61arbmPPsr96_GfgLI,4646
28
28
  spotforecast2/preprocessing/_rolling.py,sha256=_BUG_aHbOI-1e2ku8AwsJJGl3akTBWjRju2PhclkXso,4202
29
29
  spotforecast2/preprocessing/curate_data.py,sha256=4VV8aYwShyrUc9lqWVx_ckIH-moK0B8ONEMb2i463ag,9603
30
- spotforecast2/preprocessing/imputation.py,sha256=lmH-HumI_QLLm9aMESe_oZq84Axn60woLaMqd_Abw3k,3509
30
+ spotforecast2/preprocessing/imputation.py,sha256=wXHXcIwWb7_XqW9JdBjaRA7NxWhbKWoQyW5z0KkPLd8,5201
31
31
  spotforecast2/preprocessing/outlier.py,sha256=jZxAR870QtYner7b4gXk6LLGJw0juLq1VU4CGklYd3c,4208
32
32
  spotforecast2/preprocessing/split.py,sha256=mzzt5ltUZdVzfWtBBTQjp8E2MyqVdWUFtz7nN11urbU,5011
33
33
  spotforecast2/processing/agg_predict.py,sha256=VKlruB0x-eJKokkHyJxR87rZ4m53si3ODbrd0ibPlow,2378
34
- spotforecast2/processing/n2n_predict.py,sha256=Sr2AFaCZxP-tbsxlEjiSdjBU-mtBiDa_f6rJLEJov64,14912
35
- spotforecast2/processing/n2n_predict_with_covariates.py,sha256=PyB3X1rNb18JBC72YiN12hUg5eSjUAsW4M-atczmCSQ,40914
34
+ spotforecast2/processing/n2n_predict.py,sha256=NZku7xnt9ZLu4V9FMlfbDmU2rzvQPXFYyhvdu2WRtlk,15324
35
+ spotforecast2/processing/n2n_predict_with_covariates.py,sha256=20bHmODXzb2CRSXjxtsqTtKuJ-1_zjo1RKQKjmygYyw,41399
36
36
  spotforecast2/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
37
37
  spotforecast2/utils/__init__.py,sha256=NrMt_xJLe4rbTFbsbgSQYeREohEOiYG5S-97e6Jj07I,1018
38
38
  spotforecast2/utils/convert_to_utc.py,sha256=hz8mJUHK9jDLUiN5LdNX5l3KZuOKlklyycB4zFdB9Ng,1405
39
39
  spotforecast2/utils/data_transform.py,sha256=PhLeZoimM0TLfp34Fp56dQrxlCYNWGVU8h8RZHdZSlo,7294
40
- spotforecast2/utils/forecaster_config.py,sha256=0jchk_9tjxzttN8btWlRBfAjT2bz27JO4CDrpPsC58E,12875
40
+ spotforecast2/utils/forecaster_config.py,sha256=qnpgH97u8ffD3rIgSXyNDl48lgm5FeWplKwrK5tKOJ4,13236
41
41
  spotforecast2/utils/generate_holiday.py,sha256=SHaPvPMt-abis95cChHf5ObyPwCTrzJ87bxffeqZLRc,2707
42
42
  spotforecast2/utils/validation.py,sha256=x9ypQzcneDhWJA_piiY4Q3_ogoGd1LTsZ7__MFeG9Fc,21618
43
43
  spotforecast2/weather/__init__.py,sha256=1Jco88pl0deNESgNATin83Nf5i9c58pxN7G-vNiOiu0,120
44
44
  spotforecast2/weather/weather_client.py,sha256=Ec_ywug6uoa71MfXM8RNbXEvtBtBzr-SUS5xq_HKtZE,9837
45
- spotforecast2-0.2.2.dist-info/WHEEL,sha256=5DEXXimM34_d4Gx1AuF9ysMr1_maoEtGKjaILM3s4w4,80
46
- spotforecast2-0.2.2.dist-info/METADATA,sha256=f5BfMpKyfzwbTOTguKeNPgjbuEu2N0zwMrfcjG82XYo,3481
47
- spotforecast2-0.2.2.dist-info/RECORD,,
45
+ spotforecast2-0.2.3.dist-info/WHEEL,sha256=5DEXXimM34_d4Gx1AuF9ysMr1_maoEtGKjaILM3s4w4,80
46
+ spotforecast2-0.2.3.dist-info/METADATA,sha256=nsr5BzvCVIwKXeRDsPVgpSuEwcQ_-KTm3T72Yz_7tYY,3481
47
+ spotforecast2-0.2.3.dist-info/RECORD,,