autogluon.timeseries 1.2.1b20250224__py3-none-any.whl → 1.4.1b20251215__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.

This version of autogluon.timeseries has been flagged as potentially problematic.

Files changed (108)
  1. autogluon/timeseries/configs/__init__.py +3 -2
  2. autogluon/timeseries/configs/hyperparameter_presets.py +62 -0
  3. autogluon/timeseries/configs/predictor_presets.py +106 -0
  4. autogluon/timeseries/dataset/ts_dataframe.py +256 -141
  5. autogluon/timeseries/learner.py +86 -52
  6. autogluon/timeseries/metrics/__init__.py +42 -8
  7. autogluon/timeseries/metrics/abstract.py +89 -19
  8. autogluon/timeseries/metrics/point.py +142 -53
  9. autogluon/timeseries/metrics/quantile.py +46 -21
  10. autogluon/timeseries/metrics/utils.py +4 -4
  11. autogluon/timeseries/models/__init__.py +8 -2
  12. autogluon/timeseries/models/abstract/__init__.py +2 -2
  13. autogluon/timeseries/models/abstract/abstract_timeseries_model.py +361 -592
  14. autogluon/timeseries/models/abstract/model_trial.py +2 -1
  15. autogluon/timeseries/models/abstract/tunable.py +189 -0
  16. autogluon/timeseries/models/autogluon_tabular/__init__.py +2 -0
  17. autogluon/timeseries/models/autogluon_tabular/mlforecast.py +282 -194
  18. autogluon/timeseries/models/autogluon_tabular/per_step.py +513 -0
  19. autogluon/timeseries/models/autogluon_tabular/transforms.py +25 -18
  20. autogluon/timeseries/models/chronos/__init__.py +2 -1
  21. autogluon/timeseries/models/chronos/chronos2.py +361 -0
  22. autogluon/timeseries/models/chronos/model.py +219 -138
  23. autogluon/timeseries/models/chronos/{pipeline/utils.py → utils.py} +81 -50
  24. autogluon/timeseries/models/ensemble/__init__.py +37 -2
  25. autogluon/timeseries/models/ensemble/abstract.py +107 -0
  26. autogluon/timeseries/models/ensemble/array_based/__init__.py +3 -0
  27. autogluon/timeseries/models/ensemble/array_based/abstract.py +240 -0
  28. autogluon/timeseries/models/ensemble/array_based/models.py +185 -0
  29. autogluon/timeseries/models/ensemble/array_based/regressor/__init__.py +12 -0
  30. autogluon/timeseries/models/ensemble/array_based/regressor/abstract.py +88 -0
  31. autogluon/timeseries/models/ensemble/array_based/regressor/linear_stacker.py +186 -0
  32. autogluon/timeseries/models/ensemble/array_based/regressor/per_quantile_tabular.py +94 -0
  33. autogluon/timeseries/models/ensemble/array_based/regressor/tabular.py +107 -0
  34. autogluon/timeseries/models/ensemble/ensemble_selection.py +167 -0
  35. autogluon/timeseries/models/ensemble/per_item_greedy.py +172 -0
  36. autogluon/timeseries/models/ensemble/weighted/__init__.py +8 -0
  37. autogluon/timeseries/models/ensemble/weighted/abstract.py +45 -0
  38. autogluon/timeseries/models/ensemble/weighted/basic.py +91 -0
  39. autogluon/timeseries/models/ensemble/weighted/greedy.py +62 -0
  40. autogluon/timeseries/models/gluonts/__init__.py +1 -1
  41. autogluon/timeseries/models/gluonts/{abstract_gluonts.py → abstract.py} +148 -208
  42. autogluon/timeseries/models/gluonts/dataset.py +109 -0
  43. autogluon/timeseries/models/gluonts/{torch/models.py → models.py} +38 -22
  44. autogluon/timeseries/models/local/__init__.py +0 -7
  45. autogluon/timeseries/models/local/abstract_local_model.py +71 -74
  46. autogluon/timeseries/models/local/naive.py +13 -9
  47. autogluon/timeseries/models/local/npts.py +9 -2
  48. autogluon/timeseries/models/local/statsforecast.py +52 -36
  49. autogluon/timeseries/models/multi_window/multi_window_model.py +65 -45
  50. autogluon/timeseries/models/registry.py +64 -0
  51. autogluon/timeseries/models/toto/__init__.py +3 -0
  52. autogluon/timeseries/models/toto/_internal/__init__.py +9 -0
  53. autogluon/timeseries/models/toto/_internal/backbone/__init__.py +3 -0
  54. autogluon/timeseries/models/toto/_internal/backbone/attention.py +196 -0
  55. autogluon/timeseries/models/toto/_internal/backbone/backbone.py +262 -0
  56. autogluon/timeseries/models/toto/_internal/backbone/distribution.py +70 -0
  57. autogluon/timeseries/models/toto/_internal/backbone/kvcache.py +136 -0
  58. autogluon/timeseries/models/toto/_internal/backbone/rope.py +89 -0
  59. autogluon/timeseries/models/toto/_internal/backbone/rotary_embedding_torch.py +342 -0
  60. autogluon/timeseries/models/toto/_internal/backbone/scaler.py +305 -0
  61. autogluon/timeseries/models/toto/_internal/backbone/transformer.py +333 -0
  62. autogluon/timeseries/models/toto/_internal/dataset.py +165 -0
  63. autogluon/timeseries/models/toto/_internal/forecaster.py +423 -0
  64. autogluon/timeseries/models/toto/dataloader.py +108 -0
  65. autogluon/timeseries/models/toto/hf_pretrained_model.py +200 -0
  66. autogluon/timeseries/models/toto/model.py +249 -0
  67. autogluon/timeseries/predictor.py +685 -297
  68. autogluon/timeseries/regressor.py +94 -44
  69. autogluon/timeseries/splitter.py +8 -32
  70. autogluon/timeseries/trainer/__init__.py +3 -0
  71. autogluon/timeseries/trainer/ensemble_composer.py +444 -0
  72. autogluon/timeseries/trainer/model_set_builder.py +256 -0
  73. autogluon/timeseries/trainer/prediction_cache.py +149 -0
  74. autogluon/timeseries/{trainer.py → trainer/trainer.py} +387 -390
  75. autogluon/timeseries/trainer/utils.py +17 -0
  76. autogluon/timeseries/transforms/__init__.py +2 -13
  77. autogluon/timeseries/transforms/covariate_scaler.py +34 -40
  78. autogluon/timeseries/transforms/target_scaler.py +37 -20
  79. autogluon/timeseries/utils/constants.py +10 -0
  80. autogluon/timeseries/utils/datetime/lags.py +3 -5
  81. autogluon/timeseries/utils/datetime/seasonality.py +1 -3
  82. autogluon/timeseries/utils/datetime/time_features.py +2 -2
  83. autogluon/timeseries/utils/features.py +70 -47
  84. autogluon/timeseries/utils/forecast.py +19 -14
  85. autogluon/timeseries/utils/timer.py +173 -0
  86. autogluon/timeseries/utils/warning_filters.py +4 -2
  87. autogluon/timeseries/version.py +1 -1
  88. autogluon.timeseries-1.4.1b20251215-py3.11-nspkg.pth +1 -0
  89. {autogluon.timeseries-1.2.1b20250224.dist-info → autogluon_timeseries-1.4.1b20251215.dist-info}/METADATA +49 -36
  90. autogluon_timeseries-1.4.1b20251215.dist-info/RECORD +103 -0
  91. {autogluon.timeseries-1.2.1b20250224.dist-info → autogluon_timeseries-1.4.1b20251215.dist-info}/WHEEL +1 -1
  92. autogluon/timeseries/configs/presets_configs.py +0 -79
  93. autogluon/timeseries/evaluator.py +0 -6
  94. autogluon/timeseries/models/chronos/pipeline/__init__.py +0 -11
  95. autogluon/timeseries/models/chronos/pipeline/base.py +0 -160
  96. autogluon/timeseries/models/chronos/pipeline/chronos.py +0 -585
  97. autogluon/timeseries/models/chronos/pipeline/chronos_bolt.py +0 -518
  98. autogluon/timeseries/models/ensemble/abstract_timeseries_ensemble.py +0 -78
  99. autogluon/timeseries/models/ensemble/greedy_ensemble.py +0 -170
  100. autogluon/timeseries/models/gluonts/torch/__init__.py +0 -0
  101. autogluon/timeseries/models/presets.py +0 -360
  102. autogluon.timeseries-1.2.1b20250224-py3.9-nspkg.pth +0 -1
  103. autogluon.timeseries-1.2.1b20250224.dist-info/RECORD +0 -68
  104. {autogluon.timeseries-1.2.1b20250224.dist-info → autogluon_timeseries-1.4.1b20251215.dist-info/licenses}/LICENSE +0 -0
  105. {autogluon.timeseries-1.2.1b20250224.dist-info → autogluon_timeseries-1.4.1b20251215.dist-info/licenses}/NOTICE +0 -0
  106. {autogluon.timeseries-1.2.1b20250224.dist-info → autogluon_timeseries-1.4.1b20251215.dist-info}/namespace_packages.txt +0 -0
  107. {autogluon.timeseries-1.2.1b20250224.dist-info → autogluon_timeseries-1.4.1b20251215.dist-info}/top_level.txt +0 -0
  108. {autogluon.timeseries-1.2.1b20250224.dist-info → autogluon_timeseries-1.4.1b20251215.dist-info}/zip-safe +0 -0
autogluon/timeseries/models/chronos/pipeline/chronos_bolt.py
@@ -1,518 +0,0 @@
- # Implements Chronos with T5 architecture but with patched inputs instead of
- # per-time-step tokenization. a.k.a. Chronos-Bolt
-
- # Authors: Abdul Fatir Ansari <ansarnd@amazon.com>, Lorenzo Stella <stellalo@amazon.com>, Caner Turkmen <atturkm@amazon.com>
-
- import copy
- import logging
- import warnings
- from dataclasses import dataclass, fields
- from typing import List, Optional, Tuple, Union
-
- import torch
- import torch.nn as nn
- from transformers import AutoConfig
- from transformers.models.t5.modeling_t5 import (
-     ACT2FN,
-     T5Config,
-     T5LayerNorm,
-     T5PreTrainedModel,
-     T5Stack,
- )
- from transformers.utils import ModelOutput
-
- from .base import BaseChronosPipeline, ForecastType
-
- logger = logging.getLogger("autogluon.timeseries.models.chronos")
-
-
- @dataclass
- class ChronosBoltConfig:
-     context_length: int
-     prediction_length: int
-     input_patch_size: int
-     input_patch_stride: int
-     quantiles: List[float]
-     use_reg_token: bool = False
-
-
- @dataclass
- class ChronosBoltOutput(ModelOutput):
-     loss: Optional[torch.Tensor] = None
-     quantile_preds: Optional[torch.Tensor] = None
-     attentions: Optional[torch.Tensor] = None
-     cross_attentions: Optional[torch.Tensor] = None
-
-
- class Patch(nn.Module):
-     def __init__(self, patch_size: int, patch_stride: int) -> None:
-         super().__init__()
-         self.patch_size = patch_size
-         self.patch_stride = patch_stride
-
-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         length = x.shape[-1]
-
-         if length % self.patch_size != 0:
-             padding_size = (
-                 *x.shape[:-1],
-                 self.patch_size - (length % self.patch_size),
-             )
-             padding = torch.full(size=padding_size, fill_value=torch.nan, dtype=x.dtype, device=x.device)
-             x = torch.concat((padding, x), dim=-1)
-
-         x = x.unfold(dimension=-1, size=self.patch_size, step=self.patch_stride)
-         return x
-
-
- class InstanceNorm(nn.Module):
-     """
-     See, also, RevIN. Apply standardization along the last dimension.
-     """
-
-     def __init__(self, eps: float = 1e-5) -> None:
-         super().__init__()
-         self.eps = eps
-
-     def forward(
-         self,
-         x: torch.Tensor,
-         loc_scale: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-     ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-         if loc_scale is None:
-             loc = torch.nan_to_num(torch.nanmean(x, dim=-1, keepdim=True), nan=0.0)
-             scale = torch.nan_to_num((x - loc).square().nanmean(dim=-1, keepdim=True).sqrt(), nan=1.0)
-             scale = torch.where(scale == 0, torch.abs(loc) + self.eps, scale)
-         else:
-             loc, scale = loc_scale
-
-         return (x - loc) / scale, (loc, scale)
-
-     def inverse(self, x: torch.Tensor, loc_scale: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
-         loc, scale = loc_scale
-         return x * scale + loc
-
-
- class ResidualBlock(nn.Module):
-     def __init__(
-         self,
-         in_dim: int,
-         h_dim: int,
-         out_dim: int,
-         act_fn_name: str,
-         dropout_p: float = 0.0,
-         use_layer_norm: bool = False,
-     ) -> None:
-         super().__init__()
-
-         self.dropout = nn.Dropout(dropout_p)
-         self.hidden_layer = nn.Linear(in_dim, h_dim)
-         self.act = ACT2FN[act_fn_name]
-         self.output_layer = nn.Linear(h_dim, out_dim)
-         self.residual_layer = nn.Linear(in_dim, out_dim)
-
-         self.use_layer_norm = use_layer_norm
-         if use_layer_norm:
-             self.layer_norm = T5LayerNorm(out_dim)
-
-     def forward(self, x: torch.Tensor):
-         hid = self.act(self.hidden_layer(x))
-         out = self.dropout(self.output_layer(hid))
-         res = self.residual_layer(x)
-
-         out = out + res
-
-         if self.use_layer_norm:
-             return self.layer_norm(out)
-         return out
-
-
- class ChronosBoltModelForForecasting(T5PreTrainedModel):
-     _keys_to_ignore_on_load_missing = [
-         r"input_patch_embedding\.",
-         r"output_patch_embedding\.",
-     ]
-     _keys_to_ignore_on_load_unexpected = [r"lm_head.weight"]
-     _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
-
-     def __init__(self, config: T5Config):
-         assert hasattr(config, "chronos_config"), "Not a Chronos config file"
-
-         super().__init__(config)
-         self.model_dim = config.d_model
-
-         # TODO: remove filtering eventually, added for backward compatibility
-         config_fields = {f.name for f in fields(ChronosBoltConfig)}
-         self.chronos_config = ChronosBoltConfig(
-             **{k: v for k, v in config.chronos_config.items() if k in config_fields}
-         )
-
-         # Only decoder_start_id (and optionally REG token)
-         if self.chronos_config.use_reg_token:
-             config.reg_token_id = 1
-
-         config.vocab_size = 2 if self.chronos_config.use_reg_token else 1
-         self.shared = nn.Embedding(config.vocab_size, config.d_model)
-
-         # Input patch embedding layer
-         self.input_patch_embedding = ResidualBlock(
-             in_dim=self.chronos_config.input_patch_size * 2,
-             h_dim=config.d_ff,
-             out_dim=config.d_model,
-             act_fn_name=config.dense_act_fn,
-             dropout_p=config.dropout_rate,
-         )
-
-         # patching layer
-         self.patch = Patch(
-             patch_size=self.chronos_config.input_patch_size,
-             patch_stride=self.chronos_config.input_patch_stride,
-         )
-
-         # instance normalization, also referred to as "scaling" in Chronos and GluonTS
-         self.instance_norm = InstanceNorm()
-
-         encoder_config = copy.deepcopy(config)
-         encoder_config.is_decoder = False
-         encoder_config.use_cache = False
-         encoder_config.is_encoder_decoder = False
-         self.encoder = T5Stack(encoder_config, self.shared)
-
-         self._init_decoder(config)
-
-         self.num_quantiles = len(self.chronos_config.quantiles)
-         quantiles = torch.tensor(self.chronos_config.quantiles, dtype=self.dtype)
-         self.register_buffer("quantiles", quantiles, persistent=False)
-
-         self.output_patch_embedding = ResidualBlock(
-             in_dim=config.d_model,
-             h_dim=config.d_ff,
-             out_dim=self.num_quantiles * self.chronos_config.prediction_length,
-             act_fn_name=config.dense_act_fn,
-             dropout_p=config.dropout_rate,
-         )
-
-         # Initialize weights and apply final processing
-         self.post_init()
-
-         # Model parallel
-         self.model_parallel = False
-         self.device_map = None
-
-     def _init_weights(self, module):
-         super()._init_weights(module)
-         """Initialize the weights"""
-         factor = self.config.initializer_factor
-         if isinstance(module, (self.__class__)):
-             module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
-         elif isinstance(module, ResidualBlock):
-             module.hidden_layer.weight.data.normal_(
-                 mean=0.0,
-                 std=factor * ((self.chronos_config.input_patch_size * 2) ** -0.5),
-             )
-             if hasattr(module.hidden_layer, "bias") and module.hidden_layer.bias is not None:
-                 module.hidden_layer.bias.data.zero_()
-
-             module.residual_layer.weight.data.normal_(
-                 mean=0.0,
-                 std=factor * ((self.chronos_config.input_patch_size * 2) ** -0.5),
-             )
-             if hasattr(module.residual_layer, "bias") and module.residual_layer.bias is not None:
-                 module.residual_layer.bias.data.zero_()
-
-             module.output_layer.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
-             if hasattr(module.output_layer, "bias") and module.output_layer.bias is not None:
-                 module.output_layer.bias.data.zero_()
-
-     def forward(
-         self,
-         context: torch.Tensor,
-         mask: Optional[torch.Tensor] = None,
-         target: Optional[torch.Tensor] = None,
-         target_mask: Optional[torch.Tensor] = None,
-     ) -> ChronosBoltOutput:
-         mask = mask.to(context.dtype) if mask is not None else torch.isnan(context).logical_not().to(context.dtype)
-
-         batch_size, _ = context.shape
-         if context.shape[-1] > self.chronos_config.context_length:
-             context = context[..., -self.chronos_config.context_length :]
-             mask = mask[..., -self.chronos_config.context_length :]
-
-         # scaling
-         context, loc_scale = self.instance_norm(context)
-
-         # the scaling op above is done in 32-bit precision,
-         # then the context is moved to model's dtype
-         context = context.to(self.dtype)
-         mask = mask.to(self.dtype)
-
-         # patching
-         patched_context = self.patch(context)
-         patched_mask = torch.nan_to_num(self.patch(mask), nan=0.0)
-         patched_context[~(patched_mask > 0)] = 0.0
-         # concat context and mask along patch dim
-         patched_context = torch.cat([patched_context, patched_mask], dim=-1)
-
-         # attention_mask = 1 if at least one item in the patch is observed
-         attention_mask = patched_mask.sum(dim=-1) > 0  # (batch_size, patched_seq_length)
-
-         input_embeds = self.input_patch_embedding(patched_context)
-
-         if self.chronos_config.use_reg_token:
-             # Append [REG]
-             reg_input_ids = torch.full(
-                 (batch_size, 1),
-                 self.config.reg_token_id,
-                 device=input_embeds.device,
-             )
-             reg_embeds = self.shared(reg_input_ids)
-             input_embeds = torch.cat([input_embeds, reg_embeds], dim=-2)
-             attention_mask = torch.cat([attention_mask, torch.ones_like(reg_input_ids)], dim=-1)
-
-         encoder_outputs = self.encoder(
-             attention_mask=attention_mask,
-             inputs_embeds=input_embeds,
-         )
-         hidden_states = encoder_outputs[0]
-
-         sequence_output = self.decode(input_embeds, attention_mask, hidden_states)
-
-         quantile_preds_shape = (
-             batch_size,
-             self.num_quantiles,
-             self.chronos_config.prediction_length,
-         )
-         quantile_preds = self.output_patch_embedding(sequence_output).view(*quantile_preds_shape)
-
-         loss = None
-         if target is not None:
-             # normalize target
-             target, _ = self.instance_norm(target, loc_scale)
-             target = target.unsqueeze(1)  # type: ignore
-             assert self.chronos_config.prediction_length >= target.shape[-1]
-
-             target = target.to(quantile_preds.device)
-             target_mask = (
-                 target_mask.unsqueeze(1).to(quantile_preds.device) if target_mask is not None else ~torch.isnan(target)
-             )
-             target[~target_mask] = 0.0
-
-             # pad target and target_mask if they are shorter than model's prediction_length
-             if self.chronos_config.prediction_length > target.shape[-1]:
-                 padding_shape = (*target.shape[:-1], self.chronos_config.prediction_length - target.shape[-1])
-                 target = torch.cat([target, torch.zeros(padding_shape).to(target)], dim=-1)
-                 target_mask = torch.cat([target_mask, torch.zeros(padding_shape).to(target_mask)], dim=-1)
-
-             loss = (
-                 2
-                 * torch.abs(
-                     (target - quantile_preds)
-                     * ((target <= quantile_preds).float() - self.quantiles.view(1, self.num_quantiles, 1))
-                 )
-                 * target_mask.float()
-             )
-             loss = loss.mean(dim=-2)  # Mean over prediction horizon
-             loss = loss.sum(dim=-1)  # Sum over quantile levels
-             loss = loss.mean()  # Mean over batch
-
-         # Unscale predictions
-         quantile_preds = self.instance_norm.inverse(
-             quantile_preds.view(batch_size, -1),
-             loc_scale,
-         ).view(*quantile_preds_shape)
-
-         return ChronosBoltOutput(
-             loss=loss,
-             quantile_preds=quantile_preds,
-         )
-
-     def _init_decoder(self, config):
-         decoder_config = copy.deepcopy(config)
-         decoder_config.is_decoder = True
-         decoder_config.is_encoder_decoder = False
-         decoder_config.num_layers = config.num_decoder_layers
-         self.decoder = T5Stack(decoder_config, self.shared)
-
-     def decode(
-         self,
-         input_embeds,
-         attention_mask,
-         hidden_states,
-         output_attentions=False,
-     ):
-         """
-         Parameters
-         ----------
-         input_embeds: torch.Tensor
-             Patched and embedded inputs. Shape (batch_size, patched_context_length, d_model)
-         attention_mask: torch.Tensor
-             Attention mask for the patched context. Shape (batch_size, patched_context_length), type: torch.int64
-         hidden_states: torch.Tensor
-             Hidden states returned by the encoder. Shape (batch_size, patched_context_length, d_model)
-
-         Returns
-         -------
-         last_hidden_state
-             Last hidden state returned by the decoder, of shape (batch_size, 1, d_model)
-         """
-         batch_size = input_embeds.shape[0]
-         decoder_input_ids = torch.full(
-             (batch_size, 1),
-             self.config.decoder_start_token_id,
-             device=input_embeds.device,
-         )
-         decoder_outputs = self.decoder(
-             input_ids=decoder_input_ids,
-             encoder_hidden_states=hidden_states,
-             encoder_attention_mask=attention_mask,
-             output_attentions=output_attentions,
-             return_dict=True,
-         )
-
-         return decoder_outputs.last_hidden_state  # sequence_outputs, b x 1 x d_model
-
-
- class ChronosBoltPipeline(BaseChronosPipeline):
-     forecast_type: ForecastType = ForecastType.QUANTILES
-     default_context_length: int = 2048
-     # register this class name with this alias for backward compatibility
-     _aliases = ["PatchedT5Pipeline"]
-
-     def __init__(self, model: ChronosBoltModelForForecasting):
-         super().__init__(inner_model=model)
-         self.model = model
-
-     @property
-     def quantiles(self) -> List[float]:
-         return self.model.config.chronos_config["quantiles"]
-
-     def predict(  # type: ignore[override]
-         self,
-         context: Union[torch.Tensor, List[torch.Tensor]],
-         prediction_length: Optional[int] = None,
-         limit_prediction_length: bool = False,
-     ):
-         context_tensor = self._prepare_and_validate_context(context=context)
-
-         model_context_length = self.model.config.chronos_config["context_length"]
-         model_prediction_length = self.model.config.chronos_config["prediction_length"]
-         if prediction_length is None:
-             prediction_length = model_prediction_length
-
-         if prediction_length > model_prediction_length:
-             msg = (
-                 f"We recommend keeping prediction length <= {model_prediction_length}. "
-                 "The quality of longer predictions may degrade since the model is not optimized for it. "
-             )
-             if limit_prediction_length:
-                 msg += "You can turn off this check by setting `limit_prediction_length=False`."
-                 raise ValueError(msg)
-             warnings.warn(msg)
-
-         predictions = []
-         remaining = prediction_length
-
-         # We truncate the context here because otherwise batches with very long
-         # context could take up large amounts of GPU memory unnecessarily.
-         if context_tensor.shape[-1] > model_context_length:
-             context_tensor = context_tensor[..., -model_context_length:]
-
-         # TODO: We unroll the forecast of Chronos Bolt greedily with the full forecast
-         # horizon that the model was trained with (i.e., 64). This results in variance collapsing
-         # every 64 steps.
-         while remaining > 0:
-             with torch.no_grad():
-                 prediction = self.model(
-                     context=context_tensor.to(
-                         device=self.model.device,
-                         dtype=torch.float32,  # scaling should be done in 32-bit precision
-                     ),
-                 ).quantile_preds.to(context_tensor)
-
-             predictions.append(prediction)
-             remaining -= prediction.shape[-1]
-
-             if remaining <= 0:
-                 break
-
-             central_idx = torch.abs(torch.tensor(self.quantiles) - 0.5).argmin()
-             central_prediction = prediction[:, central_idx]
-
-             context_tensor = torch.cat([context_tensor, central_prediction], dim=-1)
-
-         return torch.cat(predictions, dim=-1)[..., :prediction_length]
-
-     def predict_quantiles(
-         self, context: torch.Tensor, prediction_length: int, quantile_levels: List[float], **kwargs
-     ) -> Tuple[torch.Tensor, torch.Tensor]:
-         # shape (batch_size, prediction_length, len(training_quantile_levels))
-         predictions = (
-             self.predict(
-                 context,
-                 prediction_length=prediction_length,
-             )
-             .detach()
-             .cpu()
-             .swapaxes(1, 2)
-         )
-
-         training_quantile_levels = self.quantiles
-
-         if set(quantile_levels).issubset(set(training_quantile_levels)):
-             # no need to perform intra/extrapolation
-             quantiles = predictions[..., [training_quantile_levels.index(q) for q in quantile_levels]]
-         else:
-             # we rely on torch for interpolating quantiles if quantiles that
-             # Chronos Bolt was trained on were not provided
-             if min(quantile_levels) < min(training_quantile_levels) or max(quantile_levels) > max(
-                 training_quantile_levels
-             ):
-                 logger.warning(
-                     f"\tQuantiles to be predicted ({quantile_levels}) are not within the range of "
-                     f"quantiles that Chronos-Bolt was trained on ({training_quantile_levels}). "
-                     "Quantile predictions will be set to the minimum/maximum levels at which Chronos-Bolt "
-                     "was trained on. This may significantly affect the quality of the predictions."
-                 )
-
-             # TODO: this is a hack that assumes the model's quantiles during training (training_quantile_levels)
-             # made up an equidistant grid along the quantile dimension. i.e., they were (0.1, 0.2, ..., 0.9).
-             # While this holds for official Chronos-Bolt models, this may not be true in the future, and this
-             # function may have to be revised.
-             augmented_predictions = torch.cat(
-                 [predictions[..., [0]], predictions, predictions[..., [-1]]],
-                 dim=-1,
-             )
-             quantiles = torch.quantile(
-                 augmented_predictions, q=torch.tensor(quantile_levels, dtype=augmented_predictions.dtype), dim=-1
-             ).permute(1, 2, 0)
-         mean = predictions[:, :, training_quantile_levels.index(0.5)]
-         return quantiles, mean
-
-     @classmethod
-     def from_pretrained(cls, *args, **kwargs):
-         """
-         Load the model, either from a local path or from the HuggingFace Hub.
-         Supports the same arguments as ``AutoConfig`` and ``AutoModel``
-         from ``transformers``.
-         """
-         # if optimization_strategy is provided, pop this as it won't be used
-         kwargs.pop("optimization_strategy", None)
-
-         config = AutoConfig.from_pretrained(*args, **kwargs)
-         assert hasattr(config, "chronos_config"), "Not a Chronos config file"
-
-         context_length = kwargs.pop("context_length", None)
-         if context_length is not None:
-             config.chronos_config["context_length"] = context_length
-
-         architecture = config.architectures[0]
-         class_ = globals().get(architecture)
-
-         # TODO: remove this once all models carry the correct architecture names in their configuration
-         # and raise an error instead.
-         if class_ is None:
-             logger.warning(f"Unknown architecture: {architecture}, defaulting to ChronosBoltModelForForecasting")
-             class_ = ChronosBoltModelForForecasting
-
-         model = class_.from_pretrained(*args, **kwargs)
-         return cls(model=model)
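
For orientation, the removed forward() above trains against a quantile (pinball) loss over a fixed grid of quantile levels. The following minimal sketch reproduces that computation on dummy tensors; the names, shapes, and quantile grid here are illustrative only and are not part of the package.

import torch

# Illustrative pinball-loss computation mirroring the removed forward() pass.
# quantile_preds: (batch, num_quantiles, horizon); target broadcasts over the quantile dim.
quantiles = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9])
target = torch.randn(4, 1, 64)                       # dummy ground truth, (batch, 1, horizon)
quantile_preds = torch.randn(4, len(quantiles), 64)  # dummy predictions

loss = 2 * torch.abs(
    (target - quantile_preds)
    * ((target <= quantile_preds).float() - quantiles.view(1, -1, 1))
)
# Same reduction order as the removed code: mean over dim=-2, sum over dim=-1, then batch mean.
loss = loss.mean(dim=-2).sum(dim=-1).mean()
print(loss.item())
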
autogluon/timeseries/models/ensemble/abstract_timeseries_ensemble.py
@@ -1,78 +0,0 @@
- import logging
- from typing import Dict, List, Optional
-
- from autogluon.core.utils.exceptions import TimeLimitExceeded
- from autogluon.timeseries.dataset import TimeSeriesDataFrame
- from autogluon.timeseries.models.abstract import AbstractTimeSeriesModel
-
- logger = logging.getLogger(__name__)
-
-
- class AbstractTimeSeriesEnsembleModel(AbstractTimeSeriesModel):
-     """Abstract class for time series ensemble models."""
-
-     @property
-     def model_names(self) -> List[str]:
-         """Names of base models included in the ensemble."""
-         raise NotImplementedError
-
-     def fit_ensemble(
-         self,
-         predictions_per_window: Dict[str, List[TimeSeriesDataFrame]],
-         data_per_window: List[TimeSeriesDataFrame],
-         time_limit: Optional[float] = None,
-         **kwargs,
-     ):
-         """Fit ensemble model given predictions of candidate base models and the true data.
-
-         Parameters
-         ----------
-         predictions_per_window : Dict[str, List[TimeSeriesDataFrame]]
-             Dictionary that maps the names of component models to their respective predictions for each validation
-             window.
-         data_per_window : List[TimeSeriesDataFrame]
-             Observed ground truth data used to train the ensemble for each validation window. Each entry in the list
-             includes both the forecast horizon (for which the predictions are given in ``predictions``), as well as the
-             "history".
-         time_limit : Optional[int]
-             Maximum allowed time for training in seconds.
-         """
-         if time_limit is not None and time_limit <= 0:
-             logger.warning(
-                 f"\tWarning: Model has no time left to train, skipping model... (Time Left = {round(time_limit, 1)}s)"
-             )
-             raise TimeLimitExceeded
-         if isinstance(data_per_window, TimeSeriesDataFrame):
-             raise ValueError("When fitting ensemble, `data` should contain ground truth for each validation window")
-         num_val_windows = len(data_per_window)
-         for model, preds in predictions_per_window.items():
-             if len(preds) != num_val_windows:
-                 raise ValueError(f"For model {model} predictions are unavailable for some validation windows")
-         self._fit_ensemble(
-             predictions_per_window=predictions_per_window,
-             data_per_window=data_per_window,
-             time_limit=time_limit,
-         )
-         return self
-
-     def _fit_ensemble(
-         self,
-         predictions_per_window: Dict[str, List[TimeSeriesDataFrame]],
-         data_per_window: List[TimeSeriesDataFrame],
-         time_limit: Optional[int] = None,
-         **kwargs,
-     ):
-         """Private method for `fit_ensemble`. See `fit_ensemble` for documentation of arguments. Apart from the model
-         training logic, `fit_ensemble` additionally implements other logic such as keeping track of the time limit.
-         """
-         raise NotImplementedError
-
-     def predict(self, data: Dict[str, Optional[TimeSeriesDataFrame]], **kwargs) -> TimeSeriesDataFrame:
-         raise NotImplementedError
-
-     def remap_base_models(self, model_refit_map: Dict[str, str]) -> None:
-         """Update names of the base models based on the mapping in model_refit_map.
-
-         This method should be called after performing refit_full to point to the refitted base models, if necessary.
-         """
-         raise NotImplementedError
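
The fit_ensemble docstring above fixes the shape of its inputs: one prediction frame per base model per validation window, plus the matching ground-truth windows. Below is a small, hypothetical illustration of that contract and of the length check the removed class enforces; plain strings stand in for TimeSeriesDataFrame objects and the model names are arbitrary.

from typing import Dict, List

# Hypothetical inputs: two base models, two validation windows.
predictions_per_window: Dict[str, List[str]] = {
    "DeepAR": ["deepar_preds_window_0", "deepar_preds_window_1"],
    "SeasonalNaive": ["naive_preds_window_0", "naive_preds_window_1"],
}
data_per_window: List[str] = ["ground_truth_window_0", "ground_truth_window_1"]

# The same consistency check performed by fit_ensemble before delegating to _fit_ensemble.
num_val_windows = len(data_per_window)
for model, preds in predictions_per_window.items():
    if len(preds) != num_val_windows:
        raise ValueError(f"For model {model} predictions are unavailable for some validation windows")
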