autogluon.timeseries 1.0.1b20240304__py3-none-any.whl → 1.4.1b20251210__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of autogluon.timeseries might be problematic. Click here for more details.
- autogluon/timeseries/configs/__init__.py +3 -2
- autogluon/timeseries/configs/hyperparameter_presets.py +62 -0
- autogluon/timeseries/configs/predictor_presets.py +84 -0
- autogluon/timeseries/dataset/ts_dataframe.py +339 -186
- autogluon/timeseries/learner.py +192 -60
- autogluon/timeseries/metrics/__init__.py +55 -11
- autogluon/timeseries/metrics/abstract.py +96 -25
- autogluon/timeseries/metrics/point.py +186 -39
- autogluon/timeseries/metrics/quantile.py +47 -20
- autogluon/timeseries/metrics/utils.py +6 -6
- autogluon/timeseries/models/__init__.py +13 -7
- autogluon/timeseries/models/abstract/__init__.py +2 -2
- autogluon/timeseries/models/abstract/abstract_timeseries_model.py +533 -273
- autogluon/timeseries/models/abstract/model_trial.py +10 -10
- autogluon/timeseries/models/abstract/tunable.py +189 -0
- autogluon/timeseries/models/autogluon_tabular/__init__.py +2 -0
- autogluon/timeseries/models/autogluon_tabular/mlforecast.py +369 -215
- autogluon/timeseries/models/autogluon_tabular/per_step.py +513 -0
- autogluon/timeseries/models/autogluon_tabular/transforms.py +67 -0
- autogluon/timeseries/models/autogluon_tabular/utils.py +3 -51
- autogluon/timeseries/models/chronos/__init__.py +4 -0
- autogluon/timeseries/models/chronos/chronos2.py +361 -0
- autogluon/timeseries/models/chronos/model.py +738 -0
- autogluon/timeseries/models/chronos/utils.py +369 -0
- autogluon/timeseries/models/ensemble/__init__.py +35 -2
- autogluon/timeseries/models/ensemble/{abstract_timeseries_ensemble.py → abstract.py} +50 -26
- autogluon/timeseries/models/ensemble/array_based/__init__.py +3 -0
- autogluon/timeseries/models/ensemble/array_based/abstract.py +236 -0
- autogluon/timeseries/models/ensemble/array_based/models.py +73 -0
- autogluon/timeseries/models/ensemble/array_based/regressor/__init__.py +12 -0
- autogluon/timeseries/models/ensemble/array_based/regressor/abstract.py +88 -0
- autogluon/timeseries/models/ensemble/array_based/regressor/linear_stacker.py +167 -0
- autogluon/timeseries/models/ensemble/array_based/regressor/per_quantile_tabular.py +94 -0
- autogluon/timeseries/models/ensemble/array_based/regressor/tabular.py +107 -0
- autogluon/timeseries/models/ensemble/ensemble_selection.py +167 -0
- autogluon/timeseries/models/ensemble/per_item_greedy.py +162 -0
- autogluon/timeseries/models/ensemble/weighted/__init__.py +8 -0
- autogluon/timeseries/models/ensemble/weighted/abstract.py +40 -0
- autogluon/timeseries/models/ensemble/weighted/basic.py +78 -0
- autogluon/timeseries/models/ensemble/weighted/greedy.py +57 -0
- autogluon/timeseries/models/gluonts/__init__.py +3 -1
- autogluon/timeseries/models/gluonts/abstract.py +583 -0
- autogluon/timeseries/models/gluonts/dataset.py +109 -0
- autogluon/timeseries/models/gluonts/{torch/models.py → models.py} +185 -44
- autogluon/timeseries/models/local/__init__.py +1 -10
- autogluon/timeseries/models/local/abstract_local_model.py +150 -97
- autogluon/timeseries/models/local/naive.py +31 -23
- autogluon/timeseries/models/local/npts.py +6 -2
- autogluon/timeseries/models/local/statsforecast.py +99 -112
- autogluon/timeseries/models/multi_window/multi_window_model.py +99 -40
- autogluon/timeseries/models/registry.py +64 -0
- autogluon/timeseries/models/toto/__init__.py +3 -0
- autogluon/timeseries/models/toto/_internal/__init__.py +9 -0
- autogluon/timeseries/models/toto/_internal/backbone/__init__.py +3 -0
- autogluon/timeseries/models/toto/_internal/backbone/attention.py +196 -0
- autogluon/timeseries/models/toto/_internal/backbone/backbone.py +262 -0
- autogluon/timeseries/models/toto/_internal/backbone/distribution.py +70 -0
- autogluon/timeseries/models/toto/_internal/backbone/kvcache.py +136 -0
- autogluon/timeseries/models/toto/_internal/backbone/rope.py +89 -0
- autogluon/timeseries/models/toto/_internal/backbone/rotary_embedding_torch.py +342 -0
- autogluon/timeseries/models/toto/_internal/backbone/scaler.py +305 -0
- autogluon/timeseries/models/toto/_internal/backbone/transformer.py +333 -0
- autogluon/timeseries/models/toto/_internal/dataset.py +165 -0
- autogluon/timeseries/models/toto/_internal/forecaster.py +423 -0
- autogluon/timeseries/models/toto/dataloader.py +108 -0
- autogluon/timeseries/models/toto/hf_pretrained_model.py +118 -0
- autogluon/timeseries/models/toto/model.py +236 -0
- autogluon/timeseries/predictor.py +826 -305
- autogluon/timeseries/regressor.py +253 -0
- autogluon/timeseries/splitter.py +10 -31
- autogluon/timeseries/trainer/__init__.py +2 -3
- autogluon/timeseries/trainer/ensemble_composer.py +439 -0
- autogluon/timeseries/trainer/model_set_builder.py +256 -0
- autogluon/timeseries/trainer/prediction_cache.py +149 -0
- autogluon/timeseries/trainer/trainer.py +1298 -0
- autogluon/timeseries/trainer/utils.py +17 -0
- autogluon/timeseries/transforms/__init__.py +2 -0
- autogluon/timeseries/transforms/covariate_scaler.py +164 -0
- autogluon/timeseries/transforms/target_scaler.py +149 -0
- autogluon/timeseries/utils/constants.py +10 -0
- autogluon/timeseries/utils/datetime/base.py +38 -20
- autogluon/timeseries/utils/datetime/lags.py +18 -16
- autogluon/timeseries/utils/datetime/seasonality.py +14 -14
- autogluon/timeseries/utils/datetime/time_features.py +17 -14
- autogluon/timeseries/utils/features.py +317 -53
- autogluon/timeseries/utils/forecast.py +31 -17
- autogluon/timeseries/utils/timer.py +173 -0
- autogluon/timeseries/utils/warning_filters.py +44 -6
- autogluon/timeseries/version.py +2 -1
- autogluon.timeseries-1.4.1b20251210-py3.11-nspkg.pth +1 -0
- {autogluon.timeseries-1.0.1b20240304.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info}/METADATA +71 -47
- autogluon_timeseries-1.4.1b20251210.dist-info/RECORD +103 -0
- {autogluon.timeseries-1.0.1b20240304.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info}/WHEEL +1 -1
- autogluon/timeseries/configs/presets_configs.py +0 -11
- autogluon/timeseries/evaluator.py +0 -6
- autogluon/timeseries/models/ensemble/greedy_ensemble.py +0 -170
- autogluon/timeseries/models/gluonts/abstract_gluonts.py +0 -550
- autogluon/timeseries/models/gluonts/torch/__init__.py +0 -0
- autogluon/timeseries/models/presets.py +0 -325
- autogluon/timeseries/trainer/abstract_trainer.py +0 -1144
- autogluon/timeseries/trainer/auto_trainer.py +0 -74
- autogluon.timeseries-1.0.1b20240304-py3.8-nspkg.pth +0 -1
- autogluon.timeseries-1.0.1b20240304.dist-info/RECORD +0 -58
- {autogluon.timeseries-1.0.1b20240304.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info/licenses}/LICENSE +0 -0
- {autogluon.timeseries-1.0.1b20240304.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info/licenses}/NOTICE +0 -0
- {autogluon.timeseries-1.0.1b20240304.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info}/namespace_packages.txt +0 -0
- {autogluon.timeseries-1.0.1b20240304.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info}/top_level.txt +0 -0
- {autogluon.timeseries-1.0.1b20240304.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info}/zip-safe +0 -0
|
@@ -0,0 +1,738 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import os
|
|
3
|
+
import shutil
|
|
4
|
+
import warnings
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
import pandas as pd
|
|
10
|
+
from typing_extensions import Self
|
|
11
|
+
|
|
12
|
+
from autogluon.common.loaders import load_pkl
|
|
13
|
+
from autogluon.common.space import Space
|
|
14
|
+
from autogluon.timeseries.dataset import TimeSeriesDataFrame
|
|
15
|
+
from autogluon.timeseries.models.abstract import AbstractTimeSeriesModel
|
|
16
|
+
from autogluon.timeseries.utils.warning_filters import disable_duplicate_logs, warning_filter
|
|
17
|
+
|
|
18
|
+
logger = logging.getLogger("autogluon.timeseries.models.chronos")

