spotforecast2 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spotforecast2/.DS_Store +0 -0
- spotforecast2/__init__.py +2 -0
- spotforecast2/data/__init__.py +0 -0
- spotforecast2/data/data.py +130 -0
- spotforecast2/data/fetch_data.py +209 -0
- spotforecast2/exceptions.py +681 -0
- spotforecast2/forecaster/.DS_Store +0 -0
- spotforecast2/forecaster/__init__.py +7 -0
- spotforecast2/forecaster/base.py +448 -0
- spotforecast2/forecaster/metrics.py +527 -0
- spotforecast2/forecaster/recursive/__init__.py +4 -0
- spotforecast2/forecaster/recursive/_forecaster_equivalent_date.py +1075 -0
- spotforecast2/forecaster/recursive/_forecaster_recursive.py +939 -0
- spotforecast2/forecaster/recursive/_warnings.py +15 -0
- spotforecast2/forecaster/utils.py +954 -0
- spotforecast2/model_selection/__init__.py +5 -0
- spotforecast2/model_selection/bayesian_search.py +453 -0
- spotforecast2/model_selection/grid_search.py +314 -0
- spotforecast2/model_selection/random_search.py +151 -0
- spotforecast2/model_selection/split_base.py +357 -0
- spotforecast2/model_selection/split_one_step.py +245 -0
- spotforecast2/model_selection/split_ts_cv.py +634 -0
- spotforecast2/model_selection/utils_common.py +718 -0
- spotforecast2/model_selection/utils_metrics.py +103 -0
- spotforecast2/model_selection/validation.py +685 -0
- spotforecast2/preprocessing/__init__.py +30 -0
- spotforecast2/preprocessing/_binner.py +378 -0
- spotforecast2/preprocessing/_common.py +123 -0
- spotforecast2/preprocessing/_differentiator.py +123 -0
- spotforecast2/preprocessing/_rolling.py +136 -0
- spotforecast2/preprocessing/curate_data.py +254 -0
- spotforecast2/preprocessing/imputation.py +92 -0
- spotforecast2/preprocessing/outlier.py +114 -0
- spotforecast2/preprocessing/split.py +139 -0
- spotforecast2/py.typed +0 -0
- spotforecast2/utils/__init__.py +43 -0
- spotforecast2/utils/convert_to_utc.py +44 -0
- spotforecast2/utils/data_transform.py +208 -0
- spotforecast2/utils/forecaster_config.py +344 -0
- spotforecast2/utils/generate_holiday.py +70 -0
- spotforecast2/utils/validation.py +569 -0
- spotforecast2/weather/__init__.py +0 -0
- spotforecast2/weather/weather_client.py +288 -0
- spotforecast2-0.0.1.dist-info/METADATA +47 -0
- spotforecast2-0.0.1.dist-info/RECORD +46 -0
- spotforecast2-0.0.1.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,569 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Validation utilities for time series forecasting.
|
|
3
|
+
|
|
4
|
+
This module provides validation functions for time series data and exogenous variables.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from typing import Any, Union, List, Tuple, Optional, Dict
|
|
8
|
+
import warnings
|
|
9
|
+
import pandas as pd
|
|
10
|
+
import numpy as np
|
|
11
|
+
from spotforecast2.exceptions import MissingValuesWarning, DataTypeWarning
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def check_y(y: Any, series_id: str = "`y`") -> None:
    """
    Ensure the target series is a pandas Series without missing values.

    Forecasting requires the target to be a ``pd.Series`` containing no NaN
    entries; any other object, or a Series holding NaNs, is rejected.

    Args:
        y: Object expected to be the target time series.
        series_id: Label for the series that is inserted into error messages.
            Defaults to "`y`".

    Raises:
        TypeError: If ``y`` is not a pandas Series.
        ValueError: If ``y`` contains at least one missing (NaN) value.

    Examples:
        >>> import pandas as pd
        >>> import numpy as np
        >>> from spotforecast2.utils.validation import check_y
        >>>
        >>> # Valid series
        >>> y = pd.Series([1, 2, 3, 4, 5])
        >>> check_y(y)  # No error
        >>>
        >>> # Invalid: not a Series
        >>> try:
        ...     check_y([1, 2, 3])
        ... except TypeError as e:
        ...     print(f"Error: {e}")
        Error: `y` must be a pandas Series with a DatetimeIndex or a RangeIndex. Found <class 'list'>.
        >>>
        >>> # Invalid: contains NaN
        >>> y_with_nan = pd.Series([1, 2, np.nan, 4])
        >>> try:
        ...     check_y(y_with_nan)
        ... except ValueError as e:
        ...     print(f"Error: {e}")
        Error: `y` has missing values.
    """
    is_series = isinstance(y, pd.Series)
    if not is_series:
        found_type = type(y)
        raise TypeError(
            f"{series_id} must be a pandas Series with a DatetimeIndex or a RangeIndex. "
            f"Found {found_type}."
        )

    # Boolean mask of missing entries, reduced with a single NumPy `any`.
    missing_mask = y.isna().to_numpy()
    if missing_mask.any():
        raise ValueError(f"{series_id} has missing values.")
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def check_exog(
    exog: Union[pd.Series, pd.DataFrame],
    allow_nan: bool = True,
    series_id: str = "`exog`",
) -> None:
    """
    Validate that exog is a pandas Series or DataFrame.

    This function ensures that exogenous variables meet basic requirements:
    - Must be a pandas Series or DataFrame
    - If Series, must have a name
    - If ``allow_nan`` is False, a warning is issued when NaN values are present

    Args:
        exog: Exogenous variable/s included as predictor/s.
        allow_nan: If True, NaN values are accepted silently (they are not
            checked at all). If False, a ``MissingValuesWarning`` is issued
            when ``exog`` contains NaN values; no exception is raised in
            either case. Defaults to True.
        series_id: Identifier of the series used in error messages. Defaults to "`exog`".

    Raises:
        TypeError: If exog is not a pandas Series or DataFrame.
        ValueError: If exog is a Series without a name.

    Warnings:
        MissingValuesWarning: If allow_nan=False and exog contains NaN values.

    Examples:
        >>> import pandas as pd
        >>> import numpy as np
        >>> import warnings
        >>> from spotforecast2.utils.validation import check_exog
        >>>
        >>> # Valid DataFrame
        >>> exog_df = pd.DataFrame({"temp": [20, 21, 22], "humidity": [50, 55, 60]})
        >>> check_exog(exog_df)  # No error
        >>>
        >>> # Valid Series with name
        >>> exog_series = pd.Series([1, 2, 3], name="temperature")
        >>> check_exog(exog_series)  # No error
        >>>
        >>> # NaN values only produce a warning when allow_nan=False
        >>> exog_nan = pd.Series([1.0, np.nan], name="temperature")
        >>> with warnings.catch_warnings(record=True) as caught:
        ...     warnings.simplefilter("always")
        ...     check_exog(exog_nan, allow_nan=False)
        >>> caught[0].category.__name__
        'MissingValuesWarning'
        >>>
        >>> # Invalid: Series without name
        >>> exog_no_name = pd.Series([1, 2, 3])
        >>> try:
        ...     check_exog(exog_no_name)
        ... except ValueError as e:
        ...     print(f"Error: {e}")
        Error: When `exog` is a pandas Series, it must have a name.
        >>>
        >>> # Invalid: not a Series/DataFrame
        >>> try:
        ...     check_exog([1, 2, 3])
        ... except TypeError as e:
        ...     print(f"Error: {e}")
        Error: `exog` must be a pandas Series or DataFrame. Got <class 'list'>.
    """
    if not isinstance(exog, (pd.Series, pd.DataFrame)):
        raise TypeError(
            f"{series_id} must be a pandas Series or DataFrame. Got {type(exog)}."
        )

    # A Series name becomes the exogenous variable name, so it is mandatory.
    if isinstance(exog, pd.Series) and exog.name is None:
        raise ValueError(f"When {series_id} is a pandas Series, it must have a name.")

    # NaNs are only inspected when the caller does not allow them; even then
    # the check is advisory (a warning), not a hard failure.
    if not allow_nan and exog.isna().to_numpy().any():
        warnings.warn(
            f"{series_id} has missing values. Most machine learning models "
            "do not allow missing values. Fitting the forecaster may fail.",
            MissingValuesWarning,
        )
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def _has_integer_categories(dtype: pd.CategoricalDtype) -> bool:
    """
    Return True if the categories of a categorical dtype have an integer dtype.

    ``np.issubdtype`` raises ``TypeError`` for category dtypes NumPy cannot
    interpret (e.g. pandas ``StringDtype``); such categories are reported as
    non-integer instead of leaking NumPy's error message.
    """
    try:
        return bool(np.issubdtype(dtype.categories.dtype, np.integer))
    except TypeError:
        return False


