chronos-forecasting 2.2.0__tar.gz → 2.2.0rc2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (91) hide show
  1. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/.gitignore +1 -3
  2. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/PKG-INFO +1 -1
  3. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/scripts/evaluation/evaluate.py +3 -3
  4. chronos_forecasting-2.2.0rc2/src/chronos/__about__.py +1 -0
  5. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/src/chronos/base.py +0 -5
  6. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/src/chronos/chronos2/pipeline.py +12 -41
  7. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/src/chronos/df_utils.py +37 -20
  8. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/test/test_chronos2.py +44 -92
  9. chronos_forecasting-2.2.0/src/chronos/__about__.py +0 -1
  10. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/.gitattributes +0 -0
  11. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/.github/ISSUE_TEMPLATE/bug_report.md +0 -0
  12. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/.github/ISSUE_TEMPLATE/config.yml +0 -0
  13. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/.github/workflows/ci.yml +0 -0
  14. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/.github/workflows/eval-model.yml +0 -0
  15. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/.github/workflows/publish-to-pypi.yml +0 -0
  16. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/CITATION.cff +0 -0
  17. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/CODE_OF_CONDUCT.md +0 -0
  18. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/CONTRIBUTING.md +0 -0
  19. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/LICENSE +0 -0
  20. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/NOTICE +0 -0
  21. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/README.md +0 -0
  22. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/ci/evaluate/backtest_config.yaml +0 -0
  23. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/notebooks/chronos-2-quickstart.ipynb +0 -0
  24. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/notebooks/deploy-chronos-to-amazon-sagemaker.ipynb +0 -0
  25. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/pyproject.toml +0 -0
  26. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/scripts/README.md +0 -0
  27. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/scripts/evaluation/agg-relative-score.py +0 -0
  28. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/scripts/evaluation/configs/in-domain.yaml +0 -0
  29. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/scripts/evaluation/configs/zero-shot.yaml +0 -0
  30. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/scripts/evaluation/results/chronos-bolt-base-agg-rel-scores.csv +0 -0
  31. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/scripts/evaluation/results/chronos-bolt-base-in-domain.csv +0 -0
  32. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/scripts/evaluation/results/chronos-bolt-base-zero-shot.csv +0 -0
  33. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/scripts/evaluation/results/chronos-bolt-mini-agg-rel-scores.csv +0 -0
  34. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/scripts/evaluation/results/chronos-bolt-mini-in-domain.csv +0 -0
  35. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/scripts/evaluation/results/chronos-bolt-mini-zero-shot.csv +0 -0
  36. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/scripts/evaluation/results/chronos-bolt-small-agg-rel-scores.csv +0 -0
  37. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/scripts/evaluation/results/chronos-bolt-small-in-domain.csv +0 -0
  38. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/scripts/evaluation/results/chronos-bolt-small-zero-shot.csv +0 -0
  39. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/scripts/evaluation/results/chronos-bolt-tiny-agg-rel-scores.csv +0 -0
  40. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/scripts/evaluation/results/chronos-bolt-tiny-in-domain.csv +0 -0
  41. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/scripts/evaluation/results/chronos-bolt-tiny-zero-shot.csv +0 -0
  42. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/scripts/evaluation/results/chronos-t5-base-agg-rel-scores.csv +0 -0
  43. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/scripts/evaluation/results/chronos-t5-base-in-domain.csv +0 -0
  44. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/scripts/evaluation/results/chronos-t5-base-zero-shot.csv +0 -0
  45. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/scripts/evaluation/results/chronos-t5-large-agg-rel-scores.csv +0 -0
  46. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/scripts/evaluation/results/chronos-t5-large-in-domain.csv +0 -0
  47. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/scripts/evaluation/results/chronos-t5-large-zero-shot.csv +0 -0
  48. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/scripts/evaluation/results/chronos-t5-mini-agg-rel-scores.csv +0 -0
  49. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/scripts/evaluation/results/chronos-t5-mini-in-domain.csv +0 -0
  50. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/scripts/evaluation/results/chronos-t5-mini-zero-shot.csv +0 -0
  51. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/scripts/evaluation/results/chronos-t5-small-agg-rel-scores.csv +0 -0
  52. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/scripts/evaluation/results/chronos-t5-small-in-domain.csv +0 -0
  53. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/scripts/evaluation/results/chronos-t5-small-zero-shot.csv +0 -0
  54. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/scripts/evaluation/results/chronos-t5-tiny-agg-rel-scores.csv +0 -0
  55. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/scripts/evaluation/results/chronos-t5-tiny-in-domain.csv +0 -0
  56. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/scripts/evaluation/results/chronos-t5-tiny-zero-shot.csv +0 -0
  57. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/scripts/evaluation/results/seasonal-naive-in-domain.csv +0 -0
  58. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/scripts/evaluation/results/seasonal-naive-zero-shot.csv +0 -0
  59. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/scripts/kernel-synth.py +0 -0
  60. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/scripts/training/configs/chronos-gpt2.yaml +0 -0
  61. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/scripts/training/configs/chronos-t5-base.yaml +0 -0
  62. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/scripts/training/configs/chronos-t5-large.yaml +0 -0
  63. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/scripts/training/configs/chronos-t5-mini.yaml +0 -0
  64. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/scripts/training/configs/chronos-t5-small.yaml +0 -0
  65. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/scripts/training/configs/chronos-t5-tiny.yaml +0 -0
  66. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/scripts/training/train.py +0 -0
  67. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/src/chronos/__init__.py +0 -0
  68. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/src/chronos/boto_utils.py +0 -0
  69. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/src/chronos/chronos.py +0 -0
  70. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/src/chronos/chronos2/__init__.py +0 -0
  71. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/src/chronos/chronos2/config.py +0 -0
  72. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/src/chronos/chronos2/dataset.py +0 -0
  73. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/src/chronos/chronos2/layers.py +0 -0
  74. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/src/chronos/chronos2/model.py +0 -0
  75. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/src/chronos/chronos2/trainer.py +0 -0
  76. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/src/chronos/chronos_bolt.py +0 -0
  77. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/src/chronos/utils.py +0 -0
  78. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/test/__init__.py +0 -0
  79. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/test/dummy-chronos-bolt-model/config.json +0 -0
  80. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/test/dummy-chronos-bolt-model/model.safetensors +0 -0
  81. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/test/dummy-chronos-model/config.json +0 -0
  82. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/test/dummy-chronos-model/generation_config.json +0 -0
  83. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/test/dummy-chronos-model/pytorch_model.bin +0 -0
  84. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/test/dummy-chronos2-lora/adapter_config.json +0 -0
  85. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/test/dummy-chronos2-lora/adapter_model.safetensors +0 -0
  86. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/test/dummy-chronos2-model/config.json +0 -0
  87. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/test/dummy-chronos2-model/model.safetensors +0 -0
  88. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/test/test_chronos.py +0 -0
  89. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/test/test_chronos_bolt.py +0 -0
  90. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/test/test_utils.py +0 -0
  91. {chronos_forecasting-2.2.0 → chronos_forecasting-2.2.0rc2}/test/util.py +0 -0
