autogluon.timeseries 1.2.1b20250224__py3-none-any.whl → 1.4.1b20251215__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to their respective public registries. It is provided for informational purposes only and reflects the changes between those package versions.
Potentially problematic release.
This version of autogluon.timeseries might be problematic.
- autogluon/timeseries/configs/__init__.py +3 -2
- autogluon/timeseries/configs/hyperparameter_presets.py +62 -0
- autogluon/timeseries/configs/predictor_presets.py +106 -0
- autogluon/timeseries/dataset/ts_dataframe.py +256 -141
- autogluon/timeseries/learner.py +86 -52
- autogluon/timeseries/metrics/__init__.py +42 -8
- autogluon/timeseries/metrics/abstract.py +89 -19
- autogluon/timeseries/metrics/point.py +142 -53
- autogluon/timeseries/metrics/quantile.py +46 -21
- autogluon/timeseries/metrics/utils.py +4 -4
- autogluon/timeseries/models/__init__.py +8 -2
- autogluon/timeseries/models/abstract/__init__.py +2 -2
- autogluon/timeseries/models/abstract/abstract_timeseries_model.py +361 -592
- autogluon/timeseries/models/abstract/model_trial.py +2 -1
- autogluon/timeseries/models/abstract/tunable.py +189 -0
- autogluon/timeseries/models/autogluon_tabular/__init__.py +2 -0
- autogluon/timeseries/models/autogluon_tabular/mlforecast.py +282 -194
- autogluon/timeseries/models/autogluon_tabular/per_step.py +513 -0
- autogluon/timeseries/models/autogluon_tabular/transforms.py +25 -18
- autogluon/timeseries/models/chronos/__init__.py +2 -1
- autogluon/timeseries/models/chronos/chronos2.py +361 -0
- autogluon/timeseries/models/chronos/model.py +219 -138
- autogluon/timeseries/models/chronos/{pipeline/utils.py → utils.py} +81 -50
- autogluon/timeseries/models/ensemble/__init__.py +37 -2
- autogluon/timeseries/models/ensemble/abstract.py +107 -0
- autogluon/timeseries/models/ensemble/array_based/__init__.py +3 -0
- autogluon/timeseries/models/ensemble/array_based/abstract.py +240 -0
- autogluon/timeseries/models/ensemble/array_based/models.py +185 -0
- autogluon/timeseries/models/ensemble/array_based/regressor/__init__.py +12 -0
- autogluon/timeseries/models/ensemble/array_based/regressor/abstract.py +88 -0
- autogluon/timeseries/models/ensemble/array_based/regressor/linear_stacker.py +186 -0
- autogluon/timeseries/models/ensemble/array_based/regressor/per_quantile_tabular.py +94 -0
- autogluon/timeseries/models/ensemble/array_based/regressor/tabular.py +107 -0
- autogluon/timeseries/models/ensemble/ensemble_selection.py +167 -0
- autogluon/timeseries/models/ensemble/per_item_greedy.py +172 -0
- autogluon/timeseries/models/ensemble/weighted/__init__.py +8 -0
- autogluon/timeseries/models/ensemble/weighted/abstract.py +45 -0
- autogluon/timeseries/models/ensemble/weighted/basic.py +91 -0
- autogluon/timeseries/models/ensemble/weighted/greedy.py +62 -0
- autogluon/timeseries/models/gluonts/__init__.py +1 -1
- autogluon/timeseries/models/gluonts/{abstract_gluonts.py → abstract.py} +148 -208
- autogluon/timeseries/models/gluonts/dataset.py +109 -0
- autogluon/timeseries/models/gluonts/{torch/models.py → models.py} +38 -22
- autogluon/timeseries/models/local/__init__.py +0 -7
- autogluon/timeseries/models/local/abstract_local_model.py +71 -74
- autogluon/timeseries/models/local/naive.py +13 -9
- autogluon/timeseries/models/local/npts.py +9 -2
- autogluon/timeseries/models/local/statsforecast.py +52 -36
- autogluon/timeseries/models/multi_window/multi_window_model.py +65 -45
- autogluon/timeseries/models/registry.py +64 -0
- autogluon/timeseries/models/toto/__init__.py +3 -0
- autogluon/timeseries/models/toto/_internal/__init__.py +9 -0
- autogluon/timeseries/models/toto/_internal/backbone/__init__.py +3 -0
- autogluon/timeseries/models/toto/_internal/backbone/attention.py +196 -0
- autogluon/timeseries/models/toto/_internal/backbone/backbone.py +262 -0
- autogluon/timeseries/models/toto/_internal/backbone/distribution.py +70 -0
- autogluon/timeseries/models/toto/_internal/backbone/kvcache.py +136 -0
- autogluon/timeseries/models/toto/_internal/backbone/rope.py +89 -0
- autogluon/timeseries/models/toto/_internal/backbone/rotary_embedding_torch.py +342 -0
- autogluon/timeseries/models/toto/_internal/backbone/scaler.py +305 -0
- autogluon/timeseries/models/toto/_internal/backbone/transformer.py +333 -0
- autogluon/timeseries/models/toto/_internal/dataset.py +165 -0
- autogluon/timeseries/models/toto/_internal/forecaster.py +423 -0
- autogluon/timeseries/models/toto/dataloader.py +108 -0
- autogluon/timeseries/models/toto/hf_pretrained_model.py +200 -0
- autogluon/timeseries/models/toto/model.py +249 -0
- autogluon/timeseries/predictor.py +685 -297
- autogluon/timeseries/regressor.py +94 -44
- autogluon/timeseries/splitter.py +8 -32
- autogluon/timeseries/trainer/__init__.py +3 -0
- autogluon/timeseries/trainer/ensemble_composer.py +444 -0
- autogluon/timeseries/trainer/model_set_builder.py +256 -0
- autogluon/timeseries/trainer/prediction_cache.py +149 -0
- autogluon/timeseries/{trainer.py → trainer/trainer.py} +387 -390
- autogluon/timeseries/trainer/utils.py +17 -0
- autogluon/timeseries/transforms/__init__.py +2 -13
- autogluon/timeseries/transforms/covariate_scaler.py +34 -40
- autogluon/timeseries/transforms/target_scaler.py +37 -20
- autogluon/timeseries/utils/constants.py +10 -0
- autogluon/timeseries/utils/datetime/lags.py +3 -5
- autogluon/timeseries/utils/datetime/seasonality.py +1 -3
- autogluon/timeseries/utils/datetime/time_features.py +2 -2
- autogluon/timeseries/utils/features.py +70 -47
- autogluon/timeseries/utils/forecast.py +19 -14
- autogluon/timeseries/utils/timer.py +173 -0
- autogluon/timeseries/utils/warning_filters.py +4 -2
- autogluon/timeseries/version.py +1 -1
- autogluon.timeseries-1.4.1b20251215-py3.11-nspkg.pth +1 -0
- {autogluon.timeseries-1.2.1b20250224.dist-info → autogluon_timeseries-1.4.1b20251215.dist-info}/METADATA +49 -36
- autogluon_timeseries-1.4.1b20251215.dist-info/RECORD +103 -0
- {autogluon.timeseries-1.2.1b20250224.dist-info → autogluon_timeseries-1.4.1b20251215.dist-info}/WHEEL +1 -1
- autogluon/timeseries/configs/presets_configs.py +0 -79
- autogluon/timeseries/evaluator.py +0 -6
- autogluon/timeseries/models/chronos/pipeline/__init__.py +0 -11
- autogluon/timeseries/models/chronos/pipeline/base.py +0 -160
- autogluon/timeseries/models/chronos/pipeline/chronos.py +0 -585
- autogluon/timeseries/models/chronos/pipeline/chronos_bolt.py +0 -518
- autogluon/timeseries/models/ensemble/abstract_timeseries_ensemble.py +0 -78
- autogluon/timeseries/models/ensemble/greedy_ensemble.py +0 -170
- autogluon/timeseries/models/gluonts/torch/__init__.py +0 -0
- autogluon/timeseries/models/presets.py +0 -360
- autogluon.timeseries-1.2.1b20250224-py3.9-nspkg.pth +0 -1
- autogluon.timeseries-1.2.1b20250224.dist-info/RECORD +0 -68
- {autogluon.timeseries-1.2.1b20250224.dist-info → autogluon_timeseries-1.4.1b20251215.dist-info/licenses}/LICENSE +0 -0
- {autogluon.timeseries-1.2.1b20250224.dist-info → autogluon_timeseries-1.4.1b20251215.dist-info/licenses}/NOTICE +0 -0
- {autogluon.timeseries-1.2.1b20250224.dist-info → autogluon_timeseries-1.4.1b20251215.dist-info}/namespace_packages.txt +0 -0
- {autogluon.timeseries-1.2.1b20250224.dist-info → autogluon_timeseries-1.4.1b20251215.dist-info}/top_level.txt +0 -0
- {autogluon.timeseries-1.2.1b20250224.dist-info → autogluon_timeseries-1.4.1b20251215.dist-info}/zip-safe +0 -0
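
The added/removed portion of a listing like the one above can be reproduced locally by comparing the archive contents of the two wheels (a wheel is a zip archive). The sketch below is illustrative only and uses just the Python standard library; the wheel file names/paths are assumptions (adjust them to wherever the wheels were downloaded), and per-file line counts would additionally require extracting and diffing each member.

from zipfile import ZipFile


def wheel_files(path: str) -> set[str]:
    # A wheel is a zip archive; namelist() returns every file stored inside it.
    with ZipFile(path) as zf:
        return set(zf.namelist())


old = wheel_files("autogluon.timeseries-1.2.1b20250224-py3-none-any.whl")
new = wheel_files("autogluon.timeseries-1.4.1b20251215-py3-none-any.whl")

print("added files:", sorted(new - old))
print("removed files:", sorted(old - new))
print("present in both (possibly modified):", len(old & new))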
autogluon/timeseries/models/chronos/pipeline/chronos_bolt.py (deleted)
@@ -1,518 +0,0 @@
-# Implements Chronos with T5 architecture but with patched inputs instead of
-# per-time-step tokenization. a.k.a. Chronos-Bolt
-
-# Authors: Abdul Fatir Ansari <ansarnd@amazon.com>, Lorenzo Stella <stellalo@amazon.com>, Caner Turkmen <atturkm@amazon.com>
-
-import copy
-import logging
-import warnings
-from dataclasses import dataclass, fields
-from typing import List, Optional, Tuple, Union
-
-import torch
-import torch.nn as nn
-from transformers import AutoConfig
-from transformers.models.t5.modeling_t5 import (
-    ACT2FN,
-    T5Config,
-    T5LayerNorm,
-    T5PreTrainedModel,
-    T5Stack,
-)
-from transformers.utils import ModelOutput
-
-from .base import BaseChronosPipeline, ForecastType
-
-logger = logging.getLogger("autogluon.timeseries.models.chronos")
-
-
-@dataclass
-class ChronosBoltConfig:
-    context_length: int
-    prediction_length: int
-    input_patch_size: int
-    input_patch_stride: int
-    quantiles: List[float]
-    use_reg_token: bool = False
-
-
-@dataclass
-class ChronosBoltOutput(ModelOutput):
-    loss: Optional[torch.Tensor] = None
-    quantile_preds: Optional[torch.Tensor] = None
-    attentions: Optional[torch.Tensor] = None
-    cross_attentions: Optional[torch.Tensor] = None
-
-
-class Patch(nn.Module):
-    def __init__(self, patch_size: int, patch_stride: int) -> None:
-        super().__init__()
-        self.patch_size = patch_size
-        self.patch_stride = patch_stride
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        length = x.shape[-1]
-
-        if length % self.patch_size != 0:
-            padding_size = (
-                *x.shape[:-1],
-                self.patch_size - (length % self.patch_size),
-            )
-            padding = torch.full(size=padding_size, fill_value=torch.nan, dtype=x.dtype, device=x.device)
-            x = torch.concat((padding, x), dim=-1)
-
-        x = x.unfold(dimension=-1, size=self.patch_size, step=self.patch_stride)
-        return x
-
-
-class InstanceNorm(nn.Module):
-    """
-    See, also, RevIN. Apply standardization along the last dimension.
-    """
-
-    def __init__(self, eps: float = 1e-5) -> None:
-        super().__init__()
-        self.eps = eps
-
-    def forward(
-        self,
-        x: torch.Tensor,
-        loc_scale: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-        if loc_scale is None:
-            loc = torch.nan_to_num(torch.nanmean(x, dim=-1, keepdim=True), nan=0.0)
-            scale = torch.nan_to_num((x - loc).square().nanmean(dim=-1, keepdim=True).sqrt(), nan=1.0)
-            scale = torch.where(scale == 0, torch.abs(loc) + self.eps, scale)
-        else:
-            loc, scale = loc_scale
-
-        return (x - loc) / scale, (loc, scale)
-
-    def inverse(self, x: torch.Tensor, loc_scale: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
-        loc, scale = loc_scale
-        return x * scale + loc
-
-
-class ResidualBlock(nn.Module):
-    def __init__(
-        self,
-        in_dim: int,
-        h_dim: int,
-        out_dim: int,
-        act_fn_name: str,
-        dropout_p: float = 0.0,
-        use_layer_norm: bool = False,
-    ) -> None:
-        super().__init__()
-
-        self.dropout = nn.Dropout(dropout_p)
-        self.hidden_layer = nn.Linear(in_dim, h_dim)
-        self.act = ACT2FN[act_fn_name]
-        self.output_layer = nn.Linear(h_dim, out_dim)
-        self.residual_layer = nn.Linear(in_dim, out_dim)
-
-        self.use_layer_norm = use_layer_norm
-        if use_layer_norm:
-            self.layer_norm = T5LayerNorm(out_dim)
-
-    def forward(self, x: torch.Tensor):
-        hid = self.act(self.hidden_layer(x))
-        out = self.dropout(self.output_layer(hid))
-        res = self.residual_layer(x)
-
-        out = out + res
-
-        if self.use_layer_norm:
-            return self.layer_norm(out)
-        return out
-
-
-class ChronosBoltModelForForecasting(T5PreTrainedModel):
-    _keys_to_ignore_on_load_missing = [
-        r"input_patch_embedding\.",
-        r"output_patch_embedding\.",
-    ]
-    _keys_to_ignore_on_load_unexpected = [r"lm_head.weight"]
-    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
-
-    def __init__(self, config: T5Config):
-        assert hasattr(config, "chronos_config"), "Not a Chronos config file"
-
-        super().__init__(config)
-        self.model_dim = config.d_model
-
-        # TODO: remove filtering eventually, added for backward compatibility
-        config_fields = {f.name for f in fields(ChronosBoltConfig)}
-        self.chronos_config = ChronosBoltConfig(
-            **{k: v for k, v in config.chronos_config.items() if k in config_fields}
-        )
-
-        # Only decoder_start_id (and optionally REG token)
-        if self.chronos_config.use_reg_token:
-            config.reg_token_id = 1
-
-        config.vocab_size = 2 if self.chronos_config.use_reg_token else 1
-        self.shared = nn.Embedding(config.vocab_size, config.d_model)
-
-        # Input patch embedding layer
-        self.input_patch_embedding = ResidualBlock(
-            in_dim=self.chronos_config.input_patch_size * 2,
-            h_dim=config.d_ff,
-            out_dim=config.d_model,
-            act_fn_name=config.dense_act_fn,
-            dropout_p=config.dropout_rate,
-        )
-
-        # patching layer
-        self.patch = Patch(
-            patch_size=self.chronos_config.input_patch_size,
-            patch_stride=self.chronos_config.input_patch_stride,
-        )
-
-        # instance normalization, also referred to as "scaling" in Chronos and GluonTS
-        self.instance_norm = InstanceNorm()
-
-        encoder_config = copy.deepcopy(config)
-        encoder_config.is_decoder = False
-        encoder_config.use_cache = False
-        encoder_config.is_encoder_decoder = False
-        self.encoder = T5Stack(encoder_config, self.shared)
-
-        self._init_decoder(config)
-
-        self.num_quantiles = len(self.chronos_config.quantiles)
-        quantiles = torch.tensor(self.chronos_config.quantiles, dtype=self.dtype)
-        self.register_buffer("quantiles", quantiles, persistent=False)
-
-        self.output_patch_embedding = ResidualBlock(
-            in_dim=config.d_model,
-            h_dim=config.d_ff,
-            out_dim=self.num_quantiles * self.chronos_config.prediction_length,
-            act_fn_name=config.dense_act_fn,
-            dropout_p=config.dropout_rate,
-        )
-
-        # Initialize weights and apply final processing
-        self.post_init()
-
-        # Model parallel
-        self.model_parallel = False
-        self.device_map = None
-
-    def _init_weights(self, module):
-        super()._init_weights(module)
-        """Initialize the weights"""
-        factor = self.config.initializer_factor
-        if isinstance(module, (self.__class__)):
-            module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
-        elif isinstance(module, ResidualBlock):
-            module.hidden_layer.weight.data.normal_(
-                mean=0.0,
-                std=factor * ((self.chronos_config.input_patch_size * 2) ** -0.5),
-            )
-            if hasattr(module.hidden_layer, "bias") and module.hidden_layer.bias is not None:
-                module.hidden_layer.bias.data.zero_()
-
-            module.residual_layer.weight.data.normal_(
-                mean=0.0,
-                std=factor * ((self.chronos_config.input_patch_size * 2) ** -0.5),
-            )
-            if hasattr(module.residual_layer, "bias") and module.residual_layer.bias is not None:
-                module.residual_layer.bias.data.zero_()
-
-            module.output_layer.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
-            if hasattr(module.output_layer, "bias") and module.output_layer.bias is not None:
-                module.output_layer.bias.data.zero_()
-
-    def forward(
-        self,
-        context: torch.Tensor,
-        mask: Optional[torch.Tensor] = None,
-        target: Optional[torch.Tensor] = None,
-        target_mask: Optional[torch.Tensor] = None,
-    ) -> ChronosBoltOutput:
-        mask = mask.to(context.dtype) if mask is not None else torch.isnan(context).logical_not().to(context.dtype)
-
-        batch_size, _ = context.shape
-        if context.shape[-1] > self.chronos_config.context_length:
-            context = context[..., -self.chronos_config.context_length :]
-            mask = mask[..., -self.chronos_config.context_length :]
-
-        # scaling
-        context, loc_scale = self.instance_norm(context)
-
-        # the scaling op above is done in 32-bit precision,
-        # then the context is moved to model's dtype
-        context = context.to(self.dtype)
-        mask = mask.to(self.dtype)
-
-        # patching
-        patched_context = self.patch(context)
-        patched_mask = torch.nan_to_num(self.patch(mask), nan=0.0)
-        patched_context[~(patched_mask > 0)] = 0.0
-        # concat context and mask along patch dim
-        patched_context = torch.cat([patched_context, patched_mask], dim=-1)
-
-        # attention_mask = 1 if at least one item in the patch is observed
-        attention_mask = patched_mask.sum(dim=-1) > 0  # (batch_size, patched_seq_length)
-
-        input_embeds = self.input_patch_embedding(patched_context)
-
-        if self.chronos_config.use_reg_token:
-            # Append [REG]
-            reg_input_ids = torch.full(
-                (batch_size, 1),
-                self.config.reg_token_id,
-                device=input_embeds.device,
-            )
-            reg_embeds = self.shared(reg_input_ids)
-            input_embeds = torch.cat([input_embeds, reg_embeds], dim=-2)
-            attention_mask = torch.cat([attention_mask, torch.ones_like(reg_input_ids)], dim=-1)
-
-        encoder_outputs = self.encoder(
-            attention_mask=attention_mask,
-            inputs_embeds=input_embeds,
-        )
-        hidden_states = encoder_outputs[0]
-
-        sequence_output = self.decode(input_embeds, attention_mask, hidden_states)
-
-        quantile_preds_shape = (
-            batch_size,
-            self.num_quantiles,
-            self.chronos_config.prediction_length,
-        )
-        quantile_preds = self.output_patch_embedding(sequence_output).view(*quantile_preds_shape)
-
-        loss = None
-        if target is not None:
-            # normalize target
-            target, _ = self.instance_norm(target, loc_scale)
-            target = target.unsqueeze(1)  # type: ignore
-            assert self.chronos_config.prediction_length >= target.shape[-1]
-
-            target = target.to(quantile_preds.device)
-            target_mask = (
-                target_mask.unsqueeze(1).to(quantile_preds.device) if target_mask is not None else ~torch.isnan(target)
-            )
-            target[~target_mask] = 0.0
-
-            # pad target and target_mask if they are shorter than model's prediction_length
-            if self.chronos_config.prediction_length > target.shape[-1]:
-                padding_shape = (*target.shape[:-1], self.chronos_config.prediction_length - target.shape[-1])
-                target = torch.cat([target, torch.zeros(padding_shape).to(target)], dim=-1)
-                target_mask = torch.cat([target_mask, torch.zeros(padding_shape).to(target_mask)], dim=-1)
-
-            loss = (
-                2
-                * torch.abs(
-                    (target - quantile_preds)
-                    * ((target <= quantile_preds).float() - self.quantiles.view(1, self.num_quantiles, 1))
-                )
-                * target_mask.float()
-            )
-            loss = loss.mean(dim=-2)  # Mean over prediction horizon
-            loss = loss.sum(dim=-1)  # Sum over quantile levels
-            loss = loss.mean()  # Mean over batch
-
-        # Unscale predictions
-        quantile_preds = self.instance_norm.inverse(
-            quantile_preds.view(batch_size, -1),
-            loc_scale,
-        ).view(*quantile_preds_shape)
-
-        return ChronosBoltOutput(
-            loss=loss,
-            quantile_preds=quantile_preds,
-        )
-
-    def _init_decoder(self, config):
-        decoder_config = copy.deepcopy(config)
-        decoder_config.is_decoder = True
-        decoder_config.is_encoder_decoder = False
-        decoder_config.num_layers = config.num_decoder_layers
-        self.decoder = T5Stack(decoder_config, self.shared)
-
-    def decode(
-        self,
-        input_embeds,
-        attention_mask,
-        hidden_states,
-        output_attentions=False,
-    ):
-        """
-        Parameters
-        ----------
-        input_embeds: torch.Tensor
-            Patched and embedded inputs. Shape (batch_size, patched_context_length, d_model)
-        attention_mask: torch.Tensor
-            Attention mask for the patched context. Shape (batch_size, patched_context_length), type: torch.int64
-        hidden_states: torch.Tensor
-            Hidden states returned by the encoder. Shape (batch_size, patched_context_length, d_model)
-
-        Returns
-        -------
-        last_hidden_state
-            Last hidden state returned by the decoder, of shape (batch_size, 1, d_model)
-        """
-        batch_size = input_embeds.shape[0]
-        decoder_input_ids = torch.full(
-            (batch_size, 1),
-            self.config.decoder_start_token_id,
-            device=input_embeds.device,
-        )
-        decoder_outputs = self.decoder(
-            input_ids=decoder_input_ids,
-            encoder_hidden_states=hidden_states,
-            encoder_attention_mask=attention_mask,
-            output_attentions=output_attentions,
-            return_dict=True,
-        )
-
-        return decoder_outputs.last_hidden_state  # sequence_outputs, b x 1 x d_model
-
-
-class ChronosBoltPipeline(BaseChronosPipeline):
-    forecast_type: ForecastType = ForecastType.QUANTILES
-    default_context_length: int = 2048
-    # register this class name with this alias for backward compatibility
-    _aliases = ["PatchedT5Pipeline"]
-
-    def __init__(self, model: ChronosBoltModelForForecasting):
-        super().__init__(inner_model=model)
-        self.model = model
-
-    @property
-    def quantiles(self) -> List[float]:
-        return self.model.config.chronos_config["quantiles"]
-
-    def predict(  # type: ignore[override]
-        self,
-        context: Union[torch.Tensor, List[torch.Tensor]],
-        prediction_length: Optional[int] = None,
-        limit_prediction_length: bool = False,
-    ):
-        context_tensor = self._prepare_and_validate_context(context=context)
-
-        model_context_length = self.model.config.chronos_config["context_length"]
-        model_prediction_length = self.model.config.chronos_config["prediction_length"]
-        if prediction_length is None:
-            prediction_length = model_prediction_length
-
-        if prediction_length > model_prediction_length:
-            msg = (
-                f"We recommend keeping prediction length <= {model_prediction_length}. "
-                "The quality of longer predictions may degrade since the model is not optimized for it. "
-            )
-            if limit_prediction_length:
-                msg += "You can turn off this check by setting `limit_prediction_length=False`."
-                raise ValueError(msg)
-            warnings.warn(msg)
-
-        predictions = []
-        remaining = prediction_length
-
-        # We truncate the context here because otherwise batches with very long
-        # context could take up large amounts of GPU memory unnecessarily.
-        if context_tensor.shape[-1] > model_context_length:
-            context_tensor = context_tensor[..., -model_context_length:]
-
-        # TODO: We unroll the forecast of Chronos Bolt greedily with the full forecast
-        # horizon that the model was trained with (i.e., 64). This results in variance collapsing
-        # every 64 steps.
-        while remaining > 0:
-            with torch.no_grad():
-                prediction = self.model(
-                    context=context_tensor.to(
-                        device=self.model.device,
-                        dtype=torch.float32,  # scaling should be done in 32-bit precision
-                    ),
-                ).quantile_preds.to(context_tensor)
-
-            predictions.append(prediction)
-            remaining -= prediction.shape[-1]
-
-            if remaining <= 0:
-                break
-
-            central_idx = torch.abs(torch.tensor(self.quantiles) - 0.5).argmin()
-            central_prediction = prediction[:, central_idx]
-
-            context_tensor = torch.cat([context_tensor, central_prediction], dim=-1)
-
-        return torch.cat(predictions, dim=-1)[..., :prediction_length]
-
-    def predict_quantiles(
-        self, context: torch.Tensor, prediction_length: int, quantile_levels: List[float], **kwargs
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        # shape (batch_size, prediction_length, len(training_quantile_levels))
-        predictions = (
-            self.predict(
-                context,
-                prediction_length=prediction_length,
-            )
-            .detach()
-            .cpu()
-            .swapaxes(1, 2)
-        )
-
-        training_quantile_levels = self.quantiles
-
-        if set(quantile_levels).issubset(set(training_quantile_levels)):
-            # no need to perform intra/extrapolation
-            quantiles = predictions[..., [training_quantile_levels.index(q) for q in quantile_levels]]
-        else:
-            # we rely on torch for interpolating quantiles if quantiles that
-            # Chronos Bolt was trained on were not provided
-            if min(quantile_levels) < min(training_quantile_levels) or max(quantile_levels) > max(
-                training_quantile_levels
-            ):
-                logger.warning(
-                    f"\tQuantiles to be predicted ({quantile_levels}) are not within the range of "
-                    f"quantiles that Chronos-Bolt was trained on ({training_quantile_levels}). "
-                    "Quantile predictions will be set to the minimum/maximum levels at which Chronos-Bolt "
-                    "was trained on. This may significantly affect the quality of the predictions."
-                )
-
-            # TODO: this is a hack that assumes the model's quantiles during training (training_quantile_levels)
-            # made up an equidistant grid along the quantile dimension. i.e., they were (0.1, 0.2, ..., 0.9).
-            # While this holds for official Chronos-Bolt models, this may not be true in the future, and this
-            # function may have to be revised.
-            augmented_predictions = torch.cat(
-                [predictions[..., [0]], predictions, predictions[..., [-1]]],
-                dim=-1,
-            )
-            quantiles = torch.quantile(
-                augmented_predictions, q=torch.tensor(quantile_levels, dtype=augmented_predictions.dtype), dim=-1
-            ).permute(1, 2, 0)
-        mean = predictions[:, :, training_quantile_levels.index(0.5)]
-        return quantiles, mean
-
-    @classmethod
-    def from_pretrained(cls, *args, **kwargs):
-        """
-        Load the model, either from a local path or from the HuggingFace Hub.
-        Supports the same arguments as ``AutoConfig`` and ``AutoModel``
-        from ``transformers``.
-        """
-        # if optimization_strategy is provided, pop this as it won't be used
-        kwargs.pop("optimization_strategy", None)
-
-        config = AutoConfig.from_pretrained(*args, **kwargs)
-        assert hasattr(config, "chronos_config"), "Not a Chronos config file"
-
-        context_length = kwargs.pop("context_length", None)
-        if context_length is not None:
-            config.chronos_config["context_length"] = context_length
-
-        architecture = config.architectures[0]
-        class_ = globals().get(architecture)
-
-        # TODO: remove this once all models carry the correct architecture names in their configuration
-        # and raise an error instead.
-        if class_ is None:
-            logger.warning(f"Unknown architecture: {architecture}, defaulting to ChronosBoltModelForForecasting")
-            class_ = ChronosBoltModelForForecasting
-
-        model = class_.from_pretrained(*args, **kwargs)
-        return cls(model=model)
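
The training objective in the deleted ChronosBoltModelForForecasting.forward above is written as 2 * |(target - q) * (1[target <= q] - tau)|, masked and then reduced over quantiles, horizon, and batch. The short check below is illustrative only (tensor names and shapes are chosen to mirror forward, not taken from the module); it verifies that this expression equals twice the standard pinball/quantile loss max(tau * u, (tau - 1) * u) with u = target - q.

import torch

torch.manual_seed(0)
tau = torch.tensor([0.1, 0.5, 0.9])  # quantile levels, shape (num_quantiles,)
target = torch.randn(4, 1, 6)        # (batch, 1, prediction_length), as after target.unsqueeze(1)
q_pred = torch.randn(4, 3, 6)        # (batch, num_quantiles, prediction_length)

u = target - q_pred
# Expression used in forward(): 2 * |u * (1[target <= q_pred] - tau)|
chronos_style = 2 * torch.abs(u * ((target <= q_pred).float() - tau.view(1, -1, 1)))
# Textbook pinball (quantile) loss: max(tau * u, (tau - 1) * u)
pinball = torch.maximum(tau.view(1, -1, 1) * u, (tau.view(1, -1, 1) - 1) * u)

assert torch.allclose(chronos_style, 2 * pinball)
print("mean loss:", chronos_style.mean().item())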
autogluon/timeseries/models/ensemble/abstract_timeseries_ensemble.py (deleted)
@@ -1,78 +0,0 @@
-import logging
-from typing import Dict, List, Optional
-
-from autogluon.core.utils.exceptions import TimeLimitExceeded
-from autogluon.timeseries.dataset import TimeSeriesDataFrame
-from autogluon.timeseries.models.abstract import AbstractTimeSeriesModel
-
-logger = logging.getLogger(__name__)
-
-
-class AbstractTimeSeriesEnsembleModel(AbstractTimeSeriesModel):
-    """Abstract class for time series ensemble models."""
-
-    @property
-    def model_names(self) -> List[str]:
-        """Names of base models included in the ensemble."""
-        raise NotImplementedError
-
-    def fit_ensemble(
-        self,
-        predictions_per_window: Dict[str, List[TimeSeriesDataFrame]],
-        data_per_window: List[TimeSeriesDataFrame],
-        time_limit: Optional[float] = None,
-        **kwargs,
-    ):
-        """Fit ensemble model given predictions of candidate base models and the true data.
-
-        Parameters
-        ----------
-        predictions_per_window : Dict[str, List[TimeSeriesDataFrame]]
-            Dictionary that maps the names of component models to their respective predictions for each validation
-            window.
-        data_per_window : List[TimeSeriesDataFrame]
-            Observed ground truth data used to train the ensemble for each validation window. Each entry in the list
-            includes both the forecast horizon (for which the predictions are given in ``predictions``), as well as the
-            "history".
-        time_limit : Optional[int]
-            Maximum allowed time for training in seconds.
-        """
-        if time_limit is not None and time_limit <= 0:
-            logger.warning(
-                f"\tWarning: Model has no time left to train, skipping model... (Time Left = {round(time_limit, 1)}s)"
-            )
-            raise TimeLimitExceeded
-        if isinstance(data_per_window, TimeSeriesDataFrame):
-            raise ValueError("When fitting ensemble, `data` should contain ground truth for each validation window")
-        num_val_windows = len(data_per_window)
-        for model, preds in predictions_per_window.items():
-            if len(preds) != num_val_windows:
-                raise ValueError(f"For model {model} predictions are unavailable for some validation windows")
-        self._fit_ensemble(
-            predictions_per_window=predictions_per_window,
-            data_per_window=data_per_window,
-            time_limit=time_limit,
-        )
-        return self
-
-    def _fit_ensemble(
-        self,
-        predictions_per_window: Dict[str, List[TimeSeriesDataFrame]],
-        data_per_window: List[TimeSeriesDataFrame],
-        time_limit: Optional[int] = None,
-        **kwargs,
-    ):
-        """Private method for `fit_ensemble`. See `fit_ensemble` for documentation of arguments. Apart from the model
-        training logic, `fit_ensemble` additionally implements other logic such as keeping track of the time limit.
-        """
-        raise NotImplementedError
-
-    def predict(self, data: Dict[str, Optional[TimeSeriesDataFrame]], **kwargs) -> TimeSeriesDataFrame:
-        raise NotImplementedError
-
-    def remap_base_models(self, model_refit_map: Dict[str, str]) -> None:
-        """Update names of the base models based on the mapping in model_refit_map.
-
-        This method should be called after performing refit_full to point to the refitted base models, if necessary.
-        """
-        raise NotImplementedError