autogluon.timeseries 1.1.2b20241112__py3-none-any.whl → 1.1.2b20241114__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (19)
  1. autogluon/timeseries/metrics/__init__.py +13 -3
  2. autogluon/timeseries/metrics/point.py +50 -0
  3. autogluon/timeseries/models/chronos/model.py +269 -12
  4. autogluon/timeseries/models/chronos/pipeline/base.py +14 -1
  5. autogluon/timeseries/models/chronos/pipeline/chronos.py +86 -19
  6. autogluon/timeseries/models/chronos/pipeline/chronos_bolt.py +8 -1
  7. autogluon/timeseries/models/chronos/pipeline/utils.py +239 -3
  8. autogluon/timeseries/models/gluonts/abstract_gluonts.py +33 -22
  9. autogluon/timeseries/models/gluonts/torch/models.py +39 -27
  10. autogluon/timeseries/version.py +1 -1
  11. {autogluon.timeseries-1.1.2b20241112.dist-info → autogluon.timeseries-1.1.2b20241114.dist-info}/METADATA +4 -4
  12. {autogluon.timeseries-1.1.2b20241112.dist-info → autogluon.timeseries-1.1.2b20241114.dist-info}/RECORD +19 -19
  13. /autogluon.timeseries-1.1.2b20241112-py3.8-nspkg.pth → /autogluon.timeseries-1.1.2b20241114-py3.8-nspkg.pth +0 -0
  14. {autogluon.timeseries-1.1.2b20241112.dist-info → autogluon.timeseries-1.1.2b20241114.dist-info}/LICENSE +0 -0
  15. {autogluon.timeseries-1.1.2b20241112.dist-info → autogluon.timeseries-1.1.2b20241114.dist-info}/NOTICE +0 -0
  16. {autogluon.timeseries-1.1.2b20241112.dist-info → autogluon.timeseries-1.1.2b20241114.dist-info}/WHEEL +0 -0
  17. {autogluon.timeseries-1.1.2b20241112.dist-info → autogluon.timeseries-1.1.2b20241114.dist-info}/namespace_packages.txt +0 -0
  18. {autogluon.timeseries-1.1.2b20241112.dist-info → autogluon.timeseries-1.1.2b20241114.dist-info}/top_level.txt +0 -0
  19. {autogluon.timeseries-1.1.2b20241112.dist-info → autogluon.timeseries-1.1.2b20241114.dist-info}/zip-safe +0 -0
autogluon/timeseries/metrics/__init__.py

@@ -2,7 +2,7 @@ from pprint import pformat
 from typing import Type, Union

 from .abstract import TimeSeriesScorer
-from .point import MAE, MAPE, MASE, MSE, RMSE, RMSLE, RMSSE, SMAPE, WAPE
+from .point import MAE, MAPE, MASE, MSE, RMSE, RMSLE, RMSSE, SMAPE, WAPE, WCD
 from .quantile import SQL, WQL

 __all__ = [
@@ -16,6 +16,7 @@ __all__ = [
     "RMSSE",
     "SQL",
     "WAPE",
+    "WCD",
     "WQL",
 ]

@@ -40,6 +41,11 @@ DEPRECATED_METRICS = {
     "mean_wQuantileLoss": "WQL",
 }

+# Experimental metrics that are not yet user facing
+EXPERIMENTAL_METRICS = {
+    "WCD": WCD,
+}
+

 def check_get_evaluation_metric(
     eval_metric: Union[str, TimeSeriesScorer, Type[TimeSeriesScorer], None] = None
@@ -51,12 +57,16 @@ def check_get_evaluation_metric(
         eval_metric = eval_metric()
     elif isinstance(eval_metric, str):
         eval_metric = DEPRECATED_METRICS.get(eval_metric, eval_metric)
-        if eval_metric.upper() not in AVAILABLE_METRICS:
+        metric_name = eval_metric.upper()
+        if metric_name in AVAILABLE_METRICS:
+            eval_metric = AVAILABLE_METRICS[metric_name]()
+        elif metric_name in EXPERIMENTAL_METRICS:
+            eval_metric = EXPERIMENTAL_METRICS[metric_name]()
+        else:
             raise ValueError(
                 f"Time series metric {eval_metric} not supported. Available metrics are:\n"
                 f"{pformat(sorted(AVAILABLE_METRICS.keys()))}"
             )
-        eval_metric = AVAILABLE_METRICS[eval_metric.upper()]()
     elif eval_metric is None:
         eval_metric = AVAILABLE_METRICS[DEFAULT_METRIC_NAME]()
     else:
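
With the new EXPERIMENTAL_METRICS lookup, the experimental metric can be resolved by name through the same entry point as the public metrics, while staying out of the error message that lists AVAILABLE_METRICS. A minimal sketch of the resulting behavior (illustrative only; it assumes the module layout shown above):

    # Sketch: resolving the experimental WCD metric by name (assumes the module shown above).
    from autogluon.timeseries.metrics import check_get_evaluation_metric

    scorer = check_get_evaluation_metric("WCD")  # resolved via EXPERIMENTAL_METRICS; WCD.__init__ emits an "experimental" warning
    print(type(scorer).__name__)                 # "WCD"

    # check_get_evaluation_metric("NOT_A_METRIC")  # would raise ValueError listing only AVAILABLE_METRICS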
autogluon/timeseries/metrics/point.py

@@ -1,4 +1,5 @@
 import logging
+import warnings
 from typing import Optional

 import numpy as np
@@ -359,3 +360,52 @@ class RMSLE(TimeSeriesScorer):
             seasonal_period=seasonal_period,
             **kwargs,
         )
+
+
+class WCD(TimeSeriesScorer):
+    r"""Weighted cumulative discrepancy.
+
+    Measures the discrepancy between the cumulative sum of the forecast and the cumulative sum of the actual values.
+
+    .. math::
+
+        \operatorname{WCD} = 2 \cdot \frac{1}{N} \frac{1}{H} \sum_{i=1}^{N} \sum_{t=T+1}^{T+H} \alpha \cdot \max(0, -d_{i, t}) + (1 - \alpha) \cdot \max(0, d_{i, t})
+
+    where :math:`d_{i, t}` is the difference between the cumulative predicted value and the cumulative actual value
+
+    .. math::
+
+        d_{i, t} = \left(\sum_{s=T+1}^t f_{i, s}\right) - \left(\sum_{s=T+1}^t y_{i, s}\right)
+
+    Parameters
+    ----------
+    alpha : float, default = 0.5
+        Values > 0.5 put a stronger penalty on underpredictions (when the cumulative forecast is below the
+        cumulative actual value). Values < 0.5 put a stronger penalty on overpredictions.
+    """
+
+    def __init__(self, alpha: float = 0.5):
+        assert 0 < alpha < 1, "alpha must be in (0, 1)"
+        self.alpha = alpha
+        self.num_items: Optional[int] = None
+        warnings.warn(
+            f"{self.name} is an experimental metric. Its behavior may change in a future version of AutoGluon."
+        )
+
+    def save_past_metrics(self, data_past: TimeSeriesDataFrame, **kwargs) -> None:
+        self.num_items = data_past.num_items
+
+    def _fast_cumsum(self, y: np.ndarray) -> np.ndarray:
+        """Compute the cumulative sum for each consecutive `prediction_length` items in the array."""
+        y = y.reshape(self.num_items, -1)
+        return np.nancumsum(y, axis=1).ravel()
+
+    def compute_metric(
+        self, data_future: TimeSeriesDataFrame, predictions: TimeSeriesDataFrame, target: str = "target", **kwargs
+    ) -> float:
+        y_true, y_pred = self._get_point_forecast_score_inputs(data_future, predictions, target=target)
+        cumsum_true = self._fast_cumsum(y_true.to_numpy())
+        cumsum_pred = self._fast_cumsum(y_pred.to_numpy())
+        diffs = cumsum_pred - cumsum_true
+        error = diffs * np.where(diffs < 0, -self.alpha, (1 - self.alpha))
+        return 2 * self._safemean(error)
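
For intuition, the error term above can be reproduced on a single series with plain NumPy; the numbers below are made up for illustration and alpha is left at its default of 0.5:

    # Standalone sketch of the WCD error term for one series (alpha = 0.5; toy values).
    import numpy as np

    alpha = 0.5
    y_true = np.array([10.0, 12.0, 9.0])   # actuals over the forecast horizon
    y_pred = np.array([11.0, 10.0, 10.0])  # point forecast

    d = np.cumsum(y_pred) - np.cumsum(y_true)       # cumulative differences: [1., -1., 0.]
    error = d * np.where(d < 0, -alpha, 1 - alpha)  # asymmetric penalty on under-/over-prediction
    wcd = 2 * error.mean()                          # matches the 2 * mean(...) in compute_metric, here ~0.667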
autogluon/timeseries/models/chronos/model.py

@@ -1,5 +1,8 @@
 import logging
 import os
+import shutil
+import time
+from pathlib import Path
 from typing import Any, Dict, Literal, Optional, Union

 import numpy as np
@@ -72,9 +75,10 @@ MODEL_ALIASES = {


 class ChronosModel(AbstractTimeSeriesModel):
-    """Chronos pretrained time series forecasting models. Models can be based on the original
+    """Chronos [Ansari2024]_ pretrained time series forecasting models which can be used for zero-shot forecasting or fine-tuned
+    in a task-specific manner. Models can be based on the original
     `ChronosModel <https://github.com/amazon-science/chronos-forecasting/blob/main/src/chronos/chronos.py>`_ implementation,
-    as well as a newer family of Chronos-Bolt models which are capable of much faster inference.
+    as well as a newer family of Chronos-Bolt models capable of much faster inference.

     The original Chronos is a family of pretrained models, based on the T5 family, with number of parameters ranging between
     8M and 710M. The full collection of Chronos models is available on
@@ -88,6 +92,9 @@ class ChronosModel(AbstractTimeSeriesModel):
     time series is then fed into a T5 model for forecasting. The Chronos-Bolt variants are capable of much faster inference,
     and can all run on CPUs. Chronos-Bolt models are also available on `Hugging Face <https://huggingface.co/autogluon/>`_.

+    Both Chronos and Chronos-Bolt variants can be fine-tuned by setting ``fine_tune=True`` and selecting appropriate
+    fine-tuning parameters such as the learning rate (``fine_tune_lr``) and max steps (``fine_tune_steps``).
+
     References
     ----------
     .. [Ansari2024] Ansari, Abdul Fatir, Stella, Lorenzo et al.
@@ -108,8 +115,8 @@ class ChronosModel(AbstractTimeSeriesModel):
     num_samples : int, default = 20
         Number of samples used during inference
     device : str, default = None
-        Device to use for inference. If None, model will use the GPU if available. For larger model sizes
-        `small`, `base`, and `large`; inference will fail if no GPU is available.
+        Device to use for inference (and fine-tuning, if enabled). If None, the model will use the GPU if available.
+        For the larger model sizes `small`, `base`, and `large`, inference will fail if no GPU is available.
     context_length : int or None, default = None
         The context length to use in the model. Shorter context lengths will decrease model accuracy, but result
         in faster inference. If None, the model will infer context length from the data set length at inference
@@ -129,12 +136,34 @@ class ChronosModel(AbstractTimeSeriesModel):
     data_loader_num_workers : int, default = 0
         Number of worker processes to be used in the data loader. See documentation on ``torch.utils.data.DataLoader``
         for more information.
+    fine_tune : bool, default = False
+        If True, the pretrained model will be fine-tuned.
+    fine_tune_lr : float, default = 0.0001
+        The learning rate used for fine-tuning.
+    fine_tune_steps : int, default = 5000
+        The number of gradient update steps to fine-tune for.
+    fine_tune_batch_size : int, default = 16
+        The batch size to use for fine-tuning.
+    fine_tune_shuffle_buffer_size : int, default = 10000
+        The size of the shuffle buffer used to shuffle the data during fine-tuning. If None, shuffling will
+        be turned off.
+    eval_during_fine_tune : bool, default = False
+        If True, validation will be performed during fine-tuning to select the best checkpoint.
+        Setting this argument to True may result in slower fine-tuning.
+    fine_tune_eval_max_items : int, default = 256
+        The maximum number of randomly-sampled time series to use from the validation set for evaluation
+        during fine-tuning. If None, the entire validation dataset will be used.
+    fine_tune_trainer_kwargs : dict, optional
+        Extra keyword arguments passed to ``transformers.TrainingArguments``.
+    keep_transformers_logs : bool, default = False
+        If True, the logs generated by transformers will NOT be removed after fine-tuning.
     """

     # default number of samples for prediction
     default_num_samples: int = 20
     default_model_path = "autogluon/chronos-t5-small"
     maximum_context_length = 2048
+    fine_tuned_ckpt_name: str = "fine-tuned-ckpt"

     def __init__(
         self,
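
Taken together, the new hyperparameters mean fine-tuning can be requested when the model is trained. A hedged usage sketch, assuming the usual TimeSeriesPredictor(...).fit(hyperparameters=...) interface of autogluon.timeseries and a TimeSeriesDataFrame named train_data; the model path shown is the default_model_path from this file:

    # Sketch: enabling Chronos fine-tuning via predictor-level hyperparameters
    # (assumes the standard autogluon.timeseries TimeSeriesPredictor API; train_data is a TimeSeriesDataFrame).
    from autogluon.timeseries import TimeSeriesPredictor

    predictor = TimeSeriesPredictor(prediction_length=48)
    predictor.fit(
        train_data,
        hyperparameters={
            "Chronos": {
                "model_path": "autogluon/chronos-t5-small",  # default_model_path shown in this diff
                "fine_tune": True,
                "fine_tune_lr": 1e-4,
                "fine_tune_steps": 2000,
            }
        },
    )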
@@ -202,6 +231,12 @@ class ChronosModel(AbstractTimeSeriesModel):
         model = load_pkl.load(path=os.path.join(path, cls.model_file_name), verbose=verbose)
         if reset_paths:
             model.set_contexts(path)
+
+        fine_tune_ckpt_path = Path(model.path) / cls.fine_tuned_ckpt_name
+        if fine_tune_ckpt_path.exists():
+            logger.debug(f"Fine-tuned checkpoint exists, setting model_path to {fine_tune_ckpt_path}")
+            model.model_path = fine_tune_ckpt_path
+
         return model

     def _is_gpu_available(self) -> bool:
@@ -245,7 +280,7 @@ class ChronosModel(AbstractTimeSeriesModel):
             minimum_resources["num_gpus"] = self.min_num_gpus
         return minimum_resources

-    def load_model_pipeline(self):
+    def load_model_pipeline(self, is_training: bool = False):
        from .pipeline import BaseChronosPipeline

         gpu_available = self._is_gpu_available()
@@ -262,8 +297,9 @@ class ChronosModel(AbstractTimeSeriesModel):
         pipeline = BaseChronosPipeline.from_pretrained(
             self.model_path,
             device_map=device,
+            # optimization cannot be used during fine-tuning
+            optimization_strategy=None if is_training else self.optimization_strategy,
             torch_dtype=self.torch_dtype,
-            optimization_strategy=self.optimization_strategy,
         )

         self.model_pipeline = pipeline
@@ -272,6 +308,59 @@ class ChronosModel(AbstractTimeSeriesModel):
         self.load_model_pipeline()
         return self

+    def _has_tf32(self):
+        import torch.cuda
+
+        return torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8
+
+    def _get_model_params(self) -> dict:
+        """Gets params that are passed to the inner model."""
+        init_args = super()._get_model_params().copy()
+
+        init_args.setdefault("fine_tune", False)
+        init_args.setdefault("keep_transformers_logs", False)
+        init_args.setdefault("fine_tune_lr", 1e-4)
+        init_args.setdefault("fine_tune_steps", 5000)
+        init_args.setdefault("fine_tune_batch_size", self.default_batch_size)
+        init_args.setdefault("eval_during_fine_tune", False)
+        init_args.setdefault("fine_tune_eval_max_items", 256)
+        init_args.setdefault("fine_tune_shuffle_buffer_size", 10_000)
+
+        eval_during_fine_tune = init_args["eval_during_fine_tune"]
+        output_dir = Path(self.path) / "transformers_logs"
+        fine_tune_trainer_kwargs = dict(
+            output_dir=str(output_dir),
+            per_device_train_batch_size=init_args["fine_tune_batch_size"],
+            per_device_eval_batch_size=init_args["fine_tune_batch_size"],
+            learning_rate=init_args["fine_tune_lr"],
+            lr_scheduler_type="linear",
+            warmup_ratio=0.0,
+            optim="adamw_torch_fused",
+            logging_dir=str(output_dir),
+            logging_strategy="steps",
+            logging_steps=100,
+            report_to="none",
+            max_steps=init_args["fine_tune_steps"],
+            gradient_accumulation_steps=1,
+            dataloader_num_workers=self.data_loader_num_workers,
+            tf32=self._has_tf32(),
+            save_only_model=True,
+            prediction_loss_only=True,
+            save_total_limit=1,
+            save_strategy="steps" if eval_during_fine_tune else "no",
+            save_steps=100 if eval_during_fine_tune else None,
+            evaluation_strategy="steps" if eval_during_fine_tune else "no",
+            eval_steps=100 if eval_during_fine_tune else None,
+            load_best_model_at_end=True if eval_during_fine_tune else False,
+            metric_for_best_model="eval_loss" if eval_during_fine_tune else None,
+        )
+        user_fine_tune_trainer_kwargs = init_args.get("fine_tune_trainer_kwargs", {})
+        fine_tune_trainer_kwargs.update(user_fine_tune_trainer_kwargs)
+
+        init_args["fine_tune_trainer_kwargs"] = fine_tune_trainer_kwargs
+
+        return init_args
+
     def _fit(
         self,
         train_data: TimeSeriesDataFrame,
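
Since user-supplied fine_tune_trainer_kwargs are applied on top of the defaults assembled in _get_model_params above, individual transformers.TrainingArguments fields can be overridden without restating the rest. A small sketch of such an override (the keys are standard TrainingArguments fields; the values are illustrative):

    # Sketch: overriding selected TrainingArguments defaults; unspecified keys keep
    # the defaults built in _get_model_params above.
    chronos_hyperparameters = {
        "fine_tune": True,
        "fine_tune_trainer_kwargs": {
            "warmup_ratio": 0.1,               # default above is 0.0
            "gradient_accumulation_steps": 2,  # default above is 1
            "logging_steps": 50,               # default above is 100
        },
    }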
@@ -279,8 +368,171 @@
         time_limit: int = None,
         **kwargs,
     ) -> None:
+        from transformers.trainer import PrinterCallback, Trainer, TrainingArguments
+
+        from .pipeline import ChronosBoltPipeline, ChronosPipeline
+        from .pipeline.utils import (
+            ChronosFineTuningDataset,
+            EvaluateAndSaveFinalStepCallback,
+            LoggerCallback,
+            TimeLimitCallback,
+        )
+
+        # TODO: Add support for fine-tuning models with context_length longer than the pretrained model
+
+        # verbosity < 3: all logs and warnings from transformers will be suppressed
+        # verbosity >= 3: progress bar and loss logs will be logged
+        # verbosity 4: everything will be logged
+        verbosity = kwargs.get("verbosity", 2)
+        for logger_name in logging.root.manager.loggerDict:
+            if "transformers" in logger_name:
+                transformers_logger = logging.getLogger(logger_name)
+                transformers_logger.setLevel(logging.ERROR if verbosity <= 3 else logging.INFO)
+
         self._check_fit_params()
-        self.time_limit = time_limit
+
+        fine_tune_args = self._get_model_params()
+        do_fine_tune = fine_tune_args["fine_tune"]
+
+        if do_fine_tune:
+            assert train_data is not None, "train_data cannot be None when fine_tune=True"
+
+        eval_during_fine_tune = val_data is not None and fine_tune_args["eval_during_fine_tune"]
+
+        start_time = time.monotonic()
+        if do_fine_tune:
+            context_length = self._get_context_length(train_data)
+            # load model pipeline to device memory
+            self.load_model_pipeline(is_training=True)
+
+            fine_tune_prediction_length = self.prediction_length
+            model_prediction_length = self.model_pipeline.inner_model.config.chronos_config["prediction_length"]
+
+            if isinstance(self.model_pipeline, ChronosPipeline):
+                pipeline_specific_trainer_kwargs = {}
+
+                # Update prediction_length of the model
+                # NOTE: We only do this for ChronosPipeline because the prediction length of ChronosBolt models
+                # is fixed due to the direct multistep forecasting setup
+                self.model_pipeline.model.config.prediction_length = fine_tune_prediction_length
+                self.model_pipeline.inner_model.config.chronos_config["prediction_length"] = (
+                    fine_tune_prediction_length
+                )
+
+            elif isinstance(self.model_pipeline, ChronosBoltPipeline):
+                # custom label_names is needed for validation to work with ChronosBolt models
+                pipeline_specific_trainer_kwargs = dict(label_names=["target"])
+
+                # truncate prediction_length if it goes beyond ChronosBolt's prediction_length
+                fine_tune_prediction_length = min(model_prediction_length, self.prediction_length)
+
+                if self.prediction_length != fine_tune_prediction_length:
+                    logger.debug(
+                        f"ChronosBolt models can only be fine-tuned with a maximum prediction_length of {model_prediction_length}. "
+                        f"Fine-tuning prediction_length has been changed to {fine_tune_prediction_length}."
+                    )
+
+            fine_tune_trainer_kwargs = fine_tune_args["fine_tune_trainer_kwargs"]
+            fine_tune_trainer_kwargs["disable_tqdm"] = fine_tune_trainer_kwargs.get("disable_tqdm", (verbosity < 3))
+            fine_tune_trainer_kwargs["use_cpu"] = str(self.model_pipeline.inner_model.device) == "cpu"
+            output_dir = Path(fine_tune_trainer_kwargs["output_dir"])
+
+            if not eval_during_fine_tune:
+                # turn off eval-related trainer args
+                fine_tune_trainer_kwargs["evaluation_strategy"] = "no"
+                fine_tune_trainer_kwargs["eval_steps"] = None
+                fine_tune_trainer_kwargs["load_best_model_at_end"] = False
+                fine_tune_trainer_kwargs["metric_for_best_model"] = None
+
+            training_args = TrainingArguments(**fine_tune_trainer_kwargs, **pipeline_specific_trainer_kwargs)
+            tokenizer_train_dataset = ChronosFineTuningDataset(
+                target_df=train_data,
+                target_column=self.target,
+                context_length=context_length,
+                prediction_length=fine_tune_prediction_length,
+                # if a tokenizer exists, the data is returned in the HF-style format accepted by
+                # the original Chronos models; otherwise the data is returned in ChronosBolt's format
+                tokenizer=getattr(self.model_pipeline, "tokenizer", None),
+                mode="training",
+            ).shuffle(fine_tune_args["fine_tune_shuffle_buffer_size"])
+
+            callbacks = []
+            if time_limit is not None:
+                callbacks.append(TimeLimitCallback(time_limit=time_limit))
+
+            if val_data is not None:
+                callbacks.append(EvaluateAndSaveFinalStepCallback())
+                # evaluate on a randomly-sampled subset
+                fine_tune_eval_max_items = (
+                    min(val_data.num_items, fine_tune_args["fine_tune_eval_max_items"])
+                    if fine_tune_args["fine_tune_eval_max_items"] is not None
+                    else val_data.num_items
+                )
+
+                if fine_tune_eval_max_items < val_data.num_items:
+                    eval_items = np.random.choice(
+                        val_data.item_ids.values, size=fine_tune_eval_max_items, replace=False
+                    )
+                    val_data = val_data.loc[eval_items]
+
+                tokenizer_val_dataset = ChronosFineTuningDataset(
+                    target_df=val_data,
+                    target_column=self.target,
+                    context_length=context_length,
+                    prediction_length=fine_tune_prediction_length,
+                    tokenizer=getattr(self.model_pipeline, "tokenizer", None),
+                    mode="validation",
+                )
+
+            trainer = Trainer(
+                model=self.model_pipeline.inner_model,
+                args=training_args,
+                train_dataset=tokenizer_train_dataset,
+                eval_dataset=tokenizer_val_dataset if val_data is not None else None,
+                callbacks=callbacks,
+            )
+
+            # remove PrinterCallback from callbacks, which logs to the console via a print() call
+            # and therefore cannot be handled by setting the log level
+            trainer.pop_callback(PrinterCallback)
+
+            if verbosity >= 3:
+                logger.warning(
+                    "Transformers logging is turned on during fine-tuning. Note that losses reported by transformers "
+                    "may not correspond to those specified via `eval_metric`."
+                )
+                trainer.add_callback(LoggerCallback())
+
+            if val_data is not None:
+                # evaluate once before training
+                zero_shot_eval_loss = trainer.evaluate()["eval_loss"]
+
+            trainer.train()
+
+            if eval_during_fine_tune:
+                # get the best eval_loss logged during fine-tuning
+                log_history_df = pd.DataFrame(trainer.state.log_history)
+                best_train_eval_loss = log_history_df["eval_loss"].min()
+            elif val_data is not None:
+                # evaluate at the end of fine-tuning
+                best_train_eval_loss = trainer.evaluate()["eval_loss"]
+
+            if val_data is None or best_train_eval_loss <= zero_shot_eval_loss:
+                fine_tuned_ckpt_path = Path(self.path) / self.fine_tuned_ckpt_name
+                logger.info(f"Saving fine-tuned model to {fine_tuned_ckpt_path}")
+                self.model_pipeline.inner_model.save_pretrained(Path(self.path) / self.fine_tuned_ckpt_name)
+            else:
+                # Reset the model to its pretrained state
+                logger.info("Validation loss worsened after fine-tuning. Reverting to the pretrained model.")
+                self.model_pipeline = None
+                self.load_model_pipeline(is_training=False)
+
+            if not fine_tune_args["keep_transformers_logs"]:
+                logger.debug(f"Removing transformers_logs directory {output_dir}")
+                shutil.rmtree(output_dir)
+
+        if time_limit is not None:
+            self.time_limit = time_limit - (time.monotonic() - start_time)  # inference time budget

     def _get_inference_data_loader(
         self,
@@ -305,6 +557,13 @@
             on_batch=timeout_callback(seconds=time_limit),
         )

+    def _get_context_length(self, data: TimeSeriesDataFrame) -> int:
+        context_length = self.context_length or min(
+            data.num_timesteps_per_item().max(),
+            self.maximum_context_length,
+        )
+        return context_length
+
     def _predict(
         self,
         data: TimeSeriesDataFrame,
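
The new helper reuses the selection rule previously inlined in _predict: an explicit context_length wins if set; otherwise the longest series length is used, capped at maximum_context_length (2048). A small self-contained sketch of that rule (pick_context_length is a hypothetical stand-in for the method):

    # Sketch of the selection rule in _get_context_length (hypothetical standalone function).
    MAXIMUM_CONTEXT_LENGTH = 2048

    def pick_context_length(user_context_length, longest_item_length):
        return user_context_length or min(longest_item_length, MAXIMUM_CONTEXT_LENGTH)

    assert pick_context_length(None, 3000) == 2048  # capped at the maximum
    assert pick_context_length(None, 512) == 512    # follows the data
    assert pick_context_length(1024, 3000) == 1024  # explicit user setting wins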
@@ -319,15 +578,13 @@
         # Note that this is independent of the model's own context length set in the model's config file.
         # For example, if the context_length is set to 2048 here but the model expects context length
         # (according to its config.json file) of 512, it will further truncate the series during inference.
-        context_length = self.context_length or min(
-            data.num_timesteps_per_item().max(),
-            self.maximum_context_length,
-        )
+        context_length = self._get_context_length(data)

         with warning_filter(all_warnings=True):
             import torch

             if self.model_pipeline is None:
+                # FIXME: optimization_strategy is ignored when model is fine-tuned
                 # load model pipeline to device memory
                 self.load_model_pipeline()

@@ -366,7 +623,7 @@
         return TimeSeriesDataFrame(df)

     def _more_tags(self) -> Dict:
-        return {"allow_nan": True}
+        return {"allow_nan": True, "can_use_val_data": self._get_model_params()["fine_tune"]}

     def score_and_cache_oof(
         self,
autogluon/timeseries/models/chronos/pipeline/base.py

@@ -2,12 +2,15 @@

 from enum import Enum
 from pathlib import Path
-from typing import Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

 import torch

 from .utils import left_pad_and_stack_1D

+if TYPE_CHECKING:
+    from transformers import PreTrainedModel
+

 class ForecastType(Enum):
     SAMPLES = "samples"
@@ -36,6 +39,16 @@ class BaseChronosPipeline(metaclass=PipelineRegistry):
         "float64": torch.float64,
     }

+    def __init__(self, inner_model: "PreTrainedModel"):
+        """
+        Parameters
+        ----------
+        inner_model : PreTrainedModel
+            A Hugging Face transformers PreTrainedModel, e.g., T5ForConditionalGeneration
+        """
+        # for easy access to the inner HF-style model
+        self.inner_model = inner_model
+
     def _prepare_and_validate_context(self, context: Union[torch.Tensor, List[torch.Tensor]]):
         if isinstance(context, list):
             context = left_pad_and_stack_1D(context)
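
The new inner_model attribute gives callers, such as the fine-tuning code in model.py above (which passes pipeline.inner_model to the transformers Trainer and calls save_pretrained on it), a uniform handle on the underlying Hugging Face model. A hedged sketch of that contract, using a hypothetical subclass and assuming BaseChronosPipeline from this file is importable:

    # Hypothetical subclass illustrating the new BaseChronosPipeline.__init__ contract.
    from transformers import PreTrainedModel

    from autogluon.timeseries.models.chronos.pipeline.base import BaseChronosPipeline

    class ExamplePipeline(BaseChronosPipeline):
        def __init__(self, model: PreTrainedModel):
            super().__init__(inner_model=model)  # exposes the HF model as self.inner_model

    # Downstream code can then rely on the same attribute regardless of the pipeline type:
    #   pipeline.inner_model.device
    #   pipeline.inner_model.save_pretrained(path)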