def _is_supported_numeric_dtype(dtype_name: str) -> bool:
    """
    Return True if ``dtype_name`` names an integer or float dtype.

    Accepts NumPy names (``int64``, ``uint8``, ``float32``) and pandas
    nullable names (``Int64``, ``UInt8``, ``Float64``).
    """
    # IntervalDtype names such as "interval[int64, right]" also start with
    # "int" but are not numeric, so they are excluded explicitly.
    if dtype_name.startswith("interval"):
        return False
    return dtype_name.startswith(("int", "Int", "float", "Float", "uint", "UInt"))


def check_exog_dtypes(
    exog: Union[pd.Series, pd.DataFrame],
    call_check_exog: bool = True,
    series_id: str = "`exog`",
) -> None:
    """
    Check that exogenous variables have valid data types (int, float, category).

    This function validates that the exogenous variables (Series or DataFrame)
    contain only supported data types: integer (NumPy or pandas nullable,
    signed or unsigned), float (NumPy or pandas nullable), or category. It
    issues a warning if other types (like object/string, bool, datetime or
    interval) are found, as these may cause issues with some machine learning
    estimators.

    It also strictly enforces that categorical columns must have integer
    categories, for both DataFrame and Series inputs.

    Args:
        exog: Exogenous variables to check.
        call_check_exog: If True, calls check_exog() first (with
            ``allow_nan=False``) to ensure basic validity. Defaults to True.
        series_id: Identifier used in warning messages. Defaults to "`exog`".

    Raises:
        TypeError: If categorical columns contain non-integer categories.

    Warnings:
        DataTypeWarning: If columns with unsupported data types (not int, float,
            category) are found. A single warning is issued; it does not name
            the offending columns.

    Examples:
        >>> import warnings
        >>> import pandas as pd
        >>> from spotforecast2.utils.validation import check_exog_dtypes
        >>>
        >>> # Valid types (float, int)
        >>> df_valid = pd.DataFrame({
        ...     "a": [1.0, 2.0, 3.0],
        ...     "b": [1, 2, 3]
        ... })
        >>> check_exog_dtypes(df_valid)  # No warning
        >>>
        >>> # Invalid type (object/string)
        >>> df_invalid = pd.DataFrame({
        ...     "a": [1, 2, 3],
        ...     "b": ["x", "y", "z"]
        ... })
        >>> with warnings.catch_warnings(record=True) as caught:
        ...     warnings.simplefilter("always")
        ...     check_exog_dtypes(df_invalid)
        >>> caught[0].category.__name__
        'DataTypeWarning'
        >>>
        >>> # Valid categorical (with integer categories)
        >>> df_cat = pd.DataFrame({"a": [1, 2, 1]})
        >>> df_cat["a"] = df_cat["a"].astype("category")
        >>> check_exog_dtypes(df_cat)  # No warning
    """
    if call_check_exog:
        check_exog(exog=exog, allow_nan=False, series_id=series_id)

    dtype_warning = (
        f"{series_id} may contain only `int`, `float` or `category` dtypes. "
        "Most machine learning models do not allow other types of values. "
        "Fitting the forecaster may fail."
    )
    categorical_error = "Categorical dtypes in exog must contain only integer values. "

    if isinstance(exog, pd.DataFrame):
        has_invalid_dtype = False
        for dtype in set(exog.dtypes):
            if isinstance(dtype, pd.CategoricalDtype):
                # Non-integer categories are a hard error, raised immediately.
                if not _has_integer_categories(dtype):
                    raise TypeError(categorical_error)
            elif not _is_supported_numeric_dtype(dtype.name):
                has_invalid_dtype = True

        # One aggregated warning rather than one per offending dtype.
        if has_invalid_dtype:
            warnings.warn(dtype_warning, DataTypeWarning)

    else:
        dtype_name = str(exog.dtypes)
        if not (_is_supported_numeric_dtype(dtype_name) or dtype_name == "category"):
            warnings.warn(dtype_warning, DataTypeWarning)

        if isinstance(exog.dtype, pd.CategoricalDtype) and not _has_integer_categories(
            exog.dtype
        ):
            raise TypeError(categorical_error)
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
def get_exog_dtypes(exog: Union[pd.Series, pd.DataFrame]) -> Dict[str, type]:
    """
    Map each exogenous variable name to its dtype.

    A Series contributes a single entry keyed by its ``name``; a DataFrame
    contributes one entry per column, in column order.

    Args:
        exog: Exogenous variable/s, given as a pandas Series or DataFrame.

    Returns:
        Dictionary whose keys are the variable names and whose values are the
        corresponding pandas/NumPy dtypes.

    Examples:
        >>> import pandas as pd
        >>> import numpy as np
        >>> from spotforecast2.utils.validation import get_exog_dtypes
        >>>
        >>> # DataFrame with mixed types
        >>> exog_df = pd.DataFrame({
        ...     "temp": pd.Series([20.5, 21.3, 22.1], dtype='float64'),
        ...     "day": pd.Series([1, 2, 3], dtype='int64'),
        ...     "is_weekend": pd.Series([False, False, True], dtype='bool')
        ... })
        >>> dtypes = get_exog_dtypes(exog_df)
        >>> dtypes['temp']
        dtype('float64')
        >>> dtypes['day']
        dtype('int64')
        >>>
        >>> # Series
        >>> exog_series = pd.Series([1.0, 2.0, 3.0], name="temperature", dtype='float64')
        >>> dtypes = get_exog_dtypes(exog_series)
        >>> dtypes
        {'temperature': dtype('float64')}
    """
    if not isinstance(exog, pd.Series):
        # DataFrame: `.dtypes` is a Series indexed by column name.
        return exog.dtypes.to_dict()

    return {exog.name: exog.dtypes}
