spotforecast2 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46) hide show
  1. spotforecast2/.DS_Store +0 -0
  2. spotforecast2/__init__.py +2 -0
  3. spotforecast2/data/__init__.py +0 -0
  4. spotforecast2/data/data.py +130 -0
  5. spotforecast2/data/fetch_data.py +209 -0
  6. spotforecast2/exceptions.py +681 -0
  7. spotforecast2/forecaster/.DS_Store +0 -0
  8. spotforecast2/forecaster/__init__.py +7 -0
  9. spotforecast2/forecaster/base.py +448 -0
  10. spotforecast2/forecaster/metrics.py +527 -0
  11. spotforecast2/forecaster/recursive/__init__.py +4 -0
  12. spotforecast2/forecaster/recursive/_forecaster_equivalent_date.py +1075 -0
  13. spotforecast2/forecaster/recursive/_forecaster_recursive.py +939 -0
  14. spotforecast2/forecaster/recursive/_warnings.py +15 -0
  15. spotforecast2/forecaster/utils.py +954 -0
  16. spotforecast2/model_selection/__init__.py +5 -0
  17. spotforecast2/model_selection/bayesian_search.py +453 -0
  18. spotforecast2/model_selection/grid_search.py +314 -0
  19. spotforecast2/model_selection/random_search.py +151 -0
  20. spotforecast2/model_selection/split_base.py +357 -0
  21. spotforecast2/model_selection/split_one_step.py +245 -0
  22. spotforecast2/model_selection/split_ts_cv.py +634 -0
  23. spotforecast2/model_selection/utils_common.py +718 -0
  24. spotforecast2/model_selection/utils_metrics.py +103 -0
  25. spotforecast2/model_selection/validation.py +685 -0
  26. spotforecast2/preprocessing/__init__.py +30 -0
  27. spotforecast2/preprocessing/_binner.py +378 -0
  28. spotforecast2/preprocessing/_common.py +123 -0
  29. spotforecast2/preprocessing/_differentiator.py +123 -0
  30. spotforecast2/preprocessing/_rolling.py +136 -0
  31. spotforecast2/preprocessing/curate_data.py +254 -0
  32. spotforecast2/preprocessing/imputation.py +92 -0
  33. spotforecast2/preprocessing/outlier.py +114 -0
  34. spotforecast2/preprocessing/split.py +139 -0
  35. spotforecast2/py.typed +0 -0
  36. spotforecast2/utils/__init__.py +43 -0
  37. spotforecast2/utils/convert_to_utc.py +44 -0
  38. spotforecast2/utils/data_transform.py +208 -0
  39. spotforecast2/utils/forecaster_config.py +344 -0
  40. spotforecast2/utils/generate_holiday.py +70 -0
  41. spotforecast2/utils/validation.py +569 -0
  42. spotforecast2/weather/__init__.py +0 -0
  43. spotforecast2/weather/weather_client.py +288 -0
  44. spotforecast2-0.0.1.dist-info/METADATA +47 -0
  45. spotforecast2-0.0.1.dist-info/RECORD +46 -0
  46. spotforecast2-0.0.1.dist-info/WHEEL +4 -0
