autogluon.timeseries 1.4.1b20250818__tar.gz → 1.4.1b20250820__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of autogluon.timeseries might be problematic. (The original page links to more details; the link is not preserved in this text.)

Files changed (73):
  1. {autogluon.timeseries-1.4.1b20250818 → autogluon.timeseries-1.4.1b20250820}/PKG-INFO +1 -1
  2. {autogluon.timeseries-1.4.1b20250818 → autogluon.timeseries-1.4.1b20250820}/src/autogluon/timeseries/dataset/ts_dataframe.py +13 -0
  3. {autogluon.timeseries-1.4.1b20250818 → autogluon.timeseries-1.4.1b20250820}/src/autogluon/timeseries/models/abstract/abstract_timeseries_model.py +9 -4
  4. {autogluon.timeseries-1.4.1b20250818 → autogluon.timeseries-1.4.1b20250820}/src/autogluon/timeseries/models/autogluon_tabular/mlforecast.py +4 -0
  5. {autogluon.timeseries-1.4.1b20250818 → autogluon.timeseries-1.4.1b20250820}/src/autogluon/timeseries/models/autogluon_tabular/per_step.py +1 -0
  6. {autogluon.timeseries-1.4.1b20250818 → autogluon.timeseries-1.4.1b20250820}/src/autogluon/timeseries/models/chronos/model.py +1 -0
  7. {autogluon.timeseries-1.4.1b20250818 → autogluon.timeseries-1.4.1b20250820}/src/autogluon/timeseries/models/gluonts/models.py +15 -0
  8. {autogluon.timeseries-1.4.1b20250818 → autogluon.timeseries-1.4.1b20250820}/src/autogluon/timeseries/models/local/naive.py +4 -0
  9. {autogluon.timeseries-1.4.1b20250818 → autogluon.timeseries-1.4.1b20250820}/src/autogluon/timeseries/models/local/npts.py +1 -0
  10. {autogluon.timeseries-1.4.1b20250818 → autogluon.timeseries-1.4.1b20250820}/src/autogluon/timeseries/models/local/statsforecast.py +16 -0
  11. {autogluon.timeseries-1.4.1b20250818 → autogluon.timeseries-1.4.1b20250820}/src/autogluon/timeseries/models/presets.py +52 -130
  12. autogluon.timeseries-1.4.1b20250820/src/autogluon/timeseries/models/registry.py +65 -0
  13. {autogluon.timeseries-1.4.1b20250818 → autogluon.timeseries-1.4.1b20250820}/src/autogluon/timeseries/predictor.py +3 -0
  14. {autogluon.timeseries-1.4.1b20250818 → autogluon.timeseries-1.4.1b20250820}/src/autogluon/timeseries/utils/features.py +1 -1
  15. {autogluon.timeseries-1.4.1b20250818 → autogluon.timeseries-1.4.1b20250820}/src/autogluon/timeseries/version.py +1 -1
  16. {autogluon.timeseries-1.4.1b20250818 → autogluon.timeseries-1.4.1b20250820}/src/autogluon.timeseries.egg-info/PKG-INFO +1 -1
  17. {autogluon.timeseries-1.4.1b20250818 → autogluon.timeseries-1.4.1b20250820}/src/autogluon.timeseries.egg-info/SOURCES.txt +1 -0
  18. {autogluon.timeseries-1.4.1b20250818 → autogluon.timeseries-1.4.1b20250820}/src/autogluon.timeseries.egg-info/requires.txt +4 -4
  19. {autogluon.timeseries-1.4.1b20250818 → autogluon.timeseries-1.4.1b20250820}/setup.cfg +0 -0
  20. {autogluon.timeseries-1.4.1b20250818 → autogluon.timeseries-1.4.1b20250820}/setup.py +0 -0
  21. {autogluon.timeseries-1.4.1b20250818 → autogluon.timeseries-1.4.1b20250820}/src/autogluon/timeseries/__init__.py +0 -0
  22. {autogluon.timeseries-1.4.1b20250818 → autogluon.timeseries-1.4.1b20250820}/src/autogluon/timeseries/configs/__init__.py +0 -0
  23. {autogluon.timeseries-1.4.1b20250818 → autogluon.timeseries-1.4.1b20250820}/src/autogluon/timeseries/configs/presets_configs.py +0 -0
  24. {autogluon.timeseries-1.4.1b20250818 → autogluon.timeseries-1.4.1b20250820}/src/autogluon/timeseries/dataset/__init__.py +0 -0
  25. {autogluon.timeseries-1.4.1b20250818 → autogluon.timeseries-1.4.1b20250820}/src/autogluon/timeseries/evaluator.py +0 -0
  26. {autogluon.timeseries-1.4.1b20250818 → autogluon.timeseries-1.4.1b20250820}/src/autogluon/timeseries/learner.py +0 -0
  27. {autogluon.timeseries-1.4.1b20250818 → autogluon.timeseries-1.4.1b20250820}/src/autogluon/timeseries/metrics/__init__.py +0 -0
  28. {autogluon.timeseries-1.4.1b20250818 → autogluon.timeseries-1.4.1b20250820}/src/autogluon/timeseries/metrics/abstract.py +0 -0
  29. {autogluon.timeseries-1.4.1b20250818 → autogluon.timeseries-1.4.1b20250820}/src/autogluon/timeseries/metrics/point.py +0 -0
  30. {autogluon.timeseries-1.4.1b20250818 → autogluon.timeseries-1.4.1b20250820}/src/autogluon/timeseries/metrics/quantile.py +0 -0
  31. {autogluon.timeseries-1.4.1b20250818 → autogluon.timeseries-1.4.1b20250820}/src/autogluon/timeseries/metrics/utils.py +0 -0
  32. {autogluon.timeseries-1.4.1b20250818 → autogluon.timeseries-1.4.1b20250820}/src/autogluon/timeseries/models/__init__.py +0 -0
  33. {autogluon.timeseries-1.4.1b20250818 → autogluon.timeseries-1.4.1b20250820}/src/autogluon/timeseries/models/abstract/__init__.py +0 -0
  34. {autogluon.timeseries-1.4.1b20250818 → autogluon.timeseries-1.4.1b20250820}/src/autogluon/timeseries/models/abstract/model_trial.py +0 -0
  35. {autogluon.timeseries-1.4.1b20250818 → autogluon.timeseries-1.4.1b20250820}/src/autogluon/timeseries/models/abstract/tunable.py +0 -0
  36. {autogluon.timeseries-1.4.1b20250818 → autogluon.timeseries-1.4.1b20250820}/src/autogluon/timeseries/models/autogluon_tabular/__init__.py +0 -0
  37. {autogluon.timeseries-1.4.1b20250818 → autogluon.timeseries-1.4.1b20250820}/src/autogluon/timeseries/models/autogluon_tabular/transforms.py +0 -0
  38. {autogluon.timeseries-1.4.1b20250818 → autogluon.timeseries-1.4.1b20250820}/src/autogluon/timeseries/models/autogluon_tabular/utils.py +0 -0
  39. {autogluon.timeseries-1.4.1b20250818 → autogluon.timeseries-1.4.1b20250820}/src/autogluon/timeseries/models/chronos/__init__.py +0 -0
  40. {autogluon.timeseries-1.4.1b20250818 → autogluon.timeseries-1.4.1b20250820}/src/autogluon/timeseries/models/chronos/pipeline/__init__.py +0 -0
  41. {autogluon.timeseries-1.4.1b20250818 → autogluon.timeseries-1.4.1b20250820}/src/autogluon/timeseries/models/chronos/pipeline/base.py +0 -0
  42. {autogluon.timeseries-1.4.1b20250818 → autogluon.timeseries-1.4.1b20250820}/src/autogluon/timeseries/models/chronos/pipeline/chronos.py +0 -0
  43. {autogluon.timeseries-1.4.1b20250818 → autogluon.timeseries-1.4.1b20250820}/src/autogluon/timeseries/models/chronos/pipeline/chronos_bolt.py +0 -0
  44. {autogluon.timeseries-1.4.1b20250818 → autogluon.timeseries-1.4.1b20250820}/src/autogluon/timeseries/models/chronos/pipeline/utils.py +0 -0
  45. {autogluon.timeseries-1.4.1b20250818 → autogluon.timeseries-1.4.1b20250820}/src/autogluon/timeseries/models/ensemble/__init__.py +0 -0
  46. {autogluon.timeseries-1.4.1b20250818 → autogluon.timeseries-1.4.1b20250820}/src/autogluon/timeseries/models/ensemble/abstract.py +0 -0
  47. {autogluon.timeseries-1.4.1b20250818 → autogluon.timeseries-1.4.1b20250820}/src/autogluon/timeseries/models/ensemble/basic.py +0 -0
  48. {autogluon.timeseries-1.4.1b20250818 → autogluon.timeseries-1.4.1b20250820}/src/autogluon/timeseries/models/ensemble/greedy.py +0 -0
  49. {autogluon.timeseries-1.4.1b20250818 → autogluon.timeseries-1.4.1b20250820}/src/autogluon/timeseries/models/gluonts/__init__.py +0 -0
  50. {autogluon.timeseries-1.4.1b20250818 → autogluon.timeseries-1.4.1b20250820}/src/autogluon/timeseries/models/gluonts/abstract.py +0 -0
  51. {autogluon.timeseries-1.4.1b20250818 → autogluon.timeseries-1.4.1b20250820}/src/autogluon/timeseries/models/gluonts/dataset.py +0 -0
  52. {autogluon.timeseries-1.4.1b20250818 → autogluon.timeseries-1.4.1b20250820}/src/autogluon/timeseries/models/local/__init__.py +0 -0
  53. {autogluon.timeseries-1.4.1b20250818 → autogluon.timeseries-1.4.1b20250820}/src/autogluon/timeseries/models/local/abstract_local_model.py +0 -0
  54. {autogluon.timeseries-1.4.1b20250818 → autogluon.timeseries-1.4.1b20250820}/src/autogluon/timeseries/models/multi_window/__init__.py +0 -0
  55. {autogluon.timeseries-1.4.1b20250818 → autogluon.timeseries-1.4.1b20250820}/src/autogluon/timeseries/models/multi_window/multi_window_model.py +0 -0
  56. {autogluon.timeseries-1.4.1b20250818 → autogluon.timeseries-1.4.1b20250820}/src/autogluon/timeseries/regressor.py +0 -0
  57. {autogluon.timeseries-1.4.1b20250818 → autogluon.timeseries-1.4.1b20250820}/src/autogluon/timeseries/splitter.py +0 -0
  58. {autogluon.timeseries-1.4.1b20250818 → autogluon.timeseries-1.4.1b20250820}/src/autogluon/timeseries/trainer.py +0 -0
  59. {autogluon.timeseries-1.4.1b20250818 → autogluon.timeseries-1.4.1b20250820}/src/autogluon/timeseries/transforms/__init__.py +0 -0
  60. {autogluon.timeseries-1.4.1b20250818 → autogluon.timeseries-1.4.1b20250820}/src/autogluon/timeseries/transforms/covariate_scaler.py +0 -0
  61. {autogluon.timeseries-1.4.1b20250818 → autogluon.timeseries-1.4.1b20250820}/src/autogluon/timeseries/transforms/target_scaler.py +0 -0
  62. {autogluon.timeseries-1.4.1b20250818 → autogluon.timeseries-1.4.1b20250820}/src/autogluon/timeseries/utils/__init__.py +0 -0
  63. {autogluon.timeseries-1.4.1b20250818 → autogluon.timeseries-1.4.1b20250820}/src/autogluon/timeseries/utils/datetime/__init__.py +0 -0
  64. {autogluon.timeseries-1.4.1b20250818 → autogluon.timeseries-1.4.1b20250820}/src/autogluon/timeseries/utils/datetime/base.py +0 -0
  65. {autogluon.timeseries-1.4.1b20250818 → autogluon.timeseries-1.4.1b20250820}/src/autogluon/timeseries/utils/datetime/lags.py +0 -0
  66. {autogluon.timeseries-1.4.1b20250818 → autogluon.timeseries-1.4.1b20250820}/src/autogluon/timeseries/utils/datetime/seasonality.py +0 -0
  67. {autogluon.timeseries-1.4.1b20250818 → autogluon.timeseries-1.4.1b20250820}/src/autogluon/timeseries/utils/datetime/time_features.py +0 -0
  68. {autogluon.timeseries-1.4.1b20250818 → autogluon.timeseries-1.4.1b20250820}/src/autogluon/timeseries/utils/forecast.py +0 -0
  69. {autogluon.timeseries-1.4.1b20250818 → autogluon.timeseries-1.4.1b20250820}/src/autogluon/timeseries/utils/warning_filters.py +0 -0
  70. {autogluon.timeseries-1.4.1b20250818 → autogluon.timeseries-1.4.1b20250820}/src/autogluon.timeseries.egg-info/dependency_links.txt +0 -0
  71. {autogluon.timeseries-1.4.1b20250818 → autogluon.timeseries-1.4.1b20250820}/src/autogluon.timeseries.egg-info/namespace_packages.txt +0 -0
  72. {autogluon.timeseries-1.4.1b20250818 → autogluon.timeseries-1.4.1b20250820}/src/autogluon.timeseries.egg-info/top_level.txt +0 -0
  73. {autogluon.timeseries-1.4.1b20250818 → autogluon.timeseries-1.4.1b20250820}/src/autogluon.timeseries.egg-info/zip-safe +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: autogluon.timeseries