# Older `transformers` releases emit FutureWarnings for argument names this module still uses;
# they are silenced globally at import time.
# TODO: Replace `evaluation_strategy` with `eval_strategy` when upgrading to `transformers>=4.41` + remove warning filter
warnings.filterwarnings(action="ignore", message="`evaluation_strategy` is deprecated", category=FutureWarning)
# TODO: Remove warning filter when upgrading to `transformers>=4.40`
warnings.filterwarnings(action="ignore", message="Passing the following arguments to ", category=FutureWarning)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
# allowed HuggingFace model paths with custom parameter definitions
|
|
27
|
+
# AutoGluon defaults for known Chronos checkpoints. A config applies when its key occurs as a
# substring of the resolved ``model_path`` (see ``ChronosModel.ag_default_config``).
#   - num_gpus: minimum number of GPUs required to run the model
#   - default_torch_dtype: default value of the ``torch_dtype`` hyperparameter
#   - default_batch_size: default value of the inference ``batch_size`` hyperparameter
# Every Chronos-Bolt size, including ``tiny`` (reachable through the ``bolt_tiny`` alias), gets the
# documented Bolt defaults: CPU-capable, batch size 256. Without an entry, ``chronos-bolt-tiny``
# would silently fall back to the generic batch size of 8.
MODEL_CONFIGS = {
    "chronos-t5-tiny": {
        "num_gpus": 0,  # minimum number of required GPUs
        "default_torch_dtype": "auto",
        "default_batch_size": 16,
    },
    "chronos-t5-mini": {
        "num_gpus": 0,
        "default_torch_dtype": "auto",
        "default_batch_size": 16,
    },
    "chronos-t5-small": {
        "num_gpus": 1,
        "default_torch_dtype": "bfloat16",
        "default_batch_size": 16,
    },
    "chronos-t5-base": {
        "num_gpus": 1,
        "default_torch_dtype": "bfloat16",
        "default_batch_size": 16,
    },
    "chronos-t5-large": {
        "num_gpus": 1,
        "default_torch_dtype": "bfloat16",
        "default_batch_size": 8,
    },
    "chronos-bolt-tiny": {
        "num_gpus": 0,
        "default_torch_dtype": "auto",
        "default_batch_size": 256,
    },
    "chronos-bolt-mini": {
        "num_gpus": 0,
        "default_torch_dtype": "auto",
        "default_batch_size": 256,
    },
    "chronos-bolt-small": {
        "num_gpus": 0,
        "default_torch_dtype": "auto",
        "default_batch_size": 256,
    },
    "chronos-bolt-base": {
        "num_gpus": 0,
        "default_torch_dtype": "auto",
        "default_batch_size": 256,
    },
}
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
# Short names accepted for the ``model_path`` hyperparameter, expanded to Hugging Face Hub
# repositories: ``tiny`` ... ``large`` for original Chronos (T5) models, ``bolt_*`` for Chronos-Bolt.
MODEL_ALIASES = {
    **{size: f"autogluon/chronos-t5-{size}" for size in ("tiny", "mini", "small", "base", "large")},
    **{f"bolt_{size}": f"autogluon/chronos-bolt-{size}" for size in ("tiny", "mini", "small", "base")},
}
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
class ChronosModel(AbstractTimeSeriesModel):
|
|
85
|
+
"""Chronos [Ansari2024]_ pretrained time series forecasting models which can be used for zero-shot
|
|
86
|
+
forecasting or fine-tuned in a task-specific manner.
|
|
87
|
+
|
|
88
|
+
Models can be based on the original
|
|
89
|
+
`Chronos <https://github.com/amazon-science/chronos-forecasting/blob/main/src/chronos/chronos.py>`_
|
|
90
|
+
implementation, as well as a newer family of
|
|
91
|
+
`Chronos-Bolt <https://github.com/amazon-science/chronos-forecasting/blob/main/src/chronos/chronos_bolt.py>`_
|
|
92
|
+
models capable of much faster inference.
|
|
93
|
+
|
|
94
|
+
The original Chronos is a family of pretrained models, based on the T5 family, with number of
|
|
95
|
+
parameters ranging between 8M and 710M. The full collection of Chronos models is available on
|
|
96
|
+
`Hugging Face <https://huggingface.co/collections/amazon/chronos-models-65f1791d630a8d57cb718444>`_.
|
|
97
|
+
|
|
98
|
+
For Chronos (original) ``small``, ``base``, and ``large`` variants a GPU is required to
|
|
99
|
+
perform inference efficiently. Chronos takes a minimalistic approach to pretraining time series
|
|
100
|
+
models, by discretizing time series data directly into bins which are treated as tokens,
|
|
101
|
+
effectively performing regression by classification. This results in a simple and flexible
|
|
102
|
+
framework for using any language model in the context of time series forecasting.
|
|
103
|
+
See [Ansari2024]_ for more information.
|
|
104
|
+
|
|
105
|
+
The newer Chronos-Bolt variants enable much faster inference by first "patching" the time series.
|
|
106
|
+
The resulting time series is then fed into a T5 model for forecasting. The Chronos-Bolt variants
|
|
107
|
+
are capable of much faster inference, and can all run on CPUs.
|
|
108
|
+
|
|
109
|
+
Both Chronos and Chronos-Bolt variants can be fine-tuned by setting ``fine_tune=True`` and selecting
|
|
110
|
+
appropriate fine-tuning parameters such as the learning rate (``fine_tune_lr``) and max steps
|
|
111
|
+
(``fine_tune_steps``).
|
|
112
|
+
|
|
113
|
+
References
|
|
114
|
+
----------
|
|
115
|
+
.. [Ansari2024] Ansari, Abdul Fatir, Stella, Lorenzo et al.
|
|
116
|
+
"Chronos: Learning the Language of Time Series."
|
|
117
|
+
Transactions on Machine Learning Research (2024).
|
|
118
|
+
https://openreview.net/forum?id=gerNCVqqtR
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
Other Parameters
|
|
122
|
+
----------------
|
|
123
|
+
model_path : str, default = "autogluon/chronos-bolt-small"
|
|
124
|
+
Model path used for the model, i.e., a HuggingFace transformers ``name_or_path``. Can be a
|
|
125
|
+
compatible model name on HuggingFace Hub or a local path to a model directory. Original
|
|
126
|
+
Chronos models (i.e., ``autogluon/chronos-t5-{model_size}``) can be specified with aliases
|
|
127
|
+
``tiny``, ``mini`` , ``small``, ``base``, and ``large``. Chronos-Bolt models can be specified
|
|
128
|
+
with ``bolt_tiny``, ``bolt_mini``, ``bolt_small``, and ``bolt_base``.
|
|
129
|
+
batch_size : int, default = 256
|
|
130
|
+
Size of batches used during inference.
|
|
131
|
+
|
|
132
|
+
The default ``batch_size`` is selected based on the model type. Chronos (original) models use a
|
|
133
|
+
``batch_size`` of 16, except Chronos (Large) which uses 8.
|
|
134
|
+
|
|
135
|
+
For Chronos-Bolt models the ``batch_size`` is set to 256. However, ``batch_size`` is reduced by
|
|
136
|
+
a factor of 4 when the prediction horizon is greater than the model's
|
|
137
|
+
default prediction length.
|
|
138
|
+
num_samples : int, default = 20
|
|
139
|
+
Number of samples used during inference, only used for the original Chronos models
|
|
140
|
+
device : str, default = None
|
|
141
|
+
Device to use for inference (and fine-tuning, if enabled). If None, model will use the GPU if
|
|
142
|
+
available. For larger Chronos model sizes ``small``, ``base``, and ``large``; inference will fail
|
|
143
|
+
if no GPU is available.
|
|
144
|
+
|
|
145
|
+
For Chronos-Bolt models, inference can be performed on the CPU. Although fine-tuning the smaller
|
|
146
|
+
Chronos models (``tiny`` and ``mini``) and all Chronos-Bolt is allowed on the CPU, we recommend
|
|
147
|
+
using a GPU for faster fine-tuning.
|
|
148
|
+
context_length : int or None, default = None
|
|
149
|
+
The context length to use in the model.
|
|
150
|
+
|
|
151
|
+
Shorter context lengths will decrease model accuracy, but result in faster inference. If None,
|
|
152
|
+
the model will infer context length from the data set length at inference time, but cap it at a
|
|
153
|
+
maximum of 2048.
|
|
154
|
+
|
|
155
|
+
Note that this is only the context length used to pass data into the model. Individual model
|
|
156
|
+
implementations may have different context lengths specified in their configuration, and may
|
|
157
|
+
truncate the context further. For example, original Chronos models have a context length of 512,
|
|
158
|
+
but Chronos-Bolt models handle contexts up to 2048.
|
|
159
|
+
torch_dtype : torch.dtype or {"auto", "bfloat16", "float32"}, default = "auto"
|
|
160
|
+
Torch data type for model weights, provided to ``from_pretrained`` method of Hugging Face
|
|
161
|
+
AutoModels. If original Chronos models are specified and the model size is ``small``, ``base``,
|
|
162
|
+
or ``large``, the ``torch_dtype`` will be set to ``bfloat16`` to enable inference on GPUs.
|
|
163
|
+
data_loader_num_workers : int, default = 0
|
|
164
|
+
Number of worker processes to be used in the data loader. See documentation on
|
|
165
|
+
``torch.utils.data.DataLoader`` for more information.
|
|
166
|
+
fine_tune : bool, default = False
|
|
167
|
+
If True, the pretrained model will be fine-tuned
|
|
168
|
+
fine_tune_lr : float, default = 1e-5
|
|
169
|
+
The learning rate used for fine-tuning. This default is suitable for Chronos-Bolt models; for
|
|
170
|
+
the original Chronos models, we recommend using a higher learning rate such as ``1e-4``.
|
|
171
|
+
fine_tune_steps : int, default = 1000
|
|
172
|
+
The number of gradient update steps to fine-tune for
|
|
173
|
+
fine_tune_batch_size : int, default = 32
|
|
174
|
+
The batch size to use for fine-tuning
|
|
175
|
+
fine_tune_shuffle_buffer_size : int, default = 10000
|
|
176
|
+
The size of the shuffle buffer to shuffle the data during fine-tuning. If None, shuffling will
|
|
177
|
+
be turned off.
|
|
178
|
+
eval_during_fine_tune : bool, default = False
|
|
179
|
+
If True, validation will be performed during fine-tuning to select the best checkpoint. Setting this
|
|
180
|
+
argument to True may result in slower fine-tuning. This parameter is ignored if ``skip_model_selection=True``
|
|
181
|
+
in ``TimeSeriesPredictor.fit``.
|
|
182
|
+
fine_tune_eval_max_items : int, default = 256
|
|
183
|
+
The maximum number of randomly-sampled time series to use from the validation set for evaluation
|
|
184
|
+
during fine-tuning. If None, the entire validation dataset will be used.
|
|
185
|
+
fine_tune_trainer_kwargs : dict, optional
|
|
186
|
+
Extra keyword arguments passed to ``transformers.TrainingArguments``
|
|
187
|
+
keep_transformers_logs : bool, default = False
|
|
188
|
+
If True, the logs generated by transformers will NOT be removed after fine-tuning
|
|
189
|
+
revision : str, default = None
|
|
190
|
+
Model revision to use (branch name or commit hash). If None, the default branch (usually "main") is used.
|
|
191
|
+
"""
|
|
192
|
+
|
|
193
|
+
# NOTE(review): priority AutoGluon assigns to this model (presumably used to order model fitting) — confirm in the trainer.
ag_priority = 55
# Hugging Face checkpoint used when the ``model_path`` hyperparameter is not given.
default_model_path = "autogluon/chronos-bolt-small"
# Default for the ``num_samples`` hyperparameter (number of samples used at prediction time).
default_num_samples: int = 20
# Upper bound that ``_validate_and_assign_attributes`` applies to a user-provided ``context_length``.
maximum_context_length = 2048
# NOTE(review): presumably the fraction of the available time limit this model may use — confirm in the base class.
default_max_time_limit_ratio = 0.8
# Subdirectory of the model path that ``load`` checks for a fine-tuned checkpoint.
fine_tuned_ckpt_name: str = "fine-tuned-ckpt"
|
|
199
|
+
|
|
200
|
+
def __init__(
    self,
    freq: str | None = None,
    prediction_length: int = 1,
    path: str | None = None,
    name: str | None = None,
    eval_metric: str | None = None,
    hyperparameters: dict[str, Any] | None = None,
    **kwargs,  # noqa
):
    """Resolve the requested Chronos checkpoint, derive the model name and initialize the base model.

    The ``model_path`` hyperparameter (default ``default_model_path``) may be one of the short names
    in ``MODEL_ALIASES``. The expanded path is stored in ``self.model_path``. Unless ``model_path`` is a
    search space, a bracketed suffix derived from the user-provided value is appended to ``name``
    (default ``"Chronos"``) if the name does not already contain it.
    """
    if hyperparameters is None:
        hyperparameters = {}

    requested_path: str = hyperparameters.get("model_path", self.default_model_path)
    self.model_path: str = MODEL_ALIASES.get(requested_path, requested_path)

    if name is None:
        name = "Chronos"
    if not isinstance(requested_path, Space):
        # Path separators become "__" and only the last 50 characters are kept, so that model
        # directories named after the model stay short (long paths cause errors on Windows).
        sanitized_path = str(requested_path).replace("/", "__").replace(os.path.sep, "__")
        name_suffix = f"[{sanitized_path[-50:]}]"
        if name_suffix not in name:
            name += name_suffix

    super().__init__(
        path=path,
        freq=freq,
        prediction_length=prediction_length,
        name=name,
        eval_metric=eval_metric,
        hyperparameters=hyperparameters,
        **kwargs,
    )

    # Populated lazily by ``load_model_pipeline``; holds a ``chronos.BaseChronosPipeline`` once loaded.
    self._model_pipeline: Any | None = None
