autogluon.core 1.2.1b20250115__py3-none-any.whl → 1.2.1b20250130__py3-none-any.whl

This diff compares the contents of two publicly available versions of the package as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
@@ -4,39 +4,43 @@ import copy
  import logging
  import os
  import shutil
- import sys
  import time
  import traceback
- import typing
  from collections import defaultdict
  from pathlib import Path
- from typing import Any, Dict, List, Literal, Optional, Tuple, Union
+ from typing import Any, Generic, Literal, Optional, Type, TypeVar

  import networkx as nx
  import numpy as np
  import pandas as pd
+ from typing_extensions import Self

  from autogluon.common.features.feature_metadata import FeatureMetadata
  from autogluon.common.features.types import R_FLOAT, S_STACK
  from autogluon.common.utils.distribute_utils import DistributedContext
  from autogluon.common.utils.lite import disable_if_lite_mode
  from autogluon.common.utils.log_utils import convert_time_in_s_to_log_friendly, reset_logger_for_remote_call
- from autogluon.common.utils.path_converter import PathConverter
  from autogluon.common.utils.resource_utils import ResourceManager, get_resource_manager
  from autogluon.common.utils.try_import import try_import_ray, try_import_torch
-
- from ..augmentation.distill_utils import augment_data, format_distillation_labels
- from ..calibrate import calibrate_decision_threshold
- from ..calibrate.conformity_score import compute_conformity_score
- from ..calibrate.temperature_scaling import apply_temperature_scaling, tune_temperature_scaling
- from ..callbacks import AbstractCallback
- from ..constants import AG_ARGS, BINARY, MULTICLASS, QUANTILE, REFIT_FULL_NAME, REFIT_FULL_SUFFIX, REGRESSION, SOFTCLASS
- from ..data.label_cleaner import LabelCleanerMulticlassToBinary
- from ..metrics import compute_metric, Scorer, get_metric
- from ..models import AbstractModel, BaggedEnsembleModel, GreedyWeightedEnsembleModel, SimpleWeightedEnsembleModel, StackerEnsembleModel, WeightedEnsembleModel
- from ..pseudolabeling.pseudolabeling import assert_pseudo_column_match
- from ..ray.distributed_jobs_managers import ParallelFitManager
- from ..utils import (
+ from autogluon.core.augmentation.distill_utils import augment_data, format_distillation_labels
+ from autogluon.core.calibrate import calibrate_decision_threshold
+ from autogluon.core.calibrate.conformity_score import compute_conformity_score
+ from autogluon.core.calibrate.temperature_scaling import apply_temperature_scaling, tune_temperature_scaling
+ from autogluon.core.callbacks import AbstractCallback
+ from autogluon.core.constants import BINARY, MULTICLASS, QUANTILE, REFIT_FULL_NAME, REGRESSION, SOFTCLASS
+ from autogluon.core.data.label_cleaner import LabelCleanerMulticlassToBinary
+ from autogluon.core.metrics import Scorer, compute_metric, get_metric
+ from autogluon.core.models import (
+     AbstractModel,
+     BaggedEnsembleModel,
+     GreedyWeightedEnsembleModel,
+     SimpleWeightedEnsembleModel,
+     StackerEnsembleModel,
+     WeightedEnsembleModel,
+ )
+ from autogluon.core.pseudolabeling.pseudolabeling import assert_pseudo_column_match
+ from autogluon.core.ray.distributed_jobs_managers import ParallelFitManager
+ from autogluon.core.utils import (
      compute_permutation_feature_importance,
      convert_pred_probas_to_df,
      default_holdout_frac,
@@ -45,18 +49,208 @@ from ..utils import (
      get_pred_from_proba,
      infer_eval_metric,
  )
- from ..utils.exceptions import InsufficientTime, NoGPUError, NotEnoughCudaMemoryError, NotEnoughMemoryError, NotValidStacker, NoStackFeatures, NoValidFeatures, TimeLimitExceeded
- from ..utils.feature_selection import FeatureSelector
- from ..utils.loaders import load_pkl
- from ..utils.savers import save_json, save_pkl
+ from autogluon.core.utils.exceptions import (
+     InsufficientTime,
+     NoGPUError,
+     NoStackFeatures,
+     NotEnoughCudaMemoryError,
+     NotEnoughMemoryError,
+     NotValidStacker,
+     NoValidFeatures,
+     TimeLimitExceeded,
+ )
+ from autogluon.core.utils.feature_selection import FeatureSelector
+ from autogluon.core.utils.loaders import load_pkl
+ from autogluon.core.utils.savers import save_json, save_pkl
+
  from .utils import process_hyperparameters

  logger = logging.getLogger(__name__)


- class AbstractTrainer:
+ ModelTypeT = TypeVar("ModelTypeT", bound=AbstractModel)
+
+
+ class AbstractTrainer(Generic[ModelTypeT]):
+     trainer_file_name = "trainer.pkl"
+     trainer_info_name = "info.pkl"
+     trainer_info_json_name = "info.json"
+
+     def __init__(self, path: str, *, low_memory: bool, save_data: bool):
+         self.path = path
+         self.reset_paths = False
+
+         self.low_memory: bool = low_memory
+         self.save_data: bool = save_data
+
+         self.models: dict[str, Any] = {}
+         self.model_graph = nx.DiGraph()
+         self.model_best: str | None = None
+
+         self._extra_banned_names: set[str] = set()
+
+     def _get_banned_model_names(self) -> list[str]:
+         """Gets all model names which would cause model files to be overwritten if a new model
+         was trained with the name
+         """
+         return self.get_model_names() + list(self._extra_banned_names)
+
+     @property
+     def path_root(self) -> str:
+         """directory containing learner.pkl"""
+         return os.path.dirname(self.path)
+
+     @property
+     def path_utils(self) -> str:
+         return os.path.join(self.path_root, "utils")
+
+     @property
+     def path_data(self) -> str:
+         return os.path.join(self.path_utils, "data")
+
+     def set_contexts(self, path_context: str) -> None:
+         self.path = self.create_contexts(path_context)
+
+     def create_contexts(self, path_context: str) -> str:
+         path = path_context
+         return path
+
+     def save_model(self, model: ModelTypeT, **kwargs) -> None:
+         model.save()
+         if not self.low_memory:
+             self.models[model.name] = model
+
+     def get_models_attribute_dict(self, attribute: str, models: list[str] | None = None) -> dict[str, Any]:
+         raise NotImplementedError
+
+     def get_model_attribute(self, model: str | ModelTypeT, attribute: str, **kwargs) -> Any:
+         """Return model attribute value.
+         If `default` is specified, return default value if attribute does not exist.
+         If `default` is not specified, raise ValueError if attribute does not exist.
+         """
+         if not isinstance(model, str):
+             model = model.name
+         if model not in self.model_graph.nodes:
+             raise ValueError(f"Model does not exist: (model={model})")
+         if attribute not in self.model_graph.nodes[model]:
+             if "default" in kwargs:
+                 return kwargs["default"]
+             else:
+                 raise ValueError(f"Model does not contain attribute: (model={model}, attribute={attribute})")
+         if attribute == "path":
+             return os.path.join(*self.model_graph.nodes[model][attribute])
+         return self.model_graph.nodes[model][attribute]
+
+     def set_model_attribute(self, model: str | ModelTypeT, attribute: str, val: Any):
+         if not isinstance(model, str):
+             model = model.name
+         self.model_graph.nodes[model][attribute] = val
+
+     def get_minimum_model_set(self, model: str | ModelTypeT, include_self: bool = True) -> list:
+         """Gets the minimum set of models that the provided model depends on, including itself
+         Returns a list of model names
+         """
+         if not isinstance(model, str):
+             model = model.name
+         minimum_model_set = list(nx.bfs_tree(self.model_graph, model, reverse=True))
+         if not include_self:
+             minimum_model_set = [m for m in minimum_model_set if m != model]
+         return minimum_model_set
+
+     def get_model_info(self, model: str | ModelTypeT) -> dict[str, Any]:
+         if isinstance(model, str):
+             if model in self.models.keys():
+                 model = self.models[model]
+         if isinstance(model, str):
+             model_type = self.get_model_attribute(model=model, attribute="type")
+             model_path = self.get_model_attribute(model=model, attribute="path")
+             model_info = model_type.load_info(path=os.path.join(self.path, model_path))
+         else:
+             model_info = model.get_info()
+         return model_info
+
+     def get_model_names(self, **kwargs) -> list[str]:
+         """Get all model names that are registered in the model graph, in no particular order."""
+         return list(self.model_graph.nodes)
+
+     def get_models_info(self, models: list[str | ModelTypeT] | None = None) -> dict[str, dict[str, Any]]:
+         models_ = self.get_model_names() if models is None else models
+         model_info_dict = dict()
+         for model in models_:
+             model_name = model if isinstance(model, str) else model.name
+             model_info_dict[model_name] = self.get_model_info(model=model)
+         return model_info_dict
+
+     # TODO: model_name change to model in params
+     def load_model(self, model_name: str | ModelTypeT, path: str | None = None, model_type: Type[ModelTypeT] | None = None) -> ModelTypeT:
+         if isinstance(model_name, AbstractModel):
+             return model_name
+         if model_name in self.models.keys():
+             return self.models[model_name]
+         else:
+             if path is None:
+                 path = self.get_model_attribute(model=model_name, attribute="path")  # get relative location of the model to the trainer
+             assert path is not None
+             if model_type is None:
+                 model_type = self.get_model_attribute(model=model_name, attribute="type")
+             assert model_type is not None
+             return model_type.load(path=os.path.join(self.path, path), reset_paths=self.reset_paths)
+
+     @classmethod
+     def load_info(cls, path: str, reset_paths: bool = False, load_model_if_required: bool = True) -> dict[str, Any]:
+         load_path = os.path.join(path, cls.trainer_info_name)
+         try:
+             return load_pkl.load(path=load_path)
+         except:
+             if load_model_if_required:
+                 trainer = cls.load(path=path, reset_paths=reset_paths)
+                 return trainer.get_info()
+             else:
+                 raise
+
+     def save_info(self, include_model_info: bool = False) -> dict[str, Any]:
+         info = self.get_info(include_model_info=include_model_info)
+
+         save_pkl.save(path=os.path.join(self.path, self.trainer_info_name), object=info)
+         save_json.save(path=os.path.join(self.path, self.trainer_info_json_name), obj=info)
+         return info
+
+     def construct_model_templates(
+         self, hyperparameters: str | dict[str, Any], **kwargs
+     ) -> tuple[list[ModelTypeT], dict] | list[ModelTypeT]:
+         raise NotImplementedError
+
+     def get_model_best(self, *args, **kwargs) -> str:
+         raise NotImplementedError
+
+     def get_info(self, include_model_info: bool = False, **kwargs) -> dict[str, Any]:
+         raise NotImplementedError
+
+     def save(self) -> None:
+         raise NotImplementedError
+
+     @classmethod
+     def load(cls, path: str, reset_paths: bool = False) -> Self:
+         load_path = os.path.join(path, cls.trainer_file_name)
+         if not reset_paths:
+             return load_pkl.load(path=load_path)
+         else:
+             obj = load_pkl.load(path=load_path)
+             obj.set_contexts(path)
+             obj.reset_paths = reset_paths
+             return obj
+
+     def fit(self, *args, **kwargs):
+         raise NotImplementedError
+
+     def predict(self, *args, **kwargs) -> Any:
+         raise NotImplementedError
+
+
+ # TODO: This class will be moved to autogluon.tabular
+ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
      """
-     AbstractTrainer contains logic to train a variety of models under a variety of constraints and automatically generate a multi-layer stack ensemble.
+     AbstractTabularTrainer contains logic to train a variety of models under a variety of constraints and automatically generate a multi-layer stack ensemble.
      Beyond the basic functionality, it also has support for model refitting, distillation, pseudo-labelling, unlabeled data, and much more.

      It is not recommended to directly use Trainer. Instead, use Predictor or Learner which internally uses Trainer.
@@ -65,7 +259,7 @@ class AbstractTrainer:
      Due to the complexity of the logic within this class, a text description will not give the full picture.
      It is recommended to carefully read the code and use a debugger to understand how it works.

-     AbstractTrainer makes much fewer assumptions about the problem than Learner and Predictor.
+     AbstractTabularTrainer makes much fewer assumptions about the problem than Learner and Predictor.
      It expects these ambiguities to have already been resolved upstream. For example, problem_type, feature_metadata, num_classes, etc.

      Parameters
@@ -84,7 +278,7 @@ class AbstractTrainer:
          FeatureMetadata for X. Sent to each model during fit.
      eval_metric : Scorer, default = None
          Metric to optimize. If None, a default metric is used depending on the problem_type.
-     quantile_levels : List[float] | np.ndarray, default = None
+     quantile_levels : list[float] | np.ndarray, default = None
          # TODO: Add documentation, not documented in Predictor.
          Only used when problem_type=quantile
      low_memory : bool, default = True
@@ -116,9 +310,6 @@ class AbstractTrainer:
          where `L` ranges from 0 to 50 (Note: higher values of `L` correspond to fewer print statements, opposite of verbosity levels).
      """

-     trainer_file_name = "trainer.pkl"
-     trainer_info_name = "info.pkl"
-     trainer_info_json_name = "info.json"
      distill_stackname = "distill"  # name of stack-level for distilled student models

      def __init__(
@@ -126,22 +317,26 @@ class AbstractTrainer:
          path: str,
          *,
          problem_type: str,
-         num_classes: int = None,
-         feature_metadata: FeatureMetadata = None,
-         eval_metric: Scorer = None,
-         quantile_levels: List[float] | np.ndarray = None,
+         num_classes: int | None = None,
+         feature_metadata: FeatureMetadata | None = None,
+         eval_metric: Scorer | None = None,
+         quantile_levels: list[float] | np.ndarray | None = None,
          low_memory: bool = True,
          k_fold: int = 0,
          n_repeats: int = 1,
-         sample_weight: str = None,
+         sample_weight: str | None = None,
          weight_evaluation: bool = False,
          save_data: bool = False,
          random_state: int = 0,
          verbosity: int = 2,
      ):
+         super().__init__(
+             path=path,
+             low_memory=low_memory,
+             save_data=save_data,
+         )
          self._validate_num_classes(num_classes=num_classes, problem_type=problem_type)
          self._validate_quantile_levels(quantile_levels=quantile_levels, problem_type=problem_type)
-         self.path = path
          self.problem_type = problem_type
          self.feature_metadata = feature_metadata
          self.save_data = save_data
@@ -185,7 +380,7 @@ class AbstractTrainer:

          self.model_best = None

-         self.models = {}  # Dict of model name -> model object. A key, value pair only exists if a model is persisted in memory.  # TODO: v0.1 Rename and consider making private
+         self.models = {}  # dict of model name -> model object. A key, value pair only exists if a model is persisted in memory.  # TODO: v0.1 Rename and consider making private
          self.model_graph = nx.DiGraph()  # Directed Acyclic Graph (DAG) of model interactions. Describes how certain models depend on the predictions of certain other models. Contains numerous metadata regarding each model.
          self.reset_paths = False

@@ -210,31 +405,18 @@ class AbstractTrainer:

          self._extra_banned_names = set()  # Names which are banned but are not used by a trained model.

-         self._models_failed_to_train_errors = dict()  # Dict of model name -> model failure metadata
+         self._models_failed_to_train_errors = dict()  # dict of model name -> model failure metadata

          # self._exceptions_list = []  # TODO: Keep exceptions list for debugging during benchmarking.

-         self.callbacks: List[AbstractCallback] = []
+         self.callbacks: list[AbstractCallback] = []
          self._callback_early_stop = False

-     # path_root is the directory containing learner.pkl
-     @property
-     def path_root(self) -> str:
-         return os.path.dirname(self.path)
-
-     @property
-     def path_utils(self) -> str:
-         return os.path.join(self.path_root, "utils")
-
      @property
      def _path_attr(self) -> str:
          """Path to cached model graph attributes"""
          return os.path.join(self.path_utils, "attr")

-     @property
-     def path_data(self) -> str:
-         return os.path.join(self.path_utils, "data")
-
      @property
      def has_val(self) -> bool:
          """Whether the trainer uses validation data"""
@@ -246,6 +428,7 @@ class AbstractTrainer:
          if self._num_rows_val is not None:
              return self._num_rows_val
          elif self.bagged_mode:
+             assert self._num_rows_train is not None
              return self._num_rows_train
          else:
              return 0
@@ -334,8 +517,12 @@ class AbstractTrainer:
          self._y_test_saved = True

      def get_model_names(
-         self, stack_name: Union[List[str], str] = None, level: Union[List[int], int] = None, can_infer: bool = None, models: List[str] = None
-     ) -> List[str]:
+         self,
+         stack_name: list[str] | str | None = None,
+         level: list[int] | int | None = None,
+         can_infer: bool | None = None,
+         models: list[str] | None = None
+     ) -> list[str]:
          if models is None:
              models = list(self.model_graph.nodes)
          if stack_name is not None:
@@ -354,7 +541,7 @@ class AbstractTrainer:
              models = [model for model in models if node_attributes[model] == can_infer]
          return models

-     def get_max_level(self, stack_name: str = None, models: List[str] = None) -> int:
+     def get_max_level(self, stack_name: str | None = None, models: list[str] | None = None) -> int:
          models = self.get_model_names(stack_name=stack_name, models=models)
          models_attribute_dict = self.get_models_attribute_dict(attribute="level", models=models)
          if models_attribute_dict:
@@ -362,25 +549,17 @@ class AbstractTrainer:
          else:
              return -1

-     def construct_model_templates(self, hyperparameters: dict, **kwargs) -> Tuple[List[AbstractModel], dict]:
+     def construct_model_templates(self, hyperparameters: dict, **kwargs) -> tuple[list[AbstractModel], dict]:
          """Constructs a list of unfit models based on the hyperparameters dict."""
          raise NotImplementedError

-     def construct_model_templates_distillation(self, hyperparameters: dict, **kwargs) -> Tuple[List[AbstractModel], dict]:
+     def construct_model_templates_distillation(self, hyperparameters: dict, **kwargs) -> tuple[list[AbstractModel], dict]:
          """Constructs a list of unfit models based on the hyperparameters dict for softclass distillation."""
          raise NotImplementedError

      def get_model_level(self, model_name: str) -> int:
          return self.get_model_attribute(model=model_name, attribute="level")

-     def set_contexts(self, path_context):
-         self.path = self.create_contexts(path_context)
-
-     def create_contexts(self, path_context: str) -> str:
-         path = path_context
-
-         return path
-
      def fit(self, X, y, hyperparameters: dict, X_val=None, y_val=None, **kwargs):
          raise NotImplementedError

@@ -395,19 +574,19 @@ class AbstractTrainer:
          X_test=None,
          y_test=None,
          X_unlabeled=None,
-         base_model_names: List[str] = None,
-         core_kwargs: dict = None,
-         aux_kwargs: dict = None,
+         base_model_names: list[str] | None = None,
+         core_kwargs: dict | None = None,
+         aux_kwargs: dict | None = None,
          level_start=1,
          level_end=1,
          time_limit=None,
-         name_suffix: str = None,
+         name_suffix: str | None = None,
          relative_stack=True,
          level_time_modifier=0.333,
          infer_limit=None,
          infer_limit_batch_size=None,
-         callbacks: List[AbstractCallback] = None,
-     ) -> List[str]:
+         callbacks: list[AbstractCallback] | None = None,
+     ) -> list[str]:
          """
          Trains a multi-layer stack ensemble using the input data on the hyperparameters dict input.
          hyperparameters is used to determine the models used in each stack layer.
@@ -426,6 +605,8 @@ class AbstractTrainer:
          """
          self._fit_setup(time_limit=time_limit, callbacks=callbacks)
          time_train_start = self._time_train_start
+         assert time_train_start is not None
+
          if self.callbacks:
              callback_classes = [c.__class__.__name__ for c in self.callbacks]
              logger.log(20, f"User-specified callbacks ({len(self.callbacks)}): {callback_classes}")
@@ -517,7 +698,7 @@ class AbstractTrainer:
          self.save()
          return model_names_fit

-     def _fit_setup(self, time_limit: float | None = None, callbacks: List[AbstractCallback] = None):
+     def _fit_setup(self, time_limit: float | None = None, callbacks: list[AbstractCallback] | None = None):
          """
          Prepare the trainer state at the start of / prior to a fit call.
          Should be paired with a `self._fit_cleanup()` at the conclusion of the fit call.
@@ -561,8 +742,13 @@ class AbstractTrainer:

      # TODO: Consider better greedy approximation method such as via fitting a weighted ensemble to evaluate the value of a subset.
      def _filter_base_models_via_infer_limit(
-         self, base_model_names: List[str], infer_limit: float, infer_limit_modifier: float = 1.0, as_child: bool = True, verbose: bool = True
-     ) -> List[str]:
+         self,
+         base_model_names: list[str],
+         infer_limit: float | None,
+         infer_limit_modifier: float = 1.0,
+         as_child: bool = True,
+         verbose: bool = True,
+     ) -> list[str]:
          """
          Returns a subset of base_model_names whose combined prediction time for 1 row of data does not exceed infer_limit seconds.
          With the goal of selecting the best valid subset that is most valuable to stack ensembles who use them as base models,
@@ -572,9 +758,9 @@ class AbstractTrainer:

          Parameters
          ----------
-         base_model_names: List[str]
-             List of model names. These models must already be added to the trainer.
-         infer_limit: float
+         base_model_names: list[str]
+             list of model names. These models must already be added to the trainer.
+         infer_limit: float, optional
              Inference limit in seconds for 1 row of data. This is compared against values pre-computed during fit for the models.
          infer_limit_modifier: float, default = 1.0
              Modifier to multiply infer_limit by.
@@ -655,22 +841,22 @@ class AbstractTrainer:
          self,
          X,
          y,
-         models: Union[List[AbstractModel], dict],
+         models: list[AbstractModel] | dict,
          X_val=None,
          y_val=None,
          X_test=None,
          y_test=None,
          X_unlabeled=None,
          level=1,
-         base_model_names: List[str] = None,
-         core_kwargs: dict = None,
-         aux_kwargs: dict = None,
-         name_suffix: str = None,
+         base_model_names: list[str] | None = None,
+         core_kwargs: dict | None = None,
+         aux_kwargs: dict | None = None,
+         name_suffix: str | None = None,
          infer_limit=None,
          infer_limit_batch_size=None,
          full_weighted_ensemble: bool = False,
          additional_full_weighted_ensemble: bool = False,
-     ) -> (List[str], List[str]):
+     ) -> tuple[list[str], list[str]]:
          """
          Similar to calling self.stack_new_level_core, except auxiliary models will also be trained via a call to self.stack_new_level_aux, with the models trained from self.stack_new_level_core used as base models.
          """
@@ -718,14 +904,14 @@ class AbstractTrainer:
          self,
          X,
          y,
-         models: Union[List[AbstractModel], dict],
+         models: list[AbstractModel] | dict,
          X_val=None,
          y_val=None,
          X_test=None,
          y_test=None,
          X_unlabeled=None,
          level=1,
-         base_model_names: List[str] = None,
+         base_model_names: list[str] | None = None,
          fit_strategy: Literal["sequential", "parallel"] = "sequential",
          stack_name="core",
          ag_args=None,
@@ -734,13 +920,13 @@ class AbstractTrainer:
          included_model_types=None,
          excluded_model_types=None,
          ensemble_type=StackerEnsembleModel,
-         name_suffix: str = None,
+         name_suffix: str | None = None,
          get_models_func=None,
          refit_full=False,
          infer_limit=None,
          infer_limit_batch_size=None,
          **kwargs,
-     ) -> List[str]:
+     ) -> list[str]:
          """
          Trains all models using the data provided.
          If level > 1, then the models will use base model predictions as additional features.
@@ -757,7 +943,11 @@ class AbstractTrainer:
          if not self.bagged_mode and level != 1:
              raise ValueError("Stack Ensembling is not valid for non-bagged mode.")

-         base_model_names = self._filter_base_models_via_infer_limit(base_model_names=base_model_names, infer_limit=infer_limit, infer_limit_modifier=0.8)
+         base_model_names = self._filter_base_models_via_infer_limit(
+             base_model_names=base_model_names,
+             infer_limit=infer_limit,
+             infer_limit_modifier=0.8,
+         )
          if ag_args_fit is None:
              ag_args_fit = {}
          ag_args_fit = ag_args_fit.copy()
@@ -779,7 +969,7 @@ class AbstractTrainer:
              (base_model_names, base_model_paths, base_model_types) = (None, None, None)
          elif level > 1:
              base_model_names, base_model_paths, base_model_types = self._get_models_load_info(model_names=base_model_names)
-             if len(base_model_names) == 0:
+             if len(base_model_names) == 0:  # type: ignore
                  logger.log(20, f"No base models to train on, skipping stack level {level}...")
                  return []
          else:
@@ -874,7 +1064,7 @@ class AbstractTrainer:
          self,
          X,
          y,
-         base_model_names: List[str],
+         base_model_names: list[str],
          level: int | str = "auto",
          fit=True,
          stack_name="aux1",
@@ -888,7 +1078,7 @@ class AbstractTrainer:
          fit_weighted_ensemble: bool = True,
          name_extra: str | None = None,
          total_resources: dict | None = None,
-     ) -> List[str]:
+     ) -> list[str]:
          """
          Trains auxiliary models (currently a single weighted ensemble) using the provided base models.
          Level must be greater than the level of any of the base models.
@@ -1038,13 +1228,13 @@ class AbstractTrainer:
          )

      # TODO: Slow if large ensemble with many models, could cache output result to speed up during inference
-     def _construct_model_pred_order(self, models: List[str]) -> List[str]:
+     def _construct_model_pred_order(self, models: list[str]) -> list[str]:
          """
          Constructs a list of model names in order of inference calls required to infer on all the models.

          Parameters
          ----------
-         models : List[str]
+         models : list[str]
              The list of models to construct the prediction order from.
              If a model has dependencies, the dependency models will be put earlier in the output list.
              Models explicitly mentioned in the `models` input will be placed as early as possible in the output list.
@@ -1068,17 +1258,17 @@ class AbstractTrainer:
          model_set = set(model_order)
          return model_order

-     def _construct_model_pred_order_with_pred_dict(self, models: List[str], models_to_ignore: List[str] = None) -> List[str]:
+     def _construct_model_pred_order_with_pred_dict(self, models: list[str], models_to_ignore: list[str] = None) -> list[str]:
          """
          Constructs a list of model names in order of inference calls required to infer on all the models.
          Unlike `_construct_model_pred_order`, this method's output is in undefined order when multiple models are valid to infer at the same time.

          Parameters
          ----------
-         models : List[str]
+         models : list[str]
              The list of models to construct the prediction order from.
              If a model has dependencies, the dependency models will be put earlier in the output list.
-         models_to_ignore : List[str], optional
+         models_to_ignore : list[str], optional
              A list of models that have already been computed and can be ignored.
              Models in this list and their dependencies (if not depended on by other models in `models`) will be pruned from the final output.

@@ -1109,7 +1299,23 @@ class AbstractTrainer:

          # Get model prediction order
          return list(nx.lexicographical_topological_sort(subgraph))
-
+
+     def get_models_attribute_dict(self, attribute: str, models: list | None = None) -> dict[str, Any]:
+         """Returns dictionary of model name -> attribute value for the provided attribute.
+         """
+         models_attribute_dict = nx.get_node_attributes(self.model_graph, attribute)
+         if models is not None:
+             model_names = []
+             for model in models:
+                 if not isinstance(model, str):
+                     model = model.name
+                 model_names.append(model)
+             if attribute == "path":
+                 models_attribute_dict = {key: os.path.join(*val) for key, val in models_attribute_dict.items() if key in model_names}
+             else:
+                 models_attribute_dict = {key: val for key, val in models_attribute_dict.items() if key in model_names}
+         return models_attribute_dict
+
      # TODO: Consider adding persist to disk functionality for pred_proba dictionary to lessen memory burden on large multiclass problems.
      # For datasets with 100+ classes, this function could potentially run the system OOM due to each pred_proba numpy array taking significant amounts of space.
      # This issue already existed in the previous level-based version but only had the minimum required predictions in memory at a time, whereas this has all model predictions in memory.
@@ -1117,7 +1323,7 @@ class AbstractTrainer:
      def get_model_pred_proba_dict(
          self,
          X: pd.DataFrame,
-         models: List[str],
+         models: list[str],
          model_pred_proba_dict: dict = None,
          model_pred_time_dict: dict = None,
          record_pred_time: bool = False,
@@ -1132,7 +1338,7 @@ class AbstractTrainer:
          ----------
          X : pd.DataFrame
              Input data to predict on.
-         models : List[str]
+         models : list[str]
              The list of models to predict with.
              Note that if models have dependency models, their dependencies will also be predicted with and included in the output.
          model_pred_proba_dict : dict, optional
@@ -1194,13 +1400,13 @@ class AbstractTrainer:
          else:
              return model_pred_proba_dict

-     def get_model_oof_dict(self, models: List[str]) -> dict:
+     def get_model_oof_dict(self, models: list[str]) -> dict:
          """
          Returns a dictionary of out-of-fold prediction probabilities, keyed by model name
          """
          return {model: self.get_model_oof(model) for model in models}

-     def get_model_pred_dict(self, X: pd.DataFrame, models: List[str], record_pred_time: bool = False, **kwargs):
+     def get_model_pred_dict(self, X: pd.DataFrame, models: list[str], record_pred_time: bool = False, **kwargs):
          """
          Optimally computes predictions for each model in `models`.
          Will compute each necessary model only once and store predictions in a `model_pred_dict` dictionary.
@@ -1212,7 +1418,7 @@ class AbstractTrainer:
          ----------
          X : pd.DataFrame
              Input data to predict on.
-         models : List[str]
+         models : list[str]
              The list of models to predict with.
              Note that if models have dependency models, their dependencies will also be predicted with and included in the output.
          record_pred_time : bool, default = False
@@ -1293,8 +1499,8 @@ class AbstractTrainer:
          self,
          X: pd.DataFrame,
          *,
-         model: str = None,
-         base_models: List[str] = None,
+         model: str | None = None,
+         base_models: list[str] | None = None,
          model_pred_proba_dict: Optional[dict] = None,
          fit: bool = False,
          use_orig_features: bool = True,
@@ -1311,7 +1517,7 @@ class AbstractTrainer:
          model : str, default = None
              The model to derive `base_models` from.
              Cannot be specified alongside `base_models`.
-         base_models : List[str], default = None
+         base_models : list[str], default = None
              The list of base models to augment X with.
              Base models will add their prediction probabilities as extra features to X.
              Cannot be specified alongside `model`.
@@ -1355,7 +1561,7 @@ class AbstractTrainer:
              X = X_stacker
          return X

-     def get_feature_metadata(self, use_orig_features: bool = True, model: str | None = None, base_models: List[str] | None = None) -> FeatureMetadata:
+     def get_feature_metadata(self, use_orig_features: bool = True, model: str | None = None, base_models: list[str] | None = None) -> FeatureMetadata:
          """
          Returns the FeatureMetadata input to a `model.fit` call.
          Pairs with `X = self.get_inputs_to_stacker(...)`. The returned FeatureMetadata should reflect the contents of `X`.
@@ -1368,7 +1574,7 @@ class AbstractTrainer:
          model : str, default = None
              If specified, it must be an already existing model.
              `base_models` will be set to the base models of `model`.
-         base_models : List[str], default = None
+         base_models : list[str], default = None
              If specified, will add the stack features of the `base_models` to FeatureMetadata.

          Returns
@@ -1397,7 +1603,7 @@ class AbstractTrainer:
              feature_metadata = FeatureMetadata(type_map_raw={})
          return feature_metadata

-     def _get_stack_column_names(self, models: List[str]) -> Tuple[List[str], int]:
+     def _get_stack_column_names(self, models: list[str]) -> tuple[list[str], int]:
          """
          Get the stack column names generated when the provided models are used as base models in a stack ensemble.
          Additionally output the number of columns per model as an int.
@@ -1562,7 +1768,7 @@ class AbstractTrainer:
      # Fits _FULL models and links them in the stack so _FULL models only use other _FULL models as input during stacking
      # If model is specified, will fit all _FULL models that are ancestors of the provided model, automatically linking them.
      # If no model is specified, all models are refit and linked appropriately.
-     def refit_ensemble_full(self, model: str | List[str] = "all", **kwargs) -> dict:
+     def refit_ensemble_full(self, model: str | list[str] = "all", **kwargs) -> dict:
          if model == "all":
              ensemble_set = self.get_model_names()
          elif isinstance(model, list):
@@ -1618,7 +1824,13 @@ class AbstractTrainer:
          """Get refit full model's parent. If model does not have a parent, return `model`."""
          return self.get_model_attribute(model=model, attribute="refit_full_parent", default=model)

-     def get_model_best(self, can_infer: bool = None, allow_full: bool = True, infer_limit: float = None, infer_limit_as_child: bool = False) -> str:
+     def get_model_best(
+         self,
+         can_infer: bool | None = None,
+         allow_full: bool = True,
+         infer_limit: float | None = None,
+         infer_limit_as_child: bool = False
+     ) -> str:
          """
          Returns the name of the model with the best validation score that satisfies all specified constraints.
          If no model satisfies the constraints, an AssertionError will be raised.
@@ -1704,7 +1916,7 @@ class AbstractTrainer:
          else:
              self.models[model.name] = model

-     def save(self):
+     def save(self) -> None:
          models = self.models
          if self.low_memory:
              self.models = {}
@@ -1712,7 +1924,7 @@ class AbstractTrainer:
          if self.low_memory:
              self.models = models

-     def compile(self, model_names="all", with_ancestors=False, compiler_configs=None) -> List[str]:
+     def compile(self, model_names="all", with_ancestors=False, compiler_configs=None) -> list[str]:
          """
          Compile a list of models for accelerated prediction.

@@ -1786,7 +1998,7 @@ class AbstractTrainer:
          self.save()
          return model_names

-     def persist(self, model_names="all", with_ancestors=False, max_memory=None) -> List[str]:
+     def persist(self, model_names="all", with_ancestors=False, max_memory=None) -> list[str]:
          if model_names == "all":
              model_names = self.get_model_names()
          elif model_names == "best":
@@ -1852,19 +2064,6 @@ class AbstractTrainer:
                      model.models[fold] = model.load_child(fold_model)
          return model_names

-     # TODO: model_name change to model in params
-     def load_model(self, model_name: str, path: str = None, model_type=None) -> AbstractModel:
-         if isinstance(model_name, AbstractModel):
-             return model_name
-         if model_name in self.models.keys():
-             return self.models[model_name]
-         else:
-             if path is None:
-                 path = self.get_model_attribute(model=model_name, attribute="path")  # get relative location of the model to the trainer
-             if model_type is None:
-                 model_type = self.get_model_attribute(model=model_name, attribute="type")
-             return model_type.load(path=os.path.join(self.path, path), reset_paths=self.reset_paths)
-
      def unpersist(self, model_names="all") -> list:
          if model_names == "all":
              model_names = list(self.models.keys())
@@ -1893,13 +2092,13 @@ class AbstractTrainer:
          hyperparameters=None,
          ag_args_fit=None,
          time_limit=None,
-         name_suffix: str = None,
+         name_suffix: str | None = None,
          save_bag_folds=None,
          check_if_best=True,
          child_hyperparameters=None,
          get_models_func=None,
          total_resources: dict | None = None,
-     ) -> List[str]:
+     ) -> list[str]:
          if get_models_func is None:
              get_models_func = self.construct_model_templates
          if len(base_model_names) == 0:
@@ -1979,10 +2178,10 @@ class AbstractTrainer:
          X: pd.DataFrame,
          y: pd.Series,
          model: AbstractModel,
-         X_val: pd.DataFrame = None,
-         y_val: pd.Series = None,
-         X_test: pd.DataFrame = None,
-         y_test: pd.Series = None,
+         X_val: pd.DataFrame | None = None,
+         y_val: pd.Series | None = None,
+         X_test: pd.DataFrame | None = None,
+         y_test: pd.Series | None = None,
          total_resources: dict = None,
          **model_fit_kwargs,
      ) -> AbstractModel:
@@ -1998,20 +2197,20 @@ class AbstractTrainer:
          X: pd.DataFrame,
          y: pd.Series,
          model: AbstractModel,
-         X_val: pd.DataFrame = None,
-         y_val: pd.Series = None,
-         X_test: pd.DataFrame = None,
-         y_test: pd.Series = None,
-         X_pseudo: pd.DataFrame = None,
-         y_pseudo: pd.DataFrame = None,
-         time_limit: float = None,
+         X_val: pd.DataFrame | None = None,
+         y_val: pd.Series | None = None,
+         X_test: pd.DataFrame | None = None,
+         y_test: pd.Series | None = None,
+         X_pseudo: pd.DataFrame | None = None,
+         y_pseudo: pd.DataFrame | None = None,
+         time_limit: float | None = None,
          stack_name: str = "core",
          level: int = 1,
          compute_score: bool = True,
-         total_resources: dict = None,
+         total_resources: dict | None = None,
          errors: Literal["ignore", "raise"] = "ignore",
-         errors_ignore: list = None,
-         errors_raise: list = None,
+         errors_ignore: list | None = None,
+         errors_raise: list | None = None,
          is_ray_worker: bool = False,
          **model_fit_kwargs,
      ) -> list[str]:
@@ -2186,7 +2385,7 @@ class AbstractTrainer:
          return model_names_trained

      # FIXME: v1.0 Move to AbstractModel for most fields
-     def _get_model_metadata(self, model: AbstractModel, stack_name: str = "core", level: int = 1) -> Dict[str, Any]:
+     def _get_model_metadata(self, model: AbstractModel, stack_name: str = "core", level: int = 1) -> dict[str, Any]:
          """
          Returns the model metadata used to initialize a node in the DAG (self.model_graph).
          """
@@ -2415,11 +2614,11 @@ class AbstractTrainer:
          compute_score=True,
          total_resources: dict | None = None,
          errors: Literal["ignore", "raise"] = "ignore",
-         errors_ignore: list = None,
-         errors_raise: list = None,
+         errors_ignore: list | None = None,
+         errors_raise: list | None = None,
          is_ray_worker: bool = False,
          **kwargs,
-     ) -> List[str]:
+     ) -> list[str]:
          """
          Trains a model, with the potential to train multiple versions of this model with hyperparameter tuning and feature pruning.
          Returns a list of successfully trained and saved model names.
@@ -2562,7 +2761,7 @@ class AbstractTrainer:
              raise exception
          return model_names_trained

-     # TODO: Move to a utility function outside of AbstractTrainer
+     # TODO: Move to a utility function outside of AbstractTabularTrainer
      @staticmethod
      def _check_raise_exception(
          exception: Exception,
@@ -2643,7 +2842,7 @@ class AbstractTrainer:
      def _callbacks_after_fit(
          self,
          *,
-         model_names: List[str],
+         model_names: list[str],
          stack_name: str,
          level: int,
      ):
@@ -2662,7 +2861,7 @@ class AbstractTrainer:
      # TODO: Time allowance not accurate if running from fit_continue
      # TODO: Remove level and stack_name arguments, can get them automatically
      # TODO: Make sure that pretraining on X_unlabeled only happens 1 time rather than every fold of bagging. (Do during pretrain API work?)
-     def _train_multi_repeats(self, X, y, models: list, n_repeats, n_repeat_start=1, time_limit=None, time_limit_total_level=None, **kwargs) -> List[str]:
+     def _train_multi_repeats(self, X, y, models: list, n_repeats, n_repeat_start=1, time_limit=None, time_limit_total_level=None, **kwargs) -> list[str]:
          """
          Fits bagged ensemble models with additional folds and/or bagged repeats.
          Models must have already been fit prior to entering this method.
@@ -2721,7 +2920,7 @@ class AbstractTrainer:
          return models_valid

      def _train_multi_initial(
-         self, X, y, models: List[AbstractModel], k_fold, n_repeats, hyperparameter_tune_kwargs=None, time_limit=None, feature_prune_kwargs=None, **kwargs
+         self, X, y, models: list[AbstractModel], k_fold, n_repeats, hyperparameter_tune_kwargs=None, time_limit=None, feature_prune_kwargs=None, **kwargs
      ):
          """
          Fits models that have not previously been fit.
@@ -3137,7 +3336,7 @@ class AbstractTrainer:
          self,
          X,
          y,
-         models: List[AbstractModel],
+         models: list[AbstractModel],
          hyperparameter_tune_kwargs=None,
          feature_prune_kwargs=None,
          k_fold=None,
@@ -3146,7 +3345,7 @@ class AbstractTrainer:
          time_limit=None,
          delay_bag_sets: bool = False,
          **kwargs,
-     ) -> List[str]:
+     ) -> list[str]:
          """
          Train a list of models using the same data.
          Assumes that input data has already been processed in the form the models will receive as input (including stack feature generation).
@@ -3205,13 +3404,13 @@ class AbstractTrainer:
          y_val,
          X_test=None,
          y_test=None,
-         hyperparameters: dict = None,
+         hyperparameters: dict | None = None,
          X_unlabeled=None,
          num_stack_levels=0,
          time_limit=None,
          groups=None,
          **kwargs,
-     ) -> List[str]:
+     ) -> list[str]:
          """Identical to self.train_multi_levels, but also saves the data to disk. This should only ever be called once."""
          if time_limit is not None and time_limit <= 0:
              raise AssertionError(f"Not enough time left to train models. Consider specifying a larger time_limit. Time remaining: {round(time_limit, 2)}s")
@@ -3254,19 +3453,19 @@ class AbstractTrainer:
              logger.log(30, "Warning: AutoGluon did not successfully train any models")
          return model_names_fit

-     def _predict_model(self, X: pd.DataFrame, model: str, model_pred_proba_dict: dict = None) -> np.ndarray:
+     def _predict_model(self, X: pd.DataFrame, model: str, model_pred_proba_dict: dict | None = None) -> np.ndarray:
          y_pred_proba = self._predict_proba_model(X=X, model=model, model_pred_proba_dict=model_pred_proba_dict)
          return get_pred_from_proba(y_pred_proba=y_pred_proba, problem_type=self.problem_type)

-     def _predict_proba_model(self, X: pd.DataFrame, model: str, model_pred_proba_dict: dict = None) -> np.ndarray:
+     def _predict_proba_model(self, X: pd.DataFrame, model: str, model_pred_proba_dict: dict | None = None) -> np.ndarray:
          model_pred_proba_dict = self.get_model_pred_proba_dict(X=X, models=[model], model_pred_proba_dict=model_pred_proba_dict)
          if not isinstance(model, str):
              model = model.name
          return model_pred_proba_dict[model]

      def _proxy_model_feature_prune(
-         self, model_fit_kwargs: dict, time_limit: float, layer_fit_time: float, level: int, features: List[str], **feature_prune_kwargs: dict
-     ) -> List[str]:
+         self, model_fit_kwargs: dict, time_limit: float, layer_fit_time: float, level: int, features: list[str], **feature_prune_kwargs: dict
+     ) -> list[str]:
          """
          Uses the best LightGBM-based base learner of this layer to perform time-aware permutation feature importance based feature pruning.
          If all LightGBM models fail, use the model that achieved the highest validation accuracy. Feature pruning gets the smaller of the
@@ -3287,12 +3486,12 @@ class AbstractTrainer:
              How long it took to fit all the models in this layer once. Used to calculate how long to feature prune for.
          level : int
              Level of this stack layer.
-         features: List[str]
+         features: list[str]
              The list of feature names in the inputted dataset.

          Returns
          -------
-         candidate_features : List[str]
+         candidate_features : list[str]
              Feature names that survived the pruning procedure.
          """
          k = feature_prune_kwargs.pop("k", 2)
@@ -3326,14 +3525,14 @@ class AbstractTrainer:
      def _get_default_proxy_model_class(self):
          return None

-     def _retain_better_pruned_models(self, pruned_models: List[str], original_prune_map: dict, force_prune: bool = False) -> List[str]:
+     def _retain_better_pruned_models(self, pruned_models: list[str], original_prune_map: dict, force_prune: bool = False) -> list[str]:
          """
          Compares models fit on the pruned set of features with their counterpart, models fit on full set of features.
          Take the model that achieved a higher validation set score and delete the other from self.model_graph.

          Parameters
          ----------
-         pruned_models : List[str]
+         pruned_models : list[str]
              A list of pruned model names.
          original_prune_map : dict
              A dictionary mapping the names of models fitted on pruned features to the names of models fitted on original features.
@@ -3342,7 +3541,7 @@ class AbstractTrainer:

          Returns
          ----------
-         models : List[str]
+         models : list[str]
              A list of model names.
          """
          models = []
@@ -3460,7 +3659,7 @@ class AbstractTrainer:
          model_types = self.get_models_attribute_dict(attribute="type", models=model_names)
          return model_names, model_paths, model_types

-     def get_model_attribute_full(self, model: Union[str, List[str]], attribute: str, func=sum) -> Union[float, int]:
+     def get_model_attribute_full(self, model: str | list[str], attribute: str, func=sum) -> float | int:
          """
          Sums the attribute value across all models that the provided model depends on, including itself.
          For instance, this function can return the expected total predict_time of a model.
@@ -3497,7 +3696,7 @@ class AbstractTrainer:
              attribute_full = 0
          return attribute_full

-     def get_models_attribute_full(self, models: List[str], attribute: str, func=sum):
+     def get_models_attribute_full(self, models: list[str], attribute: str, func=sum):
          """
          For each model in models, returns the output of self.get_model_attribute_full mapped to a dict.
          """
@@ -3506,57 +3705,6 @@ class AbstractTrainer:
              d[model] = self.get_model_attribute_full(model=model, attribute=attribute, func=func)
          return d

-     def get_models_attribute_dict(self, attribute: str, models: list = None) -> Dict[str, Any]:
-         """
-         Returns dictionary of model name -> attribute value for the provided attribute.
-         """
-         models_attribute_dict = nx.get_node_attributes(self.model_graph, attribute)
-         if models is not None:
-             model_names = []
-             for model in models:
-                 if not isinstance(model, str):
-                     model = model.name
-                 model_names.append(model)
-             if attribute == "path":
-                 models_attribute_dict = {key: os.path.join(*val) for key, val in models_attribute_dict.items() if key in model_names}
-             else:
-                 models_attribute_dict = {key: val for key, val in models_attribute_dict.items() if key in model_names}
-         return models_attribute_dict
-
-     def get_model_attribute(self, model, attribute: str, **kwargs) -> Any:
-         """
-         Return model attribute value.
-         If `default` is specified, return default value if attribute does not exist.
-         If `default` is not specified, raise ValueError if attribute does not exist.
-         """
-         if not isinstance(model, str):
-             model = model.name
-         if model not in self.model_graph.nodes:
-             raise ValueError(f"Model does not exist: (model={model})")
-         if attribute not in self.model_graph.nodes[model]:
-             if "default" in kwargs:
-                 return kwargs["default"]
-             else:
-                 raise ValueError(f"Model does not contain attribute: (model={model}, attribute={attribute})")
-         if attribute == "path":
-             return os.path.join(*self.model_graph.nodes[model][attribute])
-         return self.model_graph.nodes[model][attribute]
-
-     def set_model_attribute(self, model, attribute: str, val):
-         if not isinstance(model, str):
-             model = model.name
-         self.model_graph.nodes[model][attribute] = val
-
-     # Gets the minimum set of models that the provided model depends on, including itself
-     # Returns a list of model names
-     def get_minimum_model_set(self, model, include_self=True) -> list:
-         if not isinstance(model, str):
-             model = model.name
-         minimum_model_set = list(nx.bfs_tree(self.model_graph, model, reverse=True))
-         if not include_self:
-             minimum_model_set = [m for m in minimum_model_set if m != model]
-         return minimum_model_set
-
      # Gets the minimum set of models that the provided models depend on, including themselves
      # Returns a list of model names
      def get_minimum_models_set(self, models: list) -> list:
@@ -3573,7 +3721,7 @@ class AbstractTrainer:
          base_model_set = list(self.model_graph.predecessors(model))
          return base_model_set

-     def model_refit_map(self, inverse=False) -> Dict[str, str]:
+     def model_refit_map(self, inverse=False) -> dict[str, str]:
          """
          Returns dict of parent model -> refit model

@@ -3587,10 +3735,6 @@ class AbstractTrainer:
      def model_exists(self, model: str) -> bool:
          return model in self.get_model_names()

-     def _get_banned_model_names(self) -> list:
-         """Gets all model names which would cause model files to be overwritten if a new model was trained with the name"""
-         return self.get_model_names() + list(self._extra_banned_names)
-
      def _flatten_model_info(self, model_info: dict) -> dict:
          """
          Flattens the model_info nested dictionary into a shallow dictionary to convert to a pandas DataFrame row.
@@ -3643,7 +3787,7 @@ class AbstractTrainer:
                  model_info_flat[key] = custom_info[key]
          return model_info_flat

-     def leaderboard(self, extra_info=False, refit_full: bool = None, set_refit_score_to_parent: bool = False):
+     def leaderboard(self, extra_info=False, refit_full: bool | None = None, set_refit_score_to_parent: bool = False):
          model_names = self.get_model_names()
          models_full_dict = self.get_models_attribute_dict(models=model_names, attribute="refit_full_parent")
          if refit_full is not None:
@@ -3956,30 +4100,6 @@ class AbstractTrainer:

          return info

-     def get_model_info(self, model: str | AbstractModel) -> Dict[str, Any]:
-         if isinstance(model, str):
-             if model in self.models.keys():
-                 model = self.models[model]
-         if isinstance(model, str):
-             model_type = self.get_model_attribute(model=model, attribute="type")
-             model_path = self.get_model_attribute(model=model, attribute="path")
-             model_info = model_type.load_info(path=os.path.join(self.path, model_path))
-         else:
-             model_info = model.get_info()
-         return model_info
-
-     def get_models_info(self, models: List[str | AbstractModel] = None) -> Dict[str, Dict[str, Any]]:
-         if models is None:
-             models = self.get_model_names()
-         model_info_dict = dict()
-         for model in models:
-             if isinstance(model, str):
-                 model_name = model
-             else:
-                 model_name = model.name
-             model_info_dict[model_name] = self.get_model_info(model=model)
-         return model_info_dict
-
      def reduce_memory_size(
          self, remove_data=True, remove_fit_stack=False, remove_fit=True, remove_info=False, requires_save=True, reduce_children=False, **kwargs
      ):
@@ -4062,7 +4182,7 @@ class AbstractTrainer:
              for model in models_to_remove:
                  model = self.load_model(model)
                  logger.log(30, f"\tDirectory {model.path} would have been deleted.")
-             logger.log(30, f"To perform the deletion, set dry_run=False")
+             logger.log(30, "To perform the deletion, set dry_run=False")
              return

          if delete_from_disk:
@@ -4091,37 +4211,8 @@ class AbstractTrainer:
4091
4211
  path_attr_model = Path(self._path_attr_model(model))
4092
4212
  shutil.rmtree(path=path_attr_model, ignore_errors=True)
4093
4213
 
4094
- @classmethod
4095
- def load(cls, path, reset_paths=False):
4096
- load_path = os.path.join(path, cls.trainer_file_name)
4097
- if not reset_paths:
4098
- return load_pkl.load(path=load_path)
4099
- else:
4100
- obj = load_pkl.load(path=load_path)
4101
- obj.set_contexts(path)
4102
- obj.reset_paths = reset_paths
4103
- return obj
4104
-
4105
- @classmethod
4106
- def load_info(cls, path, reset_paths=False, load_model_if_required=True):
4107
- load_path = os.path.join(path, cls.trainer_info_name)
4108
- try:
4109
- return load_pkl.load(path=load_path)
4110
- except:
4111
- if load_model_if_required:
4112
- trainer = cls.load(path=path, reset_paths=reset_paths)
4113
- return trainer.get_info()
4114
- else:
4115
- raise
4116
-
4117
- def save_info(self, include_model_info=False):
4118
- info = self.get_info(include_model_info=include_model_info)
4119
-
4120
- save_pkl.save(path=os.path.join(self.path, self.trainer_info_name), object=info)
4121
- save_json.save(path=os.path.join(self.path, self.trainer_info_json_name), obj=info)
4122
- return info
4123
-
4124
- def _process_hyperparameters(self, hyperparameters: dict) -> dict:
4214
+ @staticmethod
4215
+ def _process_hyperparameters(hyperparameters: dict) -> dict:
4125
4216
  return process_hyperparameters(hyperparameters=hyperparameters)
4126
4217
 
4127
4218
  def distill(
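Two more consolidations here: the `load`/`load_info`/`save_info` persistence helpers also leave this file (again presumably for the shared base class), and `_process_hyperparameters` becomes a `@staticmethod`, which is accurate since it merely delegates to the module-level `process_hyperparameters` and never touches `self`. The removed `load_info` encodes a fallback worth keeping in mind: prefer the cheap cached info pickle, and only reload the whole trainer to recompute when the cache is unavailable. Restated generically, with the removed code's bare `except:` narrowed to `Exception`:

    import pickle

    def load_info(info_path: str, load_trainer, load_if_required: bool = True) -> dict:
        try:
            # Fast path: a previously saved info artifact.
            with open(info_path, "rb") as f:
                return pickle.load(f)
        except Exception:
            if load_if_required:
                # Slow path: reload the full trainer and recompute its info dict.
                return load_trainer().get_info()
            raise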
@@ -4327,7 +4418,7 @@ class AbstractTrainer:
4327
4418
  return distilled_model_names
4328
4419
 
4329
4420
  def _get_model_fit_kwargs(
4330
- self, X: pd.DataFrame, X_val: pd.DataFrame, time_limit: float, k_fold: int, fit_kwargs: dict, ens_sample_weight: List = None
4421
+ self, X: pd.DataFrame, X_val: pd.DataFrame, time_limit: float, k_fold: int, fit_kwargs: dict, ens_sample_weight: list | None = None
4331
4422
  ) -> dict:
4332
4423
  # Returns kwargs to be passed to AbstractModel's fit function
4333
4424
  if fit_kwargs is None:
@@ -4364,7 +4455,7 @@ class AbstractTrainer:
4364
4455
  k_fold=k_fold, k_fold_start=k_fold_start, k_fold_end=k_fold_end, n_repeats=n_repeats, n_repeat_start=n_repeat_start, compute_base_preds=False
4365
4456
  )
4366
4457
 
4367
- def _get_feature_prune_proxy_model(self, proxy_model_class: Union[AbstractModel, None], level: int) -> AbstractModel:
4458
+ def _get_feature_prune_proxy_model(self, proxy_model_class: AbstractModel | None, level: int) -> AbstractModel:
4368
4459
  """
4369
4460
  Returns proxy model to be used for feature pruning - the base learner that has the highest validation score in a particular stack layer.
4370
4461
  Ties are broken by inference speed. If proxy_model_class is not None, take the best base learner belonging to proxy_model_class.
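The tie-breaking rule described in this docstring is visible in the next hunk's context lines: filter the layer's candidate rows to those tied at the maximum `score_val`, then take the row whose `fit_time` is smallest via `idxmin` (the docstring says inference speed; the surviving code breaks ties on fit time). A self-contained pandas illustration of that selection:

    import pandas as pd

    candidates = pd.DataFrame({
        "model": ["LightGBM", "CatBoost", "XGBoost"],
        "score_val": [0.91, 0.93, 0.93],
        "fit_time": [12.0, 30.0, 18.0],
    })

    # Rows tied at the best validation score...
    best = candidates.loc[candidates["score_val"] == candidates["score_val"].max()]
    # ...then the cheapest of those wins the tie.
    proxy = best.loc[best["fit_time"].idxmin()]["model"]
    print(proxy)  # XGBoost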
@@ -4398,7 +4489,7 @@ class AbstractTrainer:
4398
4489
  best_candidate_model_rows = candidate_model_rows.loc[candidate_model_rows["score_val"] == candidate_model_rows["score_val"].max()]
4399
4490
  return self.load_model(best_candidate_model_rows.loc[best_candidate_model_rows["fit_time"].idxmin()]["model"])
4400
4491
 
4401
- def calibrate_model(self, model_name: str = None, lr: float = 0.1, max_iter: int = 200, init_val: float = 1.0):
4492
+ def calibrate_model(self, model_name: str | None = None, lr: float = 0.1, max_iter: int = 200, init_val: float = 1.0):
4402
4493
  """
4403
4494
  Applies temperature scaling to a model.
4404
4495
  Applies inverse softmax to predicted probs then trains temperature scalar
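The docstring describes standard temperature scaling (Guo et al., 2017): map validation probabilities back to logits with an inverse softmax, then fit a single scalar temperature by gradient descent, with `lr`, `max_iter`, and `init_val` configuring that fit. A minimal numpy sketch of the application step only; this is the textbook recipe, not AutoGluon's exact implementation:

    import numpy as np

    def apply_temperature(y_pred_proba: np.ndarray, temperature: float) -> np.ndarray:
        # Inverse softmax: probabilities -> logits, up to a per-row additive constant.
        logits = np.log(np.clip(y_pred_proba, 1e-12, None))
        # T > 1 softens the predicted distribution; T < 1 sharpens it.
        scaled = logits / temperature
        scaled -= scaled.max(axis=1, keepdims=True)  # numerical stability
        exp = np.exp(scaled)
        return exp / exp.sum(axis=1, keepdims=True)

The temperature itself is typically chosen to minimize negative log-likelihood on held-out predictions, which is why calibration requires validation data.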
@@ -4494,11 +4585,11 @@ class AbstractTrainer:
4494
4585
  def calibrate_decision_threshold(
4495
4586
  self,
4496
4587
  X: pd.DataFrame | None = None,
4497
- y: np.array | None = None,
4588
+ y: np.ndarray | None = None,
4498
4589
  metric: str | Scorer | None = None,
4499
4590
  model: str = "best",
4500
4591
  weights=None,
4501
- decision_thresholds: int | List[float] = 25,
4592
+ decision_thresholds: int | list[float] = 25,
4502
4593
  secondary_decision_thresholds: int | None = 19,
4503
4594
  verbose: bool = True,
4504
4595
  **kwargs,
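`decision_thresholds: int | list[float]` lets callers pass either explicit candidate thresholds or a count; the implementation is outside this hunk, but the underlying technique is a plain sweep: score each candidate threshold against validation labels with the chosen metric and keep the argmax. A hedged sketch (the helper name and the evenly-spaced-grid reading of the int case are assumptions):

    import numpy as np

    def sweep_decision_threshold(y_true, y_pred_proba, metric_fn, decision_thresholds=25):
        if isinstance(decision_thresholds, int):
            # Assumption: an int means this many evenly spaced candidates over [0, 1].
            thresholds = np.linspace(0.0, 1.0, decision_thresholds + 1)
        else:
            thresholds = np.asarray(decision_thresholds, dtype=float)
        scores = [metric_fn(y_true, (y_pred_proba >= t).astype(int)) for t in thresholds]
        return float(thresholds[int(np.argmax(scores))])

    # e.g. with scikit-learn: sweep_decision_threshold(y_val, proba_val, f1_score)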
@@ -4565,18 +4656,18 @@ class AbstractTrainer:
4565
4656
  )
4566
4657
 
4567
4658
  @staticmethod
4568
- def _validate_num_classes(num_classes: int, problem_type: str):
4659
+ def _validate_num_classes(num_classes: int | None, problem_type: str):
4569
4660
  if problem_type == BINARY:
4570
4661
  assert num_classes is not None and num_classes == 2, f"num_classes must be 2 when problem_type='{problem_type}' (num_classes={num_classes})"
4571
4662
  elif problem_type in [MULTICLASS, SOFTCLASS]:
4572
4663
  assert num_classes is not None and num_classes >= 2, f"num_classes must be >=2 when problem_type='{problem_type}' (num_classes={num_classes})"
4573
4664
  elif problem_type in [REGRESSION, QUANTILE]:
4574
- assert num_classes is None, f"num_clases must be None when problem_type='{problem_type}' (num_classes={num_classes})"
4665
+ assert num_classes is None, f"num_classes must be None when problem_type='{problem_type}' (num_classes={num_classes})"
4575
4666
  else:
4576
4667
  raise AssertionError(f"Unknown problem_type: '{problem_type}'. Valid problem types: {[BINARY, MULTICLASS, REGRESSION, SOFTCLASS, QUANTILE]}")
4577
4668
 
4578
4669
  @staticmethod
4579
- def _validate_quantile_levels(quantile_levels: List[float] | np.array, problem_type: str):
4670
+ def _validate_quantile_levels(quantile_levels: list[float] | np.ndarray | None, problem_type: str):
4580
4671
  if problem_type == QUANTILE:
4581
4672
  assert quantile_levels is not None, f"quantile_levels must not be None when problem_type='{problem_type}' (quantile_levels={quantile_levels})"
4582
4673
  assert isinstance(quantile_levels, (list, np.ndarray)), f"quantile_levels must be a list or np.ndarray (quantile_levels={quantile_levels})"
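Beyond the widened annotations, two small correctness fixes land in this hunk: `np.array` (a function, not a type) becomes `np.ndarray` in the `quantile_levels` annotation, and the `num_clases` typo in the regression/quantile assertion message is corrected. For orientation, the invariants the two validators enforce (the BINARY/MULTICLASS/... constants are the lowercase strings shown):

    # problem_type                   expected num_classes
    # "binary"                   ->  exactly 2
    # "multiclass" / "softclass" ->  >= 2
    # "regression" / "quantile"  ->  None
    # "quantile" additionally requires quantile_levels, as a list or np.ndarray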
@@ -4587,13 +4678,13 @@ class AbstractTrainer:
4587
4678
 
4588
4679
  def _detached_train_multi_fold(
4589
4680
  *,
4590
- _self: AbstractTrainer,
4681
+ _self: AbstractTabularTrainer,
4591
4682
  model: str | AbstractModel,
4592
4683
  X: pd.DataFrame,
4593
4684
  y: pd.Series,
4594
4685
  time_split: bool,
4595
4686
  time_start: float,
4596
- time_limit: float|None,
4687
+ time_limit: float | None,
4597
4688
  time_limit_model_split: float | None,
4598
4689
  hyperparameter_tune_kwargs: dict,
4599
4690
  is_ray_worker: bool = False,
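`_detached_train_multi_fold` and the three helpers renamed after it are module-level functions that take the trainer as a keyword-only `_self` argument; together with the `is_ray_worker` flag and the `_remote_` naming, this is the usual shape for work shipped to Ray workers, where a plain function plus explicit arguments serializes more predictably than a bound method. A generic sketch of the pattern, assuming Ray; the function bodies and the decoration site are outside this diff, and `train_single` here is a hypothetical stand-in:

    import ray

    def _detached_do_fit(*, _self, model_name: str):
        # Plain module-level function: straightforward to serialize,
        # with the trainer travelling as an ordinary argument.
        return _self.train_single(model_name)

    # Wrapping the detached function produces the remote variant.
    _remote_do_fit = ray.remote(_detached_do_fit)

    # On the driver:
    # ref = _remote_do_fit.remote(_self=trainer, model_name="LightGBM")
    # result = ray.get(ref)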
@@ -4636,7 +4727,7 @@ def _detached_train_multi_fold(
4636
4727
 
4637
4728
  def _remote_train_multi_fold(
4638
4729
  *,
4639
- _self: AbstractTrainer,
4730
+ _self: AbstractTabularTrainer,
4640
4731
  model: str | AbstractModel,
4641
4732
  X: pd.DataFrame,
4642
4733
  y: pd.Series,
@@ -4694,7 +4785,7 @@ def _remote_train_multi_fold(
4694
4785
 
4695
4786
  def _detached_refit_single_full(
4696
4787
  *,
4697
- _self: AbstractTrainer,
4788
+ _self: AbstractTabularTrainer,
4698
4789
  model: str,
4699
4790
  X: pd.DataFrame,
4700
4791
  y: pd.Series,
@@ -4787,7 +4878,7 @@ def _detached_refit_single_full(
4787
4878
 
4788
4879
  def _remote_refit_single_full(
4789
4880
  *,
4790
- _self: AbstractTrainer,
4881
+ _self: AbstractTabularTrainer,
4791
4882
  model: str,
4792
4883
  X: pd.DataFrame,
4793
4884
  y: pd.Series,