autogluon.timeseries 1.0.1b20240304__py3-none-any.whl → 1.4.1b20251210__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of autogluon.timeseries might be problematic. See the package registry page for more details.

Files changed (108)
  1. autogluon/timeseries/configs/__init__.py +3 -2
  2. autogluon/timeseries/configs/hyperparameter_presets.py +62 -0
  3. autogluon/timeseries/configs/predictor_presets.py +84 -0
  4. autogluon/timeseries/dataset/ts_dataframe.py +339 -186
  5. autogluon/timeseries/learner.py +192 -60
  6. autogluon/timeseries/metrics/__init__.py +55 -11
  7. autogluon/timeseries/metrics/abstract.py +96 -25
  8. autogluon/timeseries/metrics/point.py +186 -39
  9. autogluon/timeseries/metrics/quantile.py +47 -20
  10. autogluon/timeseries/metrics/utils.py +6 -6
  11. autogluon/timeseries/models/__init__.py +13 -7
  12. autogluon/timeseries/models/abstract/__init__.py +2 -2
  13. autogluon/timeseries/models/abstract/abstract_timeseries_model.py +533 -273
  14. autogluon/timeseries/models/abstract/model_trial.py +10 -10
  15. autogluon/timeseries/models/abstract/tunable.py +189 -0
  16. autogluon/timeseries/models/autogluon_tabular/__init__.py +2 -0
  17. autogluon/timeseries/models/autogluon_tabular/mlforecast.py +369 -215
  18. autogluon/timeseries/models/autogluon_tabular/per_step.py +513 -0
  19. autogluon/timeseries/models/autogluon_tabular/transforms.py +67 -0
  20. autogluon/timeseries/models/autogluon_tabular/utils.py +3 -51
  21. autogluon/timeseries/models/chronos/__init__.py +4 -0
  22. autogluon/timeseries/models/chronos/chronos2.py +361 -0
  23. autogluon/timeseries/models/chronos/model.py +738 -0
  24. autogluon/timeseries/models/chronos/utils.py +369 -0
  25. autogluon/timeseries/models/ensemble/__init__.py +35 -2
  26. autogluon/timeseries/models/ensemble/{abstract_timeseries_ensemble.py → abstract.py} +50 -26
  27. autogluon/timeseries/models/ensemble/array_based/__init__.py +3 -0
  28. autogluon/timeseries/models/ensemble/array_based/abstract.py +236 -0
  29. autogluon/timeseries/models/ensemble/array_based/models.py +73 -0
  30. autogluon/timeseries/models/ensemble/array_based/regressor/__init__.py +12 -0
  31. autogluon/timeseries/models/ensemble/array_based/regressor/abstract.py +88 -0
  32. autogluon/timeseries/models/ensemble/array_based/regressor/linear_stacker.py +167 -0
  33. autogluon/timeseries/models/ensemble/array_based/regressor/per_quantile_tabular.py +94 -0
  34. autogluon/timeseries/models/ensemble/array_based/regressor/tabular.py +107 -0
  35. autogluon/timeseries/models/ensemble/ensemble_selection.py +167 -0
  36. autogluon/timeseries/models/ensemble/per_item_greedy.py +162 -0
  37. autogluon/timeseries/models/ensemble/weighted/__init__.py +8 -0
  38. autogluon/timeseries/models/ensemble/weighted/abstract.py +40 -0
  39. autogluon/timeseries/models/ensemble/weighted/basic.py +78 -0
  40. autogluon/timeseries/models/ensemble/weighted/greedy.py +57 -0
  41. autogluon/timeseries/models/gluonts/__init__.py +3 -1
  42. autogluon/timeseries/models/gluonts/abstract.py +583 -0
  43. autogluon/timeseries/models/gluonts/dataset.py +109 -0
  44. autogluon/timeseries/models/gluonts/{torch/models.py → models.py} +185 -44
  45. autogluon/timeseries/models/local/__init__.py +1 -10
  46. autogluon/timeseries/models/local/abstract_local_model.py +150 -97
  47. autogluon/timeseries/models/local/naive.py +31 -23
  48. autogluon/timeseries/models/local/npts.py +6 -2
  49. autogluon/timeseries/models/local/statsforecast.py +99 -112
  50. autogluon/timeseries/models/multi_window/multi_window_model.py +99 -40
  51. autogluon/timeseries/models/registry.py +64 -0
  52. autogluon/timeseries/models/toto/__init__.py +3 -0
  53. autogluon/timeseries/models/toto/_internal/__init__.py +9 -0
  54. autogluon/timeseries/models/toto/_internal/backbone/__init__.py +3 -0
  55. autogluon/timeseries/models/toto/_internal/backbone/attention.py +196 -0
  56. autogluon/timeseries/models/toto/_internal/backbone/backbone.py +262 -0
  57. autogluon/timeseries/models/toto/_internal/backbone/distribution.py +70 -0
  58. autogluon/timeseries/models/toto/_internal/backbone/kvcache.py +136 -0
  59. autogluon/timeseries/models/toto/_internal/backbone/rope.py +89 -0
  60. autogluon/timeseries/models/toto/_internal/backbone/rotary_embedding_torch.py +342 -0
  61. autogluon/timeseries/models/toto/_internal/backbone/scaler.py +305 -0
  62. autogluon/timeseries/models/toto/_internal/backbone/transformer.py +333 -0
  63. autogluon/timeseries/models/toto/_internal/dataset.py +165 -0
  64. autogluon/timeseries/models/toto/_internal/forecaster.py +423 -0
  65. autogluon/timeseries/models/toto/dataloader.py +108 -0
  66. autogluon/timeseries/models/toto/hf_pretrained_model.py +118 -0
  67. autogluon/timeseries/models/toto/model.py +236 -0
  68. autogluon/timeseries/predictor.py +826 -305
  69. autogluon/timeseries/regressor.py +253 -0
  70. autogluon/timeseries/splitter.py +10 -31
  71. autogluon/timeseries/trainer/__init__.py +2 -3
  72. autogluon/timeseries/trainer/ensemble_composer.py +439 -0
  73. autogluon/timeseries/trainer/model_set_builder.py +256 -0
  74. autogluon/timeseries/trainer/prediction_cache.py +149 -0
  75. autogluon/timeseries/trainer/trainer.py +1298 -0
  76. autogluon/timeseries/trainer/utils.py +17 -0
  77. autogluon/timeseries/transforms/__init__.py +2 -0
  78. autogluon/timeseries/transforms/covariate_scaler.py +164 -0
  79. autogluon/timeseries/transforms/target_scaler.py +149 -0
  80. autogluon/timeseries/utils/constants.py +10 -0
  81. autogluon/timeseries/utils/datetime/base.py +38 -20
  82. autogluon/timeseries/utils/datetime/lags.py +18 -16
  83. autogluon/timeseries/utils/datetime/seasonality.py +14 -14
  84. autogluon/timeseries/utils/datetime/time_features.py +17 -14
  85. autogluon/timeseries/utils/features.py +317 -53
  86. autogluon/timeseries/utils/forecast.py +31 -17
  87. autogluon/timeseries/utils/timer.py +173 -0
  88. autogluon/timeseries/utils/warning_filters.py +44 -6
  89. autogluon/timeseries/version.py +2 -1
  90. autogluon.timeseries-1.4.1b20251210-py3.11-nspkg.pth +1 -0
  91. {autogluon.timeseries-1.0.1b20240304.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info}/METADATA +71 -47
  92. autogluon_timeseries-1.4.1b20251210.dist-info/RECORD +103 -0
  93. {autogluon.timeseries-1.0.1b20240304.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info}/WHEEL +1 -1
  94. autogluon/timeseries/configs/presets_configs.py +0 -11
  95. autogluon/timeseries/evaluator.py +0 -6
  96. autogluon/timeseries/models/ensemble/greedy_ensemble.py +0 -170
  97. autogluon/timeseries/models/gluonts/abstract_gluonts.py +0 -550
  98. autogluon/timeseries/models/gluonts/torch/__init__.py +0 -0
  99. autogluon/timeseries/models/presets.py +0 -325
  100. autogluon/timeseries/trainer/abstract_trainer.py +0 -1144
  101. autogluon/timeseries/trainer/auto_trainer.py +0 -74
  102. autogluon.timeseries-1.0.1b20240304-py3.8-nspkg.pth +0 -1
  103. autogluon.timeseries-1.0.1b20240304.dist-info/RECORD +0 -58
  104. {autogluon.timeseries-1.0.1b20240304.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info/licenses}/LICENSE +0 -0
  105. {autogluon.timeseries-1.0.1b20240304.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info/licenses}/NOTICE +0 -0
  106. {autogluon.timeseries-1.0.1b20240304.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info}/namespace_packages.txt +0 -0
  107. {autogluon.timeseries-1.0.1b20240304.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info}/top_level.txt +0 -0
  108. {autogluon.timeseries-1.0.1b20240304.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info}/zip-safe +0 -0