|
|
233
|
+
|
|
234
|
+
def save(self, path: str | None = None, verbose: bool = True) -> str:
    """Save the model via the base class without serializing the loaded Chronos pipeline.

    The pipeline is detached (set to None) while the base class ``save`` runs. It is re-attached
    afterwards, including when saving raises, so a failed save does not leave the in-memory model
    without its pipeline.

    Parameters
    ----------
    path
        Target directory, forwarded to the base class ``save``.
    verbose
        Forwarded to the base class ``save``.

    Returns
    -------
    str
        The path returned by the base class ``save``, converted to ``str``.
    """
    pipeline = self._model_pipeline
    self._model_pipeline = None
    try:
        saved_path = super().save(path=path, verbose=verbose)
    finally:
        # restore the pipeline even if saving failed
        self._model_pipeline = pipeline
    return str(saved_path)
|
|
241
|
+
|
|
242
|
+
@classmethod
def load(cls, path: str, reset_paths: bool = True, load_oof: bool = False, verbose: bool = True) -> Self:
    """Unpickle a model saved under ``path``.

    If ``reset_paths`` is True, the model's contexts are reset to ``path``. If a fine-tuned checkpoint
    directory (``fine_tuned_ckpt_name``) exists under the model path, ``model_path`` is redirected to it.
    The fine-tuned weights are then used when the pipeline is next loaded. ``load_oof`` is accepted for
    interface compatibility and is not used here.
    """
    loaded = load_pkl.load(path=os.path.join(path, cls.model_file_name), verbose=verbose)
    if reset_paths:
        loaded.set_contexts(path)

    checkpoint_dir = Path(loaded.path) / cls.fine_tuned_ckpt_name
    if not checkpoint_dir.exists():
        return loaded

    logger.debug(f"\tFine-tuned checkpoint exists, setting model_path to {checkpoint_dir}")
    loaded.model_path = str(checkpoint_dir)
    return loaded