|
|
282
|
+
|
|
283
|
+
|
|
284
|
+
def check_interval(
    interval: Union[List[float], Tuple[float], None] = None,
    ensure_symmetric_intervals: bool = False,
    quantiles: Union[List[float], Tuple[float], None] = None,
    alpha: Optional[float] = None,
    alpha_literal: Optional[str] = "alpha",
) -> None:
    """
    Validate a prediction-interval specification given as two percentiles.

    When ``interval`` is provided it must be a list or tuple of exactly two
    numbers, ``[lower, upper]``. The lower bound must lie in [0, 100), the
    upper bound must lie in (0, 100], and lower must be strictly smaller than
    upper. Optionally the two bounds must add up to exactly 100 (symmetric
    interval). When ``interval`` is None nothing is checked.

    Args:
        interval: Percentiles (0-100 inclusive) of the interval, e.g.
            ``[2.5, 97.5]`` for a 95% interval.
        ensure_symmetric_intervals: If True, require ``lower + upper == 100``.
        quantiles: Sequence of quantiles (0-1 inclusive). Not validated by this
            function; reserved for future use.
        alpha: Confidence level (1-alpha). Not validated by this function;
            reserved for future use.
        alpha_literal: Name for the alpha parameter in error messages. Not used
            by the current checks.

    Raises:
        TypeError: If ``interval`` is not a list or tuple.
        ValueError: If ``interval`` does not hold exactly 2 values, a bound is
            outside its allowed range, lower >= upper, or the interval is not
            symmetric when symmetry is required.

    Examples:
        >>> from spotforecast2.utils.validation import check_interval
        >>>
        >>> # Valid 95% confidence interval
        >>> check_interval(interval=[2.5, 97.5])  # No error
        >>>
        >>> # Valid symmetric interval
        >>> check_interval(interval=[2.5, 97.5], ensure_symmetric_intervals=True)  # No error
        >>>
        >>> # Invalid: not symmetric
        >>> try:
        ...     check_interval(interval=[5, 90], ensure_symmetric_intervals=True)
        ... except ValueError as e:
        ...     print("Error: Interval not symmetric")
        Error: Interval not symmetric
        >>>
        >>> # Invalid: wrong number of values
        >>> try:
        ...     check_interval(interval=[2.5, 50, 97.5])
        ... except ValueError as e:
        ...     print("Error: Must have exactly 2 values")
        Error: Must have exactly 2 values
        >>>
        >>> # Invalid: out of range
        >>> try:
        ...     check_interval(interval=[-5, 105])
        ... except ValueError as e:
        ...     print("Error: Values out of range")
        Error: Values out of range
    """
    if interval is None:
        return

    if not isinstance(interval, (list, tuple)):
        raise TypeError(
            "`interval` must be a `list` or `tuple`. For example, interval of 95% "
            "should be as `interval = [2.5, 97.5]`."
        )

    if len(interval) != 2:
        raise ValueError(
            "`interval` must contain exactly 2 values, respectively the "
            "lower and upper interval bounds. For example, interval of 95% "
            "should be as `interval = [2.5, 97.5]`."
        )

    lower, upper = interval

    if lower < 0.0 or lower >= 100.0:
        raise ValueError(f"Lower interval bound ({lower}) must be >= 0 and < 100.")

    if upper <= 0.0 or upper > 100.0:
        raise ValueError(f"Upper interval bound ({upper}) must be > 0 and <= 100.")

    if lower >= upper:
        raise ValueError(
            f"Lower interval bound ({lower}) must be less than the "
            f"upper interval bound ({upper})."
        )

    if ensure_symmetric_intervals:
        bound_sum = lower + upper
        if bound_sum != 100:
            raise ValueError(
                f"Interval must be symmetric, the sum of the lower, ({lower}), "
                f"and upper, ({upper}), interval bounds must be equal to "
                f"100. Got {bound_sum}."
            )