@@ -160,6 +160,4 @@ cython_debug/
160
160
  #.idea/
161
161
 
162
162
  # macOS stuff
163
- .DS_store
164
-
165
- chronos-2-finetuned
163
+ .DS_store
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: chronos-forecasting
3
- Version: 2.2.0
3
+ Version: 2.2.0rc2
4
4
  Summary: Chronos: Pretrained models for time series forecasting
5
5
  Project-URL: Homepage, https://github.com/amazon-science/chronos-forecasting
6
6
  Project-URL: Issues, https://github.com/amazon-science/chronos-forecasting/issues
@@ -295,7 +295,7 @@ def chronos_2(
295
295
  device: str = "cuda",
296
296
  torch_dtype: str = "float32",
297
297
  batch_size: int = 32,
298
- cross_learning: bool = False,
298
+ predict_batches_jointly: bool = False,
299
299
  ):
300
300
  """Evaluate Chronos-2 models.
301
301
 
@@ -316,7 +316,7 @@ def chronos_2(
316
316
  batch_size : int, optional, default = 32
317
317
  Batch size for inference. For Chronos-Bolt models, significantly larger
318
318
  batch sizes can be used
319
- cross_learning: bool, optional, default = False
319
+ predict_batches_jointly: bool, optional, default = False
320
320
  If True, cross-learning is enables and model makes joint predictions for all
321
321
  items in the batch
322
322
  """
@@ -335,7 +335,7 @@ def chronos_2(
335
335
  metrics_path=metrics_path,
336
336
  model_id=model_id,
337
337
  batch_size=batch_size,
338
- cross_learning=cross_learning,
338
+ predict_batches_jointly=predict_batches_jointly,
339
339
  )
340
340
 
341
341
 