|
|
254
|
+
|
|
255
|
+
def _is_gpu_available(self) -> bool:
    """Return True when PyTorch reports that a CUDA device is available."""
    from torch import cuda

    return cuda.is_available()
|
|
259
|
+
|
|
260
|
+
@property
def model_pipeline(self) -> Any:  # of type BaseChronosPipeline
    """The Chronos pipeline used for inference, loaded lazily.

    If no pipeline is currently held, it is loaded via ``load_model_pipeline`` before being returned.
    Accessing this property can therefore trigger loading the model (and may raise whatever
    ``load_model_pipeline`` raises). It does not return None just because the model was not loaded yet.
    """
    if self._model_pipeline is None:
        self.load_model_pipeline()  # load model pipeline to device memory
    return self._model_pipeline
|
|
266
|
+
|
|
267
|
+
@property
def ag_default_config(self) -> dict[str, Any]:
    """AutoGluon defaults for known Chronos checkpoints.

    Returns a copy of the first ``MODEL_CONFIGS`` entry whose key occurs as a substring of
    ``self.model_path``. This covers both the original ``chronos-t5-*`` and the ``chronos-bolt-*``
    entries listed there, and also local paths whose name contains one of those keys. Returns an empty
    dict for unknown models; callers then fall back to their own defaults.
    """
    for model_family, config in MODEL_CONFIGS.items():
        if model_family in self.model_path:
            # copy so that callers cannot mutate the shared module-level defaults
            return dict(config)
    return {}