3
- Version: 1.4.1b20250818
3
+ Version: 1.4.1b20250820
4
4
  Summary: Fast and Accurate ML in 3 Lines of Code
5
5
  Home-page: https://github.com/autogluon/autogluon
6
6
  Author: AutoGluon Community
@@ -1124,6 +1124,19 @@ class TimeSeriesDataFrame(pd.DataFrame):
1124
1124
 
1125
1125
  @overload
1126
1126
  def __new__(cls, data: pd.DataFrame, static_features: Optional[pd.DataFrame] = None) -> Self: ... # type: ignore
1127
+ @overload
1128
+ def __new__(
1129
+ cls,
1130
+ data: Union[pd.DataFrame, str, Path, Iterable],
1131
+ static_features: Optional[Union[pd.DataFrame, str, Path]] = None,
1132
+ id_column: Optional[str] = None,
1133
+ timestamp_column: Optional[str] = None,
1134
+ num_cpus: int = -1,
1135
+ *args,
1136
+ **kwargs,
1137
+ ) -> Self:
1138
+ """This overload is needed since in pandas, during type checking, the default constructor resolves to __new__"""
1139
+ ...
1127
1140
 
1128
1141
  @overload
1129
1142
  def __getitem__(self, items: List[str]) -> Self: ... # type: ignore