|
|
380
|
+
|
|
381
|
+
|
|
382
|
+
def _exog_variable_names(exog: Union[pd.Series, pd.DataFrame]) -> List[Any]:
    """Return the variable names of a Series (its name) or DataFrame (its columns)."""
    return [exog.name] if isinstance(exog, pd.Series) else exog.columns.tolist()


def _check_steps(steps: Union[int, List[int]], max_step: Optional[int]) -> None:
    """Validate that ``steps`` is positive and does not exceed ``max_step``."""
    if isinstance(steps, (int, np.integer)):
        if steps < 1:
            raise ValueError(
                f"`steps` must be an integer greater than or equal to 1. Got {steps}."
            )
        largest_step = steps
    elif isinstance(steps, list):
        # An empty list contains no valid step; without this guard min()/max()
        # would fail with an unrelated "empty sequence" message.
        if not steps or min(steps) < 1:
            raise ValueError(
                f"`steps` must be a list of integers greater than or equal to 1. Got {steps}."
            )
        largest_step = max(steps)
    else:
        # Other types are not validated here.
        return

    if max_step is not None and largest_step > max_step:
        raise ValueError(
            f"The maximum step that can be predicted is {max_step}. "
            f"Got {largest_step}."
        )


def _check_exog_dict(
    exog: Dict[str, Union[pd.Series, pd.DataFrame]],
    exog_names_in_: Optional[List[str]],
    levels: Optional[Union[str, List[str]]],
    series_names_in_: Optional[List[str]],
) -> None:
    """Validate a per-series ``exog`` dict for the requested levels."""
    if levels is None and series_names_in_ is not None:
        levels = series_names_in_

    if isinstance(levels, str):
        levels = [levels]

    # Without levels or training series names there is nothing to check against.
    if levels is None:
        return

    for level in levels:
        if level not in exog:
            raise ValueError(f"Exogenous variables for series '{level}' are missing.")

        level_exog = exog[level]
        level_id = f"`exog` for series '{level}'"
        check_exog(exog=level_exog, allow_nan=False, series_id=level_id)
        check_exog_dtypes(exog=level_exog, call_check_exog=False, series_id=level_id)

        # Only unexpected names are rejected here; presumably each series' exog
        # may hold a subset of the training variables (confirm against the
        # multi-series forecaster).
        exog_names = _exog_variable_names(level_exog)
        if exog_names_in_ is not None and set(exog_names) - set(exog_names_in_):
            raise ValueError(
                f"Exogenous variables must be: {exog_names_in_}. "
                f"Got {exog_names} for series '{level}'."
            )