|
|
276
|
+
|
|
277
|
+
@property
def min_num_gpus(self) -> int:
    """GPUs required to run the model: ``num_gpus`` from ``ag_default_config``, or 0 if not specified."""
    config = self.ag_default_config
    return config.get("num_gpus", 0)
|
|
283
|
+
|
|
284
|
+
@property
def default_batch_size(self) -> int:
    """Inference batch size: ``default_batch_size`` from ``ag_default_config``, or 8 if not specified."""
    config = self.ag_default_config
    return config.get("default_batch_size", 8)
|
|
290
|
+
|
|
291
|
+
@property
def default_torch_dtype(self) -> Any:
    """Weight dtype: ``default_torch_dtype`` from ``ag_default_config``, or ``"auto"`` if not specified."""
    config = self.ag_default_config
    return config.get("default_torch_dtype", "auto")
|
|
297
|
+
|
|
298
|
+
def get_minimum_resources(self, is_gpu_available: bool = False) -> dict[str, int | float]:
|
|
299
|
+
minimum_resources: dict[str, int | float] = {"num_cpus": 1}
|
|
300
|
+
# if GPU is available, we train with 1 GPU per trial
|
|
301
|
+
if is_gpu_available:
|
|
302
|
+
minimum_resources["num_gpus"] = self.min_num_gpus
|
|
303
|
+
return minimum_resources
|
|
304
|
+
|
|
305
|
+
def load_model_pipeline(self, is_training: bool = False):
    """Load the Chronos pipeline for ``self.model_path`` and store it in ``self._model_pipeline``.

    The pipeline is placed on ``self.device`` (or ``"cuda"`` if unset) when PyTorch reports a GPU.
    Otherwise it is placed on ``"cpu"``; a user-specified device is ignored in that case.
    ``is_training`` is accepted from callers but not used by this method.

    Raises
    ------
    RuntimeError
        If the model requires a GPU (``min_num_gpus > 0``) but none is detected.
    """
    from chronos import BaseChronosPipeline

    has_gpu = self._is_gpu_available()
    if not has_gpu and self.min_num_gpus > 0:
        raise RuntimeError(
            f"{self.name} requires a GPU to run, but no GPU was detected. "
            "Please make sure that you are using a computer with a CUDA-compatible GPU and "
            "`import torch; torch.cuda.is_available()` returns `True`."
        )

    if has_gpu:
        target_device = self.device or "cuda"
    else:
        target_device = "cpu"

    assert self.model_path is not None
    self._model_pipeline = BaseChronosPipeline.from_pretrained(
        self.model_path,
        device_map=target_device,
        torch_dtype=self.torch_dtype,
        revision=self.get_hyperparameter("revision"),
    )
|
|
328
|
+
|
|
329
|
+
def persist(self) -> "ChronosModel":
    """Eagerly (re)load the Chronos pipeline into memory and return ``self`` to allow chaining."""
    # TODO: Check the model has been fit before persist
    model = self
    model.load_model_pipeline()
    return model
|
|
333
|
+
|
|
334
|
+
def _has_tf32(self):
    """Return True if a CUDA device is available and its compute capability major version is >= 8.

    TF32 math requires compute capability 8.0 or newer.
    """
    from torch import cuda

    if not cuda.is_available():
        return False
    major_version, _minor_version = cuda.get_device_capability()
    return major_version >= 8
|
|
338
|
+
|
|
339
|
+
def get_hyperparameters(self) -> dict:
    """Gets params that are passed to the inner model.

    Extends the base-class hyperparameters. ``fine_tune_trainer_kwargs`` is replaced by the AutoGluon
    defaults from ``_get_fine_tune_trainer_kwargs``, updated with any user-provided
    ``fine_tune_trainer_kwargs`` (user values take precedence). An explicit ``None`` for
    ``fine_tune_trainer_kwargs`` is treated as "no overrides".

    Returns
    -------
    dict
        A shallow copy of the resulting hyperparameter dict.
    """
    init_args = super().get_hyperparameters()

    eval_during_fine_tune = init_args["eval_during_fine_tune"]
    fine_tune_trainer_kwargs = self._get_fine_tune_trainer_kwargs(init_args, eval_during_fine_tune)
    # `or {}` so that an explicit `fine_tune_trainer_kwargs=None` does not crash `dict.update`
    user_fine_tune_trainer_kwargs = init_args.get("fine_tune_trainer_kwargs") or {}
    fine_tune_trainer_kwargs.update(user_fine_tune_trainer_kwargs)
    init_args["fine_tune_trainer_kwargs"] = fine_tune_trainer_kwargs

    return init_args.copy()
|
|
350
|
+
|
|
351
|
+
def _get_default_hyperparameters(self) -> dict:
    """Default hyperparameter values.

    ``batch_size``, ``num_samples`` and ``torch_dtype`` depend on the selected checkpoint and come
    from the model's ``default_*`` attributes. All other values are fixed.
    """
    return dict(
        batch_size=self.default_batch_size,
        num_samples=self.default_num_samples,
        device=None,
        torch_dtype=self.default_torch_dtype,
        data_loader_num_workers=0,
        context_length=None,
        fine_tune=False,
        keep_transformers_logs=False,
        fine_tune_lr=1e-5,
        fine_tune_steps=1000,
        fine_tune_batch_size=32,
        eval_during_fine_tune=False,
        fine_tune_eval_max_items=256,
        fine_tune_shuffle_buffer_size=10_000,
        revision=None,
    )