@@ -0,0 +1,344 @@
1
+ """
2
+ Forecaster configuration utilities.
3
+
4
+ This module provides functions for initializing and validating forecaster
5
+ configuration parameters like lags and weights.
6
+ """
7
+
8
+ from typing import Any, Union, List, Tuple, Optional
9
+ import numpy as np
10
+
11
+
12
def initialize_lags(
    forecaster_name: str, lags: Any
) -> Tuple[Optional[np.ndarray], Optional[List[str]], Optional[int]]:
    """
    Validate and normalize a lag specification for forecasting.

    Converts the supported lag specifications (int, list, tuple, range,
    ndarray) into a standardized triple: sorted numpy array, lag names and
    maximum lag value.

    Args:
        forecaster_name: Name of the forecaster class. Only used to select
            the wording of the TypeError raised for unsupported `lags` types.
        lags: Lag specification in one of several formats:
            - int or NumPy integer scalar: creates lags from 1 to `lags`
              (e.g., 5 -> [1, 2, 3, 4, 5])
            - list/tuple/range: converted to a numpy array
            - numpy.ndarray: validated and used directly
            - None: returns (None, None, None)

    Returns:
        Tuple containing:
            - lags: Sorted numpy array of lag values (or None)
            - lags_names: List of lag names like ['lag_1', 'lag_2', ...] (or None)
            - max_lag: Maximum lag value as a built-in int (or None)

        An empty list/tuple/range/array is treated like None and yields
        (None, None, None).

    Raises:
        ValueError: If any lag is < 1 or the array is not 1-dimensional.
        TypeError: If `lags` is of an unsupported type or the array contains
            non-integer values.

    Examples:
        >>> import numpy as np
        >>> from spotforecast2.utils.forecaster_config import initialize_lags
        >>>
        >>> # Integer input
        >>> lags, names, max_lag = initialize_lags("ForecasterRecursive", 3)
        >>> lags
        array([1, 2, 3])
        >>> names
        ['lag_1', 'lag_2', 'lag_3']
        >>> max_lag
        3
        >>>
        >>> # Unsorted list input is sorted
        >>> lags, names, max_lag = initialize_lags("ForecasterRecursive", [5, 1, 3])
        >>> lags
        array([1, 3, 5])
        >>> names
        ['lag_1', 'lag_3', 'lag_5']
        >>>
        >>> # None input
        >>> initialize_lags("ForecasterRecursive", None)
        (None, None, None)
        >>>
        >>> # Invalid: lags < 1
        >>> initialize_lags("ForecasterRecursive", 0)
        Traceback (most recent call last):
            ...
        ValueError: Minimum value of lags allowed is 1.
    """
    if lags is None:
        return None, None, None

    # NumPy integer scalars (e.g. the result of an array reduction) are valid
    # lag counts too, so they are handled exactly like built-in ints.
    if isinstance(lags, (int, np.integer)):
        if lags < 1:
            raise ValueError("Minimum value of lags allowed is 1.")
        lags = np.arange(1, int(lags) + 1)
    elif isinstance(lags, (list, tuple, range)):
        lags = np.array(lags)

    if not isinstance(lags, np.ndarray):
        if forecaster_name == "ForecasterDirectMultiVariate":
            raise TypeError(
                f"`lags` argument must be a dict, int, 1d numpy ndarray, range, "
                f"tuple or list. Got {type(lags)}."
            )
        raise TypeError(
            f"`lags` argument must be an int, 1d numpy ndarray, range, "
            f"tuple or list. Got {type(lags)}."
        )

    # The emptiness check comes first on purpose: an empty list becomes a
    # float64 array and must not trip the integer-dtype check below.
    if lags.size == 0:
        return None, None, None
    if lags.ndim != 1:
        raise ValueError("`lags` must be a 1-dimensional array.")
    if not np.issubdtype(lags.dtype, np.integer):
        raise TypeError("All values in `lags` must be integers.")
    if np.any(lags < 1):
        raise ValueError("Minimum value of lags allowed is 1.")

    lags = np.sort(lags)
    lags_names = [f"lag_{i}" for i in lags]
    # After sorting, the last element is the maximum.
    max_lag = int(lags[-1])

    return lags, lags_names, max_lag