@@ -1,5 +1,5 @@
1
1
  import logging
2
- from typing import Any, Dict, Type
2
+ from typing import Any, Type
3
3
 
4
4
  import numpy as np
5
5
  import pandas as pd
@@ -14,22 +14,24 @@ class AbstractStatsForecastModel(AbstractLocalModel):
14
14
 
15
15
  init_time_in_seconds = 15 # numba compilation for the first run
16
16
 
17
- def _update_local_model_args(self, local_model_args: Dict[str, Any]) -> Dict[str, Any]:
17
+ def _update_local_model_args(self, local_model_args: dict[str, Any]) -> dict[str, Any]:
18
18
  seasonal_period = local_model_args.pop("seasonal_period")
19
19
  local_model_args["season_length"] = seasonal_period
20
20
  return local_model_args
21
21
 
22
- def _get_model_type(self) -> Type:
22
+ def _get_model_type(self, variant: str | None = None) -> Type:
23
23
  raise NotImplementedError
24
24
 
25
- def _get_local_model(self, local_model_args: Dict):
26
- model_type = self._get_model_type()
25
+ def _get_local_model(self, local_model_args: dict):
26
+ local_model_args = local_model_args.copy()
27
+ variant = local_model_args.pop("variant", None)
28
+ model_type = self._get_model_type(variant)
27
29
  return model_type(**local_model_args)