@@ -0,0 +1 @@
1
+ __version__ = "2.2.0rc2"
@@ -141,7 +141,6 @@ class BaseChronosPipeline(metaclass=PipelineRegistry):
141
141
  target: str = "target",
142
142
  prediction_length: int | None = None,
143
143
  quantile_levels: list[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
144
- validate_inputs: bool = True,
145
144
  **predict_kwargs,
146
145
  ) -> "pd.DataFrame":
147
146
  """
@@ -163,9 +162,6 @@ class BaseChronosPipeline(metaclass=PipelineRegistry):
163
162
  Number of steps to predict for each time series
164
163
  quantile_levels
165
164
  Quantile levels to compute
166
- validate_inputs
167
- When True, the dataframe(s) will be validated before prediction, ensuring that timestamps have a
168
- regular frequency, and item IDs match between past and future data. Setting to False disables these checks.
169
165
  **predict_kwargs
170
166
  Additional arguments passed to predict_quantiles
171
167
 
@@ -200,7 +196,6 @@ class BaseChronosPipeline(metaclass=PipelineRegistry):
200
196
  timestamp_column=timestamp_column,
201
197
  target_columns=[target],
202
198
  prediction_length=prediction_length,
203
- validate_inputs=validate_inputs,
204
199
  )
205
200
 
206
201
  # NOTE: any covariates, if present, are ignored here
@@ -19,6 +19,7 @@ from transformers import AutoConfig
19
19
  from transformers.utils.import_utils import is_peft_available
20
20
  from transformers.utils.peft_utils import find_adapter_config_file
21
21
 
22
+
22
23
  import chronos.chronos2
23
24
  from chronos.base import BaseChronosPipeline, ForecastType
24
25
  from chronos.chronos2 import Chronos2Model
@@ -113,7 +114,6 @@ class Chronos2Pipeline(BaseChronosPipeline):
113
114
  min_past: int | None = None,
114
115
  finetuned_ckpt_name: str = "finetuned-ckpt",
115
116
  callbacks: list["TrainerCallback"] | None = None,
116
- remove_printer_callback: bool = False,
117
117
  **extra_trainer_kwargs,
118
118
  ) -> "Chronos2Pipeline":
119
119
  """
@@ -156,8 +156,6 @@ class Chronos2Pipeline(BaseChronosPipeline):
156
156
  The name of the directory inside `output_dir` in which the final fine-tuned checkpoint will be saved, by default "finetuned-ckpt"
157
157
  callbacks
158
158
  A list of `TrainerCallback`s which will be forwarded to the HuggingFace `Trainer`
159
- remove_printer_callback
160
- If True, all instances of `PrinterCallback` are removed from callbacks
161
159
  **extra_trainer_kwargs
162
160
  Extra kwargs are directly forwarded to `TrainingArguments`
163
161
 