|
|
369
|
+
|
|
370
|
+
@property
def allowed_hyperparameters(self) -> list[str]:
    """Hyperparameter names accepted by this model: the base-class names plus the Chronos options."""
    chronos_options = [
        "model_path",
        "batch_size",
        "num_samples",
        "device",
        "context_length",
        "torch_dtype",
        "data_loader_num_workers",
        "fine_tune",
        "fine_tune_lr",
        "fine_tune_steps",
        "fine_tune_batch_size",
        "fine_tune_shuffle_buffer_size",
        "eval_during_fine_tune",
        "fine_tune_eval_max_items",
        "fine_tune_trainer_kwargs",
        "keep_transformers_logs",
        "revision",
    ]
    return super().allowed_hyperparameters + chronos_options
|
|
391
|
+
|
|
392
|
+
def _get_fine_tune_trainer_kwargs(self, init_args, eval_during_fine_tune: bool):
    """Default keyword arguments for ``transformers.TrainingArguments`` used during fine-tuning.

    Logs go to ``<model path>/transformers_logs``. When ``eval_during_fine_tune`` is truthy, the model
    is evaluated and checkpointed every 100 steps, and the best checkpoint by ``eval_loss`` is restored
    at the end. Otherwise evaluation, intermediate saving and best-model loading are all disabled.
    """
    logs_dir = str(Path(self.path) / "transformers_logs")
    batch_size = init_args["fine_tune_batch_size"]

    trainer_kwargs = dict(
        output_dir=logs_dir,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        learning_rate=init_args["fine_tune_lr"],
        lr_scheduler_type="linear",
        warmup_ratio=0.0,
        optim="adamw_torch_fused",
        logging_dir=logs_dir,
        logging_strategy="steps",
        logging_steps=100,
        disable_tqdm=True,
        report_to="none",
        max_steps=init_args["fine_tune_steps"],
        gradient_accumulation_steps=1,
        dataloader_num_workers=init_args["data_loader_num_workers"],
        tf32=self._has_tf32(),
        save_only_model=True,
        prediction_loss_only=True,
        save_total_limit=1,
    )

    if eval_during_fine_tune:
        checkpointing_kwargs = dict(
            save_strategy="steps",
            save_steps=100,
            evaluation_strategy="steps",
            eval_steps=100,
            load_best_model_at_end=True,
            metric_for_best_model="eval_loss",
        )
    else:
        checkpointing_kwargs = dict(
            save_strategy="no",
            save_steps=None,
            evaluation_strategy="no",
            eval_steps=None,
            load_best_model_at_end=False,
            metric_for_best_model=None,
        )
    trainer_kwargs.update(checkpointing_kwargs)

    return trainer_kwargs
|
|
423
|
+
|
|
424
|
+
def _validate_and_assign_attributes(self, model_params: dict):
    """Copy concrete hyperparameter values from ``model_params`` onto the model instance.

    Validation happens here rather than in the constructor because the values are concrete at this
    point; in the constructor they may still be a search space. A ``context_length`` larger than
    ``maximum_context_length`` is capped to that maximum, and an info message is logged.
    """
    # TODO: automatically determine batch size based on GPU / memory availability
    self.batch_size = model_params["batch_size"]
    self.num_samples = model_params["num_samples"]
    self.device = model_params["device"]
    self.torch_dtype = model_params["torch_dtype"]
    self.data_loader_num_workers = model_params["data_loader_num_workers"]
    self.context_length = model_params["context_length"]

    if self.context_length is not None and self.context_length > self.maximum_context_length:
        # trailing space after the first sentence so the two sentences are not glued together
        logger.info(
            f"\tContext length {self.context_length} exceeds maximum context length {self.maximum_context_length}. "
            f"Context length will be set to {self.maximum_context_length}."
        )
        self.context_length = self.maximum_context_length
