spotforecast2 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spotforecast2/.DS_Store +0 -0
- spotforecast2/__init__.py +2 -0
- spotforecast2/data/__init__.py +0 -0
- spotforecast2/data/data.py +130 -0
- spotforecast2/data/fetch_data.py +209 -0
- spotforecast2/exceptions.py +681 -0
- spotforecast2/forecaster/.DS_Store +0 -0
- spotforecast2/forecaster/__init__.py +7 -0
- spotforecast2/forecaster/base.py +448 -0
- spotforecast2/forecaster/metrics.py +527 -0
- spotforecast2/forecaster/recursive/__init__.py +4 -0
- spotforecast2/forecaster/recursive/_forecaster_equivalent_date.py +1075 -0
- spotforecast2/forecaster/recursive/_forecaster_recursive.py +939 -0
- spotforecast2/forecaster/recursive/_warnings.py +15 -0
- spotforecast2/forecaster/utils.py +954 -0
- spotforecast2/model_selection/__init__.py +5 -0
- spotforecast2/model_selection/bayesian_search.py +453 -0
- spotforecast2/model_selection/grid_search.py +314 -0
- spotforecast2/model_selection/random_search.py +151 -0
- spotforecast2/model_selection/split_base.py +357 -0
- spotforecast2/model_selection/split_one_step.py +245 -0
- spotforecast2/model_selection/split_ts_cv.py +634 -0
- spotforecast2/model_selection/utils_common.py +718 -0
- spotforecast2/model_selection/utils_metrics.py +103 -0
- spotforecast2/model_selection/validation.py +685 -0
- spotforecast2/preprocessing/__init__.py +30 -0
- spotforecast2/preprocessing/_binner.py +378 -0
- spotforecast2/preprocessing/_common.py +123 -0
- spotforecast2/preprocessing/_differentiator.py +123 -0
- spotforecast2/preprocessing/_rolling.py +136 -0
- spotforecast2/preprocessing/curate_data.py +254 -0
- spotforecast2/preprocessing/imputation.py +92 -0
- spotforecast2/preprocessing/outlier.py +114 -0
- spotforecast2/preprocessing/split.py +139 -0
- spotforecast2/py.typed +0 -0
- spotforecast2/utils/__init__.py +43 -0
- spotforecast2/utils/convert_to_utc.py +44 -0
- spotforecast2/utils/data_transform.py +208 -0
- spotforecast2/utils/forecaster_config.py +344 -0
- spotforecast2/utils/generate_holiday.py +70 -0
- spotforecast2/utils/validation.py +569 -0
- spotforecast2/weather/__init__.py +0 -0
- spotforecast2/weather/weather_client.py +288 -0
- spotforecast2-0.0.1.dist-info/METADATA +47 -0
- spotforecast2-0.0.1.dist-info/RECORD +46 -0
- spotforecast2-0.0.1.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,453 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Bayesian hyperparameter search functions for forecasters using Optuna.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
import logging
|
|
7
|
+
from typing import Callable
|
|
8
|
+
import warnings
|
|
9
|
+
from copy import deepcopy
|
|
10
|
+
import pandas as pd
|
|
11
|
+
|
|
12
|
+
try:
|
|
13
|
+
import optuna
|
|
14
|
+
from optuna.samplers import TPESampler
|
|
15
|
+
except ImportError:
|
|
16
|
+
warnings.warn(
|
|
17
|
+
"optuna is not installed. bayesian_search_forecaster will not work.",
|
|
18
|
+
ImportWarning,
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
from spotforecast2.model_selection.split_ts_cv import TimeSeriesFold
|
|
22
|
+
from spotforecast2.model_selection.split_one_step import OneStepAheadFold
|
|
23
|
+
from spotforecast2.model_selection.validation import (
|
|
24
|
+
_backtesting_forecaster,
|
|
25
|
+
)
|
|
26
|
+
from spotforecast2.forecaster.metrics import add_y_train_argument, _get_metric
|
|
27
|
+
from spotforecast2.model_selection.utils_common import (
|
|
28
|
+
check_one_step_ahead_input,
|
|
29
|
+
check_backtesting_input,
|
|
30
|
+
select_n_jobs_backtesting,
|
|
31
|
+
)
|
|
32
|
+
from spotforecast2.model_selection.utils_metrics import (
|
|
33
|
+
_calculate_metrics_one_step_ahead,
|
|
34
|
+
)
|
|
35
|
+
from spotforecast2.forecaster.utils import (
|
|
36
|
+
initialize_lags,
|
|
37
|
+
date_to_index_position,
|
|
38
|
+
set_skforecast_warnings,
|
|
39
|
+
)
|
|
40
|
+
from spotforecast2.exceptions import IgnoredArgumentWarning
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def bayesian_search_forecaster(
    forecaster: object,
    y: pd.Series,
    cv: TimeSeriesFold | OneStepAheadFold,
    search_space: Callable,
    metric: str | Callable | list[str | Callable],
    exog: pd.Series | pd.DataFrame | None = None,
    n_trials: int = 10,
    random_state: int = 123,
    return_best: bool = True,
    n_jobs: int | str = "auto",
    verbose: bool = False,
    show_progress: bool = True,
    suppress_warnings: bool = False,
    output_file: str | None = None,
    kwargs_create_study: dict | None = None,
    kwargs_study_optimize: dict | None = None,
) -> tuple[pd.DataFrame, object]:
    """
    Bayesian hyperparameter optimization for a Forecaster using Optuna.

    Performs Bayesian hyperparameter search using the Optuna library for a
    Forecaster object. Each sampled configuration is evaluated with the
    provided cross-validation strategy: backtesting for `TimeSeriesFold`,
    one-step-ahead metrics for `OneStepAheadFold`.

    Args:
        forecaster: Forecaster model. Can be ForecasterRecursive, ForecasterDirect,
            or any compatible forecaster class.
        y: Training time series values. Must be a pandas Series with a
            datetime or numeric index.
        cv: Cross-validation strategy with information needed to split the data
            into folds. Must be an instance of TimeSeriesFold or OneStepAheadFold.
        search_space: Callable function with argument `trial` that returns
            a dictionary with parameter names (str) as keys and Trial objects
            from optuna (trial.suggest_float, trial.suggest_int,
            trial.suggest_categorical) as values. Can optionally include 'lags'
            key to search over different lag configurations.
        metric: Metric(s) to quantify model goodness of fit. Can be:
            - str: One of 'mean_squared_error', 'mean_absolute_error',
              'mean_absolute_percentage_error', 'mean_squared_log_error',
              'mean_absolute_scaled_error', 'root_mean_squared_scaled_error'.
            - Callable: Function with arguments (y_true, y_pred) or
              (y_true, y_pred, y_train) that returns a float.
            - list: List containing multiple strings and/or Callables.
        exog: Exogenous variable(s) included as predictors. Must have the
            same number of observations as `y` and aligned so that y[i] is
            regressed on exog[i]. Default is None.
        n_trials: Number of parameter settings sampled during optimization.
            Default is 10.
        random_state: Seed for sampling reproducibility. When passing a custom
            sampler in kwargs_create_study, set the seed within the sampler
            (e.g., {'sampler': TPESampler(seed=145)}). Default is 123.
        return_best: If True, refit the forecaster using the best parameters
            found on the whole dataset at the end. Default is True.
        n_jobs: Number of parallel jobs. If -1, uses all cores. If 'auto',
            `select_n_jobs_backtesting` is used to determine the number of
            jobs automatically. Default is 'auto'.
        verbose: If True, print number of folds used for cross-validation.
            Default is False.
        show_progress: Whether to show an Optuna progress bar during
            optimization. Default is True.
        suppress_warnings: If True, suppress spotforecast warnings during
            hyperparameter search. Default is False.
        output_file: Filename or full path of a file that receives Optuna's
            log output (level INFO and above) during the search; an existing
            file is overwritten. If None, Optuna logging is limited to WARNING
            and nothing is written to disk. Default is None.
        kwargs_create_study: Additional keyword arguments passed to
            optuna.create_study(). If 'direction' is not given, it defaults to
            'minimize' for regression forecasters and 'maximize' otherwise;
            if 'sampler' is not given, TPESampler(seed=random_state) is used.
            The dict is copied and never modified. Default is None (no extra
            arguments).
        kwargs_study_optimize: Additional keyword arguments passed to
            study.optimize(). The dict is copied and never modified.
            Default is None (no extra arguments).

    Returns:
        tuple[pd.DataFrame, object]: A tuple containing:
            - results: DataFrame with columns 'lags', 'params', metric values,
              and individual parameter columns. Sorted by the first metric.
            - best_trial: Best optimization result as an optuna.FrozenTrial
              object containing the best parameters and metric value.

    Raises:
        ValueError: If exog length doesn't match y length when return_best=True.
        TypeError: If cv is not an instance of TimeSeriesFold or OneStepAheadFold.
        ValueError: If metric list contains duplicate metric names.
        ValueError: If the keys returned by `search_space` do not match the
            parameter names registered in the best trial.
    """

    if return_best and exog is not None and (len(exog) != len(y)):
        raise ValueError(
            f"`exog` must have same number of samples as `y`. "
            f"length `exog`: ({len(exog)}), length `y`: ({len(y)})"
        )

    # Fresh dicts for every call: the search fills in defaults such as
    # `direction` and `show_progress_bar`, which must neither modify the
    # caller's dicts nor leak into later calls through a shared default.
    kwargs_create_study = dict(kwargs_create_study or {})
    kwargs_study_optimize = dict(kwargs_study_optimize or {})

    results, best_trial = _bayesian_search_optuna(
        forecaster=forecaster,
        y=y,
        cv=cv,
        exog=exog,
        search_space=search_space,
        metric=metric,
        n_trials=n_trials,
        random_state=random_state,
        return_best=return_best,
        n_jobs=n_jobs,
        verbose=verbose,
        show_progress=show_progress,
        suppress_warnings=suppress_warnings,
        output_file=output_file,
        kwargs_create_study=kwargs_create_study,
        kwargs_study_optimize=kwargs_study_optimize,
    )

    return results, best_trial
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
def _bayesian_search_optuna(
    forecaster: object,
    y: pd.Series,
    cv: TimeSeriesFold | OneStepAheadFold,
    search_space: Callable,
    metric: str | Callable | list[str | Callable],
    exog: pd.Series | pd.DataFrame | None = None,
    n_trials: int = 10,
    random_state: int = 123,
    return_best: bool = True,
    n_jobs: int | str = "auto",
    verbose: bool = False,
    show_progress: bool = True,
    suppress_warnings: bool = False,
    output_file: str | None = None,
    kwargs_create_study: dict | None = None,
    kwargs_study_optimize: dict | None = None,
) -> tuple[pd.DataFrame, object]:
    """
    Bayesian search for hyperparameters of a Forecaster object using Optuna library.

    This is the internal implementation behind `bayesian_search_forecaster`.
    It validates the inputs for the given cross-validation strategy and runs
    an Optuna study. Each trial evaluates one sampled configuration on a
    private copy of `forecaster`: backtesting for `TimeSeriesFold`,
    one-step-ahead metrics for `OneStepAheadFold`. All metric values are
    collected into a results table and, if `return_best` is True,
    `forecaster` itself is refitted with the best configuration.

    Args:
        See `bayesian_search_forecaster`. `kwargs_create_study` and
        `kwargs_study_optimize` may be None. They are never mutated; the
        defaults (direction, sampler, progress bar) are filled into private
        copies.

    Returns:
        tuple[pd.DataFrame, object]: ``(results, best_trial)``.
            ``results`` holds one row per trial that recorded metrics. It is
            sorted by the first metric from best to worst according to the
            study direction, so its first row agrees with ``best_trial``
            (``study.best_trial``).

    Raises:
        TypeError: If `cv` is not a `TimeSeriesFold` or `OneStepAheadFold`.
        ValueError: If metric names are not unique, or if the keys returned
            by `search_space` do not match the parameter names registered in
            the best trial.
    """

    # Private copies: neither the caller's dicts nor any shared default object
    # may be modified by the defaults filled in further below.
    kwargs_create_study = dict(kwargs_create_study or {})
    kwargs_study_optimize = dict(kwargs_study_optimize or {})

    set_skforecast_warnings(suppress_warnings, action="ignore")
    # Created only when `output_file` is given; detached and closed in `finally`.
    file_handler = None

    try:
        forecaster_search = deepcopy(forecaster)
        forecaster_name = type(forecaster_search).__name__
        is_regression = (
            forecaster_search.__spotforecast_tags__["forecaster_task"] == "regression"
        )
        cv_name = type(cv).__name__

        if cv_name not in ["TimeSeriesFold", "OneStepAheadFold"]:
            raise TypeError(
                f"`cv` must be an instance of `TimeSeriesFold` or `OneStepAheadFold`. "
                f"Got {type(cv)}."
            )

        if cv_name == "OneStepAheadFold":
            # NOTE(review): input-check warnings are always emitted here
            # (suppress_warnings=False), unlike the TimeSeriesFold branch —
            # confirm this is intentional.
            check_one_step_ahead_input(
                forecaster=forecaster_search,
                cv=cv,
                metric=metric,
                y=y,
                exog=exog,
                show_progress=show_progress,
                suppress_warnings=False,
            )

            # Work on a copy so the caller's `cv` is untouched, and store
            # `initial_train_size` as a position in the index of `y`.
            cv = deepcopy(cv)
            initial_train_size = date_to_index_position(
                index=cv._extract_index(y),
                date_input=cv.initial_train_size,
                method="validation",
                date_literal="initial_train_size",
            )
            cv.set_params(
                {
                    "initial_train_size": initial_train_size,
                    "window_size": forecaster_search.window_size,
                    "differentiation": forecaster_search.differentiation_max,
                    "verbose": verbose,
                }
            )
        else:
            check_backtesting_input(
                forecaster=forecaster_search,
                cv=cv,
                y=y,
                metric=metric,
                exog=exog,
                n_jobs=n_jobs,
                show_progress=show_progress,
                suppress_warnings=suppress_warnings,
            )

        if not isinstance(metric, list):
            metric = [metric]
        metric = [
            _get_metric(metric=m) if isinstance(m, str) else add_y_train_argument(m)
            for m in metric
        ]
        metric_names = [m if isinstance(m, str) else m.__name__ for m in metric]
        metric_dict = {name: [] for name in metric_names}

        if len(metric_dict) != len(metric):
            raise ValueError("When `metric` is a `list`, each metric name must be unique.")

        if n_jobs == "auto":
            # `refit` only exists on TimeSeriesFold.
            refit = cv.refit if isinstance(cv, TimeSeriesFold) else None
            n_jobs = select_n_jobs_backtesting(forecaster=forecaster_search, refit=refit)
        elif (
            isinstance(cv, TimeSeriesFold)
            and not isinstance(cv.refit, bool)
            and cv.refit != 1
            and n_jobs != 1
        ):
            # Only an integer refit interval > 1 is intermittent. Booleans are
            # excluded because `False != 1` would otherwise match.
            warnings.warn(
                "If `refit` is an integer other than 1 (intermittent refit). `n_jobs` "
                "is set to 1 to avoid unexpected results during parallelization.",
                IgnoredArgumentWarning,
            )
            n_jobs = 1

        # All metric values of each evaluated trial, keyed by `trial.number`.
        # Only the first metric can be returned to Optuna, so the full list is
        # kept here and matched to its trial by number (not by position) when
        # the results table is built.
        metric_values: dict[int, list] = {}

        def _apply_sample(trial) -> None:
            """Set the hyperparameters (and lags, if sampled) drawn for `trial`."""
            sample = search_space(trial)
            forecaster_search.set_params(
                **{k: v for k, v in sample.items() if k != "lags"}
            )
            if "lags" in sample:
                forecaster_search.set_lags(sample["lags"])

        if cv_name == "TimeSeriesFold":

            def _objective(trial) -> float:
                """Backtest one sampled configuration and return its first metric."""
                _apply_sample(trial)
                metrics, _ = _backtesting_forecaster(
                    forecaster=forecaster_search,
                    y=y,
                    cv=cv,
                    metric=metric,
                    exog=exog,
                    n_jobs=n_jobs,
                    verbose=verbose,
                    show_progress=False,
                    suppress_warnings=suppress_warnings,
                )
                # The first row of the returned DataFrame holds the metrics of
                # this single configuration.
                metrics_list = metrics.iloc[0, :].to_list()
                metric_values[trial.number] = metrics_list
                return metrics_list[0]

        else:

            def _objective(trial) -> float:
                """Score one sampled configuration one step ahead; return its first metric."""
                _apply_sample(trial)
                X_train, y_train, X_test, y_test = (
                    forecaster_search._train_test_split_one_step_ahead(
                        y=y, initial_train_size=cv.initial_train_size, exog=exog
                    )
                )
                metrics_list = _calculate_metrics_one_step_ahead(
                    forecaster=forecaster_search,
                    metrics=metric,
                    X_train=X_train,
                    y_train=y_train,
                    X_test=X_test,
                    y_test=y_test,
                )
                metric_values[trial.number] = metrics_list
                return metrics_list[0]

        if "direction" not in kwargs_create_study:
            kwargs_create_study["direction"] = "minimize" if is_regression else "maximize"
        if "sampler" not in kwargs_create_study:
            kwargs_create_study["sampler"] = TPESampler(seed=random_state)

        if show_progress:
            kwargs_study_optimize["show_progress_bar"] = True

        optuna_logger = logging.getLogger("optuna")
        if output_file is not None:
            # Redirect optuna logging to file. FileHandler subclasses
            # StreamHandler, so stale file handlers are removed here too.
            optuna.logging.disable_default_handler()
            optuna_logger.setLevel(logging.INFO)
            for handler in optuna_logger.handlers.copy():
                if isinstance(handler, logging.StreamHandler):
                    optuna_logger.removeHandler(handler)
            file_handler = logging.FileHandler(output_file, mode="w")
            optuna_logger.addHandler(file_handler)
        else:
            optuna_logger.setLevel(logging.WARNING)
            optuna.logging.disable_default_handler()

        study = optuna.create_study(**kwargs_create_study)

        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore",
                category=UserWarning,
                message="Choices for a categorical distribution should be*",
            )
            study.optimize(_objective, n_trials=n_trials, **kwargs_study_optimize)
            best_trial = study.best_trial
            search_space_best = search_space(best_trial)

        if search_space_best.keys() != best_trial.params.keys():
            raise ValueError(
                f"Some of the key values do not match the search_space key names.\n"
                f" Search Space keys : {list(search_space_best.keys())}\n"
                f" Trial objects keys : {list(best_trial.params.keys())}."
            )

        # Lags used by trials whose search space does not sample 'lags'.
        default_lags = (
            forecaster_search.lags if hasattr(forecaster_search, "lags") else None
        )

        lags_list = []
        params_list = []
        for trial in study.get_trials():
            trial_metrics = metric_values.get(trial.number)
            if trial_metrics is None:
                # No metrics were recorded by `_objective` for this trial
                # (e.g. it failed and was caught via `catch=` in
                # `kwargs_study_optimize`), so it gets no results row.
                continue
            params_list.append({k: v for k, v in trial.params.items() if k != "lags"})
            lags_list.append(trial.params.get("lags", default_lags))
            for name, value in zip(metric_names, trial_metrics):
                metric_dict[name].append(value)

        lags_list = [
            initialize_lags(forecaster_name=forecaster_name, lags=lag)[0]
            for lag in lags_list
        ]

        results = pd.DataFrame({"lags": lags_list, "params": params_list, **metric_dict})

        # Sort best-first according to the direction the study actually
        # optimized, so row 0 agrees with `best_trial` even when the caller
        # overrides `direction` in `kwargs_create_study`.
        minimize = study.direction == optuna.study.StudyDirection.MINIMIZE
        results = results.sort_values(
            by=metric_names[0], ascending=minimize
        ).reset_index(drop=True)
        results = pd.concat([results, results["params"].apply(pd.Series)], axis=1)

        if return_best:
            best_lags = results.loc[0, "lags"]
            best_params = results.loc[0, "params"]
            best_metric = results.loc[0, metric_names[0]]

            # NOTE: Here we use the actual forecaster passed by the user
            forecaster.set_lags(best_lags)
            forecaster.set_params(**best_params)

            forecaster.fit(y=y, exog=exog, store_in_sample_residuals=True)

            print(
                f"`Forecaster` refitted using the best-found lags and parameters, "
                f"and the whole data set: \n"
                f" Lags: {best_lags} \n"
                f" Parameters: {best_params}\n"
                f" {'Backtesting' if cv_name == 'TimeSeriesFold' else 'One-step-ahead'} "
                f"metric: {best_metric}"
            )
    finally:
        # Always undo global side effects, even if the search raised.
        if file_handler is not None:
            logging.getLogger("optuna").removeHandler(file_handler)
            file_handler.close()
        set_skforecast_warnings(suppress_warnings, action="default")

    return results, best_trial
|