def check_predict_input(
    forecaster_name: str,
    steps: Union[int, List[int]],
    is_fitted: bool,
    exog_in_: bool,
    index_type_: type,
    index_freq_: str,
    window_size: int,
    last_window: Optional[Union[pd.Series, pd.DataFrame]],
    last_window_exog: Optional[Union[pd.Series, pd.DataFrame]] = None,
    exog: Optional[
        Union[pd.Series, pd.DataFrame, Dict[str, Union[pd.Series, pd.DataFrame]]]
    ] = None,
    exog_names_in_: Optional[List[str]] = None,
    interval: Optional[List[float]] = None,
    alpha: Optional[float] = None,
    max_step: Optional[int] = None,
    levels: Optional[Union[str, List[str]]] = None,
    levels_forecaster: Optional[Union[str, List[str]]] = None,
    series_names_in_: Optional[List[str]] = None,
    encoding: Optional[str] = None,
) -> None:
    """
    Validate the inputs of a forecaster's ``predict`` method.

    Checks that the arguments passed to ``predict`` are consistent with the
    attributes of an already fitted forecaster:

    - the forecaster must be fitted;
    - ``steps`` must be a positive integer or a non-empty list of positive
      integers, not exceeding ``max_step`` when given;
    - ``interval`` must be valid;
    - ``exog`` must be given if and only if the forecaster was trained with
      exogenous variables, and must contain the training variables;
    - ``last_window`` must not contain missing values.

    Args:
        forecaster_name: Forecaster name. Not currently used by this function.
        steps: Number of future steps to predict, or a list of specific steps.
        is_fitted: Whether the estimator has been fitted (trained).
        exog_in_: Whether the forecaster was trained using exogenous variable/s.
        index_type_: Type of index of the input used in training. Not currently
            validated by this function.
        index_freq_: Frequency of the index of the input used in training. Not
            currently validated by this function.
        window_size: Size of the window needed to create the predictors. Not
            currently validated by this function.
        last_window: Values of the series used to create the predictors (lags)
            needed in the first iteration of prediction (t + 1). A DataFrame
            must have no missing values; anything else is validated with
            ``check_y``.
        last_window_exog: Values of the exogenous variables aligned with
            ``last_window`` in ForecasterStats predictions. Not currently
            validated by this function.
        exog: Exogenous variable/s included as predictor/s: a Series, a
            DataFrame, or a dict mapping series names to a Series/DataFrame.
        exog_names_in_: Names of the exogenous variables used during training.
            When None, the variable-name comparison is skipped.
        interval: Confidence of the prediction interval, as percentiles between
            0 and 100 inclusive, e.g. ``[2.5, 97.5]`` for a 95% interval.
        alpha: The confidence intervals used in ForecasterStats are
            (1 - alpha) %. Passed on to ``check_interval``.
        max_step: Maximum number of steps allowed (`ForecasterDirect` and
            `ForecasterDirectMultiVariate`).
        levels: Time series to be predicted (`ForecasterRecursiveMultiSeries`
            and `ForecasterRnn`). For a dict ``exog``, selects which entries
            are checked; defaults to ``series_names_in_``.
        levels_forecaster: Time series used as output data of a multiseries
            problem in a RNN problem (`ForecasterRnn`). Not currently used by
            this function.
        series_names_in_: Names of the columns used during fit
            (`ForecasterRecursiveMultiSeries`, `ForecasterDirectMultiVariate`
            and `ForecasterRnn`).
        encoding: Encoding used to identify the different series
            (`ForecasterRecursiveMultiSeries`). Not currently used by this
            function.

    Raises:
        RuntimeError: If the forecaster is not fitted.
        ValueError: If ``steps`` is invalid or exceeds ``max_step``; if the
            presence of ``exog`` does not match ``exog_in_``; if the variables
            in ``exog`` differ from ``exog_names_in_`` (for a dict ``exog``,
            if a series has unexpected variables or a requested level has no
            entry); or if ``last_window`` has missing values.
        TypeError: Propagated from ``check_interval``, ``check_exog``,
            ``check_exog_dtypes`` or ``check_y``.

    Returns:
        None
    """
    if not is_fitted:
        raise RuntimeError(
            "This forecaster is not fitted yet. Call `fit` with appropriate "
            "arguments before using `predict`."
        )

    _check_steps(steps, max_step)

    if interval is not None or alpha is not None:
        check_interval(interval=interval, alpha=alpha)

    if exog_in_ and exog is None:
        raise ValueError(
            "Forecaster trained with exogenous variable/s. "
            "Same variable/s must be provided when predicting."
        )

    if not exog_in_ and exog is not None:
        raise ValueError(
            "Forecaster trained without exogenous variable/s. "
            "`exog` must be `None` when predicting."
        )

    if exog is not None:
        if isinstance(exog, dict):
            # A dict holds the exogenous variables of each series separately.
            _check_exog_dict(exog, exog_names_in_, levels, series_names_in_)
        else:
            check_exog(exog=exog, allow_nan=False)
            check_exog_dtypes(exog=exog, call_check_exog=False)

            # Both unexpected and missing variables are rejected: a training
            # variable absent from `exog` would otherwise fail later in predict.
            exog_names = _exog_variable_names(exog)
            if exog_names_in_ is not None and set(exog_names) != set(exog_names_in_):
                raise ValueError(
                    f"Exogenous variables must be: {exog_names_in_}. Got {exog_names}."
                )

    if last_window is not None:
        if isinstance(last_window, pd.DataFrame):
            if last_window.isna().to_numpy().any():
                raise ValueError("`last_window` has missing values.")
        else:
            check_y(last_window, series_id="`last_window`")
|
|
File without changes
|