autogluon.timeseries 1.1.2b20241111__py3-none-any.whl → 1.1.2b20241112__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (18)
  1. autogluon/timeseries/models/chronos/model.py +67 -38
  2. autogluon/timeseries/models/chronos/pipeline/__init__.py +11 -0
  3. autogluon/timeseries/models/chronos/pipeline/base.py +146 -0
  4. autogluon/timeseries/models/chronos/{pipeline.py → pipeline/chronos.py} +66 -102
  5. autogluon/timeseries/models/chronos/pipeline/chronos_bolt.py +511 -0
  6. autogluon/timeseries/models/chronos/{utils.py → pipeline/utils.py} +37 -1
  7. autogluon/timeseries/models/gluonts/torch/models.py +3 -0
  8. autogluon/timeseries/utils/warning_filters.py +20 -0
  9. autogluon/timeseries/version.py +1 -1
  10. {autogluon.timeseries-1.1.2b20241111.dist-info → autogluon.timeseries-1.1.2b20241112.dist-info}/METADATA +5 -5
  11. {autogluon.timeseries-1.1.2b20241111.dist-info → autogluon.timeseries-1.1.2b20241112.dist-info}/RECORD +18 -15
  12. /autogluon.timeseries-1.1.2b20241111-py3.8-nspkg.pth → /autogluon.timeseries-1.1.2b20241112-py3.8-nspkg.pth +0 -0
  13. {autogluon.timeseries-1.1.2b20241111.dist-info → autogluon.timeseries-1.1.2b20241112.dist-info}/LICENSE +0 -0
  14. {autogluon.timeseries-1.1.2b20241111.dist-info → autogluon.timeseries-1.1.2b20241112.dist-info}/NOTICE +0 -0
  15. {autogluon.timeseries-1.1.2b20241111.dist-info → autogluon.timeseries-1.1.2b20241112.dist-info}/WHEEL +0 -0
  16. {autogluon.timeseries-1.1.2b20241111.dist-info → autogluon.timeseries-1.1.2b20241112.dist-info}/namespace_packages.txt +0 -0
  17. {autogluon.timeseries-1.1.2b20241111.dist-info → autogluon.timeseries-1.1.2b20241112.dist-info}/top_level.txt +0 -0
  18. {autogluon.timeseries-1.1.2b20241111.dist-info → autogluon.timeseries-1.1.2b20241112.dist-info}/zip-safe +0 -0
@@ -9,9 +9,9 @@ from autogluon.common.loaders import load_pkl
  from autogluon.timeseries.dataset.ts_dataframe import TimeSeriesDataFrame
  from autogluon.timeseries.models.abstract import AbstractTimeSeriesModel
  from autogluon.timeseries.utils.forecast import get_forecast_horizon_index_ts_dataframe
- from autogluon.timeseries.utils.warning_filters import warning_filter
+ from autogluon.timeseries.utils.warning_filters import disable_duplicate_logs, warning_filter

- logger = logging.getLogger(__name__)
+ logger = logging.getLogger("autogluon.timeseries.models.chronos")


  # allowed HuggingFace model paths with custom parameter definitions
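`disable_duplicate_logs` comes from the 20 lines added to `autogluon/timeseries/utils/warning_filters.py` (file #8 above), whose body is not part of this diff. Going by its use as a context manager around the inference loop further down, a plausible implementation is a `logging.Filter` that drops repeated messages. The sketch below is an assumption, not the actual source:

```python
# Hypothetical sketch of disable_duplicate_logs -- the real implementation
# lives in warning_filters.py and is not shown in this diff.
import logging
from contextlib import contextmanager


class _DedupFilter(logging.Filter):
    """Drop records whose rendered message has already been emitted."""

    def __init__(self) -> None:
        super().__init__()
        self._seen: set = set()

    def filter(self, record: logging.LogRecord) -> bool:
        message = record.getMessage()
        if message in self._seen:
            return False
        self._seen.add(message)
        return True


@contextmanager
def disable_duplicate_logs(logger: logging.Logger):
    # Install the deduplicating filter only for the duration of the block.
    dedup = _DedupFilter()
    logger.addFilter(dedup)
    try:
        yield
    finally:
        logger.removeFilter(dedup)
```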
@@ -41,6 +41,21 @@ MODEL_CONFIGS = {
          "default_torch_dtype": "bfloat16",
          "default_batch_size": 8,
      },
+     "chronos-bolt-mini": {
+         "num_gpus": 0,
+         "default_torch_dtype": "auto",
+         "default_batch_size": 256,
+     },
+     "chronos-bolt-small": {
+         "num_gpus": 0,
+         "default_torch_dtype": "auto",
+         "default_batch_size": 256,
+     },
+     "chronos-bolt-base": {
+         "num_gpus": 0,
+         "default_torch_dtype": "auto",
+         "default_batch_size": 256,
+     },
  }


@@ -50,22 +65,29 @@ MODEL_ALIASES = {
      "small": "autogluon/chronos-t5-small",
      "base": "autogluon/chronos-t5-base",
      "large": "autogluon/chronos-t5-large",
+     "bolt-mini": "autogluon/chronos-bolt-mini",
+     "bolt-small": "autogluon/chronos-bolt-small",
+     "bolt-base": "autogluon/chronos-bolt-base",
  }


  class ChronosModel(AbstractTimeSeriesModel):
-     """Chronos pretrained time series forecasting models, based on the original
-     `ChronosModel <https://github.com/amazon-science/chronos-forecasting/blob/main/src/chronos/chronos.py>`_ implementation.
+     """Chronos pretrained time series forecasting models. Models can be based on the original
+     `ChronosModel <https://github.com/amazon-science/chronos-forecasting/blob/main/src/chronos/chronos.py>`_ implementation,
+     as well as a newer family of Chronos-Bolt models which are capable of much faster inference.

-     Chronos is family of pretrained models, based on the T5 family, with number of parameters ranging between 8M and 710M.
-     The full collection of Chronos models is available on
+     The original Chronos is a family of pretrained models, based on the T5 family, with number of parameters ranging between
+     8M and 710M. The full collection of Chronos models is available on
      `Hugging Face <https://huggingface.co/collections/amazon/chronos-models-65f1791d630a8d57cb718444>`_. For Chronos small,
-     base, and large variants a GPU is required to perform inference efficiently.
-
-     Chronos takes a minimalistic approach to pretraining time series models, by discretizing time series data directly into bins
-     which are treated as tokens, effectively performing regression by classification. This results in a simple and flexible framework
+     base, and large variants a GPU is required to perform inference efficiently. Chronos takes a minimalistic approach to
+     pretraining time series models, by discretizing time series data directly into bins which are treated as tokens,
+     effectively performing regression by classification. This results in a simple and flexible framework
      for using any language model in the context of time series forecasting. See [Ansari2024]_ for more information.

+     The newer Chronos-Bolt variants enable much faster inference by first "patching" the time series. The resulting
+     patched series is then fed into a T5 model for forecasting. Chronos-Bolt variants are fast enough to run on CPUs,
+     and are also available on `Hugging Face <https://huggingface.co/autogluon/>`_.
+
      References
      ----------
      .. [Ansari2024] Ansari, Abdul Fatir, Stella, Lorenzo et al.
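For orientation, the new aliases slot into the usual AutoGluon entry point. A minimal usage sketch (the dataset path is a placeholder, not from this diff):

```python
from autogluon.timeseries import TimeSeriesDataFrame, TimeSeriesPredictor

# "train.csv" is a placeholder long-format dataset with item_id/timestamp/target columns.
train_data = TimeSeriesDataFrame.from_path("train.csv")

predictor = TimeSeriesPredictor(prediction_length=48).fit(
    train_data,
    # "bolt-small" resolves to "autogluon/chronos-bolt-small" via MODEL_ALIASES above
    hyperparameters={"Chronos": {"model_path": "bolt-small"}},
)
predictions = predictor.predict(train_data)
```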
@@ -79,7 +101,8 @@ class ChronosModel(AbstractTimeSeriesModel):
          Model path used for the model, i.e., a HuggingFace transformers ``name_or_path``. Can be a
          compatible model name on HuggingFace Hub or a local path to a model directory. Original
          Chronos models (i.e., ``autogluon/chronos-t5-{model_size}``) can be specified with aliases
-         ``tiny``, ``mini`` , ``small``, ``base``, and ``large``.
+         ``tiny``, ``mini``, ``small``, ``base``, and ``large``. Chronos-Bolt models can be specified
+         with ``bolt-mini``, ``bolt-small``, and ``bolt-base``.
      batch_size : int, default = 16
          Size of batches used during inference
      num_samples : int, default = 20
@@ -90,11 +113,15 @@ class ChronosModel(AbstractTimeSeriesModel):
      context_length : int or None, default = None
          The context length to use in the model. Shorter context lengths will decrease model accuracy, but result
          in faster inference. If None, the model will infer context length from the data set length at inference
-         time, but set it to a maximum of 512.
+         time, but set it to a maximum of 2048. Note that this is only the context length used to pass data into
+         the model. Individual model implementations may have different context lengths specified in their configuration,
+         and may truncate the context further. For example, original Chronos models have a context length of 512, but
+         Chronos-Bolt models handle contexts up to 2048.
      optimization_strategy : {None, "onnx", "openvino"}, default = None
          Optimization strategy to use for inference on CPUs. If None, the model will use the default implementation.
          If `onnx`, the model will be converted to ONNX and the inference will be performed using ONNX. If ``openvino``,
-         inference will be performed with the model compiled to OpenVINO.
+         inference will be performed with the model compiled to OpenVINO. These optimizations are only available for
+         the original Chronos models; Chronos-Bolt models do not need them.
      torch_dtype : torch.dtype or {"auto", "bfloat16", "float32", "float64"}, default = "auto"
          Torch data type for model weights, provided to ``from_pretrained`` method of Hugging Face AutoModels. If
          original Chronos models are specified and the model size is ``small``, ``base``, or ``large``, the
@@ -107,7 +134,7 @@ class ChronosModel(AbstractTimeSeriesModel):
      # default number of samples for prediction
      default_num_samples: int = 20
      default_model_path = "autogluon/chronos-t5-small"
-     maximum_context_length = 512
+     maximum_context_length = 2048

      def __init__(
          self,
@@ -159,7 +186,7 @@ class ChronosModel(AbstractTimeSeriesModel):
              **kwargs,
          )

-         self.model_pipeline: Optional[Any] = None  # of type OptimizedChronosPipeline
+         self.model_pipeline: Optional[Any] = None  # of type BaseChronosPipeline
          self.time_limit: Optional[float] = None

      def save(self, path: str = None, verbose: bool = True) -> str:
@@ -218,8 +245,8 @@ class ChronosModel(AbstractTimeSeriesModel):
              minimum_resources["num_gpus"] = self.min_num_gpus
          return minimum_resources

-     def load_model_pipeline(self, context_length: Optional[int] = None):
-         from .pipeline import OptimizedChronosPipeline
+     def load_model_pipeline(self):
+         from .pipeline import BaseChronosPipeline

          gpu_available = self._is_gpu_available()
@@ -232,18 +259,17 @@ class ChronosModel(AbstractTimeSeriesModel):

          device = self.device or ("cuda" if gpu_available else "cpu")

-         pipeline = OptimizedChronosPipeline.from_pretrained(
+         pipeline = BaseChronosPipeline.from_pretrained(
              self.model_path,
              device_map=device,
-             optimization_strategy=self.optimization_strategy,
              torch_dtype=self.torch_dtype,
-             context_length=context_length or self.context_length,
+             optimization_strategy=self.optimization_strategy,
          )

          self.model_pipeline = pipeline

      def persist(self) -> "ChronosModel":
-         self.load_model_pipeline(context_length=self.context_length or self.maximum_context_length)
+         self.load_model_pipeline()
          return self

      def _fit(
@@ -263,7 +289,7 @@ class ChronosModel(AbstractTimeSeriesModel):
          num_workers: int = 0,
          time_limit: Optional[float] = None,
      ):
-         from .utils import ChronosInferenceDataLoader, ChronosInferenceDataset, timeout_callback
+         from .pipeline.utils import ChronosInferenceDataLoader, ChronosInferenceDataset, timeout_callback

          chronos_dataset = ChronosInferenceDataset(
              target_df=data,
@@ -290,6 +316,9 @@ class ChronosModel(AbstractTimeSeriesModel):
          # and use that to determine the context length of the model. If the context length is specified
          # during initialization, this is always used. If not, the context length is set to the longest
          # item length. The context length is always capped by self.maximum_context_length.
+         # Note that this is independent of the model's own context length set in the model's config file.
+         # For example, if context_length is set to 2048 here but the model's config.json specifies a
+         # context length of 512, the model will further truncate the series during inference.
          context_length = self.context_length or min(
              data.num_timesteps_per_item().max(),
              self.maximum_context_length,
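A worked example of the cap described in the comment above (numbers are illustrative):

```python
# Mirrors the min() logic above; values are illustrative.
maximum_context_length = 2048  # new cap in this version (previously 512)
longest_item_length = 3000     # longest series in the dataset
user_context_length = None     # context_length not set at initialization

context_length = user_context_length or min(longest_item_length, maximum_context_length)
assert context_length == 2048

# A Chronos-Bolt model (config context length 2048) consumes all 2048 steps;
# an original Chronos model (config.json context length 512) truncates the
# series further to 512 during inference.
```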
@@ -300,7 +329,7 @@ class ChronosModel(AbstractTimeSeriesModel):

          if self.model_pipeline is None:
              # load model pipeline to device memory
-             self.load_model_pipeline(context_length=context_length)
+             self.load_model_pipeline()

          inference_data_loader = self._get_inference_data_loader(
              data=data,
@@ -308,28 +337,28 @@ class ChronosModel(AbstractTimeSeriesModel):
              context_length=context_length,
              time_limit=kwargs.get("time_limit"),
          )
+
          self.model_pipeline.model.eval()
-         with torch.inference_mode():
-             prediction_samples = [
-                 self.model_pipeline.predict(
+         with torch.inference_mode(), disable_duplicate_logs(logger):
+             batch_quantiles, batch_means = [], []
+             for batch in inference_data_loader:
+                 qs, mn = self.model_pipeline.predict_quantiles(
                      batch,
                      prediction_length=self.prediction_length,
+                     quantile_levels=self.quantile_levels,
                      num_samples=self.num_samples,
-                     limit_prediction_length=False,
                  )
-                 .detach()
-                 .cpu()
-                 .numpy()
-                 for batch in inference_data_loader
-             ]
-
-             samples = np.concatenate(prediction_samples, axis=0).swapaxes(1, 2).reshape(-1, self.num_samples)
-
-             mean = samples.mean(axis=-1, keepdims=True)
-             quantiles = np.quantile(samples, self.quantile_levels, axis=-1).T
+                 batch_quantiles.append(qs.numpy())
+                 batch_means.append(mn.numpy())

          df = pd.DataFrame(
-             np.concatenate([mean, quantiles], axis=1),
+             np.concatenate(
+                 [
+                     np.concatenate(batch_means, axis=0).reshape(-1, 1),
+                     np.concatenate(batch_quantiles, axis=0).reshape(-1, len(self.quantile_levels)),
+                 ],
+                 axis=1,
+             ),
              columns=["mean"] + [str(q) for q in self.quantile_levels],
              index=get_forecast_horizon_index_ts_dataframe(data, self.prediction_length, freq=self.freq),
          )
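The nested `np.concatenate` calls are easier to follow with concrete shapes. A toy walkthrough of how per-batch outputs of `predict_quantiles` become one DataFrame row per (item, timestep); shapes are illustrative:

```python
import numpy as np

prediction_length, n_quantiles = 2, 3  # e.g. quantile_levels = [0.1, 0.5, 0.9]

# Two inference batches: quantile arrays are (batch_size, prediction_length, n_quantiles),
# mean arrays are (batch_size, prediction_length).
batch_quantiles = [
    np.zeros((3, prediction_length, n_quantiles)),
    np.zeros((1, prediction_length, n_quantiles)),
]
batch_means = [np.zeros((3, prediction_length)), np.zeros((1, prediction_length))]

mean_col = np.concatenate(batch_means, axis=0).reshape(-1, 1)                      # (8, 1)
quantile_cols = np.concatenate(batch_quantiles, axis=0).reshape(-1, n_quantiles)  # (8, 3)

table = np.concatenate([mean_col, quantile_cols], axis=1)  # (8, 4)
# 8 rows = 4 items x 2 forecast steps; columns map to ["mean", "0.1", "0.5", "0.9"].
```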
@@ -0,0 +1,11 @@
+ from .chronos import ChronosPipeline
+ from .chronos_bolt import ChronosBoltPipeline
+ from .base import BaseChronosPipeline, ForecastType
+
+
+ __all__ = [
+     "BaseChronosPipeline",
+     "ChronosBoltPipeline",
+     "ChronosPipeline",
+     "ForecastType",
+ ]
@@ -0,0 +1,146 @@
+ # Authors: Lorenzo Stella <stellalo@amazon.com>, Caner Turkmen <atturkm@amazon.com>
+
+ from enum import Enum
+ from pathlib import Path
+ from typing import Dict, List, Optional, Tuple, Union
+
+ import torch
+
+ from .utils import left_pad_and_stack_1D
+
+
+ class ForecastType(Enum):
+     SAMPLES = "samples"
+     QUANTILES = "quantiles"
+
+
+ class PipelineRegistry(type):
+     REGISTRY: Dict[str, "PipelineRegistry"] = {}
+
+     def __new__(cls, name, bases, attrs):
+         """See https://github.com/faif/python-patterns."""
+         new_cls = type.__new__(cls, name, bases, attrs)
+         if name is not None:
+             cls.REGISTRY[name] = new_cls
+         if aliases := attrs.get("_aliases"):
+             for alias in aliases:
+                 cls.REGISTRY[alias] = new_cls
+         return new_cls
+
+
+ class BaseChronosPipeline(metaclass=PipelineRegistry):
+     forecast_type: ForecastType
+     dtypes = {
+         "bfloat16": torch.bfloat16,
+         "float32": torch.float32,
+         "float64": torch.float64,
+     }
+
+     def _prepare_and_validate_context(self, context: Union[torch.Tensor, List[torch.Tensor]]):
+         if isinstance(context, list):
+             context = left_pad_and_stack_1D(context)
+         assert isinstance(context, torch.Tensor)
+         if context.ndim == 1:
+             context = context.unsqueeze(0)
+         assert context.ndim == 2
+
+         return context
+
+     def predict(
+         self,
+         context: Union[torch.Tensor, List[torch.Tensor]],
+         prediction_length: Optional[int] = None,
+         **kwargs,
+     ):
+         """
+         Get forecasts for the given time series.
+
+         Parameters
+         ----------
+         context
+             Input series. This is either a 1D tensor, or a list
+             of 1D tensors, or a 2D tensor whose first dimension
+             is batch. In the latter case, use left-padding with
+             ``torch.nan`` to align series of different lengths.
+         prediction_length
+             Time steps to predict. Defaults to a model-dependent
+             value if not given.
+
+         Returns
+         -------
+         forecasts
+             Tensor containing forecasts. The layout and meaning
+             of the forecast values depends on ``self.forecast_type``.
+         """
+         raise NotImplementedError()
+
+     def predict_quantiles(
+         self, context: torch.Tensor, prediction_length: int, quantile_levels: List[float], **kwargs
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         """
+         Get quantile and mean forecasts for the given time series. All
+         predictions are returned on the CPU.
+
+         Parameters
+         ----------
+         context
+             Input series. This is either a 1D tensor, or a list
+             of 1D tensors, or a 2D tensor whose first dimension
+             is batch. In the latter case, use left-padding with
+             ``torch.nan`` to align series of different lengths.
+         prediction_length
+             Time steps to predict. Defaults to a model-dependent
+             value if not given.
+         quantile_levels: List[float]
+             Quantile levels to compute
+
+         Returns
+         -------
+         quantiles
+             Tensor containing quantile forecasts. Shape
+             (batch_size, prediction_length, num_quantiles)
+         mean
+             Tensor containing mean (point) forecasts. Shape
+             (batch_size, prediction_length)
+         """
+         raise NotImplementedError()
+
+     @classmethod
+     def from_pretrained(
+         cls,
+         pretrained_model_name_or_path: Union[str, Path],
+         *model_args,
+         force=False,
+         **kwargs,
+     ):
+         """
+         Load the model, either from a local path or from the HuggingFace Hub.
+         Supports the same arguments as ``AutoConfig`` and ``AutoModel``
+         from ``transformers``.
+
+         When a local path is provided, supports both a folder and a .tar.gz archive.
+         """
+         from transformers import AutoConfig
+
+         if str(pretrained_model_name_or_path).startswith("s3://"):
+             from .utils import cache_model_from_s3
+
+             local_model_path = cache_model_from_s3(str(pretrained_model_name_or_path), force=force)
+             return cls.from_pretrained(local_model_path, *model_args, **kwargs)
+
+         torch_dtype = kwargs.get("torch_dtype", "auto")
+         if torch_dtype != "auto" and isinstance(torch_dtype, str):
+             kwargs["torch_dtype"] = cls.dtypes[torch_dtype]
+
+         config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+         is_valid_config = hasattr(config, "chronos_pipeline_class") or hasattr(config, "chronos_config")
+
+         if not is_valid_config:
+             raise ValueError("Not a Chronos config file")
+
+         pipeline_class_name = getattr(config, "chronos_pipeline_class", "ChronosPipeline")
+         class_ = PipelineRegistry.REGISTRY.get(pipeline_class_name)
+         if class_ is None:
+             raise ValueError(f"Trying to load unknown pipeline class: {pipeline_class_name}")
+
+         return class_.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
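The `PipelineRegistry` metaclass above is what lets `BaseChronosPipeline.from_pretrained` dispatch on the `chronos_pipeline_class` field of a model's `config.json`. A toy demonstration of the registration mechanics (`MyPipeline` and `my-alias` are invented names, not part of the package):

```python
from autogluon.timeseries.models.chronos.pipeline.base import BaseChronosPipeline, PipelineRegistry


# Merely defining a subclass registers it, because BaseChronosPipeline's
# metaclass intercepts class creation in PipelineRegistry.__new__.
class MyPipeline(BaseChronosPipeline):
    _aliases = ["my-alias"]


assert PipelineRegistry.REGISTRY["MyPipeline"] is MyPipeline
assert PipelineRegistry.REGISTRY["my-alias"] is MyPipeline

# from_pretrained then resolves the concrete class by name:
#   pipeline_class_name = getattr(config, "chronos_pipeline_class", "ChronosPipeline")
#   class_ = PipelineRegistry.REGISTRY.get(pipeline_class_name)
```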
@@ -11,14 +11,16 @@ from typing import Any, Dict, List, Literal, Optional, Tuple, Union

  import torch
  import torch.nn as nn
- from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForSeq2SeqLM, GenerationConfig, PreTrainedModel
+ from transformers import AutoConfig, AutoModelForSeq2SeqLM, GenerationConfig, PreTrainedModel

  from autogluon.timeseries.utils.warning_filters import set_loggers_level

- logger = logging.getLogger(__name__)
+ from .base import BaseChronosPipeline, ForecastType

+ logger = logging.getLogger("autogluon.timeseries.models.chronos")

- __all__ = ["ChronosConfig", "ChronosPipeline", "OptimizedChronosPipeline"]
+
+ __all__ = ["ChronosConfig", "ChronosPipeline"]


  @dataclass
@@ -35,7 +37,7 @@ class ChronosConfig:
      pad_token_id: int
      eos_token_id: int
      use_eos_token: bool
-     model_type: Literal["causal", "seq2seq"]
+     model_type: Literal["seq2seq"]
      context_length: int
      prediction_length: int
      num_samples: int
@@ -279,18 +281,7 @@ class ChronosPretrainedModel(nn.Module):
          return preds.reshape(input_ids.size(0), num_samples, -1)


- def left_pad_and_stack_1D(tensors: List[torch.Tensor]):
-     max_len = max(len(c) for c in tensors)
-     padded = []
-     for c in tensors:
-         assert isinstance(c, torch.Tensor)
-         assert c.ndim == 1
-         padding = torch.full(size=(max_len - len(c),), fill_value=torch.nan, device=c.device)
-         padded.append(torch.concat((padding, c), dim=-1))
-     return torch.stack(padded)
-
-
- class ChronosPipeline:
+ class ChronosPipeline(BaseChronosPipeline):
      """
      A ``ChronosPipeline`` uses the given tokenizer and model to forecast
      input time series.
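`left_pad_and_stack_1D` is deleted here because it moved to `pipeline/utils.py` (file #6 above; `base.py` imports it from `.utils`). For reference, its behavior, assuming the signature is unchanged after the move:

```python
import torch

from autogluon.timeseries.models.chronos.pipeline.utils import left_pad_and_stack_1D

# Shorter series are left-padded with NaN so a ragged batch stacks into one 2D tensor.
batch = [torch.tensor([1.0, 2.0, 3.0]), torch.tensor([4.0, 5.0])]
stacked = left_pad_and_stack_1D(batch)
# tensor([[1., 2., 3.],
#         [nan, 4., 5.]])
```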
@@ -308,21 +299,12 @@

      tokenizer: ChronosTokenizer
      model: ChronosPretrainedModel
+     forecast_type: ForecastType = ForecastType.SAMPLES

      def __init__(self, tokenizer, model):
          self.tokenizer = tokenizer
          self.model = model

-     def _prepare_and_validate_context(self, context: Union[torch.Tensor, List[torch.Tensor]]):
-         if isinstance(context, list):
-             context = left_pad_and_stack_1D(context)
-         assert isinstance(context, torch.Tensor)
-         if context.ndim == 1:
-             context = context.unsqueeze(0)
-         assert context.ndim == 2
-
-         return context
-
      @torch.no_grad()
      def embed(self, context: Union[torch.Tensor, List[torch.Tensor]]) -> Tuple[torch.Tensor, Any]:
          """
@@ -363,7 +345,7 @@
          temperature: Optional[float] = None,
          top_k: Optional[int] = None,
          top_p: Optional[float] = None,
-         limit_prediction_length: bool = True,
+         limit_prediction_length: bool = False,
      ) -> torch.Tensor:
          """
          Get forecasts for the given time series.
@@ -442,42 +424,33 @@

          return torch.cat(predictions, dim=-1)

-     @classmethod
-     def from_pretrained(cls, *args, **kwargs):
-         """
-         Load the model, either from a local path or from the HuggingFace Hub.
-         Supports the same arguments as ``AutoConfig`` and ``AutoModel``
-         from ``transformers``.
-         """
-
-         config = AutoConfig.from_pretrained(*args, **kwargs)
-
-         assert hasattr(config, "chronos_config"), "Not a Chronos config file"
-
-         chronos_config = ChronosConfig(**config.chronos_config)
-
-         if chronos_config.model_type == "seq2seq":
-             inner_model = AutoModelForSeq2SeqLM.from_pretrained(*args, **kwargs)
-         else:
-             assert config.model_type == "causal"
-             inner_model = AutoModelForCausalLM.from_pretrained(*args, **kwargs)
-
-         return cls(
-             tokenizer=chronos_config.create_tokenizer(),
-             model=ChronosPretrainedModel(config=chronos_config, model=inner_model),
+     def predict_quantiles(
+         self,
+         context: torch.Tensor,
+         prediction_length: int,
+         quantile_levels: List[float],
+         num_samples: Optional[int] = None,
+         **kwargs,
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         num_samples = num_samples or self.model.config.num_samples
+         prediction_samples = (
+             self.predict(
+                 context,
+                 prediction_length=prediction_length,
+                 num_samples=num_samples,
+             )
+             .detach()
+             .cpu()
+             .swapaxes(1, 2)
          )
+         mean = prediction_samples.mean(axis=-1, keepdims=True)
+         quantiles = torch.quantile(
+             prediction_samples,
+             q=torch.tensor(quantile_levels, dtype=prediction_samples.dtype),
+             dim=-1,
+         ).permute(1, 2, 0)

-
- class OptimizedChronosPipeline(ChronosPipeline):
-     """A wrapper around the ChronosPipeline object for CPU-optimized model classes from
-     HuggingFace optimum.
-     """
-
-     dtypes = {
-         "bfloat16": torch.bfloat16,
-         "float32": torch.float32,
-         "float64": torch.float64,
-     }
+         return quantiles, mean

      @classmethod
      def from_pretrained(cls, *args, **kwargs):
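The new `predict_quantiles` collapses sample paths into quantile and mean forecasts. A standalone shape walkthrough of the same tensor plumbing (toy values; `dim`/`keepdim` are the canonical spellings of the numpy-style aliases used above):

```python
import torch

batch_size, num_samples, horizon = 4, 20, 8
samples = torch.randn(batch_size, num_samples, horizon)  # what predict() returns

swapped = samples.swapaxes(1, 2)                  # (4, 8, 20): sample dim last
mean = swapped.mean(dim=-1, keepdim=True)         # (4, 8, 1)

q = torch.tensor([0.1, 0.5, 0.9], dtype=swapped.dtype)
quantiles = torch.quantile(swapped, q=q, dim=-1)  # (3, 4, 8): quantile dim first
quantiles = quantiles.permute(1, 2, 0)            # (4, 8, 3): batch, horizon, quantile
```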
@@ -498,49 +471,40 @@ class OptimizedChronosPipeline(ChronosPipeline):
              config.chronos_config["context_length"] = context_length
          chronos_config = ChronosConfig(**config.chronos_config)

-         torch_dtype = kwargs.get("torch_dtype", "auto")
-         if torch_dtype != "auto" and isinstance(torch_dtype, str):
-             kwargs["torch_dtype"] = cls.dtypes[torch_dtype]
-
-         if chronos_config.model_type == "seq2seq":
-             if optimization_strategy is None:
-                 inner_model = AutoModelForSeq2SeqLM.from_pretrained(*args, **kwargs)
-             else:
-                 assert optimization_strategy in [
-                     "onnx",
-                     "openvino",
-                 ], "optimization_strategy not recognized. Please provide one of `onnx` or `openvino`"
-                 torch_dtype = kwargs.pop("torch_dtype", "auto")
-                 if torch_dtype != "auto":
-                     logger.warning(
-                         f"\t`torch_dtype` will be ignored for optimization_strategy {optimization_strategy}"
+         assert chronos_config.model_type == "seq2seq"
+         if optimization_strategy is None:
+             inner_model = AutoModelForSeq2SeqLM.from_pretrained(*args, **kwargs)
+         else:
+             assert optimization_strategy in [
+                 "onnx",
+                 "openvino",
+             ], "optimization_strategy not recognized. Please provide one of `onnx` or `openvino`"
+             torch_dtype = kwargs.pop("torch_dtype", "auto")
+             if torch_dtype != "auto":
+                 logger.warning(f"\t`torch_dtype` will be ignored for optimization_strategy {optimization_strategy}")
+
+             if optimization_strategy == "onnx":
+                 try:
+                     from optimum.onnxruntime import ORTModelForSeq2SeqLM
+                 except ImportError:
+                     raise ImportError(
+                         "Huggingface Optimum library must be installed with ONNX for using the `onnx` strategy"
                      )

-                 if optimization_strategy == "onnx":
-                     try:
-                         from optimum.onnxruntime import ORTModelForSeq2SeqLM
-                     except ImportError:
-                         raise ImportError(
-                             "Huggingface Optimum library must be installed with ONNX for using the `onnx` strategy"
-                         )
-
-                     assert kwargs.pop("device_map", "cpu") in ["cpu", "auto"], "ONNX mode only available on the CPU"
-                     with set_loggers_level(regex=r"^optimum.*", level=logging.ERROR):
-                         inner_model = ORTModelForSeq2SeqLM.from_pretrained(*args, **{**kwargs, "export": True})
-                 elif optimization_strategy == "openvino":
-                     try:
-                         from optimum.intel import OVModelForSeq2SeqLM
-                     except ImportError:
-                         raise ImportError(
-                             "Huggingface Optimum library must be installed with OpenVINO for using the `openvino` strategy"
-                         )
-                     with set_loggers_level(regex=r"^optimum.*", level=logging.ERROR):
-                         inner_model = OVModelForSeq2SeqLM.from_pretrained(
-                             *args, **{**kwargs, "device_map": "cpu", "export": True}
-                         )
-                 else:
-                     assert config.model_type == "causal"
-                     inner_model = AutoModelForCausalLM.from_pretrained(*args, **kwargs)
+                 assert kwargs.pop("device_map", "cpu") in ["cpu", "auto"], "ONNX mode only available on the CPU"
+                 with set_loggers_level(regex=r"^optimum.*", level=logging.ERROR):
+                     inner_model = ORTModelForSeq2SeqLM.from_pretrained(*args, **{**kwargs, "export": True})
+             elif optimization_strategy == "openvino":
+                 try:
+                     from optimum.intel import OVModelForSeq2SeqLM
+                 except ImportError:
+                     raise ImportError(
+                         "Huggingface Optimum library must be installed with OpenVINO for using the `openvino` strategy"
+                     )
+                 with set_loggers_level(regex=r"^optimum.*", level=logging.ERROR):
+                     inner_model = OVModelForSeq2SeqLM.from_pretrained(
+                         *args, **{**kwargs, "device_map": "cpu", "export": True}
+                     )

          return cls(
              tokenizer=chronos_config.create_tokenizer(),
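Tying the retained branches together: a hedged sketch of loading a pipeline with the ONNX path via the dispatching entry point added in `base.py`. It mirrors the call made in `model.py` above, and is not an officially documented recipe; it requires the HuggingFace Optimum library with ONNX support, and only original seq2seq Chronos models take this path (Chronos-Bolt models skip it):

```python
# Sketch based on the BaseChronosPipeline.from_pretrained call in model.py above.
from autogluon.timeseries.models.chronos.pipeline import BaseChronosPipeline

pipeline = BaseChronosPipeline.from_pretrained(
    "autogluon/chronos-t5-small",
    device_map="cpu",               # ONNX mode is CPU-only, per the assert above
    optimization_strategy="onnx",   # forwarded to ChronosPipeline.from_pretrained
)
```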