autogluon.timeseries 1.0.1b20240304__py3-none-any.whl → 1.4.1b20251210__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (108)
  1. autogluon/timeseries/configs/__init__.py +3 -2
  2. autogluon/timeseries/configs/hyperparameter_presets.py +62 -0
  3. autogluon/timeseries/configs/predictor_presets.py +84 -0
  4. autogluon/timeseries/dataset/ts_dataframe.py +339 -186
  5. autogluon/timeseries/learner.py +192 -60
  6. autogluon/timeseries/metrics/__init__.py +55 -11
  7. autogluon/timeseries/metrics/abstract.py +96 -25
  8. autogluon/timeseries/metrics/point.py +186 -39
  9. autogluon/timeseries/metrics/quantile.py +47 -20
  10. autogluon/timeseries/metrics/utils.py +6 -6
  11. autogluon/timeseries/models/__init__.py +13 -7
  12. autogluon/timeseries/models/abstract/__init__.py +2 -2
  13. autogluon/timeseries/models/abstract/abstract_timeseries_model.py +533 -273
  14. autogluon/timeseries/models/abstract/model_trial.py +10 -10
  15. autogluon/timeseries/models/abstract/tunable.py +189 -0
  16. autogluon/timeseries/models/autogluon_tabular/__init__.py +2 -0
  17. autogluon/timeseries/models/autogluon_tabular/mlforecast.py +369 -215
  18. autogluon/timeseries/models/autogluon_tabular/per_step.py +513 -0
  19. autogluon/timeseries/models/autogluon_tabular/transforms.py +67 -0
  20. autogluon/timeseries/models/autogluon_tabular/utils.py +3 -51
  21. autogluon/timeseries/models/chronos/__init__.py +4 -0
  22. autogluon/timeseries/models/chronos/chronos2.py +361 -0
  23. autogluon/timeseries/models/chronos/model.py +738 -0
  24. autogluon/timeseries/models/chronos/utils.py +369 -0
  25. autogluon/timeseries/models/ensemble/__init__.py +35 -2
  26. autogluon/timeseries/models/ensemble/{abstract_timeseries_ensemble.py → abstract.py} +50 -26
  27. autogluon/timeseries/models/ensemble/array_based/__init__.py +3 -0
  28. autogluon/timeseries/models/ensemble/array_based/abstract.py +236 -0
  29. autogluon/timeseries/models/ensemble/array_based/models.py +73 -0
  30. autogluon/timeseries/models/ensemble/array_based/regressor/__init__.py +12 -0
  31. autogluon/timeseries/models/ensemble/array_based/regressor/abstract.py +88 -0
  32. autogluon/timeseries/models/ensemble/array_based/regressor/linear_stacker.py +167 -0
  33. autogluon/timeseries/models/ensemble/array_based/regressor/per_quantile_tabular.py +94 -0
  34. autogluon/timeseries/models/ensemble/array_based/regressor/tabular.py +107 -0
  35. autogluon/timeseries/models/ensemble/ensemble_selection.py +167 -0
  36. autogluon/timeseries/models/ensemble/per_item_greedy.py +162 -0
  37. autogluon/timeseries/models/ensemble/weighted/__init__.py +8 -0
  38. autogluon/timeseries/models/ensemble/weighted/abstract.py +40 -0
  39. autogluon/timeseries/models/ensemble/weighted/basic.py +78 -0
  40. autogluon/timeseries/models/ensemble/weighted/greedy.py +57 -0
  41. autogluon/timeseries/models/gluonts/__init__.py +3 -1
  42. autogluon/timeseries/models/gluonts/abstract.py +583 -0
  43. autogluon/timeseries/models/gluonts/dataset.py +109 -0
  44. autogluon/timeseries/models/gluonts/{torch/models.py → models.py} +185 -44
  45. autogluon/timeseries/models/local/__init__.py +1 -10
  46. autogluon/timeseries/models/local/abstract_local_model.py +150 -97
  47. autogluon/timeseries/models/local/naive.py +31 -23
  48. autogluon/timeseries/models/local/npts.py +6 -2
  49. autogluon/timeseries/models/local/statsforecast.py +99 -112
  50. autogluon/timeseries/models/multi_window/multi_window_model.py +99 -40
  51. autogluon/timeseries/models/registry.py +64 -0
  52. autogluon/timeseries/models/toto/__init__.py +3 -0
  53. autogluon/timeseries/models/toto/_internal/__init__.py +9 -0
  54. autogluon/timeseries/models/toto/_internal/backbone/__init__.py +3 -0
  55. autogluon/timeseries/models/toto/_internal/backbone/attention.py +196 -0
  56. autogluon/timeseries/models/toto/_internal/backbone/backbone.py +262 -0
  57. autogluon/timeseries/models/toto/_internal/backbone/distribution.py +70 -0
  58. autogluon/timeseries/models/toto/_internal/backbone/kvcache.py +136 -0
  59. autogluon/timeseries/models/toto/_internal/backbone/rope.py +89 -0
  60. autogluon/timeseries/models/toto/_internal/backbone/rotary_embedding_torch.py +342 -0
  61. autogluon/timeseries/models/toto/_internal/backbone/scaler.py +305 -0
  62. autogluon/timeseries/models/toto/_internal/backbone/transformer.py +333 -0
  63. autogluon/timeseries/models/toto/_internal/dataset.py +165 -0
  64. autogluon/timeseries/models/toto/_internal/forecaster.py +423 -0
  65. autogluon/timeseries/models/toto/dataloader.py +108 -0
  66. autogluon/timeseries/models/toto/hf_pretrained_model.py +118 -0
  67. autogluon/timeseries/models/toto/model.py +236 -0
  68. autogluon/timeseries/predictor.py +826 -305
  69. autogluon/timeseries/regressor.py +253 -0
  70. autogluon/timeseries/splitter.py +10 -31
  71. autogluon/timeseries/trainer/__init__.py +2 -3
  72. autogluon/timeseries/trainer/ensemble_composer.py +439 -0
  73. autogluon/timeseries/trainer/model_set_builder.py +256 -0
  74. autogluon/timeseries/trainer/prediction_cache.py +149 -0
  75. autogluon/timeseries/trainer/trainer.py +1298 -0
  76. autogluon/timeseries/trainer/utils.py +17 -0
  77. autogluon/timeseries/transforms/__init__.py +2 -0
  78. autogluon/timeseries/transforms/covariate_scaler.py +164 -0
  79. autogluon/timeseries/transforms/target_scaler.py +149 -0
  80. autogluon/timeseries/utils/constants.py +10 -0
  81. autogluon/timeseries/utils/datetime/base.py +38 -20
  82. autogluon/timeseries/utils/datetime/lags.py +18 -16
  83. autogluon/timeseries/utils/datetime/seasonality.py +14 -14
  84. autogluon/timeseries/utils/datetime/time_features.py +17 -14
  85. autogluon/timeseries/utils/features.py +317 -53
  86. autogluon/timeseries/utils/forecast.py +31 -17
  87. autogluon/timeseries/utils/timer.py +173 -0
  88. autogluon/timeseries/utils/warning_filters.py +44 -6
  89. autogluon/timeseries/version.py +2 -1
  90. autogluon.timeseries-1.4.1b20251210-py3.11-nspkg.pth +1 -0
  91. {autogluon.timeseries-1.0.1b20240304.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info}/METADATA +71 -47
  92. autogluon_timeseries-1.4.1b20251210.dist-info/RECORD +103 -0
  93. {autogluon.timeseries-1.0.1b20240304.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info}/WHEEL +1 -1
  94. autogluon/timeseries/configs/presets_configs.py +0 -11
  95. autogluon/timeseries/evaluator.py +0 -6
  96. autogluon/timeseries/models/ensemble/greedy_ensemble.py +0 -170
  97. autogluon/timeseries/models/gluonts/abstract_gluonts.py +0 -550
  98. autogluon/timeseries/models/gluonts/torch/__init__.py +0 -0
  99. autogluon/timeseries/models/presets.py +0 -325
  100. autogluon/timeseries/trainer/abstract_trainer.py +0 -1144
  101. autogluon/timeseries/trainer/auto_trainer.py +0 -74
  102. autogluon.timeseries-1.0.1b20240304-py3.8-nspkg.pth +0 -1
  103. autogluon.timeseries-1.0.1b20240304.dist-info/RECORD +0 -58
  104. {autogluon.timeseries-1.0.1b20240304.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info/licenses}/LICENSE +0 -0
  105. {autogluon.timeseries-1.0.1b20240304.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info/licenses}/NOTICE +0 -0
  106. {autogluon.timeseries-1.0.1b20240304.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info}/namespace_packages.txt +0 -0
  107. {autogluon.timeseries-1.0.1b20240304.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info}/top_level.txt +0 -0
  108. {autogluon.timeseries-1.0.1b20240304.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info}/zip-safe +0 -0
autogluon/timeseries/models/chronos/utils.py
@@ -0,0 +1,369 @@
+ import logging
+ import time
+ from itertools import chain, cycle
+ from typing import TYPE_CHECKING, Callable, Iterable, Iterator, Literal
+
+ import numpy as np
+ import torch
+ from chronos.chronos_bolt import ChronosBoltModelForForecasting, ResidualBlock
+ from gluonts.dataset.field_names import FieldName
+ from gluonts.transform import ExpectedNumInstanceSampler, InstanceSplitter, ValidationSplitSampler
+ from torch.utils.data import IterableDataset
+ from transformers import TrainerCallback
+
+ from autogluon.core.utils.exceptions import TimeLimitExceeded
+ from autogluon.timeseries.dataset import TimeSeriesDataFrame
+ from autogluon.timeseries.models.gluonts.dataset import SimpleGluonTSDataset
+
+ if TYPE_CHECKING:
+     # TODO: fix the underlying reason for this circular import, the pipeline should handle tokenization
+     from chronos import ChronosTokenizer
+
+
+ logger = logging.getLogger("autogluon.timeseries.models.chronos")
+
+
+ class PseudoShuffledIterableDataset(IterableDataset):
+     """
+     Shuffle entries from an iterable by temporarily accumulating them
+     in an intermediate buffer.
+
+     Parameters
+     ----------
+     base_dataset
+         The original iterable object, representing the dataset.
+     shuffle_buffer_size
+         Size of the buffer used to shuffle entries from the base dataset.
+     """
+
+     def __init__(self, base_dataset, shuffle_buffer_size: int = 100) -> None:
+         super().__init__()
+         assert shuffle_buffer_size > 0
+         self.base_dataset = base_dataset
+         self.shuffle_buffer_size = shuffle_buffer_size
+         self.generator = torch.Generator()
+
+     def __iter__(self):
+         shuffle_buffer = []
+
+         for element in self.base_dataset:
+             shuffle_buffer.append(element)
+             if len(shuffle_buffer) >= self.shuffle_buffer_size:
+                 idx = torch.randint(len(shuffle_buffer), size=(), generator=self.generator)
+                 yield shuffle_buffer.pop(idx)
+
+         while shuffle_buffer:
+             idx = torch.randint(len(shuffle_buffer), size=(), generator=self.generator)
+             yield shuffle_buffer.pop(idx)
+
+
+ class ChronosFineTuningDataset(IterableDataset):
+     """
+     Dataset wrapper to convert a ``TimeSeriesDataFrame`` into an iterable dataset
+     compatible with Chronos models.
+
+     When a ``tokenizer`` is provided, data is converted into a HuggingFace-compatible set of
+     ``input_ids``, ``attention_mask`` and ``labels``, used by the original Chronos models.
+
+     When the ``tokenizer`` is omitted, data is converted into the format compatible with
+     ChronosBolt models, i.e., ``context`` and ``target``.
+
+     Parameters
+     ----------
+     target_df
+         The ``TimeSeriesDataFrame`` to be converted
+     target_column
+         The name of the column which contains the target time series, by default "target"
+     context_length
+         The length of the historical context
+     prediction_length
+         The prediction_length, i.e., length of label or target
+     tokenizer
+         When a ``ChronosTokenizer`` object is provided, data will be converted into the
+         HuggingFace format accepted by the original Chronos models using this ``ChronosTokenizer``.
+         If None, data will be converted into the format accepted by ChronosBolt models.
+     mode
+         When ``training``, random slices from the time series will be returned for training purposes.
+         If ``validation``, the last slice of each time series is returned in the original order.
+     """
+
+     def __init__(
+         self,
+         target_df: TimeSeriesDataFrame,
+         target_column: str = "target",
+         context_length: int = 512,
+         prediction_length: int = 64,
+         tokenizer: "ChronosTokenizer | None" = None,
+         mode: Literal["training", "validation"] = "training",
+     ) -> None:
+         super().__init__()
+
+         assert mode in ("training", "validation")
+
+         # A dummy hourly freq is used because the model doesn't actually need the freq
+         self.gluonts_dataset = SimpleGluonTSDataset(target_df=target_df, freq="h", target_column=target_column)
+         self.tokenizer = tokenizer
+         self.context_length = context_length
+         self.prediction_length = prediction_length
+         self.mode = mode
+
+     def _create_instance_splitter(self, mode: str):
+         instance_sampler = {
+             "training": ExpectedNumInstanceSampler(
+                 num_instances=1.0, min_future=self.prediction_length, min_instances=1
+             ),
+             "validation": ValidationSplitSampler(min_future=self.prediction_length),
+         }[mode]
+
+         return InstanceSplitter(
+             target_field=FieldName.TARGET,
+             is_pad_field=FieldName.IS_PAD,
+             start_field=FieldName.START,
+             forecast_start_field=FieldName.FORECAST_START,
+             instance_sampler=instance_sampler,
+             past_length=self.context_length,
+             future_length=self.prediction_length,
+             dummy_value=np.nan,
+         )
+
+     def _create_training_data(self, data: Iterable[dict]):
+         data = chain.from_iterable(cycle([data]))
+         split_transform = self._create_instance_splitter("training")
+         data = split_transform.apply(data, is_train=True)  # type: ignore
+         return data
+
+     def _create_validation_data(self, data: Iterable[dict]):
+         data = self._create_instance_splitter("validation").apply(data, is_train=False)  # type: ignore
+         return data
+
+     def to_chronos_format(self, entry: dict) -> dict:
+         """Converts an entry from the GluonTS data format with past and future targets
+         to the HuggingFace format accepted by the original Chronos models using the ChronosTokenizer.
+
+         Parameters
+         ----------
+         entry
+             time series data entry in GluonTS format with ``past_target`` and ``future_target`` keys
+
+         Returns
+         -------
+         dict
+             time series data entry in HuggingFace format with ``input_ids``, ``attention_mask``, and ``labels``
+         """
+         assert self.tokenizer is not None, "A ChronosTokenizer is required to convert data into the Chronos format"
+         past_target = torch.tensor(entry[f"past_{FieldName.TARGET}"]).unsqueeze(0)
+         input_ids, attention_mask, scale = self.tokenizer.context_input_transform(past_target)
+         future_target = torch.tensor(entry[f"future_{FieldName.TARGET}"]).unsqueeze(0)
+         labels, labels_mask = self.tokenizer.label_input_transform(future_target, scale)
+         labels[labels_mask == 0] = -100
+
+         return {
+             "input_ids": input_ids.squeeze(0),
+             "attention_mask": attention_mask.squeeze(0),
+             "labels": labels.squeeze(0),
+         }
+
+     def to_chronos_bolt_format(self, entry: dict) -> dict:
+         """Converts an entry from the GluonTS data format with past and future targets
+         to the format accepted by the ChronosBolt models.
+
+         Parameters
+         ----------
+         entry
+             time series data entry in GluonTS format with ``past_target`` and ``future_target`` keys
+
+         Returns
+         -------
+         dict
+             time series data entry in ChronosBolt format with ``context`` and ``target``
+         """
+         past_target = torch.tensor(entry[f"past_{FieldName.TARGET}"])
+         future_target = torch.tensor(entry[f"future_{FieldName.TARGET}"])
+
+         return {"context": past_target, "target": future_target}
+
+     def __iter__(self) -> Iterator:
+         if self.mode == "training":
+             iterable = self._create_training_data(self.gluonts_dataset)
+         elif self.mode == "validation":
+             iterable = self._create_validation_data(self.gluonts_dataset)
+         else:
+             raise ValueError(f"Unknown mode {self.mode}")
+
+         format_transform_fn = self.to_chronos_format if self.tokenizer is not None else self.to_chronos_bolt_format
+         for entry in iterable:
+             yield format_transform_fn(entry)
+
+     def shuffle(self, shuffle_buffer_size: int | None = None):
+         """Returns a (pseudo) shuffled version of this iterable dataset.
+
+         Parameters
+         ----------
+         shuffle_buffer_size
+             The shuffle buffer size used for pseudo shuffling
+         """
+         assert shuffle_buffer_size is None or shuffle_buffer_size >= 0
+         if not shuffle_buffer_size:
+             return self
+         return PseudoShuffledIterableDataset(self, shuffle_buffer_size)
+
+
+ def left_pad_and_stack_1D(tensors: list[torch.Tensor]) -> torch.Tensor:
+     max_len = max(len(c) for c in tensors)
+     padded = []
+     for c in tensors:
+         assert isinstance(c, torch.Tensor)
+         assert c.ndim == 1
+         padding = torch.full(size=(max_len - len(c),), fill_value=torch.nan, device=c.device)
+         padded.append(torch.concat((padding, c), dim=-1))
+     return torch.stack(padded)
+
+
+ class ChronosInferenceDataset:
+     """A container for time series datasets that implements the ``torch.utils.data.Dataset`` interface"""
+
+     def __init__(
+         self,
+         target_df: TimeSeriesDataFrame,
+         context_length: int,
+         target_column: str = "target",
+     ):
+         assert context_length > 0
+         self.context_length = context_length
+         self.target_array = target_df[target_column].to_numpy(dtype=np.float32)
+
+         # store pointer to start:end of each time series
+         self.indptr = target_df.get_indptr()
+
+     def __len__(self):
+         return len(self.indptr) - 1  # noqa
+
+     def _get_context(self, a: np.ndarray, pad_value=np.nan):
+         a = a[-self.context_length :]
+         pad_size = self.context_length - len(a)
+         if pad_size > 0:
+             pad = np.full(shape=(pad_size,), fill_value=pad_value)
+             a = np.concatenate((pad, a))
+         return a
+
+     def __getitem__(self, idx) -> np.ndarray:
+         start_idx = self.indptr[idx]
+         end_idx = self.indptr[idx + 1]
+
+         return self._get_context(self.target_array[start_idx:end_idx])
+
+
+ class ChronosInferenceDataLoader(torch.utils.data.DataLoader):
+     def __init__(self, *args, **kwargs):
+         self.callback: Callable = kwargs.pop("on_batch", lambda: None)
+         super().__init__(*args, **kwargs)
+
+     def __iter__(self):  # type: ignore
+         for item in super().__iter__():
+             yield item
+             self.callback()
+
+
+ class EvaluateAndSaveFinalStepCallback(TrainerCallback):
+     """Callback to evaluate and save the model at the last training step."""
+
+     def on_step_end(self, args, state, control, **kwargs):
+         if state.global_step >= state.max_steps:
+             control.should_log = True
+             control.should_evaluate = True
+             control.should_save = True
+
+
+ class TimeLimitCallback(TrainerCallback):
+     def __init__(self, time_limit: float):
+         """
+         Callback to stop training once a specified time has elapsed.
+
+         Parameters
+         ----------
+         time_limit
+             maximum time allowed for training in seconds.
+         """
+         self.time_limit = time_limit
+         self.start_time = None
+
+     def on_train_begin(self, args, state, control, **kwargs):
+         self.start_time = time.monotonic()  # type: ignore
+
+     def on_step_end(self, args, state, control, **kwargs):
+         elapsed_time = time.monotonic() - self.start_time  # type: ignore
+         if elapsed_time > self.time_limit:
+             logger.log(15, "Stopping fine-tuning since time_limit is reached")
+             control.should_training_stop = True
+
+
+ class LoggerCallback(TrainerCallback):
+     def on_log(self, args, state, control, logs=None, **kwargs):
+         if logs:
+             logs.pop("total_flos", None)
+         if state.is_local_process_zero:
+             logger.info(logs)
+
+
+ def timeout_callback(seconds: float | None) -> Callable:
+     """Return a callback object that raises an exception if the time limit is exceeded."""
+     start_time = time.monotonic()
+
+     def callback() -> None:
+         if seconds is not None and time.monotonic() - start_time > seconds:
+             raise TimeLimitExceeded
+
+     return callback
+
+
+ def update_output_quantiles(model: ChronosBoltModelForForecasting, new_quantiles: list[float]) -> None:
+     """Updates the model's output layer in place to support only the specified new quantiles by
+     copying weights from the closest existing quantiles.
+     """
+     old_quantiles = model.chronos_config.quantiles
+     new_quantiles = sorted(new_quantiles)
+
+     if new_quantiles == old_quantiles:
+         return
+
+     model.chronos_config.quantiles = new_quantiles
+     model.num_quantiles = len(new_quantiles)
+     model.register_buffer("quantiles", torch.tensor(new_quantiles, dtype=model.dtype), persistent=False)
+
+     old_output_layer = model.output_patch_embedding
+     new_output_layer = ResidualBlock(
+         in_dim=model.config.d_model,
+         h_dim=model.config.d_ff,
+         out_dim=len(new_quantiles) * model.chronos_config.prediction_length,
+         act_fn_name=model.config.dense_act_fn,
+         dropout_p=model.config.dropout_rate,
+     )
+
+     # hidden_layer is shared across all quantiles
+     new_output_layer.hidden_layer.weight.data.copy_(old_output_layer.hidden_layer.weight.data)
+     if old_output_layer.hidden_layer.bias is not None:
+         new_output_layer.hidden_layer.bias.data.copy_(old_output_layer.hidden_layer.bias.data)
+
+     def copy_quantile_weights(src_idx: int, dst_idx: int):
+         """Copy weights for one quantile from src_idx to dst_idx"""
+         prediction_length = model.chronos_config.prediction_length
+         src_start, src_end = src_idx * prediction_length, (src_idx + 1) * prediction_length
+         dst_start, dst_end = dst_idx * prediction_length, (dst_idx + 1) * prediction_length
+
+         for layer_name in ["output_layer", "residual_layer"]:
+             old_layer_attr = getattr(old_output_layer, layer_name)
+             new_layer_attr = getattr(new_output_layer, layer_name)
+
+             new_layer_attr.weight[dst_start:dst_end] = old_layer_attr.weight[src_start:src_end]
+             if old_layer_attr.bias is not None:
+                 new_layer_attr.bias[dst_start:dst_end] = old_layer_attr.bias[src_start:src_end]
+
+     with torch.no_grad():
+         for new_idx, new_q in enumerate(new_quantiles):
+             closest_q = min(old_quantiles, key=lambda x: abs(x - new_q))
+             closest_idx = old_quantiles.index(closest_q)
+             copy_quantile_weights(closest_idx, new_idx)
+
+     model.output_patch_embedding = new_output_layer
+     model.config.chronos_config["quantiles"] = new_quantiles
+     model.chronos_config.quantiles = new_quantiles
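
For orientation, here is a minimal usage sketch of the fine-tuning dataset added above. The tiny two-series dataset, the context/prediction lengths, and the buffer size are illustrative assumptions; when no tokenizer is passed, entries come out in the ChronosBolt format.

import numpy as np
import pandas as pd

from autogluon.timeseries import TimeSeriesDataFrame
from autogluon.timeseries.models.chronos.utils import ChronosFineTuningDataset

# Toy input: two hourly series ("A", "B") with 24 observations each.
df = pd.DataFrame(
    {
        "item_id": np.repeat(["A", "B"], 24),
        "timestamp": np.tile(pd.date_range("2024-01-01", periods=24, freq="h"), 2),
        "target": np.random.default_rng(0).normal(size=48),
    }
)
train_df = TimeSeriesDataFrame.from_data_frame(df)

# No tokenizer is provided, so entries use the ChronosBolt format: {"context", "target"}.
dataset = ChronosFineTuningDataset(
    target_df=train_df,
    context_length=16,
    prediction_length=4,
    mode="training",
).shuffle(shuffle_buffer_size=8)

# Training mode cycles through random slices indefinitely; draw a single entry.
entry = next(iter(dataset))
print(entry["context"].shape, entry["target"].shape)  # torch.Size([16]) torch.Size([4])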
autogluon/timeseries/models/ensemble/__init__.py
@@ -1,2 +1,35 @@
- from .abstract_timeseries_ensemble import AbstractTimeSeriesEnsembleModel
- from .greedy_ensemble import TimeSeriesGreedyEnsemble
+ from .abstract import AbstractTimeSeriesEnsembleModel
+ from .array_based import LinearStackerEnsemble, MedianEnsemble, PerQuantileTabularEnsemble, TabularEnsemble
+ from .per_item_greedy import PerItemGreedyEnsemble
+ from .weighted import GreedyEnsemble, PerformanceWeightedEnsemble, SimpleAverageEnsemble
+
+
+ def get_ensemble_class(name: str):
+     mapping = {
+         "GreedyEnsemble": GreedyEnsemble,
+         "PerItemGreedyEnsemble": PerItemGreedyEnsemble,
+         "PerformanceWeightedEnsemble": PerformanceWeightedEnsemble,
+         "SimpleAverageEnsemble": SimpleAverageEnsemble,
+         "WeightedEnsemble": GreedyEnsemble,  # old alias for this model
+         "MedianEnsemble": MedianEnsemble,
+         "TabularEnsemble": TabularEnsemble,
+         "PerQuantileTabularEnsemble": PerQuantileTabularEnsemble,
+         "LinearStackerEnsemble": LinearStackerEnsemble,
+     }
+     if name not in mapping:
+         raise ValueError(f"Unknown ensemble type: {name}. Available: {list(mapping.keys())}")
+     return mapping[name]
+
+
+ __all__ = [
+     "AbstractTimeSeriesEnsembleModel",
+     "GreedyEnsemble",
+     "LinearStackerEnsemble",
+     "MedianEnsemble",
+     "PerformanceWeightedEnsemble",
+     "PerItemGreedyEnsemble",
+     "PerQuantileTabularEnsemble",
+     "SimpleAverageEnsemble",
+     "TabularEnsemble",
+     "get_ensemble_class",
+ ]
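
A short sketch of the lookup helper defined above, including the backward-compatible "WeightedEnsemble" alias:

from autogluon.timeseries.models.ensemble import GreedyEnsemble, get_ensemble_class

# The legacy name "WeightedEnsemble" resolves to GreedyEnsemble.
assert get_ensemble_class("WeightedEnsemble") is GreedyEnsemble

# Unknown names raise a ValueError that lists the available ensemble types.
try:
    get_ensemble_class("NotAnEnsemble")
except ValueError as err:
    print(err)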
autogluon/timeseries/models/ensemble/{abstract_timeseries_ensemble.py → abstract.py}
@@ -1,40 +1,46 @@
  import logging
- from typing import Dict, List, Optional
+ from abc import ABC, abstractmethod
+
+ from typing_extensions import final

  from autogluon.core.utils.exceptions import TimeLimitExceeded
  from autogluon.timeseries.dataset import TimeSeriesDataFrame
- from autogluon.timeseries.models.abstract import AbstractTimeSeriesModel
+ from autogluon.timeseries.models.abstract import TimeSeriesModelBase

  logger = logging.getLogger(__name__)


- class AbstractTimeSeriesEnsembleModel(AbstractTimeSeriesModel):
+ class AbstractTimeSeriesEnsembleModel(TimeSeriesModelBase, ABC):
      """Abstract class for time series ensemble models."""

      @property
-     def model_names(self) -> List[str]:
+     @abstractmethod
+     def model_names(self) -> list[str]:
          """Names of base models included in the ensemble."""
-         raise NotImplementedError
+         pass

-     def fit_ensemble(
+     @final
+     def fit(
          self,
-         predictions_per_window: Dict[str, List[TimeSeriesDataFrame]],
-         data_per_window: List[TimeSeriesDataFrame],
-         time_limit: Optional[int] = None,
-         **kwargs,
+         predictions_per_window: dict[str, list[TimeSeriesDataFrame]],
+         data_per_window: list[TimeSeriesDataFrame],
+         model_scores: dict[str, float] | None = None,
+         time_limit: float | None = None,
      ):
          """Fit ensemble model given predictions of candidate base models and the true data.

          Parameters
          ----------
-         predictions_per_window : Dict[str, List[TimeSeriesDataFrame]]
+         predictions_per_window
              Dictionary that maps the names of component models to their respective predictions for each validation
              window.
-         data_per_window : List[TimeSeriesDataFrame]
+         data_per_window
              Observed ground truth data used to train the ensemble for each validation window. Each entry in the list
              includes both the forecast horizon (for which the predictions are given in ``predictions``), as well as the
              "history".
-         time_limit : Optional[int]
+         model_scores
+             Scores (higher is better) for the models that will constitute the ensemble.
+         time_limit
              Maximum allowed time for training in seconds.
          """
          if time_limit is not None and time_limit <= 0:
@@ -48,31 +54,49 @@ class AbstractTimeSeriesEnsembleModel(AbstractTimeSeriesModel):
          for model, preds in predictions_per_window.items():
              if len(preds) != num_val_windows:
                  raise ValueError(f"For model {model} predictions are unavailable for some validation windows")
-         self._fit_ensemble(
+         self._fit(
              predictions_per_window=predictions_per_window,
              data_per_window=data_per_window,
+             model_scores=model_scores,
              time_limit=time_limit,
          )
          return self

-     def _fit_ensemble(
+     def _fit(
          self,
-         predictions_per_window: Dict[str, List[TimeSeriesDataFrame]],
-         data_per_window: List[TimeSeriesDataFrame],
-         time_limit: Optional[int] = None,
-         **kwargs,
-     ):
-         """Private method for `fit_ensemble`. See `fit_ensemble` for documentation of arguments. Apart from the model
-         training logic, `fit_ensemble` additionally implements other logic such as keeping track of the time limit.
+         predictions_per_window: dict[str, list[TimeSeriesDataFrame]],
+         data_per_window: list[TimeSeriesDataFrame],
+         model_scores: dict[str, float] | None = None,
+         time_limit: float | None = None,
+     ) -> None:
+         """Private method for `fit`. See `fit` for documentation of arguments. Apart from the model
+         training logic, `fit` additionally implements other logic such as keeping track of the time limit.
          """
          raise NotImplementedError

-     def predict(self, data: Dict[str, TimeSeriesDataFrame], **kwargs) -> TimeSeriesDataFrame:
-         raise NotImplementedError
+     @final
+     def predict(self, data: dict[str, TimeSeriesDataFrame], **kwargs) -> TimeSeriesDataFrame:
+         if not set(self.model_names).issubset(set(data.keys())):
+             raise ValueError(
+                 f"Set of models given for prediction in {self.name} differs from those provided during initialization."
+             )
+         for model_name, model_pred in data.items():
+             if model_pred is None:
+                 raise RuntimeError(f"{self.name} cannot predict because base model {model_name} failed.")

+         # Make sure that all predictions have the same shape
+         assert len(set(pred.shape for pred in data.values())) == 1
+
+         return self._predict(data=data, **kwargs)
+
+     @abstractmethod
+     def _predict(self, data: dict[str, TimeSeriesDataFrame], **kwargs) -> TimeSeriesDataFrame:
+         pass
+
-     def remap_base_models(self, model_refit_map: Dict[str, str]) -> None:
+     @abstractmethod
+     def remap_base_models(self, model_refit_map: dict[str, str]) -> None:
          """Update names of the base models based on the mapping in model_refit_map.

          This method should be called after performing refit_full to point to the refitted base models, if necessary.
          """
-         raise NotImplementedError
+         pass
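
The refactor above turns `fit` and `predict` into final template methods, with subclasses supplying `_fit`, `_predict`, `model_names`, and `remap_base_models`. A hypothetical minimal subclass could look like the sketch below; the equal-weight averaging and the constructor pass-through are assumptions for illustration, not how the shipped ensembles are implemented.

from autogluon.timeseries import TimeSeriesDataFrame
from autogluon.timeseries.models.ensemble import AbstractTimeSeriesEnsembleModel


class NaiveAverageEnsemble(AbstractTimeSeriesEnsembleModel):
    """Toy ensemble that averages all base model predictions with equal weights."""

    def __init__(self, **kwargs):
        # Assumes base-class constructor arguments can be passed through unchanged.
        super().__init__(**kwargs)
        self._model_names: list[str] = []

    @property
    def model_names(self) -> list[str]:
        return self._model_names

    def _fit(self, predictions_per_window, data_per_window, model_scores=None, time_limit=None) -> None:
        # Nothing to learn here: just record which base models participate.
        self._model_names = list(predictions_per_window)

    def _predict(self, data: dict[str, TimeSeriesDataFrame], **kwargs) -> TimeSeriesDataFrame:
        # The final `predict` has already validated that all base predictions share one shape.
        mean_pred = data[self.model_names[0]].copy()
        for name in self.model_names[1:]:
            mean_pred += data[name]
        return mean_pred / len(self.model_names)

    def remap_base_models(self, model_refit_map: dict[str, str]) -> None:
        self._model_names = [model_refit_map.get(name, name) for name in self._model_names]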
autogluon/timeseries/models/ensemble/array_based/__init__.py
@@ -0,0 +1,3 @@
+ from .models import LinearStackerEnsemble, MedianEnsemble, PerQuantileTabularEnsemble, TabularEnsemble
+
+ __all__ = ["LinearStackerEnsemble", "MedianEnsemble", "PerQuantileTabularEnsemble", "TabularEnsemble"]