autogluon.timeseries 1.1.2b20241111__py3-none-any.whl → 1.1.2b20241113__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (20)
  1. autogluon/timeseries/metrics/__init__.py +13 -3
  2. autogluon/timeseries/metrics/point.py +50 -0
  3. autogluon/timeseries/models/chronos/model.py +67 -38
  4. autogluon/timeseries/models/chronos/pipeline/__init__.py +11 -0
  5. autogluon/timeseries/models/chronos/pipeline/base.py +146 -0
  6. autogluon/timeseries/models/chronos/{pipeline.py → pipeline/chronos.py} +66 -102
  7. autogluon/timeseries/models/chronos/pipeline/chronos_bolt.py +511 -0
  8. autogluon/timeseries/models/chronos/{utils.py → pipeline/utils.py} +37 -1
  9. autogluon/timeseries/models/gluonts/torch/models.py +3 -0
  10. autogluon/timeseries/utils/warning_filters.py +20 -0
  11. autogluon/timeseries/version.py +1 -1
  12. {autogluon.timeseries-1.1.2b20241111.dist-info → autogluon.timeseries-1.1.2b20241113.dist-info}/METADATA +5 -5
  13. {autogluon.timeseries-1.1.2b20241111.dist-info → autogluon.timeseries-1.1.2b20241113.dist-info}/RECORD +20 -17
  14. /autogluon.timeseries-1.1.2b20241111-py3.8-nspkg.pth → /autogluon.timeseries-1.1.2b20241113-py3.8-nspkg.pth +0 -0
  15. {autogluon.timeseries-1.1.2b20241111.dist-info → autogluon.timeseries-1.1.2b20241113.dist-info}/LICENSE +0 -0
  16. {autogluon.timeseries-1.1.2b20241111.dist-info → autogluon.timeseries-1.1.2b20241113.dist-info}/NOTICE +0 -0
  17. {autogluon.timeseries-1.1.2b20241111.dist-info → autogluon.timeseries-1.1.2b20241113.dist-info}/WHEEL +0 -0
  18. {autogluon.timeseries-1.1.2b20241111.dist-info → autogluon.timeseries-1.1.2b20241113.dist-info}/namespace_packages.txt +0 -0
  19. {autogluon.timeseries-1.1.2b20241111.dist-info → autogluon.timeseries-1.1.2b20241113.dist-info}/top_level.txt +0 -0
  20. {autogluon.timeseries-1.1.2b20241111.dist-info → autogluon.timeseries-1.1.2b20241113.dist-info}/zip-safe +0 -0
autogluon/timeseries/metrics/__init__.py
@@ -2,7 +2,7 @@ from pprint import pformat
 from typing import Type, Union
 
 from .abstract import TimeSeriesScorer
-from .point import MAE, MAPE, MASE, MSE, RMSE, RMSLE, RMSSE, SMAPE, WAPE
+from .point import MAE, MAPE, MASE, MSE, RMSE, RMSLE, RMSSE, SMAPE, WAPE, WCD
 from .quantile import SQL, WQL
 
 __all__ = [
@@ -16,6 +16,7 @@ __all__ = [
     "RMSSE",
     "SQL",
     "WAPE",
+    "WCD",
     "WQL",
 ]
 
@@ -40,6 +41,11 @@ DEPRECATED_METRICS = {
     "mean_wQuantileLoss": "WQL",
 }
 
+# Experimental metrics that are not yet user facing
+EXPERIMENTAL_METRICS = {
+    "WCD": WCD,
+}
+
 
 def check_get_evaluation_metric(
     eval_metric: Union[str, TimeSeriesScorer, Type[TimeSeriesScorer], None] = None
@@ -51,12 +57,16 @@ def check_get_evaluation_metric(
         eval_metric = eval_metric()
     elif isinstance(eval_metric, str):
         eval_metric = DEPRECATED_METRICS.get(eval_metric, eval_metric)
-        if eval_metric.upper() not in AVAILABLE_METRICS:
+        metric_name = eval_metric.upper()
+        if metric_name in AVAILABLE_METRICS:
+            eval_metric = AVAILABLE_METRICS[metric_name]()
+        elif metric_name in EXPERIMENTAL_METRICS:
+            eval_metric = EXPERIMENTAL_METRICS[metric_name]()
+        else:
             raise ValueError(
                 f"Time series metric {eval_metric} not supported. Available metrics are:\n"
                 f"{pformat(sorted(AVAILABLE_METRICS.keys()))}"
             )
-        eval_metric = AVAILABLE_METRICS[eval_metric.upper()]()
     elif eval_metric is None:
         eval_metric = AVAILABLE_METRICS[DEFAULT_METRIC_NAME]()
     else:
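
With this change, an experimental metric name like "WCD" resolves through the same string lookup as the public metrics while staying hidden from the error message, which only lists AVAILABLE_METRICS. A usage sketch, assuming this wheel version is installed:

    from autogluon.timeseries.metrics import check_get_evaluation_metric

    metric = check_get_evaluation_metric("WCD")  # resolved via EXPERIMENTAL_METRICS
    print(type(metric).__name__)                 # "WCD"; emits the experimental-metric UserWarning

    # Unknown names still raise, and the ValueError lists only AVAILABLE_METRICS,
    # so experimental metrics stay out of user-facing output:
    check_get_evaluation_metric("NOT_A_METRIC")  # raises ValueError
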
autogluon/timeseries/metrics/point.py
@@ -1,4 +1,5 @@
 import logging
+import warnings
 from typing import Optional
 
 import numpy as np
@@ -359,3 +360,52 @@ class RMSLE(TimeSeriesScorer):
             seasonal_period=seasonal_period,
             **kwargs,
         )
+
+
+class WCD(TimeSeriesScorer):
+    r"""Weighted cumulative discrepancy.
+
+    Measures the discrepancy between the cumulative sum of the forecast and the cumulative sum of the actual values.
+
+    .. math::
+
+        \operatorname{WCD} = 2 \cdot \frac{1}{N} \frac{1}{H} \sum_{i=1}^{N} \sum_{t=T+1}^{T+H} \left[ \alpha \cdot \max(0, -d_{i, t}) + (1 - \alpha) \cdot \max(0, d_{i, t}) \right]
+
+    where :math:`d_{i, t}` is the difference between the cumulative predicted value and the cumulative actual value
+
+    .. math::
+
+        d_{i, t} = \left(\sum_{s=T+1}^t f_{i, s}\right) - \left(\sum_{s=T+1}^t y_{i, s}\right)
+
+    Parameters
+    ----------
+    alpha : float, default = 0.5
+        Values > 0.5 put a stronger penalty on underpredictions (when the cumulative forecast is below the
+        cumulative actual value). Values < 0.5 put a stronger penalty on overpredictions.
+    """
+
+    def __init__(self, alpha: float = 0.5):
+        assert 0 < alpha < 1, "alpha must be in (0, 1)"
+        self.alpha = alpha
+        self.num_items: Optional[int] = None
+        warnings.warn(
+            f"{self.name} is an experimental metric. Its behavior may change in a future version of AutoGluon."
+        )
+
+    def save_past_metrics(self, data_past: TimeSeriesDataFrame, **kwargs) -> None:
+        self.num_items = data_past.num_items
+
+    def _fast_cumsum(self, y: np.ndarray) -> np.ndarray:
+        """Compute the cumulative sum for each consecutive `prediction_length` items in the array."""
+        y = y.reshape(self.num_items, -1)
+        return np.nancumsum(y, axis=1).ravel()
+
+    def compute_metric(
+        self, data_future: TimeSeriesDataFrame, predictions: TimeSeriesDataFrame, target: str = "target", **kwargs
+    ) -> float:
+        y_true, y_pred = self._get_point_forecast_score_inputs(data_future, predictions, target=target)
+        cumsum_true = self._fast_cumsum(y_true.to_numpy())
+        cumsum_pred = self._fast_cumsum(y_pred.to_numpy())
+        diffs = cumsum_pred - cumsum_true
+        error = diffs * np.where(diffs < 0, -self.alpha, (1 - self.alpha))
+        return 2 * self._safemean(error)
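
To make the weighting concrete, here is a self-contained NumPy sketch of the same computation for a single series (N = 1). The wcd helper below is hypothetical: it mirrors compute_metric without the TimeSeriesDataFrame plumbing and takes _safemean to be a NaN-aware mean.

    import numpy as np

    def wcd(y_true: np.ndarray, y_pred: np.ndarray, alpha: float = 0.5) -> float:
        # d_t: cumulative forecast minus cumulative actuals
        d = np.nancumsum(y_pred) - np.nancumsum(y_true)
        # alpha weights underprediction (d < 0); (1 - alpha) weights overprediction
        error = d * np.where(d < 0, -alpha, 1 - alpha)
        return 2 * float(np.nanmean(error))

    y_true = np.array([1.0, 2.0, 3.0])
    y_pred = np.array([1.5, 1.5, 1.5])
    # d = [0.5, 0.0, -1.5] -> errors [0.25, 0.0, 0.75]
    print(wcd(y_true, y_pred))  # 2 * mean([0.25, 0.0, 0.75]) = 2/3
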
autogluon/timeseries/models/chronos/model.py
@@ -9,9 +9,9 @@ from autogluon.common.loaders import load_pkl
 from autogluon.timeseries.dataset.ts_dataframe import TimeSeriesDataFrame
 from autogluon.timeseries.models.abstract import AbstractTimeSeriesModel
 from autogluon.timeseries.utils.forecast import get_forecast_horizon_index_ts_dataframe
-from autogluon.timeseries.utils.warning_filters import warning_filter
+from autogluon.timeseries.utils.warning_filters import disable_duplicate_logs, warning_filter
 
-logger = logging.getLogger(__name__)
+logger = logging.getLogger("autogluon.timeseries.models.chronos")
 
 
 # allowed HuggingFace model paths with custom parameter definitions
@@ -41,6 +41,21 @@ MODEL_CONFIGS = {
         "default_torch_dtype": "bfloat16",
         "default_batch_size": 8,
     },
+    "chronos-bolt-mini": {
+        "num_gpus": 0,
+        "default_torch_dtype": "auto",
+        "default_batch_size": 256,
+    },
+    "chronos-bolt-small": {
+        "num_gpus": 0,
+        "default_torch_dtype": "auto",
+        "default_batch_size": 256,
+    },
+    "chronos-bolt-base": {
+        "num_gpus": 0,
+        "default_torch_dtype": "auto",
+        "default_batch_size": 256,
+    },
 }
 
 
@@ -50,22 +65,29 @@ MODEL_ALIASES = {
     "small": "autogluon/chronos-t5-small",
     "base": "autogluon/chronos-t5-base",
     "large": "autogluon/chronos-t5-large",
+    "bolt-mini": "autogluon/chronos-bolt-mini",
+    "bolt-small": "autogluon/chronos-bolt-small",
+    "bolt-base": "autogluon/chronos-bolt-base",
 }
 
 
 class ChronosModel(AbstractTimeSeriesModel):
-    """Chronos pretrained time series forecasting models, based on the original
-    `ChronosModel <https://github.com/amazon-science/chronos-forecasting/blob/main/src/chronos/chronos.py>`_ implementation.
+    """Chronos pretrained time series forecasting models. Models can be based on the original
+    `ChronosModel <https://github.com/amazon-science/chronos-forecasting/blob/main/src/chronos/chronos.py>`_ implementation,
+    as well as a newer family of Chronos-Bolt models which are capable of much faster inference.
 
-    Chronos is family of pretrained models, based on the T5 family, with number of parameters ranging between 8M and 710M.
-    The full collection of Chronos models is available on
+    The original Chronos is a family of pretrained models, based on the T5 family, with the number of parameters
+    ranging between 8M and 710M. The full collection of Chronos models is available on
     `Hugging Face <https://huggingface.co/collections/amazon/chronos-models-65f1791d630a8d57cb718444>`_. For Chronos small,
-    base, and large variants a GPU is required to perform inference efficiently.
-
-    Chronos takes a minimalistic approach to pretraining time series models, by discretizing time series data directly into bins
-    which are treated as tokens, effectively performing regression by classification. This results in a simple and flexible framework
+    base, and large variants a GPU is required to perform inference efficiently. Chronos takes a minimalistic approach to
+    pretraining time series models, by discretizing time series data directly into bins which are treated as tokens,
+    effectively performing regression by classification. This results in a simple and flexible framework
     for using any language model in the context of time series forecasting. See [Ansari2024]_ for more information.
 
+    The newer Chronos-Bolt variants enable much faster inference by first "patching" the time series; the resulting
+    patches are then fed into a T5 model for forecasting. Thanks to this design, all Chronos-Bolt variants can run
+    efficiently on CPUs. Chronos-Bolt models are also available on `Hugging Face <https://huggingface.co/autogluon/>`_.
+
     References
     ----------
     .. [Ansari2024] Ansari, Abdul Fatir, Stella, Lorenzo et al.
@@ -79,7 +101,8 @@ class ChronosModel(AbstractTimeSeriesModel):
         Model path used for the model, i.e., a HuggingFace transformers ``name_or_path``. Can be a
         compatible model name on HuggingFace Hub or a local path to a model directory. Original
         Chronos models (i.e., ``autogluon/chronos-t5-{model_size}``) can be specified with aliases
-        ``tiny``, ``mini`` , ``small``, ``base``, and ``large``.
+        ``tiny``, ``mini``, ``small``, ``base``, and ``large``. Chronos-Bolt models can be specified
+        with ``bolt-mini``, ``bolt-small``, and ``bolt-base``.
     batch_size : int, default = 16
         Size of batches used during inference
     num_samples : int, default = 20
@@ -90,11 +113,15 @@ class ChronosModel(AbstractTimeSeriesModel):
     context_length : int or None, default = None
         The context length to use in the model. Shorter context lengths will decrease model accuracy, but result
         in faster inference. If None, the model will infer context length from the data set length at inference
-        time, but set it to a maximum of 512.
+        time, but set it to a maximum of 2048. Note that this is only the context length used to pass data into
+        the model. Individual model implementations may have different context lengths specified in their configuration,
+        and may truncate the context further. For example, original Chronos models have a context length of 512, but
+        Chronos-Bolt models handle contexts up to 2048.
     optimization_strategy : {None, "onnx", "openvino"}, default = None
         Optimization strategy to use for inference on CPUs. If None, the model will use the default implementation.
         If `onnx`, the model will be converted to ONNX and the inference will be performed using ONNX. If ``openvino``,
-        inference will be performed with the model compiled to OpenVINO.
+        inference will be performed with the model compiled to OpenVINO. These optimizations are only available for
+        the original set of Chronos models; Chronos-Bolt models do not need them.
     torch_dtype : torch.dtype or {"auto", "bfloat16", "float32", "float64"}, default = "auto"
         Torch data type for model weights, provided to ``from_pretrained`` method of Hugging Face AutoModels. If
         original Chronos models are specified and the model size is ``small``, ``base``, or ``large``, the
@@ -107,7 +134,7 @@ class ChronosModel(AbstractTimeSeriesModel):
     # default number of samples for prediction
     default_num_samples: int = 20
     default_model_path = "autogluon/chronos-t5-small"
-    maximum_context_length = 512
+    maximum_context_length = 2048
 
     def __init__(
         self,
@@ -159,7 +186,7 @@
             **kwargs,
         )
 
-        self.model_pipeline: Optional[Any] = None  # of type OptimizedChronosPipeline
+        self.model_pipeline: Optional[Any] = None  # of type BaseChronosPipeline
         self.time_limit: Optional[float] = None
 
     def save(self, path: str = None, verbose: bool = True) -> str:
@@ -218,8 +245,8 @@
             minimum_resources["num_gpus"] = self.min_num_gpus
         return minimum_resources
 
-    def load_model_pipeline(self, context_length: Optional[int] = None):
-        from .pipeline import OptimizedChronosPipeline
+    def load_model_pipeline(self):
+        from .pipeline import BaseChronosPipeline
 
         gpu_available = self._is_gpu_available()
 
@@ -232,18 +259,17 @@
 
 
         device = self.device or ("cuda" if gpu_available else "cpu")
 
-        pipeline = OptimizedChronosPipeline.from_pretrained(
+        pipeline = BaseChronosPipeline.from_pretrained(
             self.model_path,
             device_map=device,
-            optimization_strategy=self.optimization_strategy,
             torch_dtype=self.torch_dtype,
-            context_length=context_length or self.context_length,
+            optimization_strategy=self.optimization_strategy,
         )
 
         self.model_pipeline = pipeline
 
     def persist(self) -> "ChronosModel":
-        self.load_model_pipeline(context_length=self.context_length or self.maximum_context_length)
+        self.load_model_pipeline()
         return self
 
     def _fit(
@@ -263,7 +289,7 @@
         num_workers: int = 0,
         time_limit: Optional[float] = None,
     ):
-        from .utils import ChronosInferenceDataLoader, ChronosInferenceDataset, timeout_callback
+        from .pipeline.utils import ChronosInferenceDataLoader, ChronosInferenceDataset, timeout_callback
 
         chronos_dataset = ChronosInferenceDataset(
             target_df=data,
@@ -290,6 +316,9 @@
         # and use that to determine the context length of the model. If the context length is specified
         # during initialization, this is always used. If not, the context length is set to the longest
         # item length. The context length is always capped by self.maximum_context_length.
+        # Note that this is independent of the model's own context length set in the model's config file.
+        # For example, if the context_length is set to 2048 here but the model expects a context length of 512
+        # (according to its config.json file), the pipeline will further truncate the series during inference.
         context_length = self.context_length or min(
             data.num_timesteps_per_item().max(),
             self.maximum_context_length,
@@ -300,7 +329,7 @@
 
         if self.model_pipeline is None:
             # load model pipeline to device memory
-            self.load_model_pipeline(context_length=context_length)
+            self.load_model_pipeline()
 
         inference_data_loader = self._get_inference_data_loader(
             data=data,
@@ -308,28 +337,28 @@
             context_length=context_length,
             time_limit=kwargs.get("time_limit"),
         )
+
         self.model_pipeline.model.eval()
-        with torch.inference_mode():
-            prediction_samples = [
-                self.model_pipeline.predict(
+        with torch.inference_mode(), disable_duplicate_logs(logger):
+            batch_quantiles, batch_means = [], []
+            for batch in inference_data_loader:
+                qs, mn = self.model_pipeline.predict_quantiles(
                     batch,
                     prediction_length=self.prediction_length,
+                    quantile_levels=self.quantile_levels,
                     num_samples=self.num_samples,
-                    limit_prediction_length=False,
                 )
-                .detach()
-                .cpu()
-                .numpy()
-                for batch in inference_data_loader
-            ]
-
-        samples = np.concatenate(prediction_samples, axis=0).swapaxes(1, 2).reshape(-1, self.num_samples)
-
-        mean = samples.mean(axis=-1, keepdims=True)
-        quantiles = np.quantile(samples, self.quantile_levels, axis=-1).T
+                batch_quantiles.append(qs.numpy())
+                batch_means.append(mn.numpy())
 
         df = pd.DataFrame(
-            np.concatenate([mean, quantiles], axis=1),
+            np.concatenate(
+                [
+                    np.concatenate(batch_means, axis=0).reshape(-1, 1),
+                    np.concatenate(batch_quantiles, axis=0).reshape(-1, len(self.quantile_levels)),
+                ],
+                axis=1,
+            ),
             columns=["mean"] + [str(q) for q in self.quantile_levels],
             index=get_forecast_horizon_index_ts_dataframe(data, self.prediction_length, freq=self.freq),
        )
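
Taken together, these changes let a Chronos-Bolt variant be selected through the usual predictor interface. A usage sketch (train_data is a placeholder for any TimeSeriesDataFrame; the alias resolves through MODEL_ALIASES above, and MODEL_CONFIGS gives the bolt models CPU-friendly defaults with batch size 256):

    from autogluon.timeseries import TimeSeriesPredictor

    predictor = TimeSeriesPredictor(prediction_length=48).fit(
        train_data,  # placeholder: any TimeSeriesDataFrame
        hyperparameters={"Chronos": {"model_path": "bolt-small"}},
    )
    predictions = predictor.predict(train_data)  # "mean" plus one column per quantile level
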
autogluon/timeseries/models/chronos/pipeline/__init__.py
@@ -0,0 +1,11 @@
+from .chronos import ChronosPipeline
+from .chronos_bolt import ChronosBoltPipeline
+from .base import BaseChronosPipeline, ForecastType
+
+
+__all__ = [
+    "BaseChronosPipeline",
+    "ChronosBoltPipeline",
+    "ChronosPipeline",
+    "ForecastType",
+]
autogluon/timeseries/models/chronos/pipeline/base.py
@@ -0,0 +1,146 @@
+# Authors: Lorenzo Stella <stellalo@amazon.com>, Caner Turkmen <atturkm@amazon.com>
+
+from enum import Enum
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple, Union
+
+import torch
+
+from .utils import left_pad_and_stack_1D
+
+
+class ForecastType(Enum):
+    SAMPLES = "samples"
+    QUANTILES = "quantiles"
+
+
+class PipelineRegistry(type):
+    REGISTRY: Dict[str, "PipelineRegistry"] = {}
+
+    def __new__(cls, name, bases, attrs):
+        """See, https://github.com/faif/python-patterns."""
+        new_cls = type.__new__(cls, name, bases, attrs)
+        if name is not None:
+            cls.REGISTRY[name] = new_cls
+        if aliases := attrs.get("_aliases"):
+            for alias in aliases:
+                cls.REGISTRY[alias] = new_cls
+        return new_cls
+
+
+class BaseChronosPipeline(metaclass=PipelineRegistry):
+    forecast_type: ForecastType
+    dtypes = {
+        "bfloat16": torch.bfloat16,
+        "float32": torch.float32,
+        "float64": torch.float64,
+    }
+
+    def _prepare_and_validate_context(self, context: Union[torch.Tensor, List[torch.Tensor]]):
+        if isinstance(context, list):
+            context = left_pad_and_stack_1D(context)
+        assert isinstance(context, torch.Tensor)
+        if context.ndim == 1:
+            context = context.unsqueeze(0)
+        assert context.ndim == 2
+
+        return context
+
+    def predict(
+        self,
+        context: Union[torch.Tensor, List[torch.Tensor]],
+        prediction_length: Optional[int] = None,
+        **kwargs,
+    ):
+        """
+        Get forecasts for the given time series.
+
+        Parameters
+        ----------
+        context
+            Input series. This is either a 1D tensor, or a list
+            of 1D tensors, or a 2D tensor whose first dimension
+            is batch. In the latter case, use left-padding with
+            ``torch.nan`` to align series of different lengths.
+        prediction_length
+            Time steps to predict. Defaults to a model-dependent
+            value if not given.
+
+        Returns
+        -------
+        forecasts
+            Tensor containing forecasts. The layout and meaning
+            of the forecast values depends on ``self.forecast_type``.
+        """
+        raise NotImplementedError()
+
+    def predict_quantiles(
+        self, context: torch.Tensor, prediction_length: int, quantile_levels: List[float], **kwargs
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Get quantile and mean forecasts for the given time series. All
+        predictions are returned on the CPU.
+
+        Parameters
+        ----------
+        context
+            Input series. This is either a 1D tensor, or a list
+            of 1D tensors, or a 2D tensor whose first dimension
+            is batch. In the latter case, use left-padding with
+            ``torch.nan`` to align series of different lengths.
+        prediction_length
+            Time steps to predict. Defaults to a model-dependent
+            value if not given.
+        quantile_levels: List[float]
+            Quantile levels to compute
+
+        Returns
+        -------
+        quantiles
+            Tensor containing quantile forecasts. Shape
+            (batch_size, prediction_length, num_quantiles)
+        mean
+            Tensor containing mean (point) forecasts. Shape
+            (batch_size, prediction_length)
+        """
+        raise NotImplementedError()
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        pretrained_model_name_or_path: Union[str, Path],
+        *model_args,
+        force=False,
+        **kwargs,
+    ):
+        """
+        Load the model, either from a local path or from the HuggingFace Hub.
+        Supports the same arguments as ``AutoConfig`` and ``AutoModel``
+        from ``transformers``.
+
+        When a local path is provided, supports both a folder or a .tar.gz archive.
+        """
+        from transformers import AutoConfig
+
+        if str(pretrained_model_name_or_path).startswith("s3://"):
+            from .utils import cache_model_from_s3
+
+            local_model_path = cache_model_from_s3(str(pretrained_model_name_or_path), force=force)
+            return cls.from_pretrained(local_model_path, *model_args, **kwargs)
+
+        torch_dtype = kwargs.get("torch_dtype", "auto")
+        if torch_dtype != "auto" and isinstance(torch_dtype, str):
+            kwargs["torch_dtype"] = cls.dtypes[torch_dtype]
+
+        config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+        is_valid_config = hasattr(config, "chronos_pipeline_class") or hasattr(config, "chronos_config")
+
+        if not is_valid_config:
+            raise ValueError("Not a Chronos config file")
+
+        pipeline_class_name = getattr(config, "chronos_pipeline_class", "ChronosPipeline")
+        class_ = PipelineRegistry.REGISTRY.get(pipeline_class_name)
+        if class_ is None:
+            raise ValueError(f"Trying to load unknown pipeline class: {pipeline_class_name}")
+
+        return class_.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
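
The registry metaclass makes from_pretrained a single entry point: the chronos_pipeline_class field of the checkpoint's config selects the concrete class (e.g., ChronosBoltPipeline), while plain Chronos configs fall back to ChronosPipeline. A usage sketch of predict_quantiles, with output shapes as stated in the docstring above:

    import torch
    from autogluon.timeseries.models.chronos.pipeline import BaseChronosPipeline

    pipeline = BaseChronosPipeline.from_pretrained(
        "autogluon/chronos-bolt-small", device_map="cpu", torch_dtype="float32"
    )
    context = torch.arange(1.0, 31.0)  # one toy series of length 30
    quantiles, mean = pipeline.predict_quantiles(
        context, prediction_length=8, quantile_levels=[0.1, 0.5, 0.9]
    )
    print(quantiles.shape, mean.shape)  # torch.Size([1, 8, 3]) torch.Size([1, 8])
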