@@ -167,7 +165,6 @@ class Chronos2Pipeline(BaseChronosPipeline):
167
165
  """
168
166
 
169
167
  import torch.cuda
170
- from transformers.trainer_callback import PrinterCallback
171
168
  from transformers.training_args import TrainingArguments
172
169
 
173
170
  if finetune_mode == "lora":
@@ -178,7 +175,6 @@ class Chronos2Pipeline(BaseChronosPipeline):
178
175
  "`peft` is required for `finetune_mode='lora'`. Please install it with `pip install peft`. Falling back to `finetune_mode='full'`."
179
176
  )
180
177
  finetune_mode = "full"
181
- lora_config = None
182
178
 
183
179
  from chronos.chronos2.trainer import Chronos2Trainer, EvaluateAndSaveFinalStepCallback
184
180
 
@@ -269,7 +265,7 @@ class Chronos2Pipeline(BaseChronosPipeline):
269
265
  report_to="none",
270
266
  max_steps=num_steps,
271
267
  gradient_accumulation_steps=1,
272
- dataloader_num_workers=0,
268
+ dataloader_num_workers=1,
273
269
  tf32=has_sm80 and not use_cpu,
274
270
  bf16=has_sm80 and not use_cpu,
275
271
  save_only_model=True,
@@ -326,19 +322,12 @@ class Chronos2Pipeline(BaseChronosPipeline):
326
322
  eval_dataset=eval_dataset,
327
323
  callbacks=callbacks,
328
324
  )
329
-
330
- if remove_printer_callback:
331
- trainer.pop_callback(PrinterCallback)
332
-
333
325
  trainer.train()
334
326
 
335
- # update context_length and max_output_patches, if the model was fine-tuned with larger values
336
- model.chronos_config.context_length = max(model.chronos_config.context_length, context_length)
327
+ # update max_output_patches, if the model was fine-tuned with longer prediction_length
337
328
  model.chronos_config.max_output_patches = max(
338
329
  model.chronos_config.max_output_patches, math.ceil(prediction_length / self.model_output_patch_size)
339
330
  )
340
- # update chronos_config in model's config, so it is saved correctly
341
- model.config.chronos_config = model.chronos_config.__dict__
342
331
 
343
332
  # Create a new pipeline with the fine-tuned model
344
333
  finetuned_pipeline = Chronos2Pipeline(model=model)
@@ -447,7 +436,7 @@ class Chronos2Pipeline(BaseChronosPipeline):
447
436
  prediction_length: int | None = None,
448
437
  batch_size: int = 256,
449
438
  context_length: int | None = None,
450
- cross_learning: bool = False,
439
+ predict_batches_jointly: bool = False,
451
440
  limit_prediction_length: bool = False,
452
441
  **kwargs,
453
442
  ) -> list[torch.Tensor]:
@@ -533,7 +522,7 @@ class Chronos2Pipeline(BaseChronosPipeline):
533
522
  will be lower than this value, by default 256
534
523
  context_length
535
524
  The maximum context length used during for inference, by default set to the model's default context length
536
- cross_learning
525
+ predict_batches_jointly
537
526
  If True, cross-learning is enabled, i.e., all the tasks in `inputs` will be predicted jointly and the model will share information across all inputs, by default False
538
527
  The following must be noted when using cross-learning:
539
528
  - Cross-learning doesn't always improve forecast accuracy and must be tested for individual use cases.
@@ -553,14 +542,6 @@ class Chronos2Pipeline(BaseChronosPipeline):
553
542
  if prediction_length is None:
554
543
  prediction_length = model_prediction_length
555
544
 
556
- if kwargs.get("predict_batches_jointly") is not None:
557
- warnings.warn(
558
- "The `predict_batches_jointly` argument is deprecated and will be removed in a future version. "
559
- "Please use `cross_learning=True` to enable the cross-learning mode.",
560
- category=FutureWarning,
561
- stacklevel=2,
562
- )
563
- cross_learning = kwargs.pop("predict_batches_jointly")
564
545
  # The maximum number of output patches to generate in a single forward pass before the long-horizon heuristic kicks in. Note: A value larger
565
546
  # than the model's default max_output_patches may lead to degradation in forecast accuracy, defaults to a model-specific value
566
547
  max_output_patches = kwargs.pop("max_output_patches", self.max_output_patches)
@@ -618,7 +599,7 @@ class Chronos2Pipeline(BaseChronosPipeline):
618
599
  batch_future_covariates = batch["future_covariates"]
619
600
  batch_target_idx_ranges = batch["target_idx_ranges"]
620
601
 
621
- if cross_learning:
602
+ if predict_batches_jointly:
622
603
  batch_group_ids = torch.zeros_like(batch_group_ids)
623
604
 
624
605
  batch_prediction = self._predict_batch(
@@ -809,8 +790,6 @@ class Chronos2Pipeline(BaseChronosPipeline):
809
790
  prediction_length: int | None = None,
810
791
  quantile_levels: list[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
811
792
  batch_size: int = 256,
812
- context_length: int | None = None,
813
- cross_learning: bool = False,
814
793
  validate_inputs: bool = True,
815
794
  **predict_kwargs,
816
795
  ) -> "pd.DataFrame":
@@ -841,18 +820,8 @@ class Chronos2Pipeline(BaseChronosPipeline):
841
820
  The batch size used for prediction. Note that the batch size here means the number of time series, including target(s) and covariates,
842
821
  which are input into the model. If your data has multiple target and/or covariates, the effective number of time series tasks in a batch
843
822
  will be lower than this value, by default 256
844
- context_length
845
- The maximum context length used during for inference, by default set to the model's default context length
846
- cross_learning
847
- If True, cross-learning is enabled, i.e., all the tasks in `inputs` will be predicted jointly and the model will share information across all inputs, by default False
848
- The following must be noted when using cross-learning:
849
- - Cross-learning doesn't always improve forecast accuracy and must be tested for individual use cases.
850
- - Results become dependent on batch size. Very large batch sizes may not provide benefits as they deviate from the maximum group size used during pretraining.
851
- For optimal results, consider using a batch size around 100 (as used in the Chronos-2 technical report).
852
- - Cross-learning is most helpful when individual time series have limited historical context, as the model can leverage patterns from related series in the batch.
853
823
  validate_inputs
854
- When True, the dataframe(s) will be validated before prediction, ensuring that timestamps have a
855
- regular frequency, and item IDs match between past and future data. Setting to False disables these checks.
824
+ When True, the dataframe(s) will be validated before prediction
856
825
  **predict_kwargs
857
826
  Additional arguments passed to predict_quantiles
858
827
 
@@ -893,8 +862,6 @@ class Chronos2Pipeline(BaseChronosPipeline):
893
862
  quantile_levels=quantile_levels,
894
863
  limit_prediction_length=False,
895
864
  batch_size=batch_size,
896
- context_length=context_length,
897
- cross_learning=cross_learning,
898
865
  **predict_kwargs,
899
866
  )
900
867
  # since predict_df tasks are homogenous by input design, we can safely stack the list of tensors into a single tensor
@@ -1056,7 +1023,11 @@ class Chronos2Pipeline(BaseChronosPipeline):
1056
1023
  finetune_kwargs["prediction_length"] = first_window.horizon
1057
1024
  finetune_kwargs["batch_size"] = finetune_kwargs.get("batch_size", batch_size)
1058
1025
 
1059
- pipeline = self.fit(inputs=inputs, **finetune_kwargs)
1026
+ try:
1027
+ pipeline = self.fit(inputs=inputs, **finetune_kwargs)
1028
+ except Exception as e:
1029
+ msg = f"Finetuning failed with error: {e}. Continuing with the pretrained model."
1030
+ warnings.warn(msg, category=UserWarning, stacklevel=2)
1060
1031
 
1061
1032
  predictions_per_window = []
1062
1033
  inference_time_s = 0.0
@@ -185,13 +185,25 @@ def validate_df_inputs(
185
185
  if context_ids != future_ids:
186
186
  raise ValueError("future_df must contain the same time series IDs as df")
187
187
 
188
- future_series_lengths = future_df[id_column].value_counts(sort=False)
189
- if (future_series_lengths != prediction_length).any():
190
- invalid_series = future_series_lengths[future_series_lengths != prediction_length]
191
- raise ValueError(
192
- f"future_df must contain {prediction_length=} values for each series, "
193
- f"but found series with different lengths: {invalid_series.to_dict()}"
194
- )
188
+ future_series_lengths = future_df[id_column].value_counts(sort=False).to_list()
189
+
190
+ # Validate future series lengths match prediction_length
191
+ future_start_idx = 0
192
+ future_timestamps_index = pd.DatetimeIndex(future_df[timestamp_column])
193
+ for future_length in future_series_lengths:
194
+ future_timestamps = future_timestamps_index[future_start_idx : future_start_idx + future_length]
195
+ future_series_id = future_df[id_column].iloc[future_start_idx]
196
+ if future_length != prediction_length:
197
+ raise ValueError(
198
+ f"Future covariates all time series must have length {prediction_length}, got {future_length} for series {future_series_id}"
199
+ )
200
+ if future_length < 3 or inferred_freq != validate_freq(future_timestamps, future_series_id):
201
+ raise ValueError(
202
+ f"Future covariates must have the same frequency as context, found series {future_series_id} with a different frequency"
203
+ )
204
+ future_start_idx += future_length
205
+
206
+ assert len(series_lengths) == len(future_series_lengths)
195
207
 
196
208
  return df, future_df, inferred_freq, series_lengths, original_order
197
209
 
@@ -291,16 +303,10 @@ def convert_df_input_to_list_of_dicts_input(
291
303
  past_covariates_dict = {
292
304
  col: df[col].to_numpy() for col in df.columns if col not in [id_column, timestamp_column] + target_columns
293
305
  }
294
- future_covariates_dict = {}
295
306
  if future_df is not None:
296
- for col in future_df.columns.drop([id_column, timestamp_column]):
297
- future_covariates_dict[col] = future_df[col].to_numpy()
298
- if validate_inputs:
299
- if (pd.DatetimeIndex(future_df[timestamp_column]) != pd.DatetimeIndex(prediction_timestamps_array)).any():
300
- raise ValueError(
301
- "future_df timestamps do not match the expected prediction timestamps. "
302
- "You can disable this check by setting `validate_inputs=False`"
303
- )
307
+ future_covariates_dict = {
308
+ col: future_df[col].to_numpy() for col in future_df.columns if col not in [id_column, timestamp_column]
309
+ }
304
310
 
305
311
  for i in range(len(series_lengths)):
306
312
  start_idx, end_idx = indptr[i], indptr[i + 1]
@@ -310,12 +316,23 @@ def convert_df_input_to_list_of_dicts_input(
310
316
  prediction_timestamps[series_id] = prediction_timestamps_array[future_start_idx:future_end_idx]
311
317
  task: dict[str, np.ndarray | dict[str, np.ndarray]] = {"target": target_array[:, start_idx:end_idx]}
312
318
 
319
+ # Handle covariates if present
313
320
  if len(past_covariates_dict) > 0:
314
321
  task["past_covariates"] = {col: values[start_idx:end_idx] for col, values in past_covariates_dict.items()}
315
- if len(future_covariates_dict) > 0:
316
- task["future_covariates"] = {
317
- col: values[future_start_idx:future_end_idx] for col, values in future_covariates_dict.items()
318
- }
322
+
323
+ # Handle future covariates
324
+ if future_df is not None:
325
+ first_future_timestamp = future_df[timestamp_column].iloc[future_start_idx]
326
+ assert first_future_timestamp == prediction_timestamps[series_id][0], (
327
+ f"the first timestamp in future_df must be the first forecast timestamp, found mismatch "
328
+ f"({first_future_timestamp} != {prediction_timestamps[series_id][0]}) in series {series_id}"
329
+ )
330
+
331
+ if len(future_covariates_dict) > 0:
332
+ task["future_covariates"] = {
333
+ col: values[future_start_idx:future_end_idx] for col, values in future_covariates_dict.items()
334
+ }
335
+
319
336
  inputs.append(task)
320
337
 
321
338
  assert len(inputs) == len(series_lengths)
@@ -421,39 +421,43 @@ def test_pipeline_can_evaluate_on_dummy_fev_task(pipeline, task_kwargs):
421
421
 
422
422
 
423
423
  @pytest.mark.parametrize(
424
- "context_setup, future_setup",
424
+ "context_setup, future_setup, expected_rows",
425
425
  [
426
426
  # Targets only
427
- ({}, None),
427
+ ({}, None, 6), # 2 series * 3 predictions
428
428
  # Multiple targets with different context lengths
429
- ({"target_cols": ["sales", "revenue", "profit"], "n_points": [10, 17]}, None),
429
+ (
430
+ {"target_cols": ["sales", "revenue", "profit"], "n_points": [10, 17]},
431
+ None,
432
+ 18,
433
+ ), # 2 series * 3 targets * 3 predictions
430
434
  # With past covariates
431
- ({"covariates": ["cov1"]}, None),
435
+ ({"covariates": ["cov1"]}, None, 6),
432
436
  # With future covariates
433
- ({"covariates": ["cov1"]}, {"covariates": ["cov1"]}),
437
+ ({"covariates": ["cov1"]}, {"covariates": ["cov1"], "n_points": [3, 3]}, 6),
434
438
  # With past-only and future covariates
435
- ({"covariates": ["cov1", "cov2"]}, {"covariates": ["cov1"]}),
439
+ ({"covariates": ["cov1", "cov2"]}, {"covariates": ["cov1"], "n_points": [3, 3]}, 6),
436
440
  # With past-only and future covariates and different series order
437
441
  (
438
442
  {"series_ids": ["B", "C", "A", "Z"], "n_points": [10, 20, 100, 256], "covariates": ["cov1", "cov2"]},
439
- {"series_ids": ["B", "C", "A", "Z"], "covariates": ["cov1"]},
443
+ {
444
+ "series_ids": ["B", "C", "A", "Z"],
445
+ "covariates": ["cov1"],
446
+ "n_points": [3, 3, 3, 3],
447
+ },
448
+ 12,
440
449
  ),
441
450
  ],
442
451
  )
443
452
  @pytest.mark.parametrize("freq", ["s", "min", "30min", "h", "D", "W", "ME", "QE", "YE"])
444
- @pytest.mark.parametrize("prediction_length", [1, 4])
445
453
  @pytest.mark.parametrize("validate_inputs", [True, False])
446
454
  def test_predict_df_works_for_valid_inputs(
447
- pipeline, context_setup, future_setup, freq, validate_inputs, prediction_length
455
+ pipeline, context_setup, future_setup, expected_rows, freq, validate_inputs
448
456
  ):
457
+ prediction_length = 3
449
458
  df = create_df(**context_setup, freq=freq)
450
459
  forecast_start_times = get_forecast_start_times(df, freq)
451
- if future_setup:
452
- series_ids = future_setup.get("series_ids", ["A", "B"])
453
- future_setup_with_n_points = {**future_setup, "n_points": [prediction_length] * len(series_ids)}
454
- future_df = create_future_df(forecast_start_times, **future_setup_with_n_points, freq=freq)
455
- else:
456
- future_df = None
460
+ future_df = create_future_df(forecast_start_times, **future_setup, freq=freq) if future_setup else None
457
461
 
458
462
  series_ids = context_setup.get("series_ids", ["A", "B"])
459
463
  target_columns = context_setup.get("target_cols", ["target"])
@@ -467,7 +471,6 @@ def test_predict_df_works_for_valid_inputs(
467
471
  validate_inputs=validate_inputs,
468
472
  )
469
473
 
470
- expected_rows = n_series * n_targets * prediction_length
471
474
  assert len(result) == expected_rows
472
475
  assert "item_id" in result.columns and np.all(
473
476
  result["item_id"].to_numpy() == np.array(series_ids).repeat(n_targets * prediction_length)
@@ -577,78 +580,24 @@ def test_predict_df_with_future_df_missing_series_raises_error(pipeline):
577
580
  pipeline.predict_df(df, future_df=future_df)
578
581
 
579
582
 
580
- def test_predict_df_with_future_df_with_different_freq_raises_error(pipeline):
581
- df = create_df(series_ids=["A", "B"], covariates=["cov1"], freq="h")
582
- future_df = create_future_df(
583
- get_forecast_start_times(df), series_ids=["A", "B"], n_points=[3, 3], covariates=["cov1"], freq="D"
584
- )
585
-
586
- with pytest.raises(ValueError, match="future_df timestamps do not match"):
587
- pipeline.predict_df(df, future_df=future_df, prediction_length=3)
588
-
589
-
590
583
  def test_predict_df_with_future_df_with_different_lengths_raises_error(pipeline):
591
584
  df = create_df(series_ids=["A", "B"], covariates=["cov1"])
592
585
  future_df = create_future_df(
593
586
  get_forecast_start_times(df), series_ids=["A", "B"], n_points=[3, 7], covariates=["cov1"]
594
587
  )
595
588
 
596
- with pytest.raises(ValueError, match="future_df must contain prediction"):
589
+ with pytest.raises(ValueError, match="all time series must have length"):
597
590
  pipeline.predict_df(df, future_df=future_df, prediction_length=3)
598
591
 
599
592
 
600
- @pytest.mark.parametrize(
601
- "context_setup, future_setup",
602
- [
603
- # Targets only
604
- ({}, None),
605
- # Multiple targets with different context lengths
606
- ({"target_cols": ["sales", "revenue", "profit"], "n_points": [10, 17]}, None),
607
- # With past covariates
608
- ({"covariates": ["cov1"]}, None),
609
- # With future covariates
610
- ({"covariates": ["cov1"]}, {"covariates": ["cov1"]}),
611
- # With past-only and future covariates
612
- ({"covariates": ["cov1", "cov2"]}, {"covariates": ["cov1"]}),
613
- # With past-only and future covariates and different series order
614
- (
615
- {"series_ids": ["B", "C", "A", "Z"], "n_points": [10, 20, 100, 256], "covariates": ["cov1", "cov2"]},
616
- {"series_ids": ["B", "C", "A", "Z"], "covariates": ["cov1"]},
617
- ),
618
- ],
619
- )
620
- @pytest.mark.parametrize("prediction_length", [1, 4])
621
- def test_predict_df_outputs_different_results_with_cross_learning_enabled(
622
- pipeline, context_setup, future_setup, prediction_length
623
- ):
624
- freq = "h"
625
- df = create_df(**context_setup, freq=freq)
626
- forecast_start_times = get_forecast_start_times(df, freq)
627
- if future_setup:
628
- series_ids = future_setup.get("series_ids", ["A", "B"])
629
- future_setup_with_n_points = {**future_setup, "n_points": [prediction_length] * len(series_ids)}
630
- future_df = create_future_df(forecast_start_times, **future_setup_with_n_points, freq=freq)
631
- else:
632
- future_df = None
633
-
634
- series_ids = context_setup.get("series_ids", ["A", "B"])
635
- target_columns = context_setup.get("target_cols", ["target"])
636
- result_with_cross_learning = pipeline.predict_df(
637
- df,
638
- future_df=future_df,
639
- target=target_columns,
640
- prediction_length=prediction_length,
641
- cross_learning=True,
642
- )
643
- result_without_cross_learning = pipeline.predict_df(
644
- df,
645
- future_df=future_df,
646
- target=target_columns,
647
- prediction_length=prediction_length,
648
- cross_learning=False,
593
+ def test_predict_df_with_future_df_with_different_freq_raises_error(pipeline):
594
+ df = create_df(series_ids=["A", "B"], covariates=["cov1"], freq="h")
595
+ future_df = create_future_df(
596
+ get_forecast_start_times(df), series_ids=["A", "B"], n_points=[3, 3], covariates=["cov1"], freq="D"
649
597
  )
650
598
 
651
- assert not np.array_equal(result_with_cross_learning["predictions"], result_without_cross_learning["predictions"])
599
+ with pytest.raises(ValueError, match="must have the same frequency as context"):
600
+ pipeline.predict_df(df, future_df=future_df, prediction_length=3)
652
601
 
653
602
 
654
603
  @pytest.mark.parametrize(
@@ -925,36 +874,40 @@ def test_when_input_time_series_are_too_short_then_finetuning_raises_error(pipel
925
874
 
926
875
 
927
876
  @pytest.mark.parametrize(
928
- "context_setup, future_setup",
877
+ "context_setup, future_setup, expected_rows",
929
878
  [
930
879
  # Targets only
931
- ({}, None),
880
+ ({}, None, 6), # 2 series * 3 predictions
932
881
  # Multiple targets with different context lengths
933
- ({"target_cols": ["sales", "revenue", "profit"], "n_points": [10, 17]}, None),
882
+ (
883
+ {"target_cols": ["sales", "revenue", "profit"], "n_points": [10, 17]},
884
+ None,
885
+ 18,
886
+ ), # 2 series * 3 targets * 3 predictions
934
887
  # With past covariates
935
- ({"covariates": ["cov1"]}, None),
888
+ ({"covariates": ["cov1"]}, None, 6),
936
889
  # With future covariates
937
- ({"covariates": ["cov1"]}, {"covariates": ["cov1"]}),
890
+ ({"covariates": ["cov1"]}, {"covariates": ["cov1"], "n_points": [3, 3]}, 6),
938
891
  # With past-only and future covariates
939
- ({"covariates": ["cov1", "cov2"]}, {"covariates": ["cov1"]}),
892
+ ({"covariates": ["cov1", "cov2"]}, {"covariates": ["cov1"], "n_points": [3, 3]}, 6),
940
893
  # With past-only and future covariates and different series order
941
894
  (
942
895
  {"series_ids": ["B", "C", "A", "Z"], "n_points": [10, 20, 100, 256], "covariates": ["cov1", "cov2"]},
943
- {"series_ids": ["B", "C", "A", "Z"], "covariates": ["cov1"]},
896
+ {
897
+ "series_ids": ["B", "C", "A", "Z"],
898
+ "covariates": ["cov1"],
899
+ "n_points": [3, 3, 3, 3],
900
+ },
901
+ 12,
944
902
  ),
945
903
  ],
946
904
  )
947
905
  @pytest.mark.parametrize("freq", ["h", "D", "ME"])
948
- def test_two_step_finetuning_with_df_input_works(pipeline, context_setup, future_setup, freq):
906
+ def test_two_step_finetuning_with_df_input_works(pipeline, context_setup, future_setup, expected_rows, freq):
949
907
  prediction_length = 3
950
908
  df = create_df(**context_setup, freq=freq)
951
909
  forecast_start_times = get_forecast_start_times(df, freq)
952
- if future_setup:
953
- series_ids = future_setup.get("series_ids", ["A", "B"])
954
- future_setup_with_n_points = {**future_setup, "n_points": [prediction_length] * len(series_ids)}
955
- future_df = create_future_df(forecast_start_times, **future_setup_with_n_points, freq=freq)
956
- else:
957
- future_df = None
910
+ future_df = create_future_df(forecast_start_times, **future_setup, freq=freq) if future_setup else None
958
911
 
959
912
  series_ids = context_setup.get("series_ids", ["A", "B"])
960
913
  target_columns = context_setup.get("target_cols", ["target"])
@@ -987,7 +940,6 @@ def test_two_step_finetuning_with_df_input_works(pipeline, context_setup, future
987
940
  )
988
941
 
989
942
  # Check predictions from the fine-tuned model are valid
990
- expected_rows = n_series * n_targets * prediction_length
991
943
  assert len(result) == expected_rows
992
944
  assert "item_id" in result.columns and np.all(
993
945
  result["item_id"].to_numpy() == np.array(series_ids).repeat(n_targets * prediction_length)
@@ -1 +0,0 @@
1
- __version__ = "2.2.0"