autogluon.timeseries 1.3.2b20250712__py3-none-any.whl → 1.4.1b20251116__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- autogluon/timeseries/configs/__init__.py +3 -2
- autogluon/timeseries/configs/hyperparameter_presets.py +62 -0
- autogluon/timeseries/configs/predictor_presets.py +84 -0
- autogluon/timeseries/dataset/ts_dataframe.py +98 -72
- autogluon/timeseries/learner.py +19 -18
- autogluon/timeseries/metrics/__init__.py +5 -5
- autogluon/timeseries/metrics/abstract.py +17 -17
- autogluon/timeseries/metrics/point.py +1 -1
- autogluon/timeseries/metrics/quantile.py +2 -2
- autogluon/timeseries/metrics/utils.py +4 -4
- autogluon/timeseries/models/__init__.py +4 -0
- autogluon/timeseries/models/abstract/abstract_timeseries_model.py +52 -75
- autogluon/timeseries/models/abstract/tunable.py +6 -6
- autogluon/timeseries/models/autogluon_tabular/mlforecast.py +72 -76
- autogluon/timeseries/models/autogluon_tabular/per_step.py +104 -46
- autogluon/timeseries/models/autogluon_tabular/transforms.py +9 -7
- autogluon/timeseries/models/chronos/model.py +115 -78
- autogluon/timeseries/models/chronos/{pipeline/utils.py → utils.py} +76 -44
- autogluon/timeseries/models/ensemble/__init__.py +29 -2
- autogluon/timeseries/models/ensemble/abstract.py +16 -52
- autogluon/timeseries/models/ensemble/array_based/__init__.py +3 -0
- autogluon/timeseries/models/ensemble/array_based/abstract.py +247 -0
- autogluon/timeseries/models/ensemble/array_based/models.py +50 -0
- autogluon/timeseries/models/ensemble/array_based/regressor/__init__.py +10 -0
- autogluon/timeseries/models/ensemble/array_based/regressor/abstract.py +87 -0
- autogluon/timeseries/models/ensemble/array_based/regressor/per_quantile_tabular.py +133 -0
- autogluon/timeseries/models/ensemble/array_based/regressor/tabular.py +141 -0
- autogluon/timeseries/models/ensemble/weighted/__init__.py +8 -0
- autogluon/timeseries/models/ensemble/weighted/abstract.py +41 -0
- autogluon/timeseries/models/ensemble/{basic.py → weighted/basic.py} +8 -18
- autogluon/timeseries/models/ensemble/{greedy.py → weighted/greedy.py} +13 -13
- autogluon/timeseries/models/gluonts/abstract.py +26 -26
- autogluon/timeseries/models/gluonts/dataset.py +4 -4
- autogluon/timeseries/models/gluonts/models.py +27 -12
- autogluon/timeseries/models/local/abstract_local_model.py +14 -14
- autogluon/timeseries/models/local/naive.py +4 -0
- autogluon/timeseries/models/local/npts.py +1 -0
- autogluon/timeseries/models/local/statsforecast.py +30 -14
- autogluon/timeseries/models/multi_window/multi_window_model.py +34 -23
- autogluon/timeseries/models/registry.py +65 -0
- autogluon/timeseries/models/toto/__init__.py +3 -0
- autogluon/timeseries/models/toto/_internal/__init__.py +9 -0
- autogluon/timeseries/models/toto/_internal/backbone/__init__.py +3 -0
- autogluon/timeseries/models/toto/_internal/backbone/attention.py +197 -0
- autogluon/timeseries/models/toto/_internal/backbone/backbone.py +262 -0
- autogluon/timeseries/models/toto/_internal/backbone/distribution.py +70 -0
- autogluon/timeseries/models/toto/_internal/backbone/kvcache.py +136 -0
- autogluon/timeseries/models/toto/_internal/backbone/rope.py +94 -0
- autogluon/timeseries/models/toto/_internal/backbone/scaler.py +306 -0
- autogluon/timeseries/models/toto/_internal/backbone/transformer.py +333 -0
- autogluon/timeseries/models/toto/_internal/dataset.py +165 -0
- autogluon/timeseries/models/toto/_internal/forecaster.py +423 -0
- autogluon/timeseries/models/toto/dataloader.py +108 -0
- autogluon/timeseries/models/toto/hf_pretrained_model.py +119 -0
- autogluon/timeseries/models/toto/model.py +236 -0
- autogluon/timeseries/predictor.py +94 -107
- autogluon/timeseries/regressor.py +31 -27
- autogluon/timeseries/splitter.py +7 -31
- autogluon/timeseries/trainer/__init__.py +3 -0
- autogluon/timeseries/trainer/ensemble_composer.py +250 -0
- autogluon/timeseries/trainer/model_set_builder.py +256 -0
- autogluon/timeseries/trainer/prediction_cache.py +149 -0
- autogluon/timeseries/{trainer.py → trainer/trainer.py} +182 -307
- autogluon/timeseries/trainer/utils.py +18 -0
- autogluon/timeseries/transforms/covariate_scaler.py +4 -4
- autogluon/timeseries/transforms/target_scaler.py +14 -14
- autogluon/timeseries/utils/datetime/lags.py +2 -2
- autogluon/timeseries/utils/datetime/time_features.py +2 -2
- autogluon/timeseries/utils/features.py +41 -37
- autogluon/timeseries/utils/forecast.py +5 -5
- autogluon/timeseries/utils/warning_filters.py +3 -1
- autogluon/timeseries/version.py +1 -1
- autogluon.timeseries-1.4.1b20251116-py3.9-nspkg.pth +1 -0
- {autogluon.timeseries-1.3.2b20250712.dist-info → autogluon_timeseries-1.4.1b20251116.dist-info}/METADATA +32 -17
- autogluon_timeseries-1.4.1b20251116.dist-info/RECORD +96 -0
- {autogluon.timeseries-1.3.2b20250712.dist-info → autogluon_timeseries-1.4.1b20251116.dist-info}/WHEEL +1 -1
- autogluon/timeseries/configs/presets_configs.py +0 -79
- autogluon/timeseries/evaluator.py +0 -6
- autogluon/timeseries/models/chronos/pipeline/__init__.py +0 -10
- autogluon/timeseries/models/chronos/pipeline/base.py +0 -160
- autogluon/timeseries/models/chronos/pipeline/chronos.py +0 -544
- autogluon/timeseries/models/chronos/pipeline/chronos_bolt.py +0 -530
- autogluon/timeseries/models/presets.py +0 -358
- autogluon.timeseries-1.3.2b20250712-py3.9-nspkg.pth +0 -1
- autogluon.timeseries-1.3.2b20250712.dist-info/RECORD +0 -71
- {autogluon.timeseries-1.3.2b20250712.dist-info → autogluon_timeseries-1.4.1b20251116.dist-info/licenses}/LICENSE +0 -0
- {autogluon.timeseries-1.3.2b20250712.dist-info → autogluon_timeseries-1.4.1b20251116.dist-info/licenses}/NOTICE +0 -0
- {autogluon.timeseries-1.3.2b20250712.dist-info → autogluon_timeseries-1.4.1b20251116.dist-info}/namespace_packages.txt +0 -0
- {autogluon.timeseries-1.3.2b20250712.dist-info → autogluon_timeseries-1.4.1b20251116.dist-info}/top_level.txt +0 -0
- {autogluon.timeseries-1.3.2b20250712.dist-info → autogluon_timeseries-1.4.1b20251116.dist-info}/zip-safe +0 -0
|
@@ -1,3 +1,4 @@
|
|
|
1
|
-
from .
|
|
1
|
+
from .hyperparameter_presets import get_hyperparameter_presets
|
|
2
|
+
from .predictor_presets import get_predictor_presets
|
|
2
3
|
|
|
3
|
-
__all__ = ["
|
|
4
|
+
__all__ = ["get_hyperparameter_presets", "get_predictor_presets"]
|
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
from typing import Any, Union
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def get_hyperparameter_presets() -> dict[str, dict[str, Union[dict[str, Any], list[dict[str, Any]]]]]:
|
|
5
|
+
return {
|
|
6
|
+
"very_light": {
|
|
7
|
+
"Naive": {},
|
|
8
|
+
"SeasonalNaive": {},
|
|
9
|
+
"ETS": {},
|
|
10
|
+
"Theta": {},
|
|
11
|
+
"RecursiveTabular": {"max_num_samples": 100_000},
|
|
12
|
+
"DirectTabular": {"max_num_samples": 100_000},
|
|
13
|
+
},
|
|
14
|
+
"light": {
|
|
15
|
+
"Naive": {},
|
|
16
|
+
"SeasonalNaive": {},
|
|
17
|
+
"ETS": {},
|
|
18
|
+
"Theta": {},
|
|
19
|
+
"RecursiveTabular": {},
|
|
20
|
+
"DirectTabular": {},
|
|
21
|
+
"TemporalFusionTransformer": {},
|
|
22
|
+
"Chronos": {"model_path": "bolt_small"},
|
|
23
|
+
},
|
|
24
|
+
"light_inference": {
|
|
25
|
+
"SeasonalNaive": {},
|
|
26
|
+
"DirectTabular": {},
|
|
27
|
+
"RecursiveTabular": {},
|
|
28
|
+
"TemporalFusionTransformer": {},
|
|
29
|
+
"PatchTST": {},
|
|
30
|
+
},
|
|
31
|
+
"default": {
|
|
32
|
+
"SeasonalNaive": {},
|
|
33
|
+
"AutoETS": {},
|
|
34
|
+
"NPTS": {},
|
|
35
|
+
"DynamicOptimizedTheta": {},
|
|
36
|
+
"RecursiveTabular": {},
|
|
37
|
+
"DirectTabular": {},
|
|
38
|
+
"TemporalFusionTransformer": {},
|
|
39
|
+
"PatchTST": {},
|
|
40
|
+
"DeepAR": {},
|
|
41
|
+
"Chronos": [
|
|
42
|
+
{
|
|
43
|
+
"ag_args": {"name_suffix": "ZeroShot"},
|
|
44
|
+
"model_path": "bolt_base",
|
|
45
|
+
},
|
|
46
|
+
{
|
|
47
|
+
"ag_args": {"name_suffix": "FineTuned"},
|
|
48
|
+
"model_path": "bolt_small",
|
|
49
|
+
"fine_tune": True,
|
|
50
|
+
"target_scaler": "standard",
|
|
51
|
+
"covariate_regressor": {"model_name": "CAT", "model_hyperparameters": {"iterations": 1_000}},
|
|
52
|
+
},
|
|
53
|
+
],
|
|
54
|
+
"TiDE": {
|
|
55
|
+
"encoder_hidden_dim": 256,
|
|
56
|
+
"decoder_hidden_dim": 256,
|
|
57
|
+
"temporal_hidden_dim": 64,
|
|
58
|
+
"num_batches_per_epoch": 100,
|
|
59
|
+
"lr": 1e-4,
|
|
60
|
+
},
|
|
61
|
+
},
|
|
62
|
+
}
|
|
@@ -0,0 +1,84 @@
|
|
|
1
|
+
"""Preset configurations for autogluon.timeseries Predictors"""
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from . import get_hyperparameter_presets
|
|
6
|
+
|
|
7
|
+
TIMESERIES_PRESETS_ALIASES = dict(
|
|
8
|
+
chronos="chronos_small",
|
|
9
|
+
best="best_quality",
|
|
10
|
+
high="high_quality",
|
|
11
|
+
medium="medium_quality",
|
|
12
|
+
bq="best_quality",
|
|
13
|
+
hq="high_quality",
|
|
14
|
+
mq="medium_quality",
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def get_predictor_presets() -> dict[str, Any]:
|
|
19
|
+
hp_presets = get_hyperparameter_presets()
|
|
20
|
+
|
|
21
|
+
predictor_presets = dict(
|
|
22
|
+
best_quality={"hyperparameters": "default", "num_val_windows": 2},
|
|
23
|
+
high_quality={"hyperparameters": "default"},
|
|
24
|
+
medium_quality={"hyperparameters": "light"},
|
|
25
|
+
fast_training={"hyperparameters": "very_light"},
|
|
26
|
+
# Chronos-Bolt models
|
|
27
|
+
bolt_tiny={
|
|
28
|
+
"hyperparameters": {"Chronos": {"model_path": "bolt_tiny"}},
|
|
29
|
+
"skip_model_selection": True,
|
|
30
|
+
},
|
|
31
|
+
bolt_mini={
|
|
32
|
+
"hyperparameters": {"Chronos": {"model_path": "bolt_mini"}},
|
|
33
|
+
"skip_model_selection": True,
|
|
34
|
+
},
|
|
35
|
+
bolt_small={
|
|
36
|
+
"hyperparameters": {"Chronos": {"model_path": "bolt_small"}},
|
|
37
|
+
"skip_model_selection": True,
|
|
38
|
+
},
|
|
39
|
+
bolt_base={
|
|
40
|
+
"hyperparameters": {"Chronos": {"model_path": "bolt_base"}},
|
|
41
|
+
"skip_model_selection": True,
|
|
42
|
+
},
|
|
43
|
+
# Original Chronos models
|
|
44
|
+
chronos_tiny={
|
|
45
|
+
"hyperparameters": {"Chronos": {"model_path": "tiny"}},
|
|
46
|
+
"skip_model_selection": True,
|
|
47
|
+
},
|
|
48
|
+
chronos_mini={
|
|
49
|
+
"hyperparameters": {"Chronos": {"model_path": "mini"}},
|
|
50
|
+
"skip_model_selection": True,
|
|
51
|
+
},
|
|
52
|
+
chronos_small={
|
|
53
|
+
"hyperparameters": {"Chronos": {"model_path": "small"}},
|
|
54
|
+
"skip_model_selection": True,
|
|
55
|
+
},
|
|
56
|
+
chronos_base={
|
|
57
|
+
"hyperparameters": {"Chronos": {"model_path": "base"}},
|
|
58
|
+
"skip_model_selection": True,
|
|
59
|
+
},
|
|
60
|
+
chronos_large={
|
|
61
|
+
"hyperparameters": {"Chronos": {"model_path": "large", "batch_size": 8}},
|
|
62
|
+
"skip_model_selection": True,
|
|
63
|
+
},
|
|
64
|
+
chronos_ensemble={
|
|
65
|
+
"hyperparameters": {
|
|
66
|
+
"Chronos": {"model_path": "small"},
|
|
67
|
+
**hp_presets["light_inference"],
|
|
68
|
+
}
|
|
69
|
+
},
|
|
70
|
+
chronos_large_ensemble={
|
|
71
|
+
"hyperparameters": {
|
|
72
|
+
"Chronos": {"model_path": "large", "batch_size": 8},
|
|
73
|
+
**hp_presets["light_inference"],
|
|
74
|
+
}
|
|
75
|
+
},
|
|
76
|
+
)
|
|
77
|
+
|
|
78
|
+
# update with aliases
|
|
79
|
+
predictor_presets = {
|
|
80
|
+
**predictor_presets,
|
|
81
|
+
**{k: predictor_presets[v].copy() for k, v in TIMESERIES_PRESETS_ALIASES.items()},
|
|
82
|
+
}
|
|
83
|
+
|
|
84
|
+
return predictor_presets
|
|
@@ -7,7 +7,7 @@ import reprlib
|
|
|
7
7
|
from collections.abc import Iterable
|
|
8
8
|
from itertools import islice
|
|
9
9
|
from pathlib import Path
|
|
10
|
-
from typing import TYPE_CHECKING, Any,
|
|
10
|
+
from typing import TYPE_CHECKING, Any, Final, Optional, Type, Union, overload
|
|
11
11
|
|
|
12
12
|
import numpy as np
|
|
13
13
|
import pandas as pd
|
|
@@ -19,11 +19,6 @@ from autogluon.common.loaders import load_pd
|
|
|
19
19
|
|
|
20
20
|
logger = logging.getLogger(__name__)
|
|
21
21
|
|
|
22
|
-
ITEMID = "item_id"
|
|
23
|
-
TIMESTAMP = "timestamp"
|
|
24
|
-
|
|
25
|
-
IRREGULAR_TIME_INDEX_FREQSTR = "IRREG"
|
|
26
|
-
|
|
27
22
|
|
|
28
23
|
class TimeSeriesDataFrame(pd.DataFrame):
|
|
29
24
|
"""A collection of univariate time series, where each row is identified by an (``item_id``, ``timestamp``) pair.
|
|
@@ -118,9 +113,13 @@ class TimeSeriesDataFrame(pd.DataFrame):
|
|
|
118
113
|
|
|
119
114
|
"""
|
|
120
115
|
|
|
121
|
-
index: pd.MultiIndex
|
|
116
|
+
index: pd.MultiIndex # type: ignore
|
|
122
117
|
_metadata = ["_static_features"]
|
|
123
118
|
|
|
119
|
+
IRREGULAR_TIME_INDEX_FREQSTR: Final[str] = "IRREG"
|
|
120
|
+
ITEMID: Final[str] = "item_id"
|
|
121
|
+
TIMESTAMP: Final[str] = "timestamp"
|
|
122
|
+
|
|
124
123
|
def __init__(
|
|
125
124
|
self,
|
|
126
125
|
data: Union[pd.DataFrame, str, Path, Iterable],
|
|
@@ -175,23 +174,27 @@ class TimeSeriesDataFrame(pd.DataFrame):
|
|
|
175
174
|
df = df.copy()
|
|
176
175
|
if id_column is not None:
|
|
177
176
|
assert id_column in df.columns, f"Column '{id_column}' not found!"
|
|
178
|
-
if id_column != ITEMID and ITEMID in df.columns:
|
|
179
|
-
logger.warning(
|
|
180
|
-
|
|
181
|
-
|
|
177
|
+
if id_column != cls.ITEMID and cls.ITEMID in df.columns:
|
|
178
|
+
logger.warning(
|
|
179
|
+
f"Renaming existing column '{cls.ITEMID}' -> '__{cls.ITEMID}' to avoid name collisions."
|
|
180
|
+
)
|
|
181
|
+
df.rename(columns={cls.ITEMID: "__" + cls.ITEMID}, inplace=True)
|
|
182
|
+
df.rename(columns={id_column: cls.ITEMID}, inplace=True)
|
|
182
183
|
|
|
183
184
|
if timestamp_column is not None:
|
|
184
185
|
assert timestamp_column in df.columns, f"Column '{timestamp_column}' not found!"
|
|
185
|
-
if timestamp_column != TIMESTAMP and TIMESTAMP in df.columns:
|
|
186
|
-
logger.warning(
|
|
187
|
-
|
|
188
|
-
|
|
186
|
+
if timestamp_column != cls.TIMESTAMP and cls.TIMESTAMP in df.columns:
|
|
187
|
+
logger.warning(
|
|
188
|
+
f"Renaming existing column '{cls.TIMESTAMP}' -> '__{cls.TIMESTAMP}' to avoid name collisions."
|
|
189
|
+
)
|
|
190
|
+
df.rename(columns={cls.TIMESTAMP: "__" + cls.TIMESTAMP}, inplace=True)
|
|
191
|
+
df.rename(columns={timestamp_column: cls.TIMESTAMP}, inplace=True)
|
|
189
192
|
|
|
190
|
-
if TIMESTAMP in df.columns:
|
|
191
|
-
df[TIMESTAMP] = pd.to_datetime(df[TIMESTAMP])
|
|
193
|
+
if cls.TIMESTAMP in df.columns:
|
|
194
|
+
df[cls.TIMESTAMP] = pd.to_datetime(df[cls.TIMESTAMP])
|
|
192
195
|
|
|
193
196
|
cls._validate_data_frame(df)
|
|
194
|
-
return df.set_index([ITEMID, TIMESTAMP])
|
|
197
|
+
return df.set_index([cls.ITEMID, cls.TIMESTAMP])
|
|
195
198
|
|
|
196
199
|
@classmethod
|
|
197
200
|
def _construct_tsdf_from_iterable_dataset(cls, iterable_dataset: Iterable, num_cpus: int = -1) -> pd.DataFrame:
|
|
@@ -202,7 +205,7 @@ class TimeSeriesDataFrame(pd.DataFrame):
|
|
|
202
205
|
start_timestamp = start_timestamp.to_timestamp(how="S")
|
|
203
206
|
target = ts["target"]
|
|
204
207
|
datetime_index = tuple(pd.date_range(start_timestamp, periods=len(target), freq=freq))
|
|
205
|
-
idx = pd.MultiIndex.from_product([(item_id,), datetime_index], names=[ITEMID, TIMESTAMP])
|
|
208
|
+
idx = pd.MultiIndex.from_product([(item_id,), datetime_index], names=[cls.ITEMID, cls.TIMESTAMP])
|
|
206
209
|
return pd.Series(target, name="target", index=idx).to_frame()
|
|
207
210
|
|
|
208
211
|
cls._validate_iterable(iterable_dataset)
|
|
@@ -219,32 +222,34 @@ class TimeSeriesDataFrame(pd.DataFrame):
|
|
|
219
222
|
raise ValueError(f"data must be a pd.DataFrame, got {type(data)}")
|
|
220
223
|
if not isinstance(data.index, pd.MultiIndex):
|
|
221
224
|
raise ValueError(f"data must have pd.MultiIndex, got {type(data.index)}")
|
|
222
|
-
if not pd.api.types.is_datetime64_dtype(data.index.dtypes[TIMESTAMP]):
|
|
223
|
-
raise ValueError(f"for {TIMESTAMP}, the only pandas dtype allowed is `datetime64`.")
|
|
224
|
-
if not data.index.names == (f"{ITEMID}", f"{TIMESTAMP}"):
|
|
225
|
-
raise ValueError(
|
|
225
|
+
if not pd.api.types.is_datetime64_dtype(data.index.dtypes[cls.TIMESTAMP]):
|
|
226
|
+
raise ValueError(f"for {cls.TIMESTAMP}, the only pandas dtype allowed is `datetime64`.")
|
|
227
|
+
if not data.index.names == (f"{cls.ITEMID}", f"{cls.TIMESTAMP}"):
|
|
228
|
+
raise ValueError(
|
|
229
|
+
f"data must have index names as ('{cls.ITEMID}', '{cls.TIMESTAMP}'), got {data.index.names}"
|
|
230
|
+
)
|
|
226
231
|
item_id_index = data.index.levels[0]
|
|
227
232
|
if not (pd.api.types.is_integer_dtype(item_id_index) or pd.api.types.is_string_dtype(item_id_index)):
|
|
228
|
-
raise ValueError(f"all entries in index `{ITEMID}` must be of integer or string dtype")
|
|
233
|
+
raise ValueError(f"all entries in index `{cls.ITEMID}` must be of integer or string dtype")
|
|
229
234
|
|
|
230
235
|
@classmethod
|
|
231
236
|
def _validate_data_frame(cls, df: pd.DataFrame):
|
|
232
237
|
"""Validate that a pd.DataFrame with ITEMID and TIMESTAMP columns can be converted to TimeSeriesDataFrame"""
|
|
233
238
|
if not isinstance(df, pd.DataFrame):
|
|
234
239
|
raise ValueError(f"data must be a pd.DataFrame, got {type(df)}")
|
|
235
|
-
if ITEMID not in df.columns:
|
|
236
|
-
raise ValueError(f"data must have a `{ITEMID}` column")
|
|
237
|
-
if TIMESTAMP not in df.columns:
|
|
238
|
-
raise ValueError(f"data must have a `{TIMESTAMP}` column")
|
|
239
|
-
if df[ITEMID].isnull().any():
|
|
240
|
-
raise ValueError(f"`{ITEMID}` column can not have nan")
|
|
241
|
-
if df[TIMESTAMP].isnull().any():
|
|
242
|
-
raise ValueError(f"`{TIMESTAMP}` column can not have nan")
|
|
243
|
-
if not pd.api.types.is_datetime64_dtype(df[TIMESTAMP]):
|
|
244
|
-
raise ValueError(f"for {TIMESTAMP}, the only pandas dtype allowed is `datetime64`.")
|
|
245
|
-
item_id_column = df[ITEMID]
|
|
240
|
+
if cls.ITEMID not in df.columns:
|
|
241
|
+
raise ValueError(f"data must have a `{cls.ITEMID}` column")
|
|
242
|
+
if cls.TIMESTAMP not in df.columns:
|
|
243
|
+
raise ValueError(f"data must have a `{cls.TIMESTAMP}` column")
|
|
244
|
+
if df[cls.ITEMID].isnull().any():
|
|
245
|
+
raise ValueError(f"`{cls.ITEMID}` column can not have nan")
|
|
246
|
+
if df[cls.TIMESTAMP].isnull().any():
|
|
247
|
+
raise ValueError(f"`{cls.TIMESTAMP}` column can not have nan")
|
|
248
|
+
if not pd.api.types.is_datetime64_dtype(df[cls.TIMESTAMP]):
|
|
249
|
+
raise ValueError(f"for {cls.TIMESTAMP}, the only pandas dtype allowed is `datetime64`.")
|
|
250
|
+
item_id_column = df[cls.ITEMID]
|
|
246
251
|
if not (pd.api.types.is_integer_dtype(item_id_column) or pd.api.types.is_string_dtype(item_id_column)):
|
|
247
|
-
raise ValueError(f"all entries in column `{ITEMID}` must be of integer or string dtype")
|
|
252
|
+
raise ValueError(f"all entries in column `{cls.ITEMID}` must be of integer or string dtype")
|
|
248
253
|
|
|
249
254
|
@classmethod
|
|
250
255
|
def _validate_iterable(cls, data: Iterable):
|
|
@@ -386,7 +391,7 @@ class TimeSeriesDataFrame(pd.DataFrame):
|
|
|
386
391
|
@property
|
|
387
392
|
def item_ids(self) -> pd.Index:
|
|
388
393
|
"""List of unique time series IDs contained in the data set."""
|
|
389
|
-
return self.index.unique(level=ITEMID)
|
|
394
|
+
return self.index.unique(level=self.ITEMID)
|
|
390
395
|
|
|
391
396
|
@classmethod
|
|
392
397
|
def _construct_static_features(
|
|
@@ -403,10 +408,12 @@ class TimeSeriesDataFrame(pd.DataFrame):
|
|
|
403
408
|
|
|
404
409
|
if id_column is not None:
|
|
405
410
|
assert id_column in static_features.columns, f"Column '{id_column}' not found in static_features!"
|
|
406
|
-
if id_column != ITEMID and ITEMID in static_features.columns:
|
|
407
|
-
logger.warning(
|
|
408
|
-
|
|
409
|
-
|
|
411
|
+
if id_column != cls.ITEMID and cls.ITEMID in static_features.columns:
|
|
412
|
+
logger.warning(
|
|
413
|
+
f"Renaming existing column '{cls.ITEMID}' -> '__{cls.ITEMID}' to avoid name collisions."
|
|
414
|
+
)
|
|
415
|
+
static_features.rename(columns={cls.ITEMID: "__" + cls.ITEMID}, inplace=True)
|
|
416
|
+
static_features.rename(columns={id_column: cls.ITEMID}, inplace=True)
|
|
410
417
|
return static_features
|
|
411
418
|
|
|
412
419
|
@property
|
|
@@ -431,10 +438,10 @@ class TimeSeriesDataFrame(pd.DataFrame):
|
|
|
431
438
|
|
|
432
439
|
# Avoid modifying static features inplace
|
|
433
440
|
value = value.copy()
|
|
434
|
-
if ITEMID in value.columns and value.index.name != ITEMID:
|
|
435
|
-
value = value.set_index(ITEMID)
|
|
436
|
-
if value.index.name != ITEMID:
|
|
437
|
-
value.index.rename(ITEMID, inplace=True)
|
|
441
|
+
if self.ITEMID in value.columns and value.index.name != self.ITEMID:
|
|
442
|
+
value = value.set_index(self.ITEMID)
|
|
443
|
+
if value.index.name != self.ITEMID:
|
|
444
|
+
value.index.rename(self.ITEMID, inplace=True)
|
|
438
445
|
missing_item_ids = self.item_ids.difference(value.index)
|
|
439
446
|
if len(missing_item_ids) > 0:
|
|
440
447
|
raise ValueError(
|
|
@@ -456,7 +463,7 @@ class TimeSeriesDataFrame(pd.DataFrame):
|
|
|
456
463
|
Number of items (individual time series) randomly selected to infer the frequency. Lower values speed up
|
|
457
464
|
the method, but increase the chance that some items with invalid frequency are missed by subsampling.
|
|
458
465
|
|
|
459
|
-
If set to
|
|
466
|
+
If set to ``None``, all items will be used for inferring the frequency.
|
|
460
467
|
raise_if_irregular : bool, default = False
|
|
461
468
|
If True, an exception will be raised if some items have an irregular frequency, or if different items have
|
|
462
469
|
different frequencies.
|
|
@@ -467,7 +474,7 @@ class TimeSeriesDataFrame(pd.DataFrame):
|
|
|
467
474
|
If all time series have a regular frequency, returns a pandas-compatible `frequency alias <https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#offset-aliases>`_.
|
|
468
475
|
|
|
469
476
|
If some items have an irregular frequency or if different items have different frequencies, returns string
|
|
470
|
-
|
|
477
|
+
``IRREG``.
|
|
471
478
|
"""
|
|
472
479
|
ts_df = self
|
|
473
480
|
if num_items is not None and ts_df.num_items > num_items:
|
|
@@ -514,7 +521,7 @@ class TimeSeriesDataFrame(pd.DataFrame):
|
|
|
514
521
|
else:
|
|
515
522
|
raise ValueError(f"Cannot infer frequency. Multiple frequencies detected: {unique_freqs}")
|
|
516
523
|
else:
|
|
517
|
-
return IRREGULAR_TIME_INDEX_FREQSTR
|
|
524
|
+
return self.IRREGULAR_TIME_INDEX_FREQSTR
|
|
518
525
|
else:
|
|
519
526
|
return pd.tseries.frequencies.to_offset(unique_freqs[0]).freqstr
|
|
520
527
|
|
|
@@ -526,7 +533,7 @@ class TimeSeriesDataFrame(pd.DataFrame):
|
|
|
526
533
|
values. For reliable results, use :meth:`~autogluon.timeseries.TimeSeriesDataFrame.infer_frequency`.
|
|
527
534
|
"""
|
|
528
535
|
inferred_freq = self.infer_frequency(num_items=50)
|
|
529
|
-
return None if inferred_freq == IRREGULAR_TIME_INDEX_FREQSTR else inferred_freq
|
|
536
|
+
return None if inferred_freq == self.IRREGULAR_TIME_INDEX_FREQSTR else inferred_freq
|
|
530
537
|
|
|
531
538
|
@property
|
|
532
539
|
def num_items(self):
|
|
@@ -536,7 +543,7 @@ class TimeSeriesDataFrame(pd.DataFrame):
|
|
|
536
543
|
def num_timesteps_per_item(self) -> pd.Series:
|
|
537
544
|
"""Number of observations in each time series in the dataframe.
|
|
538
545
|
|
|
539
|
-
Returns a
|
|
546
|
+
Returns a ``pandas.Series`` with ``item_id`` as index and number of observations per item as values.
|
|
540
547
|
"""
|
|
541
548
|
counts = pd.Series(self.index.codes[0]).value_counts(sort=False)
|
|
542
549
|
counts.index = self.index.levels[0][counts.index]
|
|
@@ -572,7 +579,7 @@ class TimeSeriesDataFrame(pd.DataFrame):
|
|
|
572
579
|
self.static_features = other._static_features
|
|
573
580
|
return self
|
|
574
581
|
|
|
575
|
-
def split_by_time(self, cutoff_time: pd.Timestamp) ->
|
|
582
|
+
def split_by_time(self, cutoff_time: pd.Timestamp) -> tuple[TimeSeriesDataFrame, TimeSeriesDataFrame]:
|
|
576
583
|
"""Split dataframe to two different ``TimeSeriesDataFrame`` s before and after a certain ``cutoff_time``.
|
|
577
584
|
|
|
578
585
|
Parameters
|
|
@@ -603,7 +610,7 @@ class TimeSeriesDataFrame(pd.DataFrame):
|
|
|
603
610
|
This operation is equivalent to selecting a slice ``[start_index : end_index]`` from each time series, and then
|
|
604
611
|
combining these slices into a new ``TimeSeriesDataFrame``. See examples below.
|
|
605
612
|
|
|
606
|
-
It is recommended to sort the index with
|
|
613
|
+
It is recommended to sort the index with ``ts_df.sort_index()`` before calling this method to take advantage of
|
|
607
614
|
a fast optimized algorithm.
|
|
608
615
|
|
|
609
616
|
Parameters
|
|
@@ -735,7 +742,7 @@ class TimeSeriesDataFrame(pd.DataFrame):
|
|
|
735
742
|
return self.loc[mask]
|
|
736
743
|
else:
|
|
737
744
|
# Fall back to a slow groupby operation
|
|
738
|
-
result = self.groupby(level=ITEMID, sort=False, as_index=False).nth(slice(start_index, end_index))
|
|
745
|
+
result = self.groupby(level=self.ITEMID, sort=False, as_index=False).nth(slice(start_index, end_index))
|
|
739
746
|
result.static_features = self.static_features
|
|
740
747
|
return result
|
|
741
748
|
|
|
@@ -798,11 +805,11 @@ class TimeSeriesDataFrame(pd.DataFrame):
|
|
|
798
805
|
method : str, default = "auto"
|
|
799
806
|
Method used to impute missing values.
|
|
800
807
|
|
|
801
|
-
- "auto" - first forward fill (to fill the in-between and trailing NaNs), then backward fill (to fill the leading NaNs)
|
|
802
|
-
- "ffill" or "pad" - propagate last valid observation forward. Note: missing values at the start of the time series are not filled.
|
|
803
|
-
- "bfill" or "backfill" - use next valid observation to fill gap. Note: this may result in information leakage; missing values at the end of the time series are not filled.
|
|
804
|
-
- "constant" - replace NaNs with the given constant ``value``.
|
|
805
|
-
- "interpolate" - fill NaN values using linear interpolation. Note: this may result in information leakage.
|
|
808
|
+
- ``"auto"`` - first forward fill (to fill the in-between and trailing NaNs), then backward fill (to fill the leading NaNs)
|
|
809
|
+
- ``"ffill"`` or ``"pad"`` - propagate last valid observation forward. Note: missing values at the start of the time series are not filled.
|
|
810
|
+
- ``"bfill"`` or ``"backfill"`` - use next valid observation to fill gap. Note: this may result in information leakage; missing values at the end of the time series are not filled.
|
|
811
|
+
- ``"constant"`` - replace NaNs with the given constant ``value``.
|
|
812
|
+
- ``"interpolate"`` - fill NaN values using linear interpolation. Note: this may result in information leakage.
|
|
806
813
|
value : float, default = 0.0
|
|
807
814
|
Value used by the "constant" imputation method.
|
|
808
815
|
|
|
@@ -852,12 +859,12 @@ class TimeSeriesDataFrame(pd.DataFrame):
|
|
|
852
859
|
"It is highly recommended to call `ts_df.sort_index()` before calling `ts_df.fill_missing_values()`"
|
|
853
860
|
)
|
|
854
861
|
|
|
855
|
-
grouped_df = df.groupby(level=ITEMID, sort=False, group_keys=False)
|
|
862
|
+
grouped_df = df.groupby(level=self.ITEMID, sort=False, group_keys=False)
|
|
856
863
|
if method == "auto":
|
|
857
864
|
filled_df = grouped_df.ffill()
|
|
858
865
|
# If necessary, fill missing values at the start of each time series with bfill
|
|
859
866
|
if filled_df.isna().any(axis=None):
|
|
860
|
-
filled_df = filled_df.groupby(level=ITEMID, sort=False, group_keys=False).bfill()
|
|
867
|
+
filled_df = filled_df.groupby(level=self.ITEMID, sort=False, group_keys=False).bfill()
|
|
861
868
|
elif method in ["ffill", "pad"]:
|
|
862
869
|
filled_df = grouped_df.ffill()
|
|
863
870
|
elif method in ["bfill", "backfill"]:
|
|
@@ -900,17 +907,17 @@ class TimeSeriesDataFrame(pd.DataFrame):
|
|
|
900
907
|
return super().sort_index(*args, **kwargs) # type: ignore
|
|
901
908
|
|
|
902
909
|
def get_model_inputs_for_scoring(
|
|
903
|
-
self, prediction_length: int, known_covariates_names: Optional[
|
|
904
|
-
) ->
|
|
910
|
+
self, prediction_length: int, known_covariates_names: Optional[list[str]] = None
|
|
911
|
+
) -> tuple[TimeSeriesDataFrame, Optional[TimeSeriesDataFrame]]:
|
|
905
912
|
"""Prepare model inputs necessary to predict the last ``prediction_length`` time steps of each time series in the dataset.
|
|
906
913
|
|
|
907
914
|
Parameters
|
|
908
915
|
----------
|
|
909
916
|
prediction_length : int
|
|
910
917
|
The forecast horizon, i.e., How many time steps into the future must be predicted.
|
|
911
|
-
known_covariates_names :
|
|
918
|
+
known_covariates_names : list[str], optional
|
|
912
919
|
Names of the dataframe columns that contain covariates known in the future.
|
|
913
|
-
See
|
|
920
|
+
See ``known_covariates_names`` of :class:`~autogluon.timeseries.TimeSeriesPredictor` for more details.
|
|
914
921
|
|
|
915
922
|
Returns
|
|
916
923
|
-------
|
|
@@ -933,7 +940,7 @@ class TimeSeriesDataFrame(pd.DataFrame):
|
|
|
933
940
|
prediction_length: int,
|
|
934
941
|
end_index: Optional[int] = None,
|
|
935
942
|
suffix: Optional[str] = None,
|
|
936
|
-
) ->
|
|
943
|
+
) -> tuple[TimeSeriesDataFrame, TimeSeriesDataFrame]:
|
|
937
944
|
"""Generate a train/test split from the given dataset.
|
|
938
945
|
|
|
939
946
|
This method can be used to generate splits for multi-window backtesting.
|
|
@@ -1083,11 +1090,11 @@ class TimeSeriesDataFrame(pd.DataFrame):
|
|
|
1083
1090
|
iterable = iter(iterable)
|
|
1084
1091
|
return iter(lambda: tuple(islice(iterable, size)), ())
|
|
1085
1092
|
|
|
1086
|
-
def resample_chunk(chunk: Iterable[
|
|
1093
|
+
def resample_chunk(chunk: Iterable[tuple[str, pd.DataFrame]]) -> pd.DataFrame:
|
|
1087
1094
|
resampled_dfs = []
|
|
1088
1095
|
for item_id, df in chunk:
|
|
1089
|
-
resampled_df = df.resample(offset, level=TIMESTAMP, **kwargs).agg(aggregation)
|
|
1090
|
-
resampled_dfs.append(pd.concat({item_id: resampled_df}, names=[ITEMID]))
|
|
1096
|
+
resampled_df = df.resample(offset, level=self.TIMESTAMP, **kwargs).agg(aggregation)
|
|
1097
|
+
resampled_dfs.append(pd.concat({item_id: resampled_df}, names=[self.ITEMID]))
|
|
1091
1098
|
return pd.concat(resampled_dfs)
|
|
1092
1099
|
|
|
1093
1100
|
# Resampling time for 1 item < overhead time for a single parallel job. Therefore, we group items into chunks
|
|
@@ -1095,15 +1102,15 @@ class TimeSeriesDataFrame(pd.DataFrame):
|
|
|
1095
1102
|
df = pd.DataFrame(self)
|
|
1096
1103
|
# Make sure that timestamp index has dtype 'datetime64[ns]', otherwise index may contain NaT values.
|
|
1097
1104
|
# See https://github.com/autogluon/autogluon/issues/4917
|
|
1098
|
-
df.index = df.index.set_levels(df.index.levels[1].astype("datetime64[ns]"), level=TIMESTAMP)
|
|
1099
|
-
chunks = split_into_chunks(df.groupby(level=ITEMID, sort=False), chunk_size)
|
|
1105
|
+
df.index = df.index.set_levels(df.index.levels[1].astype("datetime64[ns]"), level=self.TIMESTAMP)
|
|
1106
|
+
chunks = split_into_chunks(df.groupby(level=self.ITEMID, sort=False), chunk_size)
|
|
1100
1107
|
resampled_chunks = Parallel(n_jobs=num_cpus)(delayed(resample_chunk)(chunk) for chunk in chunks)
|
|
1101
1108
|
resampled_df = TimeSeriesDataFrame(pd.concat(resampled_chunks))
|
|
1102
1109
|
resampled_df.static_features = self.static_features
|
|
1103
1110
|
return resampled_df
|
|
1104
1111
|
|
|
1105
1112
|
def to_data_frame(self) -> pd.DataFrame:
|
|
1106
|
-
"""Convert
|
|
1113
|
+
"""Convert ``TimeSeriesDataFrame`` to a ``pandas.DataFrame``"""
|
|
1107
1114
|
return pd.DataFrame(self)
|
|
1108
1115
|
|
|
1109
1116
|
def get_indptr(self) -> np.ndarray:
|
|
@@ -1124,8 +1131,27 @@ class TimeSeriesDataFrame(pd.DataFrame):
|
|
|
1124
1131
|
|
|
1125
1132
|
@overload
|
|
1126
1133
|
def __new__(cls, data: pd.DataFrame, static_features: Optional[pd.DataFrame] = None) -> Self: ... # type: ignore
|
|
1134
|
+
@overload
|
|
1135
|
+
def __new__(
|
|
1136
|
+
cls,
|
|
1137
|
+
data: Union[pd.DataFrame, str, Path, Iterable],
|
|
1138
|
+
static_features: Optional[Union[pd.DataFrame, str, Path]] = None,
|
|
1139
|
+
id_column: Optional[str] = None,
|
|
1140
|
+
timestamp_column: Optional[str] = None,
|
|
1141
|
+
num_cpus: int = -1,
|
|
1142
|
+
*args,
|
|
1143
|
+
**kwargs,
|
|
1144
|
+
) -> Self:
|
|
1145
|
+
"""This overload is needed since in pandas, during type checking, the default constructor resolves to __new__"""
|
|
1146
|
+
...
|
|
1127
1147
|
|
|
1128
1148
|
@overload
|
|
1129
|
-
def __getitem__(self, items:
|
|
1149
|
+
def __getitem__(self, items: list[str]) -> Self: ... # type: ignore
|
|
1130
1150
|
@overload
|
|
1131
1151
|
def __getitem__(self, item: str) -> pd.Series: ... # type: ignore
|
|
1152
|
+
|
|
1153
|
+
|
|
1154
|
+
# TODO: remove with v2.0
|
|
1155
|
+
# module-level constants kept for backward compatibility.
|
|
1156
|
+
ITEMID = TimeSeriesDataFrame.ITEMID
|
|
1157
|
+
TIMESTAMP = TimeSeriesDataFrame.TIMESTAMP
|
autogluon/timeseries/learner.py
CHANGED
|
@@ -1,15 +1,14 @@
|
|
|
1
1
|
import logging
|
|
2
2
|
import reprlib
|
|
3
3
|
import time
|
|
4
|
-
from typing import Any,
|
|
4
|
+
from typing import Any, Literal, Optional, Type, Union
|
|
5
5
|
|
|
6
6
|
import pandas as pd
|
|
7
7
|
|
|
8
8
|
from autogluon.core.learner import AbstractLearner
|
|
9
|
-
from autogluon.timeseries.dataset
|
|
9
|
+
from autogluon.timeseries.dataset import TimeSeriesDataFrame
|
|
10
10
|
from autogluon.timeseries.metrics import TimeSeriesScorer, check_get_evaluation_metric
|
|
11
11
|
from autogluon.timeseries.models.abstract import AbstractTimeSeriesModel
|
|
12
|
-
from autogluon.timeseries.splitter import AbstractWindowSplitter
|
|
13
12
|
from autogluon.timeseries.trainer import TimeSeriesTrainer
|
|
14
13
|
from autogluon.timeseries.utils.features import TimeSeriesFeatureGenerator
|
|
15
14
|
from autogluon.timeseries.utils.forecast import make_future_data_frame
|
|
@@ -26,7 +25,7 @@ class TimeSeriesLearner(AbstractLearner):
|
|
|
26
25
|
self,
|
|
27
26
|
path_context: str,
|
|
28
27
|
target: str = "target",
|
|
29
|
-
known_covariates_names: Optional[
|
|
28
|
+
known_covariates_names: Optional[list[str]] = None,
|
|
30
29
|
trainer_type: Type[TimeSeriesTrainer] = TimeSeriesTrainer,
|
|
31
30
|
eval_metric: Union[str, TimeSeriesScorer, None] = None,
|
|
32
31
|
prediction_length: int = 1,
|
|
@@ -56,11 +55,12 @@ class TimeSeriesLearner(AbstractLearner):
|
|
|
56
55
|
def fit(
|
|
57
56
|
self,
|
|
58
57
|
train_data: TimeSeriesDataFrame,
|
|
59
|
-
hyperparameters: Union[str,
|
|
58
|
+
hyperparameters: Union[str, dict],
|
|
60
59
|
val_data: Optional[TimeSeriesDataFrame] = None,
|
|
61
60
|
hyperparameter_tune_kwargs: Optional[Union[str, dict]] = None,
|
|
62
61
|
time_limit: Optional[float] = None,
|
|
63
|
-
|
|
62
|
+
num_val_windows: Optional[int] = None,
|
|
63
|
+
val_step_size: Optional[int] = None,
|
|
64
64
|
refit_every_n_windows: Optional[int] = 1,
|
|
65
65
|
random_seed: Optional[int] = None,
|
|
66
66
|
**kwargs,
|
|
@@ -86,7 +86,8 @@ class TimeSeriesLearner(AbstractLearner):
|
|
|
86
86
|
skip_model_selection=kwargs.get("skip_model_selection", False),
|
|
87
87
|
enable_ensemble=kwargs.get("enable_ensemble", True),
|
|
88
88
|
covariate_metadata=self.feature_generator.covariate_metadata,
|
|
89
|
-
|
|
89
|
+
num_val_windows=num_val_windows,
|
|
90
|
+
val_step_size=val_step_size,
|
|
90
91
|
refit_every_n_windows=refit_every_n_windows,
|
|
91
92
|
cache_predictions=self.cache_predictions,
|
|
92
93
|
ensemble_model_type=self.ensemble_model_type,
|
|
@@ -194,9 +195,9 @@ class TimeSeriesLearner(AbstractLearner):
|
|
|
194
195
|
self,
|
|
195
196
|
data: TimeSeriesDataFrame,
|
|
196
197
|
model: Optional[str] = None,
|
|
197
|
-
metrics: Optional[Union[str, TimeSeriesScorer,
|
|
198
|
+
metrics: Optional[Union[str, TimeSeriesScorer, list[Union[str, TimeSeriesScorer]]]] = None,
|
|
198
199
|
use_cache: bool = True,
|
|
199
|
-
) ->
|
|
200
|
+
) -> dict[str, float]:
|
|
200
201
|
data = self.feature_generator.transform(data)
|
|
201
202
|
return self.load_trainer().evaluate(data=data, model=model, metrics=metrics, use_cache=use_cache)
|
|
202
203
|
|
|
@@ -205,7 +206,7 @@ class TimeSeriesLearner(AbstractLearner):
|
|
|
205
206
|
data: Optional[TimeSeriesDataFrame] = None,
|
|
206
207
|
model: Optional[str] = None,
|
|
207
208
|
metric: Optional[Union[str, TimeSeriesScorer]] = None,
|
|
208
|
-
features: Optional[
|
|
209
|
+
features: Optional[list[str]] = None,
|
|
209
210
|
time_limit: Optional[float] = None,
|
|
210
211
|
method: Literal["naive", "permutation"] = "permutation",
|
|
211
212
|
subsample_size: int = 50,
|
|
@@ -273,7 +274,7 @@ class TimeSeriesLearner(AbstractLearner):
|
|
|
273
274
|
self,
|
|
274
275
|
data: Optional[TimeSeriesDataFrame] = None,
|
|
275
276
|
extra_info: bool = False,
|
|
276
|
-
extra_metrics: Optional[
|
|
277
|
+
extra_metrics: Optional[list[Union[str, TimeSeriesScorer]]] = None,
|
|
277
278
|
use_cache: bool = True,
|
|
278
279
|
) -> pd.DataFrame:
|
|
279
280
|
if data is not None:
|
|
@@ -282,7 +283,7 @@ class TimeSeriesLearner(AbstractLearner):
|
|
|
282
283
|
data, extra_info=extra_info, extra_metrics=extra_metrics, use_cache=use_cache
|
|
283
284
|
)
|
|
284
285
|
|
|
285
|
-
def get_info(self, include_model_info: bool = False, **kwargs) ->
|
|
286
|
+
def get_info(self, include_model_info: bool = False, **kwargs) -> dict[str, Any]:
|
|
286
287
|
learner_info = super().get_info(include_model_info=include_model_info)
|
|
287
288
|
trainer = self.load_trainer()
|
|
288
289
|
trainer_info = trainer.get_info(include_model_info=include_model_info)
|
|
@@ -300,31 +301,31 @@ class TimeSeriesLearner(AbstractLearner):
|
|
|
300
301
|
return learner_info
|
|
301
302
|
|
|
302
303
|
def persist_trainer(
|
|
303
|
-
self, models: Union[Literal["all", "best"],
|
|
304
|
-
) ->
|
|
304
|
+
self, models: Union[Literal["all", "best"], list[str]] = "all", with_ancestors: bool = False
|
|
305
|
+
) -> list[str]:
|
|
305
306
|
"""Loads models and trainer in memory so that they don't have to be
|
|
306
307
|
loaded during predictions
|
|
307
308
|
|
|
308
309
|
Returns
|
|
309
310
|
-------
|
|
310
|
-
list_of_models
|
|
311
|
+
list_of_models
|
|
311
312
|
List of models persisted in memory
|
|
312
313
|
"""
|
|
313
314
|
self.trainer = self.load_trainer()
|
|
314
315
|
return self.trainer.persist(models, with_ancestors=with_ancestors)
|
|
315
316
|
|
|
316
|
-
def unpersist_trainer(self) ->
|
|
317
|
+
def unpersist_trainer(self) -> list[str]:
|
|
317
318
|
"""Unloads models and trainer from memory. Models will have to be reloaded from disk
|
|
318
319
|
when predicting.
|
|
319
320
|
|
|
320
321
|
Returns
|
|
321
322
|
-------
|
|
322
|
-
list_of_models
|
|
323
|
+
list_of_models
|
|
323
324
|
List of models removed from memory
|
|
324
325
|
"""
|
|
325
326
|
unpersisted_models = self.load_trainer().unpersist()
|
|
326
327
|
self.trainer = None # type: ignore
|
|
327
328
|
return unpersisted_models
|
|
328
329
|
|
|
329
|
-
def refit_full(self, model: str = "all") ->
|
|
330
|
+
def refit_full(self, model: str = "all") -> dict[str, str]:
|
|
330
331
|
return self.load_trainer().refit_full(model=model)
|