autogluon.timeseries 1.0.1b20240304__py3-none-any.whl → 1.4.1b20251210__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of autogluon.timeseries might be problematic.

Files changed (108)
  1. autogluon/timeseries/configs/__init__.py +3 -2
  2. autogluon/timeseries/configs/hyperparameter_presets.py +62 -0
  3. autogluon/timeseries/configs/predictor_presets.py +84 -0
  4. autogluon/timeseries/dataset/ts_dataframe.py +339 -186
  5. autogluon/timeseries/learner.py +192 -60
  6. autogluon/timeseries/metrics/__init__.py +55 -11
  7. autogluon/timeseries/metrics/abstract.py +96 -25
  8. autogluon/timeseries/metrics/point.py +186 -39
  9. autogluon/timeseries/metrics/quantile.py +47 -20
  10. autogluon/timeseries/metrics/utils.py +6 -6
  11. autogluon/timeseries/models/__init__.py +13 -7
  12. autogluon/timeseries/models/abstract/__init__.py +2 -2
  13. autogluon/timeseries/models/abstract/abstract_timeseries_model.py +533 -273
  14. autogluon/timeseries/models/abstract/model_trial.py +10 -10
  15. autogluon/timeseries/models/abstract/tunable.py +189 -0
  16. autogluon/timeseries/models/autogluon_tabular/__init__.py +2 -0
  17. autogluon/timeseries/models/autogluon_tabular/mlforecast.py +369 -215
  18. autogluon/timeseries/models/autogluon_tabular/per_step.py +513 -0
  19. autogluon/timeseries/models/autogluon_tabular/transforms.py +67 -0
  20. autogluon/timeseries/models/autogluon_tabular/utils.py +3 -51
  21. autogluon/timeseries/models/chronos/__init__.py +4 -0
  22. autogluon/timeseries/models/chronos/chronos2.py +361 -0
  23. autogluon/timeseries/models/chronos/model.py +738 -0
  24. autogluon/timeseries/models/chronos/utils.py +369 -0
  25. autogluon/timeseries/models/ensemble/__init__.py +35 -2
  26. autogluon/timeseries/models/ensemble/{abstract_timeseries_ensemble.py → abstract.py} +50 -26
  27. autogluon/timeseries/models/ensemble/array_based/__init__.py +3 -0
  28. autogluon/timeseries/models/ensemble/array_based/abstract.py +236 -0
  29. autogluon/timeseries/models/ensemble/array_based/models.py +73 -0
  30. autogluon/timeseries/models/ensemble/array_based/regressor/__init__.py +12 -0
  31. autogluon/timeseries/models/ensemble/array_based/regressor/abstract.py +88 -0
  32. autogluon/timeseries/models/ensemble/array_based/regressor/linear_stacker.py +167 -0
  33. autogluon/timeseries/models/ensemble/array_based/regressor/per_quantile_tabular.py +94 -0
  34. autogluon/timeseries/models/ensemble/array_based/regressor/tabular.py +107 -0
  35. autogluon/timeseries/models/ensemble/ensemble_selection.py +167 -0
  36. autogluon/timeseries/models/ensemble/per_item_greedy.py +162 -0
  37. autogluon/timeseries/models/ensemble/weighted/__init__.py +8 -0
  38. autogluon/timeseries/models/ensemble/weighted/abstract.py +40 -0
  39. autogluon/timeseries/models/ensemble/weighted/basic.py +78 -0
  40. autogluon/timeseries/models/ensemble/weighted/greedy.py +57 -0
  41. autogluon/timeseries/models/gluonts/__init__.py +3 -1
  42. autogluon/timeseries/models/gluonts/abstract.py +583 -0
  43. autogluon/timeseries/models/gluonts/dataset.py +109 -0
  44. autogluon/timeseries/models/gluonts/{torch/models.py → models.py} +185 -44
  45. autogluon/timeseries/models/local/__init__.py +1 -10
  46. autogluon/timeseries/models/local/abstract_local_model.py +150 -97
  47. autogluon/timeseries/models/local/naive.py +31 -23
  48. autogluon/timeseries/models/local/npts.py +6 -2
  49. autogluon/timeseries/models/local/statsforecast.py +99 -112
  50. autogluon/timeseries/models/multi_window/multi_window_model.py +99 -40
  51. autogluon/timeseries/models/registry.py +64 -0
  52. autogluon/timeseries/models/toto/__init__.py +3 -0
  53. autogluon/timeseries/models/toto/_internal/__init__.py +9 -0
  54. autogluon/timeseries/models/toto/_internal/backbone/__init__.py +3 -0
  55. autogluon/timeseries/models/toto/_internal/backbone/attention.py +196 -0
  56. autogluon/timeseries/models/toto/_internal/backbone/backbone.py +262 -0
  57. autogluon/timeseries/models/toto/_internal/backbone/distribution.py +70 -0
  58. autogluon/timeseries/models/toto/_internal/backbone/kvcache.py +136 -0
  59. autogluon/timeseries/models/toto/_internal/backbone/rope.py +89 -0
  60. autogluon/timeseries/models/toto/_internal/backbone/rotary_embedding_torch.py +342 -0
  61. autogluon/timeseries/models/toto/_internal/backbone/scaler.py +305 -0
  62. autogluon/timeseries/models/toto/_internal/backbone/transformer.py +333 -0
  63. autogluon/timeseries/models/toto/_internal/dataset.py +165 -0
  64. autogluon/timeseries/models/toto/_internal/forecaster.py +423 -0
  65. autogluon/timeseries/models/toto/dataloader.py +108 -0
  66. autogluon/timeseries/models/toto/hf_pretrained_model.py +118 -0
  67. autogluon/timeseries/models/toto/model.py +236 -0
  68. autogluon/timeseries/predictor.py +826 -305
  69. autogluon/timeseries/regressor.py +253 -0
  70. autogluon/timeseries/splitter.py +10 -31
  71. autogluon/timeseries/trainer/__init__.py +2 -3
  72. autogluon/timeseries/trainer/ensemble_composer.py +439 -0
  73. autogluon/timeseries/trainer/model_set_builder.py +256 -0
  74. autogluon/timeseries/trainer/prediction_cache.py +149 -0
  75. autogluon/timeseries/trainer/trainer.py +1298 -0
  76. autogluon/timeseries/trainer/utils.py +17 -0
  77. autogluon/timeseries/transforms/__init__.py +2 -0
  78. autogluon/timeseries/transforms/covariate_scaler.py +164 -0
  79. autogluon/timeseries/transforms/target_scaler.py +149 -0
  80. autogluon/timeseries/utils/constants.py +10 -0
  81. autogluon/timeseries/utils/datetime/base.py +38 -20
  82. autogluon/timeseries/utils/datetime/lags.py +18 -16
  83. autogluon/timeseries/utils/datetime/seasonality.py +14 -14
  84. autogluon/timeseries/utils/datetime/time_features.py +17 -14
  85. autogluon/timeseries/utils/features.py +317 -53
  86. autogluon/timeseries/utils/forecast.py +31 -17
  87. autogluon/timeseries/utils/timer.py +173 -0
  88. autogluon/timeseries/utils/warning_filters.py +44 -6
  89. autogluon/timeseries/version.py +2 -1
  90. autogluon.timeseries-1.4.1b20251210-py3.11-nspkg.pth +1 -0
  91. {autogluon.timeseries-1.0.1b20240304.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info}/METADATA +71 -47
  92. autogluon_timeseries-1.4.1b20251210.dist-info/RECORD +103 -0
  93. {autogluon.timeseries-1.0.1b20240304.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info}/WHEEL +1 -1
  94. autogluon/timeseries/configs/presets_configs.py +0 -11
  95. autogluon/timeseries/evaluator.py +0 -6
  96. autogluon/timeseries/models/ensemble/greedy_ensemble.py +0 -170
  97. autogluon/timeseries/models/gluonts/abstract_gluonts.py +0 -550
  98. autogluon/timeseries/models/gluonts/torch/__init__.py +0 -0
  99. autogluon/timeseries/models/presets.py +0 -325
  100. autogluon/timeseries/trainer/abstract_trainer.py +0 -1144
  101. autogluon/timeseries/trainer/auto_trainer.py +0 -74
  102. autogluon.timeseries-1.0.1b20240304-py3.8-nspkg.pth +0 -1
  103. autogluon.timeseries-1.0.1b20240304.dist-info/RECORD +0 -58
  104. {autogluon.timeseries-1.0.1b20240304.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info/licenses}/LICENSE +0 -0
  105. {autogluon.timeseries-1.0.1b20240304.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info/licenses}/NOTICE +0 -0
  106. {autogluon.timeseries-1.0.1b20240304.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info}/namespace_packages.txt +0 -0
  107. {autogluon.timeseries-1.0.1b20240304.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info}/top_level.txt +0 -0
  108. {autogluon.timeseries-1.0.1b20240304.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info}/zip-safe +0 -0
autogluon/timeseries/models/chronos/chronos2.py (new file)
@@ -0,0 +1,361 @@
+ import logging
+ import os
+ from typing import Any
+
+ import numpy as np
+ import pandas as pd
+ from typing_extensions import Self
+
+ from autogluon.timeseries.dataset import TimeSeriesDataFrame
+ from autogluon.timeseries.models.abstract import AbstractTimeSeriesModel
+
+ logger = logging.getLogger(__name__)
+
+
+ class Chronos2Model(AbstractTimeSeriesModel):
+     """Chronos-2 pretrained time series forecasting model [Ansari2025]_, which provides strong zero-shot forecasting
+     capability while natively taking advantage of covariates. The model can also be fine-tuned in a task-specific manner.
+
+     This implementation wraps the original implementation in the `chronos-forecasting`
+     `library <https://github.com/amazon-science/chronos-forecasting/blob/main/src/chronos/chronos2/pipeline.py>`_ .
+
+     Chronos-2 can be used both on GPU and CPU. However, we recommend using a GPU for faster inference and fine-tuning.
+
+     Chronos-2 variants can be fine-tuned by setting ``fine_tune=True`` and selecting appropriate fine-tuning parameters
+     such as the learning rate (``fine_tune_lr``) and max steps (``fine_tune_steps``). By default, a low-rank adapter (LoRA)
+     will be used for fine-tuning.
+
+     References
+     ----------
+     .. [Ansari2025] Ansari, Abdul Fatir, Shchur, Oleksandr, Kuken, Jaris et al.
+        "Chronos-2: From Univariate to Universal Forecasting." (2025).
+        https://arxiv.org/abs/2510.15821
+
+     Other Parameters
+     ----------------
+     model_path : str, default = "autogluon/chronos-2"
+         Model path used for the model, i.e., a Hugging Face transformers ``name_or_path``. Can be a
+         compatible model name on Hugging Face Hub or a local path to a model directory.
+     batch_size : int, default = 256
+         Size of batches used during inference.
+     device : str, default = None
+         Device to use for inference (and fine-tuning, if enabled). If None, the model will use the GPU if
+         available.
+     cross_learning : bool, default = True
+         If True, the cross-learning mode of Chronos-2 is enabled, meaning that the model will make joint
+         predictions across the time series in a batch.
+         Note: Enabling this mode makes the results sensitive to the ``batch_size`` used.
+     context_length : int or None, default = None
+         The context length to use for inference. If None, the model will use its default context length
+         of 8192. Shorter context lengths may reduce accuracy, but result in faster inference.
+     fine_tune : bool, default = False
+         If True, the pretrained model will be fine-tuned.
+     fine_tune_mode : str, default = "lora"
+         Fine-tuning mode, either "full" for full fine-tuning or "lora" for Low Rank Adaptation (LoRA).
+         LoRA is faster and uses less memory.
+     fine_tune_lr : float, default = 1e-5
+         The learning rate used for fine-tuning. When using full fine-tuning, a lower learning rate such as 1e-6
+         is recommended.
+     fine_tune_steps : int, default = 1000
+         The number of gradient update steps to fine-tune for.
+     fine_tune_batch_size : int, default = 32
+         The batch size to use for fine-tuning.
+     fine_tune_context_length : int, default = 2048
+         The maximum context length to use for fine-tuning.
+     eval_during_fine_tune : bool, default = False
+         If True, validation will be performed during fine-tuning to select the best checkpoint. Setting this
+         argument to True may result in slower fine-tuning. This parameter is ignored if ``skip_model_selection=True``
+         in ``TimeSeriesPredictor.fit``.
+     fine_tune_eval_max_items : int, default = 256
+         The maximum number of randomly-sampled time series to use from the validation set for evaluation
+         during fine-tuning. If None, the entire validation dataset will be used.
+     fine_tune_lora_config : dict, optional
+         Configuration for LoRA fine-tuning when ``fine_tune_mode="lora"``. If None and LoRA is enabled,
+         a default configuration will be used. Example: ``{"r": 8, "lora_alpha": 16}``.
+     fine_tune_trainer_kwargs : dict, optional
+         Extra keyword arguments passed to ``transformers.TrainingArguments``.
+     revision : str, default = None
+         Model revision to use (branch name or commit hash). If None, the default branch (usually "main") is used.
+     """
+
+     ag_model_aliases = ["Chronos-2"]
+     fine_tuned_ckpt_name: str = "fine-tuned-ckpt"
+
+     _supports_known_covariates = True
+     _supports_past_covariates = True
+
+     def __init__(
+         self,
+         freq: str | None = None,
+         prediction_length: int = 1,
+         path: str | None = None,
+         name: str | None = None,
+         eval_metric: str | None = None,
+         hyperparameters: dict[str, Any] | None = None,
+         **kwargs,
+     ):
+         super().__init__(
+             path=path,
+             freq=freq,
+             prediction_length=prediction_length,
+             name=name,
+             eval_metric=eval_metric,
+             hyperparameters=hyperparameters,
+             **kwargs,
+         )
+         self._is_fine_tuned: bool = False
+         self._model_pipeline = None
+
+     @property
+     def model_path(self) -> str:
+         default_model_path = self.get_hyperparameter("model_path")
+
+         if self._is_fine_tuned:
+             model_path = os.path.join(self.path, self.fine_tuned_ckpt_name)
+             if not os.path.exists(model_path):
+                 raise ValueError("Cannot find finetuned checkpoint for Chronos-2.")
+             else:
+                 return model_path
+
+         return default_model_path
+
+     def save(self, path: str | None = None, verbose: bool = True) -> str:
+         pipeline = self._model_pipeline
+         self._model_pipeline = None
+         path = super().save(path=path, verbose=verbose)
+         self._model_pipeline = pipeline
+
+         return str(path)
+
+     def _fit(
+         self,
+         train_data: TimeSeriesDataFrame,
+         val_data: TimeSeriesDataFrame | None = None,
+         time_limit: float | None = None,
+         num_cpus: int | None = None,
+         num_gpus: int | None = None,
+         verbosity: int = 2,
+         **kwargs,
+     ) -> None:
+         self._check_fit_params()
+         self._log_unused_hyperparameters()
+         self.load_model_pipeline()
+
+         # NOTE: This must be placed after load_model_pipeline to ensure that the loggers are available in loggerDict
+         self._update_transformers_loggers(logging.ERROR if verbosity <= 3 else logging.INFO)
+
+         if self.get_hyperparameter("fine_tune"):
+             self._fine_tune(train_data, val_data, time_limit=time_limit, verbosity=verbosity)
+
+     def get_hyperparameters(self) -> dict:
+         """Gets params that are passed to the inner model."""
+         init_args = super().get_hyperparameters()
+
+         fine_tune_trainer_kwargs = dict(disable_tqdm=True)
+         user_fine_tune_trainer_kwargs = init_args.get("fine_tune_trainer_kwargs", {})
+         fine_tune_trainer_kwargs.update(user_fine_tune_trainer_kwargs)
+         init_args["fine_tune_trainer_kwargs"] = fine_tune_trainer_kwargs
+
+         return init_args.copy()
+
+     def _get_default_hyperparameters(self) -> dict:
+         return {
+             "model_path": "autogluon/chronos-2",
+             "batch_size": 256,
+             "device": None,
+             "cross_learning": True,
+             "context_length": None,
+             "fine_tune": False,
+             "fine_tune_mode": "lora",
+             "fine_tune_lr": 1e-5,
+             "fine_tune_steps": 1000,
+             "fine_tune_batch_size": 32,
+             "fine_tune_context_length": 2048,
+             "eval_during_fine_tune": False,
+             "fine_tune_eval_max_items": 256,
+             "fine_tune_lora_config": None,
+             "revision": None,
+         }
+
+     @property
+     def allowed_hyperparameters(self) -> list[str]:
+         return super().allowed_hyperparameters + [
+             "model_path",
+             "batch_size",
+             "device",
+             "cross_learning",
+             "context_length",
+             "fine_tune",
+             "fine_tune_mode",
+             "fine_tune_lr",
+             "fine_tune_steps",
+             "fine_tune_batch_size",
+             "fine_tune_context_length",
+             "eval_during_fine_tune",
+             "fine_tune_eval_max_items",
+             "fine_tune_lora_config",
+             "fine_tune_trainer_kwargs",
+             "revision",
+         ]
+
+     def _predict(
+         self,
+         data: TimeSeriesDataFrame,
+         known_covariates: TimeSeriesDataFrame | None = None,
+         **kwargs,
+     ) -> TimeSeriesDataFrame:
+         if self._model_pipeline is None:
+             self.load_model_pipeline()
+         assert self._model_pipeline is not None
+
+         if max(data.num_timesteps_per_item()) < 3:
+             # If all time series have length 2 or less, we prepend 2 dummy timesteps to the first series
+             first_item_id = data.index.get_level_values(0)[0]
+             dummy_timestamps = pd.date_range(end=data.loc[first_item_id].index[0], periods=3, freq=self.freq)[:-1]
+             full_time_index_first_item = data.loc[first_item_id].index.union(dummy_timestamps)
+             new_index = (
+                 pd.MultiIndex.from_product([[first_item_id], full_time_index_first_item], names=data.index.names)
+             ).union(data.index)
+             context_df = data.reindex(new_index).reset_index()
+         else:
+             context_df = data.reset_index().to_data_frame()
+
+         batch_size = self.get_hyperparameter("batch_size")
+         cross_learning = self.get_hyperparameter("cross_learning")
+         context_length = self.get_hyperparameter("context_length")
+         future_df = known_covariates.reset_index().to_data_frame() if known_covariates is not None else None
+
+         forecast_df = self._model_pipeline.predict_df(
+             df=context_df,
+             future_df=future_df,
+             target=self.target,
+             prediction_length=self.prediction_length,
+             quantile_levels=self.quantile_levels,
+             context_length=context_length,
+             batch_size=batch_size,
+             validate_inputs=False,
+             cross_learning=cross_learning,
+         )
+
+         forecast_df = forecast_df.rename(columns={"predictions": "mean"}).drop(columns="target_name")
+
+         return TimeSeriesDataFrame(forecast_df)
+
+     def load_model_pipeline(self):
+         from chronos.chronos2.pipeline import Chronos2Pipeline
+
+         device = (self.get_hyperparameter("device") or "cuda") if self._is_gpu_available() else "cpu"
+
+         assert self.model_path is not None
+         pipeline = Chronos2Pipeline.from_pretrained(
+             self.model_path,
+             device_map=device,
+             revision=self.get_hyperparameter("revision"),
+         )
+
+         self._model_pipeline = pipeline
+
+     def persist(self) -> Self:
+         self.load_model_pipeline()
+         return self
+
+     def _update_transformers_loggers(self, log_level: int):
+         for logger_name in logging.root.manager.loggerDict:
+             if "transformers" in logger_name:
+                 transformers_logger = logging.getLogger(logger_name)
+                 transformers_logger.setLevel(log_level)
+
+     def _fine_tune(
+         self,
+         train_data: TimeSeriesDataFrame,
+         val_data: TimeSeriesDataFrame | None,
+         time_limit: float | None = None,
+         verbosity: int = 2,
+     ):
+         from chronos.df_utils import convert_df_input_to_list_of_dicts_input
+
+         from .utils import LoggerCallback, TimeLimitCallback
+
+         def convert_data(df: TimeSeriesDataFrame):
+             inputs, _, _ = convert_df_input_to_list_of_dicts_input(
+                 df=df.reset_index().to_data_frame(),
+                 future_df=None,
+                 target_columns=[self.target],
+                 prediction_length=self.prediction_length,
+                 validate_inputs=False,
+             )
+
+             # The above utility will only split the dataframe into target and past_covariates, where past_covariates contains
+             # past values of both past-only and known-future covariates. We need to add future_covariates to enable fine-tuning
+             # with known covariates by indicating which covariates are known in the future.
+             known_covariates = self.covariate_metadata.known_covariates
+
+             if len(known_covariates) > 0:
+                 for input_dict in inputs:
+                     # NOTE: the covariates are empty because the actual values are not used.
+                     # This only indicates which covariates are known in the future.
+                     input_dict["future_covariates"] = {name: np.array([]) for name in known_covariates}
+
+             return inputs
+
+         assert self._model_pipeline is not None
+         hyperparameters = self.get_hyperparameters()
+
+         callbacks = []
+         if time_limit is not None:
+             callbacks.append(TimeLimitCallback(time_limit=time_limit))
+
+         val_inputs = None
+         if val_data is not None and hyperparameters["eval_during_fine_tune"]:
+             # evaluate on a randomly-sampled subset
+             fine_tune_eval_max_items = (
+                 min(val_data.num_items, hyperparameters["fine_tune_eval_max_items"])
+                 if hyperparameters["fine_tune_eval_max_items"] is not None
+                 else val_data.num_items
+             )
+
+             if fine_tune_eval_max_items < val_data.num_items:
+                 eval_items = np.random.choice(val_data.item_ids.values, size=fine_tune_eval_max_items, replace=False)  # noqa: F841
+                 val_data = val_data.query("item_id in @eval_items")
+
+             assert isinstance(val_data, TimeSeriesDataFrame)
+             val_inputs = convert_data(val_data)
+
+         if verbosity >= 3:
+             logger.warning(
+                 "Transformers logging is turned on during fine-tuning. Note that losses reported by transformers "
+                 "do not correspond to those specified via `eval_metric`."
+             )
+             callbacks.append(LoggerCallback())
+
+         self._model_pipeline = self._model_pipeline.fit(
+             inputs=convert_data(train_data),
+             prediction_length=self.prediction_length,
+             validation_inputs=val_inputs,
+             finetune_mode=hyperparameters["fine_tune_mode"],
+             lora_config=hyperparameters["fine_tune_lora_config"],
+             context_length=hyperparameters["fine_tune_context_length"],
+             learning_rate=hyperparameters["fine_tune_lr"],
+             num_steps=hyperparameters["fine_tune_steps"],
+             batch_size=hyperparameters["fine_tune_batch_size"],
+             output_dir=self.path,
+             finetuned_ckpt_name=self.fine_tuned_ckpt_name,
+             callbacks=callbacks,
+             remove_printer_callback=True,
+             min_past=1,
+             **hyperparameters["fine_tune_trainer_kwargs"],
+         )
+         self._is_fine_tuned = True
+
+     def _more_tags(self) -> dict[str, Any]:
+         do_fine_tune = self.get_hyperparameter("fine_tune")
+         return {
+             "allow_nan": True,
+             "can_use_train_data": do_fine_tune,
+             "can_use_val_data": do_fine_tune,
+         }
+
+     def _is_gpu_available(self) -> bool:
+         import torch.cuda
+
+         return torch.cuda.is_available()
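
The docstring above shows that fine-tuning of the new Chronos-2 model is controlled entirely through model hyperparameters. A minimal usage sketch follows; it assumes the model is selected via the "Chronos-2" alias declared in ``ag_model_aliases`` above (the exact key accepted by ``TimeSeriesPredictor.fit`` is an assumption) and uses a small synthetic dataset purely for illustration:

    import numpy as np
    import pandas as pd
    from autogluon.timeseries import TimeSeriesDataFrame, TimeSeriesPredictor

    # Tiny synthetic dataset in long format (item_id, timestamp, target)
    df = pd.DataFrame(
        {
            "item_id": np.repeat(["A", "B"], 100),
            "timestamp": np.tile(pd.date_range("2024-01-01", periods=100, freq="D"), 2),
            "target": np.random.default_rng(0).random(200),
        }
    )
    train_data = TimeSeriesDataFrame.from_data_frame(df, id_column="item_id", timestamp_column="timestamp")

    predictor = TimeSeriesPredictor(prediction_length=24).fit(
        train_data,
        hyperparameters={
            "Chronos-2": {              # model key assumed from ag_model_aliases above
                "fine_tune": True,      # LoRA fine-tuning by default (fine_tune_mode="lora")
                "fine_tune_steps": 1000,
                "fine_tune_lr": 1e-5,
            }
        },
    )
    predictions = predictor.predict(train_data)  # mean and quantile forecasts as a TimeSeriesDataFrame

With ``fine_tune`` left at its default of False, the same call performs zero-shot inference and the data is only used as context at prediction time.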