autogluon.timeseries 1.1.2b20241109__py3-none-any.whl → 1.1.2b20241112__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- autogluon/timeseries/dataset/ts_dataframe.py +5 -1
- autogluon/timeseries/models/abstract/abstract_timeseries_model.py +73 -5
- autogluon/timeseries/models/chronos/model.py +67 -38
- autogluon/timeseries/models/chronos/pipeline/__init__.py +11 -0
- autogluon/timeseries/models/chronos/pipeline/base.py +146 -0
- autogluon/timeseries/models/chronos/{pipeline.py → pipeline/chronos.py} +66 -102
- autogluon/timeseries/models/chronos/pipeline/chronos_bolt.py +511 -0
- autogluon/timeseries/models/chronos/{utils.py → pipeline/utils.py} +37 -1
- autogluon/timeseries/models/gluonts/abstract_gluonts.py +1 -0
- autogluon/timeseries/models/gluonts/torch/models.py +3 -0
- autogluon/timeseries/models/local/abstract_local_model.py +4 -1
- autogluon/timeseries/models/local/statsforecast.py +3 -0
- autogluon/timeseries/models/multi_window/multi_window_model.py +5 -0
- autogluon/timeseries/predictor.py +1 -1
- autogluon/timeseries/regressor.py +146 -0
- autogluon/timeseries/transforms/scaler.py +1 -1
- autogluon/timeseries/utils/warning_filters.py +20 -0
- autogluon/timeseries/version.py +1 -1
- {autogluon.timeseries-1.1.2b20241109.dist-info → autogluon.timeseries-1.1.2b20241112.dist-info}/METADATA +5 -5
- {autogluon.timeseries-1.1.2b20241109.dist-info → autogluon.timeseries-1.1.2b20241112.dist-info}/RECORD +27 -23
- /autogluon.timeseries-1.1.2b20241109-py3.8-nspkg.pth → /autogluon.timeseries-1.1.2b20241112-py3.8-nspkg.pth +0 -0
- {autogluon.timeseries-1.1.2b20241109.dist-info → autogluon.timeseries-1.1.2b20241112.dist-info}/LICENSE +0 -0
- {autogluon.timeseries-1.1.2b20241109.dist-info → autogluon.timeseries-1.1.2b20241112.dist-info}/NOTICE +0 -0
- {autogluon.timeseries-1.1.2b20241109.dist-info → autogluon.timeseries-1.1.2b20241112.dist-info}/WHEEL +0 -0
- {autogluon.timeseries-1.1.2b20241109.dist-info → autogluon.timeseries-1.1.2b20241112.dist-info}/namespace_packages.txt +0 -0
- {autogluon.timeseries-1.1.2b20241109.dist-info → autogluon.timeseries-1.1.2b20241112.dist-info}/top_level.txt +0 -0
- {autogluon.timeseries-1.1.2b20241109.dist-info → autogluon.timeseries-1.1.2b20241112.dist-info}/zip-safe +0 -0
@@ -0,0 +1,511 @@
+# Implements Chronos with T5 architecture but with patched inputs instead of
+# per-time-step tokenization. a.k.a. Chronos-Bolt
+
+# Authors: Abdul Fatir Ansari <ansarnd@amazon.com>, Lorenzo Stella <stellalo@amazon.com>, Caner Turkmen <atturkm@amazon.com>
+
+import copy
+import logging
+import warnings
+from dataclasses import dataclass, fields
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+from transformers import AutoConfig
+from transformers.models.t5.modeling_t5 import (
+    ACT2FN,
+    T5Config,
+    T5LayerNorm,
+    T5PreTrainedModel,
+    T5Stack,
+)
+from transformers.utils import ModelOutput
+
+from .base import BaseChronosPipeline, ForecastType
+
+logger = logging.getLogger("autogluon.timeseries.models.chronos")
+
+
+@dataclass
+class ChronosBoltConfig:
+    context_length: int
+    prediction_length: int
+    input_patch_size: int
+    input_patch_stride: int
+    quantiles: List[float]
+    use_reg_token: bool = False
+
+
+@dataclass
+class ChronosBoltOutput(ModelOutput):
+    loss: Optional[torch.Tensor] = None
+    quantile_preds: Optional[torch.Tensor] = None
+    attentions: Optional[torch.Tensor] = None
+    cross_attentions: Optional[torch.Tensor] = None
+
+
+class Patch(nn.Module):
+    def __init__(self, patch_size: int, patch_stride: int) -> None:
+        super().__init__()
+        self.patch_size = patch_size
+        self.patch_stride = patch_stride
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        length = x.shape[-1]
+
+        if length % self.patch_size != 0:
+            padding_size = (
+                *x.shape[:-1],
+                self.patch_size - (length % self.patch_size),
+            )
+            padding = torch.full(size=padding_size, fill_value=torch.nan, dtype=x.dtype, device=x.device)
+            x = torch.concat((padding, x), dim=-1)
+
+        x = x.unfold(dimension=-1, size=self.patch_size, step=self.patch_stride)
+        return x
+
+
+class InstanceNorm(nn.Module):
+    """
+    See, also, RevIN. Apply standardization along the last dimension.
+    """
+
+    def __init__(self, eps: float = 1e-5) -> None:
+        super().__init__()
+        self.eps = eps
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        loc_scale: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        if loc_scale is None:
+            loc = torch.nan_to_num(torch.nanmean(x, dim=-1, keepdim=True), nan=0.0)
+            scale = torch.nan_to_num((x - loc).square().nanmean(dim=-1, keepdim=True).sqrt(), nan=1.0)
+            scale = torch.where(scale == 0, torch.abs(loc) + self.eps, scale)
+        else:
+            loc, scale = loc_scale
+
+        return (x - loc) / scale, (loc, scale)
+
+    def inverse(self, x: torch.Tensor, loc_scale: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
+        loc, scale = loc_scale
+        return x * scale + loc
+
+
+class ResidualBlock(nn.Module):
+    def __init__(
+        self,
+        in_dim: int,
+        h_dim: int,
+        out_dim: int,
+        act_fn_name: str,
+        dropout_p: float = 0.0,
+        use_layer_norm: bool = False,
+    ) -> None:
+        super().__init__()
+
+        self.dropout = nn.Dropout(dropout_p)
+        self.hidden_layer = nn.Linear(in_dim, h_dim)
+        self.act = ACT2FN[act_fn_name]
+        self.output_layer = nn.Linear(h_dim, out_dim)
+        self.residual_layer = nn.Linear(in_dim, out_dim)
+
+        self.use_layer_norm = use_layer_norm
+        if use_layer_norm:
+            self.layer_norm = T5LayerNorm(out_dim)
+
+    def forward(self, x: torch.Tensor):
+        hid = self.act(self.hidden_layer(x))
+        out = self.dropout(self.output_layer(hid))
+        res = self.residual_layer(x)
+
+        out = out + res
+
+        if self.use_layer_norm:
+            return self.layer_norm(out)
+        return out
+
+
+class ChronosBoltModelForForecasting(T5PreTrainedModel):
+    _keys_to_ignore_on_load_missing = [
+        r"input_patch_embedding\.",
+        r"output_patch_embedding\.",
+    ]
+    _keys_to_ignore_on_load_unexpected = [r"lm_head.weight"]
+    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
+
+    def __init__(self, config: T5Config):
+        assert hasattr(config, "chronos_config"), "Not a Chronos config file"
+
+        super().__init__(config)
+        self.model_dim = config.d_model
+
+        # TODO: remove filtering eventually, added for backward compatibility
+        config_fields = {f.name for f in fields(ChronosBoltConfig)}
+        self.chronos_config = ChronosBoltConfig(
+            **{k: v for k, v in config.chronos_config.items() if k in config_fields}
+        )
+
+        # Only decoder_start_id (and optionally REG token)
+        if self.chronos_config.use_reg_token:
+            config.reg_token_id = 1
+
+        config.vocab_size = 2 if self.chronos_config.use_reg_token else 1
+        self.shared = nn.Embedding(config.vocab_size, config.d_model)
+
+        # Input patch embedding layer
+        self.input_patch_embedding = ResidualBlock(
+            in_dim=self.chronos_config.input_patch_size * 2,
+            h_dim=config.d_ff,
+            out_dim=config.d_model,
+            act_fn_name=config.dense_act_fn,
+            dropout_p=config.dropout_rate,
+        )
+
+        # patching layer
+        self.patch = Patch(
+            patch_size=self.chronos_config.input_patch_size,
+            patch_stride=self.chronos_config.input_patch_stride,
+        )
+
+        # instance normalization, also referred to as "scaling" in Chronos and GluonTS
+        self.instance_norm = InstanceNorm()
+
+        encoder_config = copy.deepcopy(config)
+        encoder_config.is_decoder = False
+        encoder_config.use_cache = False
+        encoder_config.is_encoder_decoder = False
+        self.encoder = T5Stack(encoder_config, self.shared)
+
+        self._init_decoder(config)
+
+        self.num_quantiles = len(self.chronos_config.quantiles)
+        quantiles = torch.tensor(self.chronos_config.quantiles, dtype=self.dtype)
+        self.register_buffer("quantiles", quantiles, persistent=False)
+
+        self.output_patch_embedding = ResidualBlock(
+            in_dim=config.d_model,
+            h_dim=config.d_ff,
+            out_dim=self.num_quantiles * self.chronos_config.prediction_length,
+            act_fn_name=config.dense_act_fn,
+            dropout_p=config.dropout_rate,
+        )
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+        # Model parallel
+        self.model_parallel = False
+        self.device_map = None
+
+    def _init_weights(self, module):
+        super()._init_weights(module)
+        """Initialize the weights"""
+        factor = self.config.initializer_factor
+        if isinstance(module, (self.__class__)):
+            module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
+        elif isinstance(module, ResidualBlock):
+            module.hidden_layer.weight.data.normal_(
+                mean=0.0,
+                std=factor * ((self.chronos_config.input_patch_size * 2) ** -0.5),
+            )
+            if hasattr(module.hidden_layer, "bias") and module.hidden_layer.bias is not None:
+                module.hidden_layer.bias.data.zero_()
+
+            module.residual_layer.weight.data.normal_(
+                mean=0.0,
+                std=factor * ((self.chronos_config.input_patch_size * 2) ** -0.5),
+            )
+            if hasattr(module.residual_layer, "bias") and module.residual_layer.bias is not None:
+                module.residual_layer.bias.data.zero_()
+
+            module.output_layer.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
+            if hasattr(module.output_layer, "bias") and module.output_layer.bias is not None:
+                module.output_layer.bias.data.zero_()
+
+    def forward(
+        self,
+        context: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+        target: Optional[torch.Tensor] = None,
+        target_mask: Optional[torch.Tensor] = None,
+    ) -> ChronosBoltOutput:
+        mask = mask.to(context.dtype) if mask is not None else torch.isnan(context).logical_not().to(context.dtype)
+
+        batch_size, _ = context.shape
+        if context.shape[-1] > self.chronos_config.context_length:
+            context = context[..., -self.chronos_config.context_length :]
+            mask = mask[..., -self.chronos_config.context_length :]
+
+        # scaling
+        context, loc_scale = self.instance_norm(context)
+
+        # the scaling op above is done in 32-bit precision,
+        # then the context is moved to model's dtype
+        context = context.to(self.dtype)
+        mask = mask.to(self.dtype)
+
+        # patching
+        patched_context = self.patch(context)
+        patched_mask = torch.nan_to_num(self.patch(mask), nan=0.0)
+        patched_context[~(patched_mask > 0)] = 0.0
+        # concat context and mask along patch dim
+        patched_context = torch.cat([patched_context, patched_mask], dim=-1)
+
+        # attention_mask = 1 if at least one item in the patch is observed
+        attention_mask = patched_mask.sum(dim=-1) > 0  # (batch_size, patched_seq_length)
+
+        input_embeds = self.input_patch_embedding(patched_context)
+
+        if self.chronos_config.use_reg_token:
+            # Append [REG]
+            reg_input_ids = torch.full(
+                (batch_size, 1),
+                self.config.reg_token_id,
+                device=input_embeds.device,
+            )
+            reg_embeds = self.shared(reg_input_ids)
+            input_embeds = torch.cat([input_embeds, reg_embeds], dim=-2)
+            attention_mask = torch.cat([attention_mask, torch.ones_like(reg_input_ids)], dim=-1)
+
+        encoder_outputs = self.encoder(
+            attention_mask=attention_mask,
+            inputs_embeds=input_embeds,
+        )
+        hidden_states = encoder_outputs[0]
+
+        sequence_output = self.decode(input_embeds, attention_mask, hidden_states)
+
+        quantile_preds_shape = (
+            batch_size,
+            self.num_quantiles,
+            self.chronos_config.prediction_length,
+        )
+        quantile_preds = self.output_patch_embedding(sequence_output).view(*quantile_preds_shape)
+
+        loss = None
+        if target is not None:
+            # normalize target
+            target, _ = self.instance_norm(target, loc_scale)
+            target = target.unsqueeze(1)  # type: ignore
+            assert self.chronos_config.prediction_length == target.shape[-1]
+
+            target = target.to(quantile_preds.device)
+            target_mask = (
+                target_mask.unsqueeze(1).to(quantile_preds.device) if target_mask is not None else ~torch.isnan(target)
+            )
+            target[~target_mask] = 0.0
+
+            loss = (
+                2
+                * torch.abs(
+                    (target - quantile_preds)
+                    * ((target <= quantile_preds).float() - self.quantiles.view(1, self.num_quantiles, 1))
+                )
+                * target_mask.float()
+            )
+            loss = loss.mean(dim=-2)  # Mean over prediction horizon
+            loss = loss.sum(dim=-1)  # Sum over quantile levels
+            loss = loss.mean()  # Mean over batch
+
+        # Unscale predictions
+        quantile_preds = self.instance_norm.inverse(
+            quantile_preds.view(batch_size, -1),
+            loc_scale,
+        ).view(*quantile_preds_shape)
+
+        return ChronosBoltOutput(
+            loss=loss,
+            quantile_preds=quantile_preds,
+        )
+
+    def _init_decoder(self, config):
+        decoder_config = copy.deepcopy(config)
+        decoder_config.is_decoder = True
+        decoder_config.is_encoder_decoder = False
+        decoder_config.num_layers = config.num_decoder_layers
+        self.decoder = T5Stack(decoder_config, self.shared)
+
+    def decode(
+        self,
+        input_embeds,
+        attention_mask,
+        hidden_states,
+        output_attentions=False,
+    ):
+        """
+        Parameters
+        ----------
+        input_embeds: torch.Tensor
+            Patched and embedded inputs. Shape (batch_size, patched_context_length, d_model)
+        attention_mask: torch.Tensor
+            Attention mask for the patched context. Shape (batch_size, patched_context_length), type: torch.int64
+        hidden_states: torch.Tensor
+            Hidden states returned by the encoder. Shape (batch_size, patched_context_length, d_model)
+
+        Returns
+        -------
+        last_hidden_state
+            Last hidden state returned by the decoder, of shape (batch_size, 1, d_model)
+        """
+        batch_size = input_embeds.shape[0]
+        decoder_input_ids = torch.full(
+            (batch_size, 1),
+            self.config.decoder_start_token_id,
+            device=input_embeds.device,
+        )
+        decoder_outputs = self.decoder(
+            input_ids=decoder_input_ids,
+            encoder_hidden_states=hidden_states,
+            encoder_attention_mask=attention_mask,
+            output_attentions=output_attentions,
+            return_dict=True,
+        )
+
+        return decoder_outputs.last_hidden_state  # sequence_outputs, b x 1 x d_model
+
+
+class ChronosBoltPipeline(BaseChronosPipeline):
+    forecast_type: ForecastType = ForecastType.QUANTILES
+    default_context_length: int = 2048
+    # register this class name with this alias for backward compatibility
+    _aliases = ["PatchedT5Pipeline"]
+
+    def __init__(self, model: ChronosBoltModelForForecasting):
+        self.model = model
+
+    @property
+    def quantiles(self) -> List[float]:
+        return self.model.config.chronos_config["quantiles"]
+
+    def predict(  # type: ignore[override]
+        self,
+        context: Union[torch.Tensor, List[torch.Tensor]],
+        prediction_length: Optional[int] = None,
+        limit_prediction_length: bool = False,
+    ):
+        context_tensor = self._prepare_and_validate_context(context=context)
+
+        model_context_length = self.model.config.chronos_config["context_length"]
+        model_prediction_length = self.model.config.chronos_config["prediction_length"]
+        if prediction_length is None:
+            prediction_length = model_prediction_length
+
+        if prediction_length > model_prediction_length:
+            msg = (
+                f"We recommend keeping prediction length <= {model_prediction_length}. "
+                "The quality of longer predictions may degrade since the model is not optimized for it. "
+            )
+            if limit_prediction_length:
+                msg += "You can turn off this check by setting `limit_prediction_length=False`."
+                raise ValueError(msg)
+            warnings.warn(msg)
+
+        predictions = []
+        remaining = prediction_length
+
+        # We truncate the context here because otherwise batches with very long
+        # context could take up large amounts of GPU memory unnecessarily.
+        if context_tensor.shape[-1] > model_context_length:
+            context_tensor = context_tensor[..., -model_context_length:]
+
+        # TODO: We unroll the forecast of Chronos Bolt greedily with the full forecast
+        # horizon that the model was trained with (i.e., 64). This results in variance collapsing
+        # every 64 steps.
+        while remaining > 0:
+            with torch.no_grad():
+                prediction = self.model(
+                    context=context_tensor.to(
+                        device=self.model.device,
+                        dtype=torch.float32,  # scaling should be done in 32-bit precision
+                    ),
+                ).quantile_preds.to(context_tensor)
+
+            predictions.append(prediction)
+            remaining -= prediction.shape[-1]
+
+            if remaining <= 0:
+                break
+
+            central_idx = torch.abs(torch.tensor(self.quantiles) - 0.5).argmin()
+            central_prediction = prediction[:, central_idx]
+
+            context_tensor = torch.cat([context_tensor, central_prediction], dim=-1)
+
+        return torch.cat(predictions, dim=-1)[..., :prediction_length]
+
+    def predict_quantiles(
+        self, context: torch.Tensor, prediction_length: int, quantile_levels: List[float], **kwargs
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # shape (batch_size, prediction_length, len(training_quantile_levels))
+        predictions = (
+            self.predict(
+                context,
+                prediction_length=prediction_length,
+            )
+            .detach()
+            .cpu()
+            .swapaxes(1, 2)
+        )
+
+        training_quantile_levels = self.quantiles
+
+        if set(quantile_levels).issubset(set(training_quantile_levels)):
+            # no need to perform intra/extrapolation
+            quantiles = predictions[..., [training_quantile_levels.index(q) for q in quantile_levels]]
+        else:
+            # we rely on torch for interpolating quantiles if quantiles that
+            # Chronos Bolt was trained on were not provided
+            if min(quantile_levels) < min(training_quantile_levels) or max(quantile_levels) > max(
+                training_quantile_levels
+            ):
+                logger.warning(
+                    f"\tQuantiles to be predicted ({quantile_levels}) are not within the range of "
+                    f"quantiles that Chronos-Bolt was trained on ({training_quantile_levels}). "
+                    "Quantile predictions will be set to the minimum/maximum levels at which Chronos-Bolt "
+                    "was trained on. This may significantly affect the quality of the predictions."
+                )
+
+            # TODO: this is a hack that assumes the model's quantiles during training (training_quantile_levels)
+            # made up an equidistant grid along the quantile dimension. i.e., they were (0.1, 0.2, ..., 0.9).
+            # While this holds for official Chronos-Bolt models, this may not be true in the future, and this
+            # function may have to be revised.
+            augmented_predictions = torch.cat(
+                [predictions[..., [0]], predictions, predictions[..., [-1]]],
+                dim=-1,
+            )
+            quantiles = torch.quantile(
+                augmented_predictions, q=torch.tensor(quantile_levels, dtype=augmented_predictions.dtype), dim=-1
+            ).permute(1, 2, 0)
+        mean = predictions[:, :, training_quantile_levels.index(0.5)]
+        return quantiles, mean
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        """
+        Load the model, either from a local path or from the HuggingFace Hub.
+        Supports the same arguments as ``AutoConfig`` and ``AutoModel``
+        from ``transformers``.
+        """
+        # if optimization_strategy is provided, pop this as it won't be used
+        kwargs.pop("optimization_strategy", None)
+
+        config = AutoConfig.from_pretrained(*args, **kwargs)
+        assert hasattr(config, "chronos_config"), "Not a Chronos config file"
+
+        context_length = kwargs.pop("context_length", None)
+        if context_length is not None:
+            config.chronos_config["context_length"] = context_length
+
+        architecture = config.architectures[0]
+        class_ = globals().get(architecture)
+
+        # TODO: remove this once all models carry the correct architecture names in their configuration
+        # and raise an error instead.
+        if class_ is None:
+            logger.warning(f"Unknown architecture: {architecture}, defaulting to ChronosBoltModelForForecasting")
+            class_ = ChronosBoltModelForForecasting
+
+        model = class_.from_pretrained(*args, **kwargs)
+        return cls(model=model)
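For orientation, here is a minimal usage sketch of the `ChronosBoltPipeline` added above. It is not part of the diff: the checkpoint path is a placeholder, and the example only exercises the `from_pretrained` and `predict_quantiles` methods shown in this file.

import torch

from autogluon.timeseries.models.chronos.pipeline.chronos_bolt import ChronosBoltPipeline

# Placeholder checkpoint: any local path or Hub id whose config carries a
# `chronos_config` section for the Chronos-Bolt architecture would work here.
pipeline = ChronosBoltPipeline.from_pretrained("/path/to/chronos-bolt-checkpoint")

# Batch of 4 series, each with 512 context steps; NaN entries are treated as missing.
context = torch.randn(4, 512)

# quantiles: (batch, prediction_length, len(quantile_levels)); mean is the median forecast.
quantiles, mean = pipeline.predict_quantiles(
    context, prediction_length=48, quantile_levels=[0.1, 0.5, 0.9]
)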
@@ -1,13 +1,49 @@
+import os
+import re
 import time
-from
+from pathlib import Path
+from typing import Callable, List, Optional
 
 import numpy as np
 import torch
 
+from autogluon.common.loaders.load_s3 import download, list_bucket_prefix_suffix_contains_s3
 from autogluon.core.utils.exceptions import TimeLimitExceeded
 from autogluon.timeseries.dataset.ts_dataframe import TimeSeriesDataFrame
 
 
+def left_pad_and_stack_1D(tensors: List[torch.Tensor]) -> torch.Tensor:
+    max_len = max(len(c) for c in tensors)
+    padded = []
+    for c in tensors:
+        assert isinstance(c, torch.Tensor)
+        assert c.ndim == 1
+        padding = torch.full(size=(max_len - len(c),), fill_value=torch.nan, device=c.device)
+        padded.append(torch.concat((padding, c), dim=-1))
+    return torch.stack(padded)
+
+
+def cache_model_from_s3(s3_uri: str, force=False):
+    if re.match("^s3://([^/]+)/(.*?([^/]+)/?)$", s3_uri) is None:
+        raise ValueError(f"Not a valid S3 URI: {s3_uri}")
+
+    # we expect the prefix to point to a "directory" on S3
+    if not s3_uri.endswith("/"):
+        s3_uri += "/"
+
+    cache_home = Path(os.environ.get("XDG_CACHE_HOME") or Path.home() / ".cache")
+    bucket, prefix = s3_uri.replace("s3://", "").split("/", 1)
+    bucket_cache_path = cache_home / "autogluon" / "timeseries" / bucket
+
+    for obj_path in list_bucket_prefix_suffix_contains_s3(bucket=bucket, prefix=prefix):
+        destination_path = bucket_cache_path / obj_path
+        if not force and destination_path.exists():
+            continue
+        download(bucket, obj_path, local_path=str(destination_path))
+
+    return str(bucket_cache_path / prefix)
+
+
 class ChronosInferenceDataset:
     """A container for time series datasets that implements the ``torch.utils.data.Dataset`` interface"""
 
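To make the new `left_pad_and_stack_1D` helper concrete, a small sketch with illustrative values: ragged 1D contexts are left-padded with NaN to the longest length and stacked into a single 2D batch, which matches how the Chronos-Bolt model treats NaN entries as missing when building its observation mask.

import torch

from autogluon.timeseries.models.chronos.pipeline.utils import left_pad_and_stack_1D

short = torch.tensor([1.0, 2.0, 3.0])
longer = torch.tensor([10.0, 20.0, 30.0, 40.0, 50.0])

batch = left_pad_and_stack_1D([short, longer])
# batch.shape == (2, 5); batch[0] == [nan, nan, 1., 2., 3.]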
@@ -353,6 +353,7 @@ class AbstractGluonTSModel(AbstractTimeSeriesModel):
         columns = self.metadata.known_covariates_real
         if self.supports_known_covariates and len(columns) > 0:
             assert "known" in self._real_column_transformers, "Preprocessing pipeline must be fit first"
+            known_covariates = known_covariates.copy()
             known_covariates[columns] = self._real_column_transformers["known"].transform(known_covariates[columns])
         return known_covariates
 
@@ -339,6 +339,8 @@ class PatchTSTModel(AbstractGluonTSModel):
         If True, ``lightning_logs`` directory will NOT be removed after the model finished training.
     """
 
+    supports_known_covariates = True
+
     @property
     def default_context_length(self) -> int:
         return 96
@@ -351,6 +353,7 @@ class PatchTSTModel(AbstractGluonTSModel):
     def _get_estimator_init_args(self) -> Dict[str, Any]:
         init_kwargs = super()._get_estimator_init_args()
         init_kwargs.setdefault("patch_len", 16)
+        init_kwargs["num_feat_dynamic_real"] = self.num_feat_dynamic_real
         return init_kwargs
 
 
@@ -113,8 +113,11 @@ class AbstractLocalModel(AbstractTimeSeriesModel):
         local_model_args = {}
         # TODO: Move filtering logic to AbstractTimeSeriesModel
         for key, value in raw_local_model_args.items():
-            if key in self.
+            if key in self.allowed_local_model_args:
                 local_model_args[key] = value
+            elif key in self.allowed_hyperparameters:
+                # Quietly ignore params in self.allowed_hyperparameters - they are used by AbstractTimeSeriesModel
+                pass
             else:
                 unused_local_model_args.append(key)
 
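The hyperparameter filtering above now distinguishes three cases. A toy sketch of the same loop with made-up attribute contents (none of these values come from the library) shows the effect:

allowed_local_model_args = ["seasonal_period"]           # model-specific args (illustrative)
allowed_hyperparameters = ["max_ts_length", "n_jobs"]    # consumed by AbstractTimeSeriesModel (illustrative)
raw_local_model_args = {"seasonal_period": 12, "max_ts_length": 512, "typo_arg": 1}

local_model_args, unused_local_model_args = {}, []
for key, value in raw_local_model_args.items():
    if key in allowed_local_model_args:
        local_model_args[key] = value
    elif key in allowed_hyperparameters:
        pass  # quietly ignored here; handled elsewhere by the abstract model
    else:
        unused_local_model_args.append(key)

# local_model_args == {"seasonal_period": 12}
# unused_local_model_args == ["typo_arg"]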
@@ -129,6 +129,7 @@ class AutoARIMAModel(AbstractProbabilisticStatsForecastModel):
         This significantly speeds up fitting and usually leads to no change in accuracy.
     """
 
+    init_time_in_seconds = 0  # C++ models require no compilation
     allowed_local_model_args = [
         "d",
         "D",
@@ -206,6 +207,7 @@ class ARIMAModel(AbstractProbabilisticStatsForecastModel):
         This significantly speeds up fitting and usually leads to no change in accuracy.
     """
 
+    init_time_in_seconds = 0  # C++ models require no compilation
    allowed_local_model_args = [
         "order",
         "seasonal_order",
@@ -261,6 +263,7 @@ class AutoETSModel(AbstractProbabilisticStatsForecastModel):
         This significantly speeds up fitting and usually leads to no change in accuracy.
     """
 
+    init_time_in_seconds = 0  # C++ models require no compilation
     allowed_local_model_args = [
         "damped",
         "model",
@@ -12,6 +12,7 @@ import autogluon.core as ag
 from autogluon.timeseries.dataset.ts_dataframe import TimeSeriesDataFrame
 from autogluon.timeseries.models.abstract import AbstractTimeSeriesModel
 from autogluon.timeseries.models.local.abstract_local_model import AbstractLocalModel
+from autogluon.timeseries.regressor import CovariateRegressor
 from autogluon.timeseries.splitter import AbstractWindowSplitter, ExpandingWindowSplitter
 from autogluon.timeseries.transforms import LocalTargetScaler
 
@@ -89,6 +90,10 @@ class MultiWindowBacktestingModel(AbstractTimeSeriesModel):
         # Do not use scaler in the MultiWindowModel to avoid duplication; it will be created in the inner model
         return None
 
+    def _create_covariates_regressor(self) -> Optional[CovariateRegressor]:
+        # Do not use regressor in the MultiWindowModel to avoid duplication; it will be created in the inner model
+        return None
+
     def _fit(
         self,
         train_data: TimeSeriesDataFrame,
@@ -293,7 +293,7 @@ class TimeSeriesPredictor(TimeSeriesPredictorDeprecatedMixin):
         df = self._to_data_frame(data, name=name)
         if not pd.api.types.is_numeric_dtype(df[self.target]):
             raise ValueError(f"Target column {name}['{self.target}'] has a non-numeric dtype {df[self.target].dtype}")
-        df
+        df = df.assign(**{self.target: df[self.target].astype("float64")})
         # MultiIndex.is_monotonic_increasing checks if index is sorted by ["item_id", "timestamp"]
         if not df.index.is_monotonic_increasing:
             df = df.sort_index()
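Finally, a small sketch of what the new target-casting line in `TimeSeriesPredictor` does during input validation; the frame and column names here are illustrative, not taken from the library:

import pandas as pd

target = "target"  # illustrative target column name
df = pd.DataFrame({"target": [1, 2, 3], "store_id": ["a", "b", "c"]})

# Cast only the target column to float64; other columns keep their dtypes.
df = df.assign(**{target: df[target].astype("float64")})
print(df.dtypes)  # target -> float64, store_id -> object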