122
+
123
+
124
def initialize_weights(
    forecaster_name: str, estimator: Any, weight_func: Any, series_weights: Any
) -> Tuple[Any, Optional[Union[str, dict]], Any]:
    """
    Validate and initialize weight function configuration for forecasting.

    Validates `weight_func` and `series_weights`, captures the source code of
    the weight function(s) for serialization, and checks whether the
    estimator's `fit` method accepts `sample_weight`.

    Args:
        forecaster_name: Name of the forecaster class. Only
            "ForecasterRecursiveMultiSeries" may receive a dict of callables.
        estimator: Object exposing a `fit` method whose signature is inspected.
        weight_func: Weight function specification:
            - Callable: Single weight function
            - dict: Dictionary of weight functions (MultiSeries forecaster only)
            - None: No weighting
        series_weights: Dictionary of series-level weights, or None.

    Returns:
        Tuple containing:
            - weight_func: The given weight function(s), or None when it was
              not provided or the estimator's `fit` has no `sample_weight`.
            - source_code_weight_func: Source code of the weight function (str)
              or of each function (dict keyed like `weight_func`). When the
              source cannot be retrieved (e.g. callable instances,
              `functools.partial`, builtins, interactively defined functions)
              a "<source unavailable: ...>" placeholder is stored instead.
              None when `weight_func` ends up as None.
            - series_weights: The given dict, or None when it was not provided
              or the estimator's `fit` has no `sample_weight`.

    Raises:
        TypeError: If `weight_func` is not a Callable (or, for the MultiSeries
            forecaster, a dict of Callables), or if `series_weights` is not a
            dict.

    Warnings:
        IgnoredArgumentWarning: If the estimator's `fit` does not accept
            `sample_weight`; the corresponding argument is then dropped.

    Examples:
        >>> import numpy as np
        >>> from sklearn.linear_model import Ridge
        >>> from spotforecast2.utils.forecaster_config import initialize_weights
        >>>
        >>> def custom_weights(index):
        ...     return np.ones(len(index))
        >>>
        >>> estimator = Ridge()
        >>> wf, source, sw = initialize_weights(
        ...     "ForecasterRecursive", estimator, custom_weights, None
        ... )
        >>> wf is custom_weights
        True
        >>> isinstance(source, str)
        True
        >>>
        >>> # No weight function
        >>> initialize_weights("ForecasterRecursive", estimator, None, None)
        (None, None, None)
        >>>
        >>> # Invalid type for non-MultiSeries forecaster
        >>> initialize_weights("ForecasterRecursive", estimator, "invalid", None)
        Traceback (most recent call last):
            ...
        TypeError: Argument `weight_func` must be a Callable. Got <class 'str'>.
    """
    import inspect
    import warnings

    # Import IgnoredArgumentWarning if available, otherwise define locally
    try:
        from spotforecast2.exceptions import IgnoredArgumentWarning
    except ImportError:

        class IgnoredArgumentWarning(UserWarning):
            """Warning for ignored arguments."""

    def _source_of(func: Any) -> str:
        # `inspect.getsource` raises OSError when the source file cannot be
        # read and TypeError for objects it does not support (callable
        # instances, functools.partial, builtins). Both mean "no source".
        try:
            return inspect.getsource(func)
        except (OSError, TypeError):
            return f"<source unavailable: {func!r}>"

    def _fit_accepts_sample_weight() -> bool:
        return "sample_weight" in inspect.signature(estimator.fit).parameters

    source_code_weight_func = None

    if weight_func is not None:
        if forecaster_name in ["ForecasterRecursiveMultiSeries"]:
            if not (callable(weight_func) or isinstance(weight_func, dict)):
                raise TypeError(
                    f"Argument `weight_func` must be a Callable or a dict of "
                    f"Callables. Got {type(weight_func)}."
                )
        elif not callable(weight_func):
            raise TypeError(
                f"Argument `weight_func` must be a Callable. Got {type(weight_func)}."
            )

        if isinstance(weight_func, dict):
            # Fail early with a clear message instead of letting a
            # non-callable entry surface later as an obscure error.
            for key, func in weight_func.items():
                if not callable(func):
                    raise TypeError(
                        f"Argument `weight_func` must be a Callable or a dict of "
                        f"Callables. Key {key!r} maps to {type(func)}."
                    )
            source_code_weight_func = {
                key: _source_of(func) for key, func in weight_func.items()
            }
        else:
            source_code_weight_func = _source_of(weight_func)

        if not _fit_accepts_sample_weight():
            warnings.warn(
                f"Argument `weight_func` is ignored since estimator {estimator} "
                f"does not accept `sample_weight` in its `fit` method.",
                IgnoredArgumentWarning,
            )
            weight_func = None
            source_code_weight_func = None

    if series_weights is not None:
        if not isinstance(series_weights, dict):
            raise TypeError(
                f"Argument `series_weights` must be a dict of floats or ints. "
                f"Got {type(series_weights)}."
            )
        if not _fit_accepts_sample_weight():
            warnings.warn(
                f"Argument `series_weights` is ignored since estimator {estimator} "
                f"does not accept `sample_weight` in its `fit` method.",
                IgnoredArgumentWarning,
            )
            series_weights = None

    return weight_func, source_code_weight_func, series_weights