|
|
442
|
+
|
|
443
|
+
def _fit(
    self,
    train_data: TimeSeriesDataFrame,
    val_data: TimeSeriesDataFrame | None = None,
    time_limit: float | None = None,
    num_cpus: int | None = None,
    num_gpus: int | None = None,
    verbosity: int = 2,
    **kwargs,
) -> None:
    """Validate hyperparameters and, when ``fine_tune`` is enabled, fine-tune the Chronos pipeline.

    Without fine-tuning, this method only configures logging and validates/assigns the
    hyperparameters; it does not load the model pipeline itself. With fine-tuning, the
    pipeline is loaded for training and trained with a ``transformers.Trainer`` on
    ``train_data``. If ``val_data`` is given, it is used (optionally subsampled) as the
    evaluation dataset. The resulting weights are saved to
    ``Path(self.path) / self.fine_tuned_ckpt_name``.

    Args:
        train_data: data used to build the fine-tuning dataset (required when fine-tuning).
        val_data: optional validation data. Evaluation during training is only enabled when
            this is given AND the ``eval_during_fine_tune`` hyperparameter is set.
        time_limit: optional time budget passed to ``TimeLimitCallback``.
        num_cpus, num_gpus: accepted for interface compatibility; not used in this method.
        verbosity: controls ``transformers`` logging (see comments below).
    """
    import transformers
    from chronos import ChronosBoltPipeline, ChronosPipeline
    from packaging import version
    from transformers.trainer import PrinterCallback, Trainer, TrainingArguments

    from .utils import (
        ChronosFineTuningDataset,
        EvaluateAndSaveFinalStepCallback,
        LoggerCallback,
        TimeLimitCallback,
        update_output_quantiles,
    )

    # TODO: Add support for fine-tuning models with context_length longer than the pretrained model

    # verbosity <= 3: transformers loggers are restricted to ERROR
    # verbosity >= 3: progress and loss logs are reported through LoggerCallback (added below)
    # verbosity >= 4: transformers loggers are set to INFO
    # Iterate over a snapshot because getLogger() may replace placeholder entries in loggerDict.
    for logger_name in list(logging.root.manager.loggerDict):
        if "transformers" in logger_name:
            transformers_logger = logging.getLogger(logger_name)
            transformers_logger.setLevel(logging.ERROR if verbosity <= 3 else logging.INFO)

    self._check_fit_params()
    self._log_unused_hyperparameters()
    model_params = self.get_hyperparameters()
    self._validate_and_assign_attributes(model_params)
    do_fine_tune = model_params["fine_tune"]

    if not do_fine_tune:
        return

    assert train_data is not None, "train_data cannot be None when fine_tune=True"

    eval_during_fine_tune = val_data is not None and model_params["eval_during_fine_tune"]

    context_length = self._get_context_length(train_data)
    # load model pipeline to device memory
    self.load_model_pipeline(is_training=True)

    fine_tune_prediction_length = self.prediction_length
    model_prediction_length = self.model_pipeline.inner_model.config.chronos_config["prediction_length"]

    if isinstance(self.model_pipeline, ChronosPipeline):
        pipeline_specific_trainer_kwargs = {}

        # Update prediction_length of the model
        # NOTE: We only do this for ChronosPipeline because the prediction length of ChronosBolt models
        # is fixed due to direct multistep forecasting setup
        self.model_pipeline.model.config.prediction_length = fine_tune_prediction_length
        self.model_pipeline.inner_model.config.chronos_config["prediction_length"] = fine_tune_prediction_length

    elif isinstance(self.model_pipeline, ChronosBoltPipeline):
        # custom label_names is needed for validation to work with ChronosBolt models
        pipeline_specific_trainer_kwargs = dict(label_names=["target"])

        # truncate prediction_length if it goes beyond ChronosBolt's prediction_length
        fine_tune_prediction_length = min(model_prediction_length, self.prediction_length)

        if self.prediction_length != fine_tune_prediction_length:
            logger.debug(
                f"\tChronos-Bolt models can only be fine-tuned with a maximum prediction_length of {model_prediction_length}. "
                f"Fine-tuning prediction_length has been changed to {fine_tune_prediction_length}."
            )
        if self.quantile_levels != self.model_pipeline.quantiles:
            update_output_quantiles(self.model_pipeline.model, self.quantile_levels)
            logger.info(f"\tChronos-Bolt will be fine-tuned with quantile_levels={self.quantile_levels}")
    else:
        raise ValueError(f"Unsupported model pipeline: {type(self.model_pipeline)}")

    # Work on a shallow copy: the keys below are overwritten/popped, and mutating the dict held in
    # the hyperparameters would leak into later fits (e.g. a second pop of "evaluation_strategy").
    fine_tune_trainer_kwargs = dict(model_params["fine_tune_trainer_kwargs"])
    fine_tune_trainer_kwargs["use_cpu"] = str(self.model_pipeline.inner_model.device) == "cpu"

    if fine_tune_trainer_kwargs["use_cpu"]:
        logger.info("\tFine-tuning on the CPU detected. We recommend using a GPU for faster fine-tuning of Chronos.")

        # TODO: adamw_torch_fused is not supported on CPU in torch <= 2.3. When torch 2.4 becomes the lower bound
        # this if block can be removed because torch >= 2.4 supports AdamW optimizer with fused=True on CPU
        if fine_tune_trainer_kwargs.get("optim") == "adamw_torch_fused":
            fine_tune_trainer_kwargs["optim"] = "adamw_torch"

    output_dir = Path(fine_tune_trainer_kwargs["output_dir"])

    if not eval_during_fine_tune:
        # turn off eval-related trainer args
        fine_tune_trainer_kwargs["evaluation_strategy"] = "no"
        fine_tune_trainer_kwargs["eval_steps"] = None
        fine_tune_trainer_kwargs["load_best_model_at_end"] = False
        fine_tune_trainer_kwargs["metric_for_best_model"] = None

    if (
        version.parse(transformers.__version__) >= version.parse("4.46")
        and "evaluation_strategy" in fine_tune_trainer_kwargs
    ):
        # transformers changed the argument name from `evaluation_strategy` to `eval_strategy`
        fine_tune_trainer_kwargs["eval_strategy"] = fine_tune_trainer_kwargs.pop("evaluation_strategy")

    training_args = TrainingArguments(**fine_tune_trainer_kwargs, **pipeline_specific_trainer_kwargs)  # type: ignore
    tokenizer_train_dataset = ChronosFineTuningDataset(
        target_df=train_data,
        target_column=self.target,
        context_length=context_length,
        prediction_length=fine_tune_prediction_length,
        # if tokenizer exists, then the data is returned in the HF-style format accepted by
        # the original Chronos models otherwise the data is returned in ChronosBolt's format
        tokenizer=getattr(self.model_pipeline, "tokenizer", None),
        mode="training",
    ).shuffle(model_params["fine_tune_shuffle_buffer_size"])

    callbacks = []
    if time_limit is not None:
        callbacks.append(TimeLimitCallback(time_limit=time_limit))

    tokenizer_val_dataset: ChronosFineTuningDataset | None = None
    if val_data is not None:
        callbacks.append(EvaluateAndSaveFinalStepCallback())
        # evaluate on a randomly-sampled subset
        fine_tune_eval_max_items = (
            min(val_data.num_items, model_params["fine_tune_eval_max_items"])
            if model_params["fine_tune_eval_max_items"] is not None
            else val_data.num_items
        )

        if fine_tune_eval_max_items < val_data.num_items:
            eval_items = np.random.choice(val_data.item_ids.values, size=fine_tune_eval_max_items, replace=False)
            val_data = val_data.loc[eval_items]

        assert isinstance(val_data, TimeSeriesDataFrame)
        tokenizer_val_dataset = ChronosFineTuningDataset(
            target_df=val_data,
            target_column=self.target,
            context_length=context_length,
            prediction_length=fine_tune_prediction_length,
            tokenizer=getattr(self.model_pipeline, "tokenizer", None),
            mode="validation",
        )

    trainer = Trainer(
        model=self.model_pipeline.inner_model,
        args=training_args,
        train_dataset=tokenizer_train_dataset,
        eval_dataset=tokenizer_val_dataset,
        callbacks=callbacks,
    )

    # remove PrinterCallback from callbacks which logs to the console via a print() call,
    # so it cannot be handled by setting the log level
    trainer.pop_callback(PrinterCallback)

    if verbosity >= 3:
        logger.warning(
            "Transformers logging is turned on during fine-tuning. Note that losses reported by transformers "
            "do not correspond to those specified via `eval_metric`."
        )
        trainer.add_callback(LoggerCallback())

    trainer.train()

    fine_tuned_ckpt_path = Path(self.path) / self.fine_tuned_ckpt_name
    logger.info(f"\tSaving fine-tuned model to {fine_tuned_ckpt_path}")
    self.model_pipeline.inner_model.save_pretrained(fine_tuned_ckpt_path)

    if not model_params["keep_transformers_logs"]:
        logger.debug(f"Removing transformers_logs directory {output_dir}")
        # the trainer may not have written anything to output_dir, in which case it does not exist
        if output_dir.exists():
            shutil.rmtree(output_dir)