28
30
 
29
31
  def _get_point_forecast(
30
32
  self,
31
33
  time_series: pd.Series,
32
- local_model_args: Dict,
34
+ local_model_args: dict,
33
35
  ) -> np.ndarray:
34
36
  return self._get_local_model(local_model_args).forecast(
35
37
  h=self.prediction_length, y=time_series.values.ravel()
@@ -49,15 +51,7 @@ class AbstractProbabilisticStatsForecastModel(AbstractStatsForecastModel):
49
51
  time_series: pd.Series,
50
52
  local_model_args: dict,
51
53
  ) -> pd.DataFrame:
52
- # Code does conversion between confidence levels and quantiles
53
- levels = []
54
- quantile_to_key = {}
55
- for q in self.quantile_levels:
56
- level = round(abs(q - 0.5) * 200, 1)
57
- suffix = "lo" if q < 0.5 else "hi"
58
- levels.append(level)
59
- quantile_to_key[str(q)] = f"{suffix}-{level}"
60
- levels = sorted(list(set(levels)))
54
+ levels, quantile_to_key = self._get_confidence_levels()
61
55
 
62
56
  forecast = self._get_local_model(local_model_args).forecast(
63
57
  h=self.prediction_length, y=time_series.values.ravel(), level=levels
@@ -67,6 +61,18 @@ class AbstractProbabilisticStatsForecastModel(AbstractStatsForecastModel):
67
61
  predictions[q] = forecast[key]
68
62
  return pd.DataFrame(predictions)
69
63
 
64
+ def _get_confidence_levels(self) -> tuple[list[float], dict[str, str]]:
65
+ """Get StatsForecast compatible levels from quantiles"""
66
+ levels = []
67
+ quantile_to_key = {}
68
+ for q in self.quantile_levels:
69
+ level = round(abs(q - 0.5) * 200, 1)
70
+ suffix = "lo" if q < 0.5 else "hi"
71
+ levels.append(level)
72
+ quantile_to_key[str(q)] = f"{suffix}-{level}"
73
+ levels = sorted(list(set(levels)))
74
+ return levels, quantile_to_key
75
+
70
76
 
71
77
  class AutoARIMAModel(AbstractProbabilisticStatsForecastModel):
72
78
  """Automatically tuned ARIMA model.
@@ -117,7 +123,7 @@ class AutoARIMAModel(AbstractProbabilisticStatsForecastModel):
117
123
  When set to None, seasonal_period will be inferred from the frequency of the training data. Can also be
118
124
  specified manually by providing an integer > 1.
119
125
  If seasonal_period (inferred or provided) is equal to 1, seasonality will be disabled.
120
- n_jobs : int or float, default = 0.5
126
+ n_jobs : int or float, default = joblib.cpu_count(only_physical_cores=True)
121
127
  Number of CPU cores used to fit the models in parallel.
122
128
  When set to a float between 0.0 and 1.0, that fraction of available CPU cores is used.
123
129
  When set to a positive integer, that many cores are used.
@@ -127,6 +133,8 @@ class AutoARIMAModel(AbstractProbabilisticStatsForecastModel):
127
133
  This significantly speeds up fitting and usually leads to no change in accuracy.
128
134
  """
129
135
 
136
+ ag_priority = 60
137
+ init_time_in_seconds = 0 # C++ models require no compilation
130
138
  allowed_local_model_args = [
131
139
  "d",
132
140
  "D",
@@ -154,7 +162,7 @@ class AutoARIMAModel(AbstractProbabilisticStatsForecastModel):
154
162
  local_model_args.setdefault("allowmean", True)
155
163
  return local_model_args
156
164
 
157
- def _get_model_type(self):
165
+ def _get_model_type(self, variant: str | None = None):
158
166
  from statsforecast.models import AutoARIMA
159
167
 
160
168
  return AutoARIMA
@@ -168,9 +176,9 @@ class ARIMAModel(AbstractProbabilisticStatsForecastModel):
168
176
 
169
177
  Other Parameters
170
178
  ----------------
171
- order: Tuple[int, int, int], default = (1, 1, 1)
179
+ order: tuple[int, int, int], default = (1, 1, 1)
172
180
  The (p, d, q) order of the model for the number of AR parameters, differences, and MA parameters to use.
173
- seasonal_order: Tuple[int, int, int], default = (0, 0, 0)
181
+ seasonal_order: tuple[int, int, int], default = (0, 0, 0)
174
182
  The (P, D, Q) parameters of the seasonal ARIMA model. Setting to (0, 0, 0) disables seasonality.
175
183
  include_mean : bool, default = True
176
184
  Should the ARIMA model include a mean term?
@@ -186,7 +194,7 @@ class ARIMAModel(AbstractProbabilisticStatsForecastModel):
186
194
  method : {"CSS-ML", "CSS", "ML"}, default = "CSS-ML"
187
195
  Fitting method: CSS (conditional sum of squares), ML (maximum likelihood), CSS-ML (initialize with CSS, then
188
196
  optimize with ML).
189
- fixed : Dict[str, float], optional
197
+ fixed : dict[str, float], optional
190
198
  Dictionary containing fixed coefficients for the ARIMA model.
191
199
  seasonal_period : int or None, default = None
192
200
  Number of time steps in a complete seasonal cycle for seasonal models. For example, 7 for daily data with a
@@ -194,7 +202,7 @@ class ARIMAModel(AbstractProbabilisticStatsForecastModel):
194
202
  When set to None, seasonal_period will be inferred from the frequency of the training data. Can also be
195
203
  specified manually by providing an integer > 1.
196
204
  If seasonal_period (inferred or provided) is equal to 1, seasonality will be disabled.
197
- n_jobs : int or float, default = 0.5
205
+ n_jobs : int or float, default = joblib.cpu_count(only_physical_cores=True)
198
206
  Number of CPU cores used to fit the models in parallel.
199
207
  When set to a float between 0.0 and 1.0, that fraction of available CPU cores is used.
200
208
  When set to a positive integer, that many cores are used.
@@ -204,6 +212,8 @@ class ARIMAModel(AbstractProbabilisticStatsForecastModel):
204
212
  This significantly speeds up fitting and usually leads to no change in accuracy.
205
213
  """
206
214
 
215
+ ag_priority = 10
216
+ init_time_in_seconds = 0 # C++ models require no compilation
207
217
  allowed_local_model_args = [
208
218
  "order",
209
219
  "seasonal_order",
@@ -222,7 +232,7 @@ class ARIMAModel(AbstractProbabilisticStatsForecastModel):
222
232
  local_model_args.setdefault("order", (1, 1, 1))
223
233
  return local_model_args
224
234
 
225
- def _get_model_type(self):
235
+ def _get_model_type(self, variant: str | None = None):
226
236
  from statsforecast.models import ARIMA
227
237
 
228
238
  return ARIMA
@@ -249,7 +259,7 @@ class AutoETSModel(AbstractProbabilisticStatsForecastModel):
249
259
  If seasonal_period (inferred or provided) is equal to 1, seasonality will be disabled.
250
260
  damped : bool, default = False
251
261
  Whether to dampen the trend.
252
- n_jobs : int or float, default = 0.5
262
+ n_jobs : int or float, default = joblib.cpu_count(only_physical_cores=True)
253
263
  Number of CPU cores used to fit the models in parallel.
254
264
  When set to a float between 0.0 and 1.0, that fraction of available CPU cores is used.
255
265
  When set to a positive integer, that many cores are used.
@@ -259,13 +269,15 @@ class AutoETSModel(AbstractProbabilisticStatsForecastModel):
259
269
  This significantly speeds up fitting and usually leads to no change in accuracy.
260
270
  """
261
271
 
272
+ ag_priority = 70
273
+ init_time_in_seconds = 0 # C++ models require no compilation
262
274
  allowed_local_model_args = [
263
275
  "damped",
264
276
  "model",
265
277
  "seasonal_period",
266
278
  ]
267
279
 
268
- def _get_model_type(self):
280
+ def _get_model_type(self, variant: str | None = None):
269
281
  from statsforecast.models import AutoETS
270
282
 
271
283
  return AutoETS
@@ -284,7 +296,7 @@ class AutoETSModel(AbstractProbabilisticStatsForecastModel):
284
296
  # Disable seasonality if time series too short for chosen season_length, season_length is too high, or
285
297
  # season_length == 1. Otherwise model will crash
286
298
  season_length = local_model_args["season_length"]
287
- if len(time_series) < 2 * season_length or season_length == 1 or season_length > 24:
299
+ if len(time_series) < 2 * season_length or season_length == 1:
288
300
  # changing last character to "N" disables seasonality, e.g., model="AAA" -> model="AAN"
289
301
  local_model_args["model"] = local_model_args["model"][:-1] + "N"
290
302
  return super()._predict_with_local_model(time_series=time_series, local_model_args=local_model_args)
@@ -311,7 +323,7 @@ class ETSModel(AutoETSModel):
311
323
  If seasonal_period (inferred or provided) is equal to 1, seasonality will be disabled.
312
324
  damped : bool, default = False
313
325
  Whether to dampen the trend.
314
- n_jobs : int or float, default = 0.5
326
+ n_jobs : int or float, default = joblib.cpu_count(only_physical_cores=True)
315
327
  Number of CPU cores used to fit the models in parallel.
316
328
  When set to a float between 0.0 and 1.0, that fraction of available CPU cores is used.
317
329
  When set to a positive integer, that many cores are used.
@@ -321,6 +333,8 @@ class ETSModel(AutoETSModel):
321
333
  This significantly speeds up fitting and usually leads to no change in accuracy.
322
334
  """
323
335
 
336
+ ag_priority = 80
337
+
324
338
  def _update_local_model_args(self, local_model_args: dict) -> dict:
325
339
  local_model_args = super()._update_local_model_args(local_model_args)
326
340
  local_model_args.setdefault("model", "AAA")
@@ -350,7 +364,7 @@ class DynamicOptimizedThetaModel(AbstractProbabilisticStatsForecastModel):
350
364
  When set to None, seasonal_period will be inferred from the frequency of the training data. Can also be
351
365
  specified manually by providing an integer > 1.
352
366
  If seasonal_period (inferred or provided) is equal to 1, seasonality will be disabled.
353
- n_jobs : int or float, default = 0.5
367
+ n_jobs : int or float, default = joblib.cpu_count(only_physical_cores=True)
354
368
  Number of CPU cores used to fit the models in parallel.
355
369
  When set to a float between 0.0 and 1.0, that fraction of available CPU cores is used.
356
370
  When set to a positive integer, that many cores are used.
@@ -360,12 +374,13 @@ class DynamicOptimizedThetaModel(AbstractProbabilisticStatsForecastModel):
360
374
  This significantly speeds up fitting and usually leads to no change in accuracy.
361
375
  """
362
376
 
377
+ ag_priority = 75
363
378
  allowed_local_model_args = [
364
379
  "decomposition_type",
365
380
  "seasonal_period",
366
381
  ]
367
382
 
368
- def _get_model_type(self):
383
+ def _get_model_type(self, variant: str | None = None):
369
384
  from statsforecast.models import DynamicOptimizedTheta
370
385
 
371
386
  return DynamicOptimizedTheta
@@ -394,7 +409,7 @@ class ThetaModel(AbstractProbabilisticStatsForecastModel):
394
409
  When set to None, seasonal_period will be inferred from the frequency of the training data. Can also be
395
410
  specified manually by providing an integer > 1.
396
411
  If seasonal_period (inferred or provided) is equal to 1, seasonality will be disabled.
397
- n_jobs : int or float, default = 0.5
412
+ n_jobs : int or float, default = joblib.cpu_count(only_physical_cores=True)
398
413
  Number of CPU cores used to fit the models in parallel.
399
414
  When set to a float between 0.0 and 1.0, that fraction of available CPU cores is used.
400
415
  When set to a positive integer, that many cores are used.
@@ -404,12 +419,13 @@ class ThetaModel(AbstractProbabilisticStatsForecastModel):
404
419
  This significantly speeds up fitting and usually leads to no change in accuracy.
405
420
  """
406
421
 
422
+ ag_priority = 75
407
423
  allowed_local_model_args = [
408
424
  "decomposition_type",
409
425
  "seasonal_period",
410
426
  ]
411
427
 
412
- def _get_model_type(self):
428
+ def _get_model_type(self, variant: str | None = None):
413
429
  from statsforecast.models import Theta
414
430
 
415
431
  return Theta
@@ -433,7 +449,7 @@ class AbstractConformalizedStatsForecastModel(AbstractStatsForecastModel):
433
449
  def _get_nonconformity_scores(
434
450
  self,
435
451
  time_series: pd.Series,
436
- local_model_args: Dict,
452
+ local_model_args: dict,
437
453
  ) -> np.ndarray:
438
454
  h = self.prediction_length
439
455
  y = time_series.values.ravel()
@@ -489,10 +505,9 @@ class AbstractConformalizedStatsForecastModel(AbstractStatsForecastModel):
489
505
  return pd.DataFrame(predictions)
490
506
 
491
507
 
492
- # TODO: Starting from StatsForecast v1.5.0, AutoCES can inherit from AbstractProbabilisticStatsForecastModel
493
- class AutoCESModel(AbstractConformalizedStatsForecastModel):
508
+ class AutoCESModel(AbstractProbabilisticStatsForecastModel):
494
509
  """Forecasting with an Complex Exponential Smoothing model where the model selection is performed using the
495
- Akaike Information Criterion.
510
+ Akaike Information Criterion [Svetunkov2022]_.
496
511
 
497
512
  Based on `statsforecast.models.AutoCES <https://nixtla.mintlify.app/statsforecast/docs/models/autoces.html>`_.
498
513
 
@@ -515,7 +530,7 @@ class AutoCESModel(AbstractConformalizedStatsForecastModel):
515
530
  When set to None, seasonal_period will be inferred from the frequency of the training data. Can also be
516
531
  specified manually by providing an integer > 1.
517
532
  If seasonal_period (inferred or provided) is equal to 1, seasonality will be disabled.
518
- n_jobs : int or float, default = 0.5
533
+ n_jobs : int or float, default = joblib.cpu_count(only_physical_cores=True)
519
534
  Number of CPU cores used to fit the models in parallel.
520
535
  When set to a float between 0.0 and 1.0, that fraction of available CPU cores is used.
521
536
  When set to a positive integer, that many cores are used.
@@ -525,12 +540,13 @@ class AutoCESModel(AbstractConformalizedStatsForecastModel):
525
540
  This significantly speeds up fitting and usually leads to no change in accuracy.
526
541
  """
527
542
 
543
+ ag_priority = 10
528
544
  allowed_local_model_args = [
529
545
  "model",
530
546
  "seasonal_period",
531
547
  ]
532
548
 
533
- def _get_model_type(self):
549
+ def _get_model_type(self, variant: str | None = None):
534
550
  from statsforecast.models import AutoCES
535
551
 
536
552
  return AutoCES
@@ -540,7 +556,7 @@ class AutoCESModel(AbstractConformalizedStatsForecastModel):
540
556
  local_model_args.setdefault("model", "Z")
541
557
  return local_model_args
542
558
 
543
- def _get_point_forecast(self, time_series: pd.Series, local_model_args: Dict):
559
+ def _get_point_forecast(self, time_series: pd.Series, local_model_args: dict):
544
560
  # Disable seasonality if time series too short for chosen season_length or season_length == 1,
545
561
  # otherwise model will crash
546
562
  if len(time_series) < 5:
@@ -552,7 +568,7 @@ class AutoCESModel(AbstractConformalizedStatsForecastModel):
552
568
 
553
569
 
554
570
  class AbstractStatsForecastIntermittentDemandModel(AbstractConformalizedStatsForecastModel):
555
- def _update_local_model_args(self, local_model_args: Dict[str, Any]) -> Dict[str, Any]:
571
+ def _update_local_model_args(self, local_model_args: dict[str, Any]) -> dict[str, Any]:
556
572
  _ = local_model_args.pop("seasonal_period")
557
573
  return local_model_args
558
574
 
@@ -582,7 +598,7 @@ class ADIDAModel(AbstractStatsForecastIntermittentDemandModel):
582
598
 
583
599
  Other Parameters
584
600
  ----------------
585
- n_jobs : int or float, default = 0.5
601
+ n_jobs : int or float, default = joblib.cpu_count(only_physical_cores=True)
586
602
  Number of CPU cores used to fit the models in parallel.
587
603
  When set to a float between 0.0 and 1.0, that fraction of available CPU cores is used.
588
604
  When set to a positive integer, that many cores are used.
@@ -592,59 +608,35 @@ class ADIDAModel(AbstractStatsForecastIntermittentDemandModel):
592
608
  This significantly speeds up fitting and usually leads to no change in accuracy.
593
609
  """
594
610
 
595
- def _get_model_type(self):
611
+ ag_priority = 10
612
+
613
+ def _get_model_type(self, variant: str | None = None):
596
614
  from statsforecast.models import ADIDA
597
615
 
598
616
  return ADIDA
599
617
 
600
618
 
601
- class CrostonSBAModel(AbstractStatsForecastIntermittentDemandModel):
602
- """Intermittent demand forecasting model using Croston's model with the Syntetos-Boylan
603
- bias correction approach [SyntetosBoylan2001]_.
604
-
605
- Based on `statsforecast.models.CrostonSBA <https://nixtla.mintlify.app/statsforecast/docs/models/crostonsba.html>`_.
606
-
619
+ class CrostonModel(AbstractStatsForecastIntermittentDemandModel):
620
+ """Intermittent demand forecasting model using Croston's model from [Croston1972]_ and [SyntetosBoylan2001]_.
607
621
 
608
622
  References
609
623
  ----------
624
+ .. [Croston1972] Croston, John D. "Forecasting and stock control for intermittent demands." Journal of
625
+ the Operational Research Society 23.3 (1972): 289-303.
610
626
  .. [SyntetosBoylan2001] Syntetos, Aris A., and John E. Boylan. "On the bias of intermittent
611
627
  demand estimates." International journal of production economics 71.1-3 (2001): 457-466.
612
628
 
613
629
 
614
630
  Other Parameters
615
631
  ----------------
616
- n_jobs : int or float, default = 0.5
617
- Number of CPU cores used to fit the models in parallel.
618
- When set to a float between 0.0 and 1.0, that fraction of available CPU cores is used.
619
- When set to a positive integer, that many cores are used.
620
- When set to -1, all CPU cores are used.
621
- max_ts_length : int, default = 2500
622
- If not None, only the last ``max_ts_length`` time steps of each time series will be used to train the model.
623
- This significantly speeds up fitting and usually leads to no change in accuracy.
624
- """
625
-
626
- def _get_model_type(self):
627
- from statsforecast.models import CrostonSBA
628
-
629
- return CrostonSBA
630
-
631
-
632
- class CrostonOptimizedModel(AbstractStatsForecastIntermittentDemandModel):
633
- """Intermittent demand forecasting model using Croston's model where the smoothing parameter
634
- is optimized [Croston1972]_.
635
-
636
- Based on `statsforecast.models.CrostonOptimized <https://nixtla.mintlify.app/statsforecast/docs/models/crostonoptimized.html>`_.
632
+ variant : {"SBA", "classic", "optimized"}, default = "SBA"
633
+ Variant of the Croston model that is used. Available options:
637
634
 
635
+ - ``"classic"`` - variant of the Croston method where the smoothing parameter is fixed to 0.1 (based on `statsforecast.models.CrostonClassic <https://nixtla.mintlify.app/statsforecast/docs/models/crostonclassic.html>`_)
636
+ - ``"SBA"`` - variant of the Croston method based on Syntetos-Boylan Approximation (based on `statsforecast.models.CrostonSBA <https://nixtla.mintlify.app/statsforecast/docs/models/crostonsba.html>`_)
637
+ - ``"optimized"`` - variant of the Croston method where the smoothing parameter is optimized (based on `statsforecast.models.CrostonOptimized <https://nixtla.mintlify.app/statsforecast/docs/models/crostonoptimized.html>`_)
638
638
 
639
- References
640
- ----------
641
- .. [Croston1972] Croston, John D. "Forecasting and stock control for intermittent demands." Journal of
642
- the Operational Research Society 23.3 (1972): 289-303.
643
-
644
-
645
- Other Parameters
646
- ----------------
647
- n_jobs : int or float, default = 0.5
639
+ n_jobs : int or float, default = joblib.cpu_count(only_physical_cores=True)
648
640
  Number of CPU cores used to fit the models in parallel.
649
641
  When set to a float between 0.0 and 1.0, that fraction of available CPU cores is used.
650
642
  When set to a positive integer, that many cores are used.
@@ -654,41 +646,32 @@ class CrostonOptimizedModel(AbstractStatsForecastIntermittentDemandModel):
654
646
  This significantly speeds up fitting and usually leads to no change in accuracy.
655
647
  """
656
648
 
657
- def _get_model_type(self):
658
- from statsforecast.models import CrostonOptimized
659
-
660
- return CrostonOptimized
661
-
662
-
663
- class CrostonClassicModel(AbstractStatsForecastIntermittentDemandModel):
664
- """Intermittent demand forecasting model using Croston's model where the smoothing parameter
665
- is fixed to 0.1 [Croston1972]_.
666
-
667
- Based on `statsforecast.models.CrostonClassic <https://nixtla.mintlify.app/statsforecast/docs/models/crostonclassic.html>`_.
668
-
669
-
670
- References
671
- ----------
672
- .. [Croston1972] Croston, John D. "Forecasting and stock control for intermittent demands." Journal of
673
- the Operational Research Society 23.3 (1972): 289-303.
649
+ ag_model_aliases = ["CrostonSBA"]
650
+ ag_priority = 80
651
+ allowed_local_model_args = [
652
+ "variant",
653
+ ]
674
654
 
655
+ def _get_model_type(self, variant: str | None = None):
656
+ from statsforecast.models import CrostonClassic, CrostonOptimized, CrostonSBA
675
657
 
676
- Other Parameters
677
- ----------------
678
- n_jobs : int or float, default = 0.5
679
- Number of CPU cores used to fit the models in parallel.
680
- When set to a float between 0.0 and 1.0, that fraction of available CPU cores is used.
681
- When set to a positive integer, that many cores are used.
682
- When set to -1, all CPU cores are used.
683
- max_ts_length : int, default = 2500
684
- If not None, only the last ``max_ts_length`` time steps of each time series will be used to train the model.
685
- This significantly speeds up fitting and usually leads to no change in accuracy.
686
- """
658
+ model_variants = {
659
+ "classic": CrostonClassic,
660
+ "sba": CrostonSBA,
661
+ "optimized": CrostonOptimized,
662
+ }
687
663
 
688
- def _get_model_type(self):
689
- from statsforecast.models import CrostonClassic
664
+ if not isinstance(variant, str) or variant.lower() not in model_variants:
665
+ raise ValueError(
666
+ f"Invalid model variant '{variant}'. Available Croston model variants: {list(model_variants)}"
667
+ )
668
+ else:
669
+ return model_variants[variant.lower()]
690
670
 
691
- return CrostonClassic
671
+ def _update_local_model_args(self, local_model_args: dict) -> dict:
672
+ local_model_args = super()._update_local_model_args(local_model_args)
673
+ local_model_args.setdefault("variant", "SBA")
674
+ return local_model_args
692
675
 
693
676
 
694
677
  class IMAPAModel(AbstractStatsForecastIntermittentDemandModel):
@@ -707,7 +690,7 @@ class IMAPAModel(AbstractStatsForecastIntermittentDemandModel):
707
690
 
708
691
  Other Parameters
709
692
  ----------------
710
- n_jobs : int or float, default = 0.5
693
+ n_jobs : int or float, default = joblib.cpu_count(only_physical_cores=True)
711
694
  Number of CPU cores used to fit the models in parallel.
712
695
  When set to a float between 0.0 and 1.0, that fraction of available CPU cores is used.
713
696
  When set to a positive integer, that many cores are used.
@@ -717,7 +700,9 @@ class IMAPAModel(AbstractStatsForecastIntermittentDemandModel):
717
700
  This significantly speeds up fitting and usually leads to no change in accuracy.
718
701
  """
719
702
 
720
- def _get_model_type(self):
703
+ ag_priority = 10
704
+
705
+ def _get_model_type(self, variant: str | None = None):
721
706
  from statsforecast.models import IMAPA
722
707
 
723
708
  return IMAPA
@@ -729,7 +714,7 @@ class ZeroModel(AbstractStatsForecastIntermittentDemandModel):
729
714
 
730
715
  Other Parameters
731
716
  ----------------
732
- n_jobs : int or float, default = 0.5
717
+ n_jobs : int or float, default = joblib.cpu_count(only_physical_cores=True)
733
718
  Number of CPU cores used to fit the models in parallel.
734
719
  When set to a float between 0.0 and 1.0, that fraction of available CPU cores is used.
735
720
  When set to a positive integer, that many cores are used.
@@ -739,13 +724,15 @@ class ZeroModel(AbstractStatsForecastIntermittentDemandModel):
739
724
  This significantly speeds up fitting and usually leads to no change in accuracy.
740
725
  """
741
726
 
742
- def _get_model_type(self):
727
+ ag_priority = 100
728
+
729
+ def _get_model_type(self, variant: str | None = None):
743
730
  # ZeroModel does not depend on a StatsForecast implementation
744
731
  raise NotImplementedError
745
732
 
746
733
  def _get_point_forecast(
747
734
  self,
748
735
  time_series: pd.Series,
749
- local_model_args: Dict,
736
+ local_model_args: dict,
750
737
  ):
751
738
  return np.zeros(self.prediction_length)