259
+
260
+
261
def check_select_fit_kwargs(
    estimator: Any, fit_kwargs: Optional[dict] = None
) -> dict:
    """
    Check that `fit_kwargs` is a dict and keep only keys used by estimator's `fit`.

    Keys that are not parameters of `estimator.fit` are dropped with a warning.
    `sample_weight` is always dropped (with a warning), because sample weights
    are expected to be supplied through `weight_func` instead. The input
    dictionary is never modified; a new dictionary is returned.

    Args:
        estimator: Object exposing a `fit` method whose signature is inspected.
        fit_kwargs: Dictionary of arguments intended for the estimator's `fit`
            method, or None.

    Returns:
        New dictionary with only the arguments accepted by the estimator's
        `fit` method, excluding `sample_weight`. Empty if `fit_kwargs` is None.

    Raises:
        TypeError: If `fit_kwargs` is neither None nor a dict.

    Warnings:
        IgnoredArgumentWarning: If `fit_kwargs` contains keys not used by the
            `fit` method, or if 'sample_weight' is present.

    Examples:
        >>> from sklearn.linear_model import Ridge
        >>> from spotforecast2.utils.forecaster_config import check_select_fit_kwargs
        >>>
        >>> estimator = Ridge()
        >>> kwargs = {"sample_weight": [1, 1], "invalid_arg": 10}
        >>> # `sample_weight` is dropped (pass it via `weight_func` instead) and
        >>> # `invalid_arg` is dropped because `Ridge.fit` does not accept it.
        >>> check_select_fit_kwargs(estimator, kwargs)
        {}
        >>> # The caller's dictionary is left untouched.
        >>> sorted(kwargs)
        ['invalid_arg', 'sample_weight']
    """
    import inspect
    import warnings

    # Import IgnoredArgumentWarning if available, otherwise define locally
    try:
        from spotforecast2.exceptions import IgnoredArgumentWarning
    except ImportError:

        class IgnoredArgumentWarning(UserWarning):
            """Warning for ignored arguments."""

    if fit_kwargs is None:
        return {}

    if not isinstance(fit_kwargs, dict):
        raise TypeError(
            f"Argument `fit_kwargs` must be a dict. Got {type(fit_kwargs)}."
        )

    # Get parameters accepted by estimator.fit
    fit_params = inspect.signature(estimator.fit).parameters

    non_used_keys = [k for k in fit_kwargs if k not in fit_params]
    if non_used_keys:
        warnings.warn(
            f"Argument/s {non_used_keys} ignored since they are not used by the "
            f"estimator's `fit` method.",
            IgnoredArgumentWarning,
        )

    if "sample_weight" in fit_kwargs:
        warnings.warn(
            "The `sample_weight` argument is ignored. Use `weight_func` to pass "
            "a function that defines the individual weights for each sample "
            "based on its index.",
            IgnoredArgumentWarning,
        )

    # Build a fresh dict rather than deleting from `fit_kwargs`, so the
    # caller's dictionary is not mutated as a side effect.
    return {
        k: v
        for k, v in fit_kwargs.items()
        if k in fit_params and k != "sample_weight"
    }
@@ -0,0 +1,70 @@
1
+ """Utilities for generating holiday dataframe as covariate."""
2
+
3
+ from typing import Union
4
+ import pandas as pd
5
+ import holidays
6
+
7
+
8
def create_holiday_df(
    start: Union[str, pd.Timestamp],
    end: Union[str, pd.Timestamp],
    tz: str = "UTC",
    freq: str = "h",
    country_code: str = "DE",
    state: str = "NW",
) -> pd.DataFrame:
    """Create a DataFrame with datetime index and a binary holiday indicator column.

    Every timestamp in the index is flagged according to its own calendar
    date (in the index's timezone), so the result is correct even when
    `start` does not fall on midnight.

    Args:
        start: Start date/datetime.
        end: End date/datetime.
        tz: Timezone applied when neither `start` nor `end` carries timezone
            information. If either endpoint is timezone-aware, that timezone
            is used and `tz` is ignored.
        freq: Frequency of the resulting DataFrame.
        country_code: Country code for holidays (e.g. "DE", "US").
        state: Subdivision code for holidays (e.g. "NW", "CA"); passed to
            `holidays.country_holidays` as `subdiv`.

    Returns:
        pd.DataFrame: DataFrame with index covering [start, end] at `freq`,
        and an integer 'holiday' column (1 if holiday, 0 otherwise).

    Examples:
        >>> df = create_holiday_df("2023-12-24", "2023-12-26", freq="D")
        >>> df["holiday"].tolist()
        [0, 1, 1]
    """
    # Normalize both endpoints first so timezone information is detected for
    # aware strings and datetime objects as well, not only for pd.Timestamp.
    start_ts = pd.Timestamp(start)
    end_ts = pd.Timestamp(end)

    # With a timezone-aware endpoint, do not pass `tz` as well: pandas infers
    # the timezone from the endpoint and rejects a conflicting `tz` argument.
    if start_ts.tz is not None or end_ts.tz is not None:
        full_index = pd.date_range(start=start_ts, end=end_ts, freq=freq)
    else:
        full_index = pd.date_range(start=start_ts, end=end_ts, freq=freq, tz=tz)

    # Get holidays for the country/state
    country_holidays = holidays.country_holidays(country_code, subdiv=state)

    # Look each distinct calendar date up once, then broadcast the flag to
    # every timestamp falling on that date. Deriving the flag from the
    # timestamp's own date (instead of forward-filling a daily grid anchored
    # at `start`) keeps the hours after midnight from inheriting the previous
    # day's flag.
    local_dates = full_index.date
    flag_by_date = {day: int(day in country_holidays) for day in set(local_dates)}
    flags = [flag_by_date[day] for day in local_dates]

    return pd.DataFrame({"holiday": flags}, index=full_index).astype(int)