|
|
620
|
+
|
|
621
|
+
def _get_inference_data_loader(
    self,
    data: TimeSeriesDataFrame,
    context_length: int,
    batch_size: int,
    num_workers: int = 0,
    time_limit: float | None = None,
):
    """Create an ordered (non-shuffled) batch loader over ``data`` for inference.

    The series in column ``self.target`` are wrapped in a ``ChronosInferenceDataset`` with the
    given ``context_length``. The loader receives ``timeout_callback(seconds=time_limit)`` as its
    per-batch hook. Presumably this hook aborts once ``time_limit`` elapses (TODO confirm in .utils).
    """
    from .utils import ChronosInferenceDataLoader, ChronosInferenceDataset, timeout_callback

    # The dataset is built before the timeout hook is created, preserving the original call order.
    inference_dataset = ChronosInferenceDataset(
        target_df=data,
        target_column=self.target,
        context_length=context_length,
    )
    per_batch_hook = timeout_callback(seconds=time_limit)

    loader = ChronosInferenceDataLoader(
        inference_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        on_batch=per_batch_hook,
    )
    return loader
|
|
644
|
+
|
|
645
|
+
def _get_context_length(self, data: TimeSeriesDataFrame) -> int:
|
|
646
|
+
context_length = self.context_length or min(
|
|
647
|
+
data.num_timesteps_per_item().max(),
|
|
648
|
+
self.maximum_context_length,
|
|
649
|
+
)
|
|
650
|
+
return context_length
|
|
651
|
+
|
|
652
|
+
def _predict(
    self,
    data: TimeSeriesDataFrame,
    known_covariates: TimeSeriesDataFrame | None = None,
    **kwargs,
) -> TimeSeriesDataFrame:
    """Forecast the mean and each of ``self.quantile_levels`` for every item in ``data``.

    Returns a TimeSeriesDataFrame with a ``"mean"`` column followed by one column per quantile
    level (named ``str(q)``), indexed by ``self.get_forecast_horizon_index(data)``.
    ``known_covariates`` is accepted for interface compatibility but not used here. An optional
    ``time_limit`` in ``kwargs`` is forwarded to the inference data loader.

    Raises:
        The torch out-of-memory error from ``predict_quantiles`` is re-raised after logging a
        hint to lower ``batch_size``.
    """
    from chronos import ChronosBoltPipeline, ChronosPipeline

    # We defer initialization of the model pipeline. i.e., the model is only loaded to device memory
    # during inference. We also infer the maximum length of the time series in the inference data set
    # and use that to determine the context length of the model. If the context length is specified
    # during initialization, this is always used. If not, the context length is set to the longest
    # item length. The context length is always capped by self.maximum_context_length.
    # Note that this is independent of the model's own context length set in the model's config file.
    # For example, if the context_length is set to 2048 here but the model expects context length
    # (according to its config.json file) of 512, it will further truncate the series during inference.
    context_length = self._get_context_length(data)

    extra_predict_kwargs = (
        {"num_samples": self.num_samples} if isinstance(self.model_pipeline, ChronosPipeline) else {}
    )

    # adapt batch size for Chronos bolt if requested prediction length is longer than model prediction length
    batch_size = self.batch_size
    if isinstance(self.model_pipeline, ChronosBoltPipeline):
        model_prediction_length = self.model_pipeline.model.config.chronos_config.get("prediction_length")
        if model_prediction_length and self.prediction_length > model_prediction_length:
            batch_size = max(1, batch_size // 4)
            logger.debug(
                f"\tThe prediction_length {self.prediction_length} exceeds model's prediction_length {model_prediction_length}. "
                f"The inference batch_size has been reduced from {self.batch_size} to {batch_size} to avoid OOM errors."
            )

    with warning_filter(all_warnings=True):
        import torch

        # `torch.OutOfMemoryError` only exists in newer torch releases. Resolve it up front and fall back
        # to the CUDA-specific class, so that handling an exception on older torch does not itself fail
        # with AttributeError and mask the real error.
        oom_error = getattr(torch, "OutOfMemoryError", torch.cuda.OutOfMemoryError)

        self.model_pipeline.model.eval()

        inference_data_loader = self._get_inference_data_loader(
            data=data,
            batch_size=batch_size,
            num_workers=self.data_loader_num_workers,
            context_length=context_length,
            time_limit=kwargs.get("time_limit"),
        )

        with torch.inference_mode(), disable_duplicate_logs(logger):
            batch_quantiles, batch_means = [], []
            for batch in inference_data_loader:
                try:
                    qs, mn = self.model_pipeline.predict_quantiles(
                        batch,
                        prediction_length=self.prediction_length,
                        quantile_levels=self.quantile_levels,
                        **extra_predict_kwargs,
                    )
                except oom_error:
                    # never suggest a batch size of 0
                    logger.error(
                        "The call to predict() resulted in an out of memory error. Try reducing the batch_size by setting:"
                        f" predictor.fit(..., hyperparameters={{'Chronos': {{'batch_size': {max(1, batch_size // 2)}, ...}}}})"
                    )
                    raise
                batch_quantiles.append(qs.numpy())
                batch_means.append(mn.numpy())

    df = pd.DataFrame(
        np.concatenate(
            [
                np.concatenate(batch_means, axis=0).reshape(-1, 1),
                np.concatenate(batch_quantiles, axis=0).reshape(-1, len(self.quantile_levels)),
            ],
            axis=1,
        ),
        columns=["mean"] + [str(q) for q in self.quantile_levels],
        index=self.get_forecast_horizon_index(data),
    )

    return TimeSeriesDataFrame(df)
|
|
731
|
+
|
|
732
|
+
def _more_tags(self) -> dict:
|
|
733
|
+
do_fine_tune = self.get_hyperparameter("fine_tune")
|
|
734
|
+
return {
|
|
735
|
+
"allow_nan": True,
|
|
736
|
+
"can_use_train_data": do_fine_tune,
|
|
737
|
+
"can_use_val_data": do_fine_tune,
|
|
738
|
+
}
|