@@ -21,6 +21,7 @@ from autogluon.core.models import ModelBase
21
21
  from autogluon.core.utils.exceptions import TimeLimitExceeded
22
22
  from autogluon.timeseries.dataset import TimeSeriesDataFrame
23
23
  from autogluon.timeseries.metrics import TimeSeriesScorer, check_get_evaluation_metric
24
+ from autogluon.timeseries.models.registry import ModelRegistry
24
25
  from autogluon.timeseries.regressor import CovariateRegressor, get_covariate_regressor
25
26
  from autogluon.timeseries.transforms import CovariateScaler, TargetScaler, get_covariate_scaler, get_target_scaler
26
27
  from autogluon.timeseries.utils.features import CovariateMetadata
@@ -376,11 +377,13 @@ class TimeSeriesModelBase(ModelBase, ABC):
376
377
  return template
377
378
 
378
379
 
379
- class AbstractTimeSeriesModel(TimeSeriesModelBase, TimeSeriesTunable, ABC):
380
+ class AbstractTimeSeriesModel(TimeSeriesModelBase, TimeSeriesTunable, metaclass=ModelRegistry):
380
381
  """Abstract base class for all time series models that take historical data as input and
381
382
  make predictions for the forecast horizon.
382
383
  """
383
384
 
385
+ ag_priority: int = 0
386
+
384
387
  def __init__(
385
388
  self,
386
389
  path: Optional[str] = None,
@@ -427,7 +430,7 @@ class AbstractTimeSeriesModel(TimeSeriesModelBase, TimeSeriesTunable, ABC):
427
430
  @property
428
431
  def allowed_hyperparameters(self) -> List[str]:
429
432
  """List of hyperparameters allowed by the model."""
430
- return ["target_scaler", "covariate_regressor"]
433
+ return ["target_scaler", "covariate_regressor", "covariate_scaler"]
431
434
 
432
435
  def fit(
433
436
  self,
@@ -608,11 +611,13 @@ class AbstractTimeSeriesModel(TimeSeriesModelBase, TimeSeriesTunable, ABC):
608
611
  predictions = self._predict(data=data, known_covariates=known_covariates, **kwargs)
609
612
  self.covariate_regressor = covariate_regressor
610
613
 
611
- column_order = pd.Index(["mean"] + [str(q) for q in self.quantile_levels])
614
+ # Ensure that 'mean' is the leading column. Trailing columns might not match quantile_levels if self is
615
+ # a MultiWindowBacktestingModel and base_model.must_drop_median=True
616
+ column_order = pd.Index(["mean"] + [col for col in predictions.columns if col != "mean"])
612
617
  if not predictions.columns.equals(column_order):
613
618
  predictions = predictions.reindex(columns=column_order)
614
619
 
615
- # "0.5" might be missing from the quantiles if self is a wrapper (MultiWindowBacktestingModel or ensemble)
620
+ # "0.5" might be missing from the quantiles if self is a MultiWindowBacktestingModel
616
621
  if "0.5" in predictions.columns:
617
622
  if self.eval_metric.optimized_by_median:
618
623
  predictions["mean"] = predictions["0.5"]
@@ -517,6 +517,8 @@ class DirectTabularModel(AbstractMLForecastModel):
517
517
  (starting from the end of each time series).
518
518
  """
519
519
 
520
+ ag_priority = 85
521
+
520
522
  @property
521
523
  def is_quantile_model(self) -> bool:
522
524
  return self.eval_metric.needs_quantile
@@ -698,6 +700,8 @@ class RecursiveTabularModel(AbstractMLForecastModel):
698
700
  (starting from the end of each time series).
699
701
  """
700
702
 
703
+ ag_priority = 90
704
+
701
705
  def get_hyperparameters(self) -> Dict[str, Any]:
702
706
  model_params = super().get_hyperparameters()
703
707
  # We don't set 'target_scaler' if user already provided 'scaler' to avoid overriding the user-provided value
@@ -81,6 +81,7 @@ class PerStepTabularModel(AbstractTimeSeriesModel):
81
81
  If None, automatically determined based on available memory to prevent OOM errors.
82
82
  """
83
83
 
84
+ ag_priority = 70
84
85
  _dummy_freq = "D"
85
86
 
86
87
  def __init__(self, *args, **kwargs):
@@ -166,6 +166,7 @@ class ChronosModel(AbstractTimeSeriesModel):
166
166
  If True, the logs generated by transformers will NOT be removed after fine-tuning
167
167
  """
168
168
 
169
+ ag_priority = 55
169
170
  # default number of samples for prediction
170
171
  default_num_samples: int = 20
171
172
  default_model_path = "autogluon/chronos-bolt-small"
@@ -81,6 +81,8 @@ class DeepARModel(AbstractGluonTSModel):
81
81
 
82
82
  # TODO: Replace "scaling: bool" with "window_scaler": {"mean_abs", None} for consistency?
83
83
 
84
+ ag_priority = 40
85
+
84
86
  _supports_known_covariates = True
85
87
  _supports_static_features = True
86
88
 
@@ -138,6 +140,8 @@ class SimpleFeedForwardModel(AbstractGluonTSModel):
138
140
  If True, ``lightning_logs`` directory will NOT be removed after the model finished training.
139
141
  """
140
142
 
143
+ ag_priority = 10
144
+
141
145
  def _get_estimator_class(self) -> Type[GluonTSEstimator]:
142
146
  from gluonts.torch.model.simple_feedforward import SimpleFeedForwardEstimator
143
147
 
@@ -199,6 +203,9 @@ class TemporalFusionTransformerModel(AbstractGluonTSModel):
199
203
  If True, ``lightning_logs`` directory will NOT be removed after the model finished training.
200
204
  """
201
205
 
206
+ ag_priority = 45
207
+ ag_model_aliases = ["TFT"]
208
+
202
209
  _supports_known_covariates = True
203
210
  _supports_past_covariates = True
204
211
  _supports_cat_covariates = True
@@ -282,6 +289,8 @@ class DLinearModel(AbstractGluonTSModel):
282
289
  If True, ``lightning_logs`` directory will NOT be removed after the model finished training.
283
290
  """
284
291
 
292
+ ag_priority = 10
293
+
285
294
  def _get_default_hyperparameters(self):
286
295
  return super()._get_default_hyperparameters() | {
287
296
  "context_length": 96,
@@ -340,6 +349,8 @@ class PatchTSTModel(AbstractGluonTSModel):
340
349
  If True, ``lightning_logs`` directory will NOT be removed after the model finished training.
341
350
  """
342
351
 
352
+ ag_priority = 30
353
+
343
354
  _supports_known_covariates = True
344
355
 
345
356
  def _get_estimator_class(self) -> Type[GluonTSEstimator]:
@@ -416,6 +427,8 @@ class WaveNetModel(AbstractGluonTSModel):
416
427
  If True, ``lightning_logs`` directory will NOT be removed after the model finished training.
417
428
  """
418
429
 
430
+ ag_priority = 25
431
+
419
432
  _supports_known_covariates = True
420
433
  _supports_static_features = True
421
434
  default_num_samples: int = 100
@@ -508,6 +521,8 @@ class TiDEModel(AbstractGluonTSModel):
508
521
  If True, ``lightning_logs`` directory will NOT be removed after the model finished training.
509
522
  """
510
523
 
524
+ ag_priority = 30
525
+
511
526
  _supports_known_covariates = True
512
527
  _supports_static_features = True
513
528
 
@@ -24,6 +24,7 @@ class NaiveModel(AbstractLocalModel):
24
24
  When set to -1, all CPU cores are used.
25
25
  """
26
26
 
27
+ ag_priority = 100
27
28
  allowed_local_model_args = ["seasonal_period"]
28
29
 
29
30
  def _predict_with_local_model(
@@ -66,6 +67,7 @@ class SeasonalNaiveModel(AbstractLocalModel):
66
67
  When set to -1, all CPU cores are used.
67
68
  """
68
69
 
70
+ ag_priority = 100
69
71
  allowed_local_model_args = ["seasonal_period"]
70
72
 
71
73
  def _predict_with_local_model(
@@ -99,6 +101,7 @@ class AverageModel(AbstractLocalModel):
99
101
  This significantly speeds up fitting and usually leads to no change in accuracy.
100
102
  """
101
103
 
104
+ ag_priority = 100
102
105
  allowed_local_model_args = ["seasonal_period"]
103
106
  default_max_ts_length = None
104
107
 
@@ -138,6 +141,7 @@ class SeasonalAverageModel(AbstractLocalModel):
138
141
  This significantly speeds up fitting and usually leads to no change in accuracy.
139
142
  """
140
143
 
144
+ ag_priority = 100
141
145
  allowed_local_model_args = ["seasonal_period"]
142
146
  default_max_ts_length = None
143
147
 
@@ -36,6 +36,7 @@ class NPTSModel(AbstractLocalModel):
36
36
  This significantly speeds up fitting and usually leads to no change in accuracy.
37
37
  """
38
38
 
39
+ ag_priority = 80
39
40
  allowed_local_model_args = [
40
41
  "kernel_type",
41
42
  "exp_kernel_weights",
@@ -133,6 +133,7 @@ class AutoARIMAModel(AbstractProbabilisticStatsForecastModel):
133
133
  This significantly speeds up fitting and usually leads to no change in accuracy.
134
134
  """
135
135
 
136
+ ag_priority = 60
136
137
  init_time_in_seconds = 0 # C++ models require no compilation
137
138
  allowed_local_model_args = [
138
139
  "d",
@@ -211,6 +212,7 @@ class ARIMAModel(AbstractProbabilisticStatsForecastModel):
211
212
  This significantly speeds up fitting and usually leads to no change in accuracy.
212
213
  """
213
214
 
215
+ ag_priority = 10
214
216
  init_time_in_seconds = 0 # C++ models require no compilation
215
217
  allowed_local_model_args = [
216
218
  "order",
@@ -267,6 +269,7 @@ class AutoETSModel(AbstractProbabilisticStatsForecastModel):
267
269
  This significantly speeds up fitting and usually leads to no change in accuracy.
268
270
  """
269
271
 
272
+ ag_priority = 70
270
273
  init_time_in_seconds = 0 # C++ models require no compilation
271
274
  allowed_local_model_args = [
272
275
  "damped",
@@ -330,6 +333,8 @@ class ETSModel(AutoETSModel):
330
333
  This significantly speeds up fitting and usually leads to no change in accuracy.
331
334
  """
332
335
 
336
+ ag_priority = 80
337
+
333
338
  def _update_local_model_args(self, local_model_args: dict) -> dict:
334
339
  local_model_args = super()._update_local_model_args(local_model_args)
335
340
  local_model_args.setdefault("model", "AAA")
@@ -369,6 +374,7 @@ class DynamicOptimizedThetaModel(AbstractProbabilisticStatsForecastModel):
369
374
  This significantly speeds up fitting and usually leads to no change in accuracy.
370
375
  """
371
376
 
377
+ ag_priority = 75
372
378
  allowed_local_model_args = [
373
379
  "decomposition_type",
374
380
  "seasonal_period",
@@ -413,6 +419,7 @@ class ThetaModel(AbstractProbabilisticStatsForecastModel):
413
419
  This significantly speeds up fitting and usually leads to no change in accuracy.
414
420
  """
415
421
 
422
+ ag_priority = 75
416
423
  allowed_local_model_args = [
417
424
  "decomposition_type",
418
425
  "seasonal_period",
@@ -533,6 +540,7 @@ class AutoCESModel(AbstractProbabilisticStatsForecastModel):
533
540
  This significantly speeds up fitting and usually leads to no change in accuracy.
534
541
  """
535
542
 
543
+ ag_priority = 10
536
544
  allowed_local_model_args = [
537
545
  "model",
538
546
  "seasonal_period",
@@ -600,6 +608,8 @@ class ADIDAModel(AbstractStatsForecastIntermittentDemandModel):
600
608
  This significantly speeds up fitting and usually leads to no change in accuracy.
601
609
  """
602
610
 
611
+ ag_priority = 10
612
+
603
613
  def _get_model_type(self, variant: Optional[str] = None):
604
614
  from statsforecast.models import ADIDA
605
615
 
@@ -636,6 +646,8 @@ class CrostonModel(AbstractStatsForecastIntermittentDemandModel):
636
646
  This significantly speeds up fitting and usually leads to no change in accuracy.
637
647
  """
638
648
 
649
+ ag_model_aliases = ["CrostonSBA"]
650
+ ag_priority = 80
639
651
  allowed_local_model_args = [
640
652
  "variant",
641
653
  ]
@@ -688,6 +700,8 @@ class IMAPAModel(AbstractStatsForecastIntermittentDemandModel):
688
700
  This significantly speeds up fitting and usually leads to no change in accuracy.
689
701
  """
690
702
 
703
+ ag_priority = 10
704
+
691
705
  def _get_model_type(self, variant: Optional[str] = None):
692
706
  from statsforecast.models import IMAPA
693
707
 
@@ -710,6 +724,8 @@ class ZeroModel(AbstractStatsForecastIntermittentDemandModel):
710
724
  This significantly speeds up fitting and usually leads to no change in accuracy.
711
725
  """
712
726
 
727
+ ag_priority = 100
728
+
713
729
  def _get_model_type(self, variant: Optional[str] = None):
714
730
  # ZeroModel does not depend on a StatsForecast implementation
715
731
  raise NotImplementedError
@@ -2,115 +2,21 @@ import copy
2
2
  import logging
3
3
  import re
4
4
  from collections import defaultdict
5
- from typing import Any, Dict, List, Optional, Type, Union
5
+ from typing import Any, Dict, List, Optional, Set, Type, Union
6
6
 
7
7
  from autogluon.common import space
8
8
  from autogluon.core import constants
9
9
  from autogluon.timeseries.metrics import TimeSeriesScorer
10
10
  from autogluon.timeseries.utils.features import CovariateMetadata
11
11
 
12
- from . import (
13
- ADIDAModel,
14
- ARIMAModel,
15
- AutoARIMAModel,
16
- AutoCESModel,
17
- AutoETSModel,
18
- AverageModel,
19
- ChronosModel,
20
- CrostonModel,
21
- DeepARModel,
22
- DirectTabularModel,
23
- DLinearModel,
24
- DynamicOptimizedThetaModel,
25
- ETSModel,
26
- IMAPAModel,
27
- NaiveModel,
28
- NPTSModel,
29
- PatchTSTModel,
30
- PerStepTabularModel,
31
- RecursiveTabularModel,
32
- SeasonalAverageModel,
33
- SeasonalNaiveModel,
34
- SimpleFeedForwardModel,
35
- TemporalFusionTransformerModel,
36
- ThetaModel,
37
- TiDEModel,
38
- WaveNetModel,
39
- ZeroModel,
40
- )
41
12
  from .abstract import AbstractTimeSeriesModel
42
13
  from .multi_window.multi_window_model import MultiWindowBacktestingModel
14
+ from .registry import ModelRegistry
43
15
 
44
16
  logger = logging.getLogger(__name__)
45
17
 
46
18
  ModelHyperparameters = Dict[str, Any]
47
19
 
48
- # define the model zoo with their aliases
49
- MODEL_TYPES = dict(
50
- SimpleFeedForward=SimpleFeedForwardModel,
51
- DeepAR=DeepARModel,
52
- DLinear=DLinearModel,
53
- PatchTST=PatchTSTModel,
54
- TemporalFusionTransformer=TemporalFusionTransformerModel,
55
- TiDE=TiDEModel,
56
- WaveNet=WaveNetModel,
57
- RecursiveTabular=RecursiveTabularModel,
58
- DirectTabular=DirectTabularModel,
59
- PerStepTabular=PerStepTabularModel,
60
- Average=AverageModel,
61
- SeasonalAverage=SeasonalAverageModel,
62
- Naive=NaiveModel,
63
- SeasonalNaive=SeasonalNaiveModel,
64
- Zero=ZeroModel,
65
- AutoETS=AutoETSModel,
66
- AutoCES=AutoCESModel,
67
- AutoARIMA=AutoARIMAModel,
68
- DynamicOptimizedTheta=DynamicOptimizedThetaModel,
69
- NPTS=NPTSModel,
70
- Theta=ThetaModel,
71
- ETS=ETSModel,
72
- ARIMA=ARIMAModel,
73
- ADIDA=ADIDAModel,
74
- Croston=CrostonModel,
75
- CrostonSBA=CrostonModel, # Alias for backward compatibility
76
- IMAPA=IMAPAModel,
77
- Chronos=ChronosModel,
78
- )
79
-
80
- DEFAULT_MODEL_NAMES = {v: k for k, v in MODEL_TYPES.items()}
81
- DEFAULT_MODEL_PRIORITY = dict(
82
- Naive=100,
83
- SeasonalNaive=100,
84
- Average=100,
85
- SeasonalAverage=100,
86
- Zero=100,
87
- RecursiveTabular=90,
88
- DirectTabular=85,
89
- PerStepTabular=70, # TODO: Update priority
90
- # All local models are grouped together to make sure that joblib parallel pool is reused
91
- NPTS=80,
92
- ETS=80,
93
- CrostonSBA=80, # Alias for backward compatibility
94
- Croston=80,
95
- Theta=75,
96
- DynamicOptimizedTheta=75,
97
- AutoETS=70,
98
- AutoARIMA=60,
99
- Chronos=55,
100
- # Models that can early stop are trained at the end
101
- TemporalFusionTransformer=45,
102
- DeepAR=40,
103
- TiDE=30,
104
- PatchTST=30,
105
- # Models below are not included in any presets
106
- WaveNet=25,
107
- AutoCES=10,
108
- ARIMA=10,
109
- ADIDA=10,
110
- IMAPA=10,
111
- SimpleFeedForward=10,
112
- )
113
- DEFAULT_CUSTOM_MODEL_PRIORITY = 0
114
20
 
115
21
  VALID_AG_ARGS_KEYS = {
116
22
  "name",
@@ -199,45 +105,23 @@ def get_preset_models(
199
105
  will create models according to presets.
200
106
  """
201
107
  models = []
202
- if hyperparameters is None:
203
- hp_string = "default"
204
- hyperparameters = copy.deepcopy(get_default_hps(hp_string))
205
- elif isinstance(hyperparameters, str):
206
- hyperparameters = copy.deepcopy(get_default_hps(hyperparameters))
207
- elif isinstance(hyperparameters, dict):
208
- hyperparameters = copy.deepcopy(hyperparameters)
209
- else:
210
- raise ValueError(
211
- f"hyperparameters must be a dict, a string or None (received {type(hyperparameters)}). "
212
- f"Please see the documentation for TimeSeriesPredictor.fit"
213
- )
214
- hyperparameters = check_and_clean_hyperparameters(hyperparameters, must_contain_searchspace=hyperparameter_tune)
108
+ hyperparameter_dict = get_hyperparameter_dict(hyperparameters, hyperparameter_tune)
215
109
 
216
- excluded_models = set()
217
- if excluded_model_types is not None and len(excluded_model_types) > 0:
218
- if not isinstance(excluded_model_types, list):
219
- raise ValueError(f"`excluded_model_types` must be a list, received {type(excluded_model_types)}")
220
- logger.info(f"Excluded model types: {excluded_model_types}")
221
- for model in excluded_model_types:
222
- if not isinstance(model, str):
223
- raise ValueError(f"Each entry in `excluded_model_types` must be a string, received {type(model)}")
224
- excluded_models.add(normalize_model_type_name(model))
225
-
226
- all_assigned_names = set(all_assigned_names)
227
-
228
- model_priority_list = sorted(hyperparameters.keys(), key=lambda x: DEFAULT_MODEL_PRIORITY.get(x, 0), reverse=True)
110
+ model_priority_list = sorted(
111
+ hyperparameter_dict.keys(), key=lambda x: ModelRegistry.get_model_priority(x), reverse=True
112
+ )
113
+ excluded_models = get_excluded_models(excluded_model_types)
114
+ all_assigned_names = all_assigned_names.copy()
229
115
 
230
116
  for model in model_priority_list:
231
117
  if isinstance(model, str):
232
- if model not in MODEL_TYPES:
233
- raise ValueError(f"Model {model} is not supported. Available models: {sorted(MODEL_TYPES)}")
234
118
  if model in excluded_models:
235
119
  logger.info(
236
120
  f"\tFound '{model}' model in `hyperparameters`, but '{model}' "
237
121
  "is present in `excluded_model_types` and will be removed."
238
122
  )
239
123
  continue
240
- model_type = MODEL_TYPES[model]
124
+ model_type: Type[AbstractTimeSeriesModel] = ModelRegistry.get_model_class(model)
241
125
  elif isinstance(model, type):
242
126
  if not issubclass(model, AbstractTimeSeriesModel):
243
127
  raise ValueError(f"Custom model type {model} must inherit from `AbstractTimeSeriesModel`.")
@@ -247,7 +131,7 @@ def get_preset_models(
247
131
  f"Keys of the `hyperparameters` dictionary must be strings or types, received {type(model)}."
248
132
  )
249
133
 
250
- for model_hps in hyperparameters[model]:
134
+ for model_hps in hyperparameter_dict[model]:
251
135
  ag_args = model_hps.pop(constants.AG_ARGS, {})
252
136
  for key in ag_args:
253
137
  if key not in VALID_AG_ARGS_KEYS:
@@ -256,7 +140,7 @@ def get_preset_models(
256
140
  )
257
141
  model_name_base = get_model_name(ag_args, model_type)
258
142
 
259
- model_type_kwargs = dict(
143
+ model_type_kwargs: Dict[str, Any] = dict(
260
144
  name=model_name_base,
261
145
  path=path,
262
146
  freq=freq,
@@ -269,22 +153,60 @@ def get_preset_models(
269
153
 
270
154
  # add models while preventing name collisions
271
155
  model = model_type(**model_type_kwargs)
272
-
273
156
  model_type_kwargs.pop("name", None)
157
+
274
158
  increment = 1
275
159
  while model.name in all_assigned_names:
276
160
  increment += 1
277
161
  model = model_type(name=f"{model_name_base}_{increment}", **model_type_kwargs)
278
162
 
279
163
  if multi_window:
280
- model = MultiWindowBacktestingModel(model_base=model, name=model.name, **model_type_kwargs)
164
+ model = MultiWindowBacktestingModel(model_base=model, name=model.name, **model_type_kwargs) # type: ignore
281
165
 
282
- all_assigned_names.add(model.name)
166
+ all_assigned_names.append(model.name)
283
167
  models.append(model)
284
168
 
285
169
  return models
286
170
 
287
171
 
172
+ def get_excluded_models(excluded_model_types: Optional[List[str]]) -> Set[str]:
173
+ excluded_models = set()
174
+ if excluded_model_types is not None and len(excluded_model_types) > 0:
175
+ if not isinstance(excluded_model_types, list):
176
+ raise ValueError(f"`excluded_model_types` must be a list, received {type(excluded_model_types)}")
177
+ logger.info(f"Excluded model types: {excluded_model_types}")
178
+ for model in excluded_model_types:
179
+ if not isinstance(model, str):
180
+ raise ValueError(f"Each entry in `excluded_model_types` must be a string, received {type(model)}")
181
+ excluded_models.add(normalize_model_type_name(model))
182
+ return excluded_models
183
+
184
+
185
+ def get_hyperparameter_dict(
186
+ hyperparameters: Union[str, Dict[str, Union[ModelHyperparameters, List[ModelHyperparameters]]], None],
187
+ hyperparameter_tune: bool,
188
+ ) -> Dict[str, List[ModelHyperparameters]]:
189
+ hyperparameter_dict = {}
190
+
191
+ if hyperparameters is None:
192
+ hyperparameter_dict = copy.deepcopy(get_default_hps("default"))
193
+ elif isinstance(hyperparameters, str):
194
+ hyperparameter_dict = copy.deepcopy(get_default_hps(hyperparameters))
195
+ elif isinstance(hyperparameters, dict):
196
+ hyperparameter_dict = copy.deepcopy(hyperparameters)
197
+ else:
198
+ raise ValueError(
199
+ f"hyperparameters must be a dict, a string or None (received {type(hyperparameters)}). "
200
+ f"Please see the documentation for TimeSeriesPredictor.fit"
201
+ )
202
+
203
+ hyperparameter_dict = check_and_clean_hyperparameters(
204
+ hyperparameter_dict, must_contain_searchspace=hyperparameter_tune
205
+ )
206
+
207
+ return hyperparameter_dict
208
+
209
+
288
210
  def normalize_model_type_name(model_name: str) -> str:
289
211
  """Remove 'Model' suffix from the end of the string, if it's present."""
290
212
  if model_name.endswith("Model"):
@@ -0,0 +1,65 @@
1
+ from abc import ABCMeta
2
+ from dataclasses import dataclass
3
+ from inspect import isabstract
4
+ from typing import Dict, List, Union
5
+
6
+
7
@dataclass
class ModelRecord:
    """Registry entry pairing a registered model class with its ``ag_priority`` value."""

    model_class: type
    ag_priority: int


class ModelRegistry(ABCMeta):
    """Registry metaclass for time series models. Ensures that TimeSeriesModel classes
    which implement this metaclass are automatically registered, in order to centralize
    access to model types.

    Every non-abstract class created through this metaclass is stored in ``REGISTRY``
    under its class name with a trailing ``"Model"`` removed, and under each entry of an
    ``ag_model_aliases`` attribute declared in the class body.

    See https://github.com/faif/python-patterns.
    """

    # Single mapping shared by all classes created through this metaclass: alias -> record.
    REGISTRY: Dict[str, ModelRecord] = {}

    def __new__(cls, name, bases, attrs, **kwargs):
        """Create the class and register it unless it is abstract.

        ``**kwargs`` are class keyword arguments (``class A(Base, key=value)``); they are
        forwarded to ``ABCMeta.__new__`` so that ``__init_subclass__`` hooks keep working.

        Raises
        ------
        ValueError
            If the class name or one of its aliases is already registered.
        """
        new_cls = super().__new__(cls, name, bases, attrs, **kwargs)

        if not isabstract(new_cls):
            record = ModelRecord(
                model_class=new_cls,
                ag_priority=getattr(new_cls, "ag_priority", 0),
            )
            cls._add(name.removesuffix("Model"), record)

            # Aliases are read from the class body (``attrs``) rather than via getattr,
            # so aliases inherited from a parent class are not registered a second time.
            for alias in attrs.get("ag_model_aliases") or ():
                cls._add(alias, record)

        return new_cls

    @classmethod
    def _add(cls, alias: str, record: ModelRecord) -> None:
        """Store ``record`` under ``alias``; raise ``ValueError`` if the alias is taken."""
        if alias in cls.REGISTRY:
            raise ValueError(f"You are trying to define a new model with {alias}, but this model already exists.")
        cls.REGISTRY[alias] = record

    @classmethod
    def _get_model_record(cls, alias: Union[str, type]) -> ModelRecord:
        """Look up a record by alias or by class (its ``__name__`` is used as the alias).

        A trailing ``"Model"`` is removed from the alias before the lookup.

        Raises
        ------
        ValueError
            If no record is registered under the resulting alias.
        """
        if isinstance(alias, type):
            alias = alias.__name__
        alias = alias.removesuffix("Model")
        if alias not in cls.REGISTRY:
            raise ValueError(f"Unknown model: {alias}, available models are: {cls.available_aliases()}")
        return cls.REGISTRY[alias]

    @classmethod
    def get_model_class(cls, alias: Union[str, type]) -> type:
        """Return the model class registered under ``alias``."""
        return cls._get_model_record(alias).model_class

    @classmethod
    def get_model_priority(cls, alias: Union[str, type]) -> int:
        """Return the ``ag_priority`` recorded for the model registered under ``alias``."""
        return cls._get_model_record(alias).ag_priority

    @classmethod
    def available_aliases(cls) -> List[str]:
        """Return all registered aliases in sorted order."""
        return sorted(cls.REGISTRY.keys())
@@ -296,7 +296,10 @@ class TimeSeriesPredictor:
296
296
  df: TimeSeriesDataFrame = self._to_data_frame(data, name=name)
297
297
  if not pd.api.types.is_numeric_dtype(df[self.target]):
298
298
  raise ValueError(f"Target column {name}['{self.target}'] has a non-numeric dtype {df[self.target].dtype}")
299
+ # Assign makes a copy, so future operations can be performed in-place
299
300
  df = df.assign(**{self.target: df[self.target].astype("float64")})
301
+ df.replace(to_replace=[float("-inf"), float("inf")], value=float("nan"), inplace=True)
302
+
300
303
  # MultiIndex.is_monotonic_increasing checks if index is sorted by ["item_id", "timestamp"]
301
304
  if not df.index.is_monotonic_increasing:
302
305
  df = df.sort_index()
@@ -485,4 +485,4 @@ class ConstantReplacementFeatureImportanceTransform(AbstractFeatureImportanceTra
485
485
  if is_categorical:
486
486
  return feature_data.groupby(level=ITEMID, sort=False).transform(lambda x: x.mode()[0])
487
487
  else:
488
- return feature_data.groupby(level=ITEMID, sort=False).transform(self.real_value_aggregation)
488
+ return feature_data.groupby(level=ITEMID, sort=False).transform(self.real_value_aggregation) # type: ignore
@@ -1,4 +1,4 @@
1
1
  """This is the autogluon version file."""
2
2
 
3
- __version__ = "1.4.1b20250818"
3
+ __version__ = "1.4.1b20250820"
4
4
  __lite__ = False
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: autogluon.timeseries
3
- Version: 1.4.1b20250818
3
+ Version: 1.4.1b20250820
4
4
  Summary: Fast and Accurate ML in 3 Lines of Code
5
5
  Home-page: https://github.com/autogluon/autogluon
6
6
  Author: AutoGluon Community
@@ -27,6 +27,7 @@ src/autogluon/timeseries/metrics/quantile.py
27
27
  src/autogluon/timeseries/metrics/utils.py
28
28
  src/autogluon/timeseries/models/__init__.py
29
29
  src/autogluon/timeseries/models/presets.py
30
+ src/autogluon/timeseries/models/registry.py
30
31
  src/autogluon/timeseries/models/abstract/__init__.py
31
32
  src/autogluon/timeseries/models/abstract/abstract_timeseries_model.py
32
33
  src/autogluon/timeseries/models/abstract/model_trial.py
@@ -17,10 +17,10 @@ fugue>=0.9.0
17
17
  tqdm<5,>=4.38
18
18
  orjson~=3.9
19
19
  tensorboard<3,>=2.9
20
- autogluon.core[raytune]==1.4.1b20250818
21
- autogluon.common==1.4.1b20250818
22
- autogluon.features==1.4.1b20250818
23
- autogluon.tabular[catboost,lightgbm,xgboost]==1.4.1b20250818
20
+ autogluon.core[raytune]==1.4.1b20250820
21
+ autogluon.common==1.4.1b20250820
22
+ autogluon.features==1.4.1b20250820
23
+ autogluon.tabular[catboost,lightgbm,xgboost]==1.4.1b20250820
24
24
 
25
25
  [all]
26
26