autogluon.tabular 1.5.1b20260105__py3-none-any.whl → 1.5.1b20260116__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of autogluon.tabular might be problematic. Click here for more details.

Files changed (135)
  1. autogluon/tabular/__init__.py +1 -0
  2. autogluon/tabular/configs/config_helper.py +18 -6
  3. autogluon/tabular/configs/feature_generator_presets.py +3 -1
  4. autogluon/tabular/configs/hyperparameter_configs.py +42 -9
  5. autogluon/tabular/configs/presets_configs.py +38 -14
  6. autogluon/tabular/configs/zeroshot/zeroshot_portfolio_2023.py +84 -14
  7. autogluon/tabular/configs/zeroshot/zeroshot_portfolio_2025.py +48 -48
  8. autogluon/tabular/configs/zeroshot/zeroshot_portfolio_cpu_2025_12_18.py +774 -1
  9. autogluon/tabular/configs/zeroshot/zeroshot_portfolio_gpu_2025_12_18.py +421 -1
  10. autogluon/tabular/experimental/_scikit_mixin.py +6 -2
  11. autogluon/tabular/experimental/_tabular_classifier.py +3 -1
  12. autogluon/tabular/experimental/_tabular_regressor.py +3 -1
  13. autogluon/tabular/experimental/plot_leaderboard.py +73 -19
  14. autogluon/tabular/learner/abstract_learner.py +160 -42
  15. autogluon/tabular/learner/default_learner.py +78 -22
  16. autogluon/tabular/models/__init__.py +2 -2
  17. autogluon/tabular/models/_utils/rapids_utils.py +3 -1
  18. autogluon/tabular/models/abstract/abstract_torch_model.py +2 -0
  19. autogluon/tabular/models/automm/automm_model.py +12 -3
  20. autogluon/tabular/models/automm/ft_transformer.py +5 -1
  21. autogluon/tabular/models/catboost/callbacks.py +2 -2
  22. autogluon/tabular/models/catboost/catboost_model.py +93 -29
  23. autogluon/tabular/models/catboost/catboost_softclass_utils.py +4 -1
  24. autogluon/tabular/models/catboost/catboost_utils.py +3 -1
  25. autogluon/tabular/models/ebm/ebm_model.py +8 -13
  26. autogluon/tabular/models/ebm/hyperparameters/parameters.py +1 -0
  27. autogluon/tabular/models/ebm/hyperparameters/searchspaces.py +1 -0
  28. autogluon/tabular/models/fastainn/callbacks.py +20 -3
  29. autogluon/tabular/models/fastainn/hyperparameters/searchspaces.py +11 -1
  30. autogluon/tabular/models/fastainn/quantile_helpers.py +10 -2
  31. autogluon/tabular/models/fastainn/tabular_nn_fastai.py +65 -18
  32. autogluon/tabular/models/fasttext/fasttext_model.py +3 -1
  33. autogluon/tabular/models/image_prediction/image_predictor.py +7 -2
  34. autogluon/tabular/models/knn/knn_model.py +41 -8
  35. autogluon/tabular/models/lgb/callbacks.py +32 -9
  36. autogluon/tabular/models/lgb/hyperparameters/searchspaces.py +3 -1
  37. autogluon/tabular/models/lgb/lgb_model.py +150 -34
  38. autogluon/tabular/models/lgb/lgb_utils.py +12 -4
  39. autogluon/tabular/models/lr/hyperparameters/searchspaces.py +5 -1
  40. autogluon/tabular/models/lr/lr_model.py +40 -10
  41. autogluon/tabular/models/lr/lr_rapids_model.py +22 -13
  42. autogluon/tabular/models/mitra/_internal/__init__.py +1 -1
  43. autogluon/tabular/models/mitra/_internal/config/__init__.py +1 -1
  44. autogluon/tabular/models/mitra/_internal/config/config_pretrain.py +36 -40
  45. autogluon/tabular/models/mitra/_internal/config/config_run.py +2 -14
  46. autogluon/tabular/models/mitra/_internal/config/enums.py +27 -26
  47. autogluon/tabular/models/mitra/_internal/core/__init__.py +1 -1
  48. autogluon/tabular/models/mitra/_internal/core/callbacks.py +14 -21
  49. autogluon/tabular/models/mitra/_internal/core/get_loss.py +10 -12
  50. autogluon/tabular/models/mitra/_internal/core/get_optimizer.py +17 -32
  51. autogluon/tabular/models/mitra/_internal/core/get_scheduler.py +12 -27
  52. autogluon/tabular/models/mitra/_internal/core/prediction_metrics.py +16 -21
  53. autogluon/tabular/models/mitra/_internal/core/trainer_finetune.py +130 -111
  54. autogluon/tabular/models/mitra/_internal/data/__init__.py +1 -1
  55. autogluon/tabular/models/mitra/_internal/data/collator.py +30 -26
  56. autogluon/tabular/models/mitra/_internal/data/dataset_finetune.py +18 -26
  57. autogluon/tabular/models/mitra/_internal/data/dataset_split.py +10 -7
  58. autogluon/tabular/models/mitra/_internal/data/preprocessor.py +70 -100
  59. autogluon/tabular/models/mitra/_internal/models/__init__.py +1 -1
  60. autogluon/tabular/models/mitra/_internal/models/base.py +7 -10
  61. autogluon/tabular/models/mitra/_internal/models/embedding.py +46 -56
  62. autogluon/tabular/models/mitra/_internal/models/tab2d.py +140 -120
  63. autogluon/tabular/models/mitra/_internal/utils/__init__.py +1 -1
  64. autogluon/tabular/models/mitra/_internal/utils/set_seed.py +3 -1
  65. autogluon/tabular/models/mitra/mitra_model.py +16 -11
  66. autogluon/tabular/models/mitra/sklearn_interface.py +178 -162
  67. autogluon/tabular/models/realmlp/realmlp_model.py +28 -15
  68. autogluon/tabular/models/rf/compilers/onnx.py +1 -1
  69. autogluon/tabular/models/rf/rf_model.py +45 -12
  70. autogluon/tabular/models/rf/rf_quantile.py +4 -2
  71. autogluon/tabular/models/tabdpt/tabdpt_model.py +8 -17
  72. autogluon/tabular/models/tabicl/tabicl_model.py +8 -1
  73. autogluon/tabular/models/tabm/_tabm_internal.py +6 -4
  74. autogluon/tabular/models/tabm/rtdl_num_embeddings.py +80 -127
  75. autogluon/tabular/models/tabm/tabm_model.py +8 -4
  76. autogluon/tabular/models/tabm/tabm_reference.py +53 -85
  77. autogluon/tabular/models/tabpfnmix/_internal/core/callbacks.py +7 -16
  78. autogluon/tabular/models/tabpfnmix/_internal/core/collator.py +16 -24
  79. autogluon/tabular/models/tabpfnmix/_internal/core/dataset_split.py +5 -7
  80. autogluon/tabular/models/tabpfnmix/_internal/core/enums.py +0 -2
  81. autogluon/tabular/models/tabpfnmix/_internal/core/get_loss.py +0 -1
  82. autogluon/tabular/models/tabpfnmix/_internal/core/get_optimizer.py +7 -18
  83. autogluon/tabular/models/tabpfnmix/_internal/core/get_scheduler.py +3 -14
  84. autogluon/tabular/models/tabpfnmix/_internal/core/trainer_finetune.py +79 -64
  85. autogluon/tabular/models/tabpfnmix/_internal/core/y_transformer.py +3 -5
  86. autogluon/tabular/models/tabpfnmix/_internal/data/dataset_finetune.py +17 -30
  87. autogluon/tabular/models/tabpfnmix/_internal/data/preprocessor.py +15 -35
  88. autogluon/tabular/models/tabpfnmix/_internal/models/foundation/embedding.py +21 -38
  89. autogluon/tabular/models/tabpfnmix/_internal/models/foundation/foundation_transformer.py +33 -51
  90. autogluon/tabular/models/tabpfnmix/_internal/results/prediction_metrics.py +4 -4
  91. autogluon/tabular/models/tabpfnmix/_internal/tabpfnmix_classifier.py +32 -12
  92. autogluon/tabular/models/tabpfnmix/_internal/tabpfnmix_regressor.py +32 -13
  93. autogluon/tabular/models/tabpfnmix/tabpfnmix_model.py +55 -19
  94. autogluon/tabular/models/tabpfnv2/tabpfnv2_5_model.py +21 -48
  95. autogluon/tabular/models/tabprep/prep_mixin.py +34 -26
  96. autogluon/tabular/models/tabular_nn/compilers/onnx.py +36 -8
  97. autogluon/tabular/models/tabular_nn/torch/tabular_nn_torch.py +130 -36
  98. autogluon/tabular/models/tabular_nn/torch/tabular_torch_dataset.py +8 -4
  99. autogluon/tabular/models/tabular_nn/torch/torch_network_modules.py +26 -5
  100. autogluon/tabular/models/tabular_nn/utils/categorical_encoders.py +41 -24
  101. autogluon/tabular/models/tabular_nn/utils/data_preprocessor.py +33 -8
  102. autogluon/tabular/models/tabular_nn/utils/nn_architecture_utils.py +21 -6
  103. autogluon/tabular/models/xgboost/callbacks.py +9 -3
  104. autogluon/tabular/models/xgboost/xgboost_model.py +59 -11
  105. autogluon/tabular/models/xt/xt_model.py +1 -0
  106. autogluon/tabular/predictor/interpretable_predictor.py +3 -1
  107. autogluon/tabular/predictor/predictor.py +409 -128
  108. autogluon/tabular/registry/__init__.py +1 -1
  109. autogluon/tabular/registry/_ag_model_registry.py +4 -5
  110. autogluon/tabular/registry/_model_registry.py +1 -0
  111. autogluon/tabular/testing/fit_helper.py +55 -15
  112. autogluon/tabular/testing/generate_datasets.py +1 -1
  113. autogluon/tabular/testing/model_fit_helper.py +10 -4
  114. autogluon/tabular/trainer/abstract_trainer.py +644 -230
  115. autogluon/tabular/trainer/auto_trainer.py +19 -8
  116. autogluon/tabular/trainer/model_presets/presets.py +33 -9
  117. autogluon/tabular/trainer/model_presets/presets_distill.py +16 -2
  118. autogluon/tabular/version.py +1 -1
  119. {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/METADATA +26 -26
  120. {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/RECORD +127 -135
  121. autogluon/tabular/models/tabpfnv2/rfpfn/__init__.py +0 -20
  122. autogluon/tabular/models/tabpfnv2/rfpfn/configs.py +0 -40
  123. autogluon/tabular/models/tabpfnv2/rfpfn/scoring_utils.py +0 -201
  124. autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_based_decision_tree_tabpfn.py +0 -1464
  125. autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_based_random_forest_tabpfn.py +0 -747
  126. autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_compat.py +0 -863
  127. autogluon/tabular/models/tabpfnv2/rfpfn/utils.py +0 -106
  128. autogluon/tabular/models/tabpfnv2/tabpfnv2_model.py +0 -466
  129. /autogluon.tabular-1.5.1b20260105-py3.11-nspkg.pth → /autogluon.tabular-1.5.1b20260116-py3.11-nspkg.pth +0 -0
  130. {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/WHEEL +0 -0
  131. {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/licenses/LICENSE +0 -0
  132. {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/licenses/NOTICE +0 -0
  133. {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/namespace_packages.txt +0 -0
  134. {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/top_level.txt +0 -0
  135. {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/zip-safe +0 -0
@@ -9,7 +9,7 @@ import pprint
9
9
  import shutil
10
10
  import time
11
11
  import warnings
12
- from typing import overload, Any, Literal, Optional, Union
12
+ from typing import Any, Literal, Optional, Union, overload
13
13
 
14
14
  import networkx as nx
15
15
  import numpy as np
@@ -22,13 +22,25 @@ from autogluon.common.savers import save_json
22
22
  from autogluon.common.utils.cv_splitter import CVSplitter
23
23
  from autogluon.common.utils.decorators import apply_presets
24
24
  from autogluon.common.utils.file_utils import get_directory_size, get_directory_size_per_file
25
- from autogluon.common.utils.resource_utils import ResourceManager, get_resource_manager
26
- from autogluon.common.utils.hyperparameter_utils import get_hyperparameter_str_deprecation_msg, is_advanced_hyperparameter_format
27
- from autogluon.common.utils.log_utils import add_log_to_file, set_logger_verbosity, warn_if_mlflow_autologging_is_enabled
25
+ from autogluon.common.utils.hyperparameter_utils import (
26
+ get_hyperparameter_str_deprecation_msg,
27
+ is_advanced_hyperparameter_format,
28
+ )
29
+ from autogluon.common.utils.log_utils import (
30
+ add_log_to_file,
31
+ set_logger_verbosity,
32
+ warn_if_mlflow_autologging_is_enabled,
33
+ )
28
34
  from autogluon.common.utils.pandas_utils import get_approximate_df_mem_usage
35
+ from autogluon.common.utils.resource_utils import ResourceManager, get_resource_manager
29
36
  from autogluon.common.utils.system_info import get_ag_system_info
30
37
  from autogluon.common.utils.try_import import try_import_ray
31
- from autogluon.common.utils.utils import check_saved_predictor_version, compare_autogluon_metadata, get_autogluon_metadata, setup_outputdir
38
+ from autogluon.common.utils.utils import (
39
+ check_saved_predictor_version,
40
+ compare_autogluon_metadata,
41
+ get_autogluon_metadata,
42
+ setup_outputdir,
43
+ )
32
44
  from autogluon.core.callbacks import AbstractCallback
33
45
  from autogluon.core.constants import (
34
46
  AUTO_WEIGHT,
@@ -47,7 +59,12 @@ from autogluon.core.problem_type import problem_type_info
47
59
  from autogluon.core.pseudolabeling.pseudolabeling import filter_ensemble_pseudo, filter_pseudo
48
60
  from autogluon.core.scheduler.scheduler_factory import scheduler_factory
49
61
  from autogluon.core.stacked_overfitting.utils import check_stacked_overfitting_from_leaderboard
50
- from autogluon.core.utils import get_pred_from_proba_df, plot_performance_vs_trials, plot_summary_of_models, plot_tabular_models
62
+ from autogluon.core.utils import (
63
+ get_pred_from_proba_df,
64
+ plot_performance_vs_trials,
65
+ plot_summary_of_models,
66
+ plot_tabular_models,
67
+ )
51
68
  from autogluon.core.utils.loaders import load_pkl, load_str
52
69
  from autogluon.core.utils.savers import save_pkl, save_str
53
70
  from autogluon.core.utils.utils import generate_train_test_split_combined
@@ -60,8 +77,8 @@ from ..configs.pipeline_presets import (
60
77
  )
61
78
  from ..configs.presets_configs import tabular_presets_alias, tabular_presets_dict
62
79
  from ..learner import AbstractTabularLearner, DefaultLearner
63
- from ..trainer.abstract_trainer import AbstractTabularTrainer
64
80
  from ..registry import ag_model_registry
81
+ from ..trainer.abstract_trainer import AbstractTabularTrainer
65
82
  from ..version import __version__
66
83
 
67
84
  logger = logging.getLogger(__name__) # return autogluon root logger
@@ -205,7 +222,9 @@ class TabularPredictor:
205
222
  logger.log(15, f"{AUTO_WEIGHT} currently does not use any sample weights.")
206
223
  self.sample_weight = sample_weight
207
224
  self.weight_evaluation = weight_evaluation # TODO: sample_weight and weight_evaluation can both be properties that link to self._learner.sample_weight, self._learner.weight_evaluation
208
- self._decision_threshold = None # TODO: Each model should have its own decision threshold instead of one global threshold
225
+ self._decision_threshold = (
226
+ None # TODO: Each model should have its own decision threshold instead of one global threshold
227
+ )
209
228
  if self.sample_weight in [AUTO_WEIGHT, BALANCE_WEIGHT] and self.weight_evaluation:
210
229
  logger.warning(
211
230
  f"We do not recommend specifying weight_evaluation when sample_weight='{self.sample_weight}', instead specify appropriate eval_metric."
@@ -1086,7 +1105,9 @@ class TabularPredictor:
1086
1105
  >>> predictor = TabularPredictor(label=label, eval_metric=eval_metric).fit(train_data, presets=['best_quality'], time_limit=time_limit)
1087
1106
  """
1088
1107
  if self.is_fit:
1089
- raise AssertionError("Predictor is already fit! To fit additional models, refer to `predictor.fit_extra`, or create a new `Predictor`.")
1108
+ raise AssertionError(
1109
+ "Predictor is already fit! To fit additional models, refer to `predictor.fit_extra`, or create a new `Predictor`."
1110
+ )
1090
1111
 
1091
1112
  verbosity = kwargs.get("verbosity", self.verbosity)
1092
1113
  set_logger_verbosity(verbosity)
@@ -1160,7 +1181,9 @@ class TabularPredictor:
1160
1181
 
1161
1182
  if ag_args is None:
1162
1183
  ag_args = {}
1163
- ag_args = self._set_hyperparameter_tune_kwargs_in_ag_args(kwargs["hyperparameter_tune_kwargs"], ag_args, time_limit=time_limit)
1184
+ ag_args = self._set_hyperparameter_tune_kwargs_in_ag_args(
1185
+ kwargs["hyperparameter_tune_kwargs"], ag_args, time_limit=time_limit
1186
+ )
1164
1187
 
1165
1188
  feature_generator_init_kwargs = kwargs["_feature_generator_kwargs"]
1166
1189
  if feature_generator_init_kwargs is None:
@@ -1169,13 +1192,17 @@ class TabularPredictor:
1169
1192
  train_data, tuning_data, test_data, unlabeled_data = self._validate_fit_data(
1170
1193
  train_data=train_data, tuning_data=tuning_data, test_data=test_data, unlabeled_data=unlabeled_data
1171
1194
  )
1172
- infer_limit, infer_limit_batch_size = self._validate_infer_limit(infer_limit=infer_limit, infer_limit_batch_size=infer_limit_batch_size)
1195
+ infer_limit, infer_limit_batch_size = self._validate_infer_limit(
1196
+ infer_limit=infer_limit, infer_limit_batch_size=infer_limit_batch_size
1197
+ )
1173
1198
 
1174
1199
  # TODO: Temporary for v1.4. Make this more extensible for v1.5 by letting users make their own dynamic hyperparameters.
1175
1200
  dynamic_hyperparameters = kwargs["_experimental_dynamic_hyperparameters"]
1176
1201
  if dynamic_hyperparameters:
1177
1202
  logger.log(20, f"`extreme_v140` preset uses a dynamic portfolio based on dataset size...")
1178
- assert hyperparameters is None, f"hyperparameters must be unspecified when `_experimental_dynamic_hyperparameters=True`."
1203
+ assert hyperparameters is None, (
1204
+ f"hyperparameters must be unspecified when `_experimental_dynamic_hyperparameters=True`."
1205
+ )
1179
1206
  n_samples = len(train_data)
1180
1207
  if n_samples > 30000:
1181
1208
  data_size = "large"
@@ -1183,7 +1210,10 @@ class TabularPredictor:
1183
1210
  data_size = "small"
1184
1211
  assert data_size in ["large", "small"]
1185
1212
  if data_size == "large":
1186
- logger.log(20, f"\tDetected data size: large (>30000 samples), using `zeroshot` portfolio (identical to 'best_quality' preset).")
1213
+ logger.log(
1214
+ 20,
1215
+ f"\tDetected data size: large (>30000 samples), using `zeroshot` portfolio (identical to 'best_quality' preset).",
1216
+ )
1187
1217
  hyperparameters = "zeroshot"
1188
1218
  else:
1189
1219
  if "num_stack_levels" not in kwargs_orig:
@@ -1200,7 +1230,7 @@ class TabularPredictor:
1200
1230
  f"and a CUDA compatible GPU with 32+ GB vRAM when using this preset. "
1201
1231
  f"\n\t\tThis portfolio will download foundation model weights from HuggingFace during training. "
1202
1232
  f"Ensure you have an internet connection or have pre-downloaded the weights to use these models."
1203
- f"\n\t\tThis portfolio was meta-learned with TabArena: https://tabarena.ai"
1233
+ f"\n\t\tThis portfolio was meta-learned with TabArena: https://tabarena.ai",
1204
1234
  )
1205
1235
  hyperparameters = "zeroshot_2025_tabfm"
1206
1236
 
@@ -1222,7 +1252,11 @@ class TabularPredictor:
1222
1252
 
1223
1253
  if feature_metadata is not None and isinstance(feature_metadata, str) and feature_metadata == "infer":
1224
1254
  feature_metadata = None
1225
- self._set_feature_generator(feature_generator=feature_generator, feature_metadata=feature_metadata, init_kwargs=feature_generator_init_kwargs)
1255
+ self._set_feature_generator(
1256
+ feature_generator=feature_generator,
1257
+ feature_metadata=feature_metadata,
1258
+ init_kwargs=feature_generator_init_kwargs,
1259
+ )
1226
1260
 
1227
1261
  if self.problem_type is not None:
1228
1262
  inferred_problem_type = self.problem_type
@@ -1230,7 +1264,9 @@ class TabularPredictor:
1230
1264
  self._learner.validate_label(X=train_data)
1231
1265
  inferred_problem_type = self._learner.infer_problem_type(y=train_data[self.label], silent=True)
1232
1266
 
1233
- learning_curves = self._initialize_learning_curve_params(learning_curves=learning_curves, problem_type=inferred_problem_type)
1267
+ learning_curves = self._initialize_learning_curve_params(
1268
+ learning_curves=learning_curves, problem_type=inferred_problem_type
1269
+ )
1234
1270
  if len(learning_curves) == 0:
1235
1271
  test_data = None
1236
1272
  if ag_args_fit is not None:
@@ -1240,10 +1276,10 @@ class TabularPredictor:
1240
1276
 
1241
1277
  use_bag_holdout_was_auto = False
1242
1278
  dynamic_stacking_was_auto = False
1243
- if isinstance(use_bag_holdout,str) and use_bag_holdout == "auto":
1279
+ if isinstance(use_bag_holdout, str) and use_bag_holdout == "auto":
1244
1280
  use_bag_holdout = None
1245
1281
  use_bag_holdout_was_auto = True
1246
- if isinstance(dynamic_stacking,str) and dynamic_stacking == "auto":
1282
+ if isinstance(dynamic_stacking, str) and dynamic_stacking == "auto":
1247
1283
  dynamic_stacking = None
1248
1284
  dynamic_stacking_was_auto = True
1249
1285
 
@@ -1295,7 +1331,9 @@ class TabularPredictor:
1295
1331
  kwargs["save_bag_folds"] = kwargs["_save_bag_folds"]
1296
1332
 
1297
1333
  if kwargs["save_bag_folds"] is not None:
1298
- assert isinstance(kwargs["save_bag_folds"], bool), f"save_bag_folds must be a bool, found: {type(kwargs['save_bag_folds'])}"
1334
+ assert isinstance(kwargs["save_bag_folds"], bool), (
1335
+ f"save_bag_folds must be a bool, found: {type(kwargs['save_bag_folds'])}"
1336
+ )
1299
1337
  if use_bag_holdout and not kwargs["save_bag_folds"]:
1300
1338
  logger.log(
1301
1339
  30,
@@ -1390,7 +1428,9 @@ class TabularPredictor:
1390
1428
  f"DyStack is enabled (dynamic_stacking={dynamic_stacking}). "
1391
1429
  "AutoGluon will try to determine whether the input data is affected by stacked overfitting and enable or disable stacking as a consequence.",
1392
1430
  )
1393
- num_stack_levels, time_limit = self._dynamic_stacking(**ds_args, ag_fit_kwargs=ag_fit_kwargs, ag_post_fit_kwargs=ag_post_fit_kwargs)
1431
+ num_stack_levels, time_limit = self._dynamic_stacking(
1432
+ **ds_args, ag_fit_kwargs=ag_fit_kwargs, ag_post_fit_kwargs=ag_post_fit_kwargs
1433
+ )
1394
1434
  logger.info(
1395
1435
  f"Starting main fit with num_stack_levels={num_stack_levels}.\n"
1396
1436
  f"\tFor future fit calls on this dataset, you can skip DyStack to save time: "
@@ -1450,7 +1490,9 @@ class TabularPredictor:
1450
1490
 
1451
1491
  if time_limit_og is not None:
1452
1492
  time_limit = int(time_limit_og * detection_time_frac)
1453
- logger.info(f"\tRunning DyStack for up to {time_limit}s of the {time_limit_og}s of remaining time ({detection_time_frac*100:.0f}%).")
1493
+ logger.info(
1494
+ f"\tRunning DyStack for up to {time_limit}s of the {time_limit_og}s of remaining time ({detection_time_frac * 100:.0f}%)."
1495
+ )
1454
1496
  else:
1455
1497
  logger.info(f"\tWarning: No time limit provided for DyStack. This could take awhile.")
1456
1498
  time_limit = None
@@ -1468,8 +1510,12 @@ class TabularPredictor:
1468
1510
  inner_ag_fit_kwargs["X_val"] = X_val
1469
1511
  inner_ag_fit_kwargs["X_unlabeled"] = X_unlabeled
1470
1512
  inner_ag_post_fit_kwargs = copy.deepcopy(ag_post_fit_kwargs)
1471
- inner_ag_post_fit_kwargs["keep_only_best"] = False # Do not keep only best, otherwise it eliminates the purpose of the comparison
1472
- inner_ag_post_fit_kwargs["calibrate"] = False # Do not calibrate as calibration is only applied to the model with the best validation score
1513
+ inner_ag_post_fit_kwargs["keep_only_best"] = (
1514
+ False # Do not keep only best, otherwise it eliminates the purpose of the comparison
1515
+ )
1516
+ inner_ag_post_fit_kwargs["calibrate"] = (
1517
+ False # Do not calibrate as calibration is only applied to the model with the best validation score
1518
+ )
1473
1519
  # FIXME: Ensure all weighted ensembles have skip connections
1474
1520
 
1475
1521
  # Verify problem type is set
@@ -1485,7 +1531,9 @@ class TabularPredictor:
1485
1531
  # -- Validation Method
1486
1532
  if validation_procedure == "holdout":
1487
1533
  if holdout_data is None:
1488
- ds_fit_kwargs.update(dict(holdout_frac=holdout_frac, ds_fit_context=os.path.join(ds_fit_context, "sub_fit_ho")))
1534
+ ds_fit_kwargs.update(
1535
+ dict(holdout_frac=holdout_frac, ds_fit_context=os.path.join(ds_fit_context, "sub_fit_ho"))
1536
+ )
1489
1537
  else:
1490
1538
  _, holdout_data, _, _ = self._validate_fit_data(train_data=X, tuning_data=holdout_data)
1491
1539
  ds_fit_kwargs["ds_fit_context"] = os.path.join(ds_fit_context, "sub_fit_custom_ho")
@@ -1511,25 +1559,28 @@ class TabularPredictor:
1511
1559
  stratify=is_stratified,
1512
1560
  bin=is_binned,
1513
1561
  random_state=42,
1514
- ).split(
1515
- X=X.drop(self.label, axis=1),
1516
- y=X[self.label]
1517
- )
1562
+ ).split(X=X.drop(self.label, axis=1), y=X[self.label])
1518
1563
  n_splits = len(splits)
1519
1564
  logger.info(
1520
1565
  f'\tStarting (repeated-)cross-validation-based sub-fits for dynamic stacking. Context path: "{ds_fit_context}"'
1521
1566
  f"Run at most {n_splits} sub-fits based on {n_repeats}-repeated {n_folds}-fold cross-validation."
1522
1567
  )
1523
- np.random.RandomState(42).shuffle(splits) # shuffle splits to mix up order such that if only one of the repeats shows leakage we might stop early.
1568
+ np.random.RandomState(42).shuffle(
1569
+ splits
1570
+ ) # shuffle splits to mix up order such that if only one of the repeats shows leakage we might stop early.
1524
1571
  for split_index, (train_indices, val_indices) in enumerate(splits):
1525
1572
  if time_limit is None:
1526
1573
  sub_fit_time = None
1527
1574
  else:
1528
1575
  time_spend_sub_fits_so_far = int(time.time() - time_start)
1529
1576
  rest_time = time_limit - time_spend_sub_fits_so_far
1530
- sub_fit_time = int(1 / (n_splits - split_index) * rest_time) # if we are faster, give more time to rest of the folds.
1577
+ sub_fit_time = int(
1578
+ 1 / (n_splits - split_index) * rest_time
1579
+ ) # if we are faster, give more time to rest of the folds.
1531
1580
  if sub_fit_time <= 0:
1532
- logger.info(f"\tStop cross-validation during dynamic stacking early as no more time left. Consider specifying a larger time_limit.")
1581
+ logger.info(
1582
+ f"\tStop cross-validation during dynamic stacking early as no more time left. Consider specifying a larger time_limit."
1583
+ )
1533
1584
  break
1534
1585
  ds_fit_kwargs.update(
1535
1586
  dict(
@@ -1566,7 +1617,9 @@ class TabularPredictor:
1566
1617
  num_stack_levels = 0 if stacked_overfitting else org_num_stack_levels
1567
1618
  self._stacked_overfitting_occurred = stacked_overfitting
1568
1619
 
1569
- logger.info(f"\t{num_stack_levels}\t = Optimal num_stack_levels (Stacked Overfitting Occurred: {self._stacked_overfitting_occurred})")
1620
+ logger.info(
1621
+ f"\t{num_stack_levels}\t = Optimal num_stack_levels (Stacked Overfitting Occurred: {self._stacked_overfitting_occurred})"
1622
+ )
1570
1623
  log_str = f"\t{round(time_spend_sub_fits)}s\t = DyStack runtime"
1571
1624
  if time_limit_og is None:
1572
1625
  time_limit_fit_full = None
@@ -1580,7 +1633,10 @@ class TabularPredictor:
1580
1633
  if holdout_data is None:
1581
1634
  ag_fit_kwargs["X"] = X
1582
1635
  else:
1583
- logger.log(20, "\tConcatenating holdout data from dynamic stacking to the training data for the full fit (and reset the index).")
1636
+ logger.log(
1637
+ 20,
1638
+ "\tConcatenating holdout data from dynamic stacking to the training data for the full fit (and reset the index).",
1639
+ )
1584
1640
  ag_fit_kwargs["X"] = pd.concat([X, holdout_data], ignore_index=True)
1585
1641
 
1586
1642
  ag_fit_kwargs["X_val"] = X_val
@@ -1614,14 +1670,13 @@ class TabularPredictor:
1614
1670
  30,
1615
1671
  f"DyStack: Disabling memory safe fit mode in DyStack "
1616
1672
  f"because GPUs were detected and num_gpus='auto' (GPUs cannot be used in memory safe fit mode). "
1617
- f"If you want to use memory safe fit mode, manually set `num_gpus=0`."
1673
+ f"If you want to use memory safe fit mode, manually set `num_gpus=0`.",
1618
1674
  )
1619
1675
  if num_gpus > 0:
1620
1676
  memory_safe_fits = False
1621
1677
  else:
1622
1678
  memory_safe_fits = True
1623
1679
 
1624
-
1625
1680
  if memory_safe_fits:
1626
1681
  try:
1627
1682
  _ds_ray = try_import_ray()
@@ -1643,7 +1698,10 @@ class TabularPredictor:
1643
1698
  log_to_driver=False,
1644
1699
  )
1645
1700
  except Exception as e:
1646
- warnings.warn(f"Failed to use ray for memory safe fits. Falling back to normal fit. Error: {repr(e)}", stacklevel=2)
1701
+ warnings.warn(
1702
+ f"Failed to use ray for memory safe fits. Falling back to normal fit. Error: {repr(e)}",
1703
+ stacklevel=2,
1704
+ )
1647
1705
  _ds_ray = None
1648
1706
 
1649
1707
  if time_limit is not None:
@@ -1654,9 +1712,11 @@ class TabularPredictor:
1654
1712
  return False
1655
1713
 
1656
1714
  if holdout_data is None:
1657
- logger.info(f"\t\tContext path: \"{ds_fit_kwargs['ds_fit_context']}\"")
1715
+ logger.info(f'\t\tContext path: "{ds_fit_kwargs["ds_fit_context"]}"')
1658
1716
  else:
1659
- logger.info(f"\t\tRunning DyStack holdout-based sub-fit with custom validation data. Context path: \"{ds_fit_kwargs['ds_fit_context']}\"")
1717
+ logger.info(
1718
+ f'\t\tRunning DyStack holdout-based sub-fit with custom validation data. Context path: "{ds_fit_kwargs["ds_fit_context"]}"'
1719
+ )
1660
1720
 
1661
1721
  if _ds_ray is not None:
1662
1722
  # Handle resources
@@ -1790,7 +1850,9 @@ class TabularPredictor:
1790
1850
  if infer_limit is not None:
1791
1851
  infer_limit = infer_limit - self._learner.preprocess_1_time
1792
1852
  trainer_model_best = self._trainer.get_model_best(infer_limit=infer_limit, infer_limit_as_child=True)
1793
- logger.log(20, "Automatically performing refit_full as a post-fit operation (due to `.fit(..., refit_full=True)`")
1853
+ logger.log(
1854
+ 20, "Automatically performing refit_full as a post-fit operation (due to `.fit(..., refit_full=True)`"
1855
+ )
1794
1856
  if set_best_to_refit_full:
1795
1857
  _set_best_to_refit_full = trainer_model_best
1796
1858
  else:
@@ -1849,7 +1911,10 @@ class TabularPredictor:
1849
1911
  elif self.problem_type == QUANTILE:
1850
1912
  self._trainer.calibrate_model()
1851
1913
  else:
1852
- logger.log(30, "WARNING: `calibrate=True` is only applicable to classification or quantile regression problems. Skipping calibration...")
1914
+ logger.log(
1915
+ 30,
1916
+ "WARNING: `calibrate=True` is only applicable to classification or quantile regression problems. Skipping calibration...",
1917
+ )
1853
1918
 
1854
1919
  if isinstance(calibrate_decision_threshold, str) and calibrate_decision_threshold == "auto":
1855
1920
  calibrate_decision_threshold = self._can_calibrate_decision_threshold()
@@ -1888,10 +1953,16 @@ class TabularPredictor:
1888
1953
  f"Force calibration via specifying `calibrate_decision_threshold=True`.",
1889
1954
  )
1890
1955
  if calibrate_decision_threshold:
1891
- logger.log(20, f"Enabling decision threshold calibration (calibrate_decision_threshold='auto', metric is valid, problem_type is 'binary')")
1956
+ logger.log(
1957
+ 20,
1958
+ f"Enabling decision threshold calibration (calibrate_decision_threshold='auto', metric is valid, problem_type is 'binary')",
1959
+ )
1892
1960
  if calibrate_decision_threshold:
1893
1961
  if self.problem_type != BINARY:
1894
- logger.log(30, "WARNING: `calibrate_decision_threshold=True` is only applicable to binary classification. Skipping calibration...")
1962
+ logger.log(
1963
+ 30,
1964
+ "WARNING: `calibrate_decision_threshold=True` is only applicable to binary classification. Skipping calibration...",
1965
+ )
1895
1966
  else:
1896
1967
  best_threshold = self.calibrate_decision_threshold()
1897
1968
  self.set_decision_threshold(decision_threshold=best_threshold)
@@ -2036,7 +2107,9 @@ class TabularPredictor:
2036
2107
 
2037
2108
  if ag_args is None:
2038
2109
  ag_args = {}
2039
- ag_args = self._set_hyperparameter_tune_kwargs_in_ag_args(kwargs["hyperparameter_tune_kwargs"], ag_args, time_limit=time_limit)
2110
+ ag_args = self._set_hyperparameter_tune_kwargs_in_ag_args(
2111
+ kwargs["hyperparameter_tune_kwargs"], ag_args, time_limit=time_limit
2112
+ )
2040
2113
 
2041
2114
  fit_new_weighted_ensemble = False # TODO: Add as option
2042
2115
  aux_kwargs = {
@@ -2086,10 +2159,12 @@ class TabularPredictor:
2086
2159
  y_og = self._learner.label_cleaner.inverse_transform(y)
2087
2160
  y_og_classes = y_og.unique()
2088
2161
  y_pseudo_classes = y_pseudo_og.unique()
2089
- matching_classes = np.in1d(y_pseudo_classes, y_og_classes)
2162
+ matching_classes = np.isin(y_pseudo_classes, y_og_classes)
2090
2163
 
2091
2164
  if not matching_classes.all():
2092
- raise Exception(f"Pseudo training data contains classes not in original train data: {y_pseudo_classes[~matching_classes]}")
2165
+ raise Exception(
2166
+ f"Pseudo training data contains classes not in original train data: {y_pseudo_classes[~matching_classes]}"
2167
+ )
2093
2168
 
2094
2169
  name_suffix = kwargs.get("name_suffix", "")
2095
2170
 
@@ -2164,7 +2239,9 @@ class TabularPredictor:
2164
2239
  def _predict_pseudo(self, X_test: pd.DataFrame, use_ensemble: bool):
2165
2240
  if use_ensemble:
2166
2241
  if self.problem_type in PROBLEM_TYPES_CLASSIFICATION:
2167
- test_pseudo_idxes_true, y_pred_proba, y_pred = filter_ensemble_pseudo(predictor=self, unlabeled_data=X_test)
2242
+ test_pseudo_idxes_true, y_pred_proba, y_pred = filter_ensemble_pseudo(
2243
+ predictor=self, unlabeled_data=X_test
2244
+ )
2168
2245
  else:
2169
2246
  test_pseudo_idxes_true, y_pred = filter_ensemble_pseudo(predictor=self, unlabeled_data=X_test)
2170
2247
  y_pred_proba = y_pred.copy()
@@ -2219,7 +2296,11 @@ class TabularPredictor:
2219
2296
  --------
2220
2297
  self: TabularPredictor
2221
2298
  """
2222
- previous_score = self.leaderboard(set_refit_score_to_parent=True).set_index("model", drop=True).loc[self.model_best]["score_val"]
2299
+ previous_score = (
2300
+ self.leaderboard(set_refit_score_to_parent=True)
2301
+ .set_index("model", drop=True)
2302
+ .loc[self.model_best]["score_val"]
2303
+ )
2223
2304
  y_pseudo_og = pd.Series()
2224
2305
  X_test = unlabeled_data.copy()
2225
2306
 
@@ -2235,7 +2316,10 @@ class TabularPredictor:
2235
2316
  logger.log(20, f"Beginning iteration {iter_print} of pseudolabeling out of max {max_iter}")
2236
2317
 
2237
2318
  if len(test_pseudo_idxes_true) < 1:
2238
- logger.log(20, f"Could not confidently assign pseudolabels for any of the provided rows in iteration {iter_print}. Done with pseudolabeling...")
2319
+ logger.log(
2320
+ 20,
2321
+ f"Could not confidently assign pseudolabels for any of the provided rows in iteration {iter_print}. Done with pseudolabeling...",
2322
+ )
2239
2323
  break
2240
2324
  else:
2241
2325
  logger.log(
@@ -2252,7 +2336,9 @@ class TabularPredictor:
2252
2336
  if len(y_pseudo_og) == 0:
2253
2337
  y_pseudo_og = y_pred.loc[test_pseudo_idxes_true.index].copy()
2254
2338
  else:
2255
- y_pseudo_og = pd.concat([y_pseudo_og, y_pred.loc[test_pseudo_idxes_true.index]], verify_integrity=True)
2339
+ y_pseudo_og = pd.concat(
2340
+ [y_pseudo_og, y_pred.loc[test_pseudo_idxes_true.index]], verify_integrity=True
2341
+ )
2256
2342
 
2257
2343
  pseudo_data = unlabeled_data.loc[y_pseudo_og.index]
2258
2344
  pseudo_data[self.label] = y_pseudo_og
@@ -2261,7 +2347,11 @@ class TabularPredictor:
2261
2347
  if fit_ensemble and fit_ensemble_every_iter:
2262
2348
  self._fit_weighted_ensemble_pseudo()
2263
2349
 
2264
- current_score = self.leaderboard(set_refit_score_to_parent=True).set_index("model", drop=True).loc[self.model_best]["score_val"]
2350
+ current_score = (
2351
+ self.leaderboard(set_refit_score_to_parent=True)
2352
+ .set_index("model", drop=True)
2353
+ .loc[self.model_best]["score_val"]
2354
+ )
2265
2355
  logger.log(
2266
2356
  20,
2267
2357
  f"Pseudolabeling algorithm changed validation score from: {previous_score}, to: {current_score}"
@@ -2277,10 +2367,14 @@ class TabularPredictor:
2277
2367
  previous_score = current_score
2278
2368
 
2279
2369
  # Update y_pred_proba and test_pseudo_idxes_true based on the latest pseudolabelled iteration
2280
- y_pred, y_pred_proba, test_pseudo_idxes_true = self._predict_pseudo(X_test=X_test, use_ensemble=use_ensemble)
2370
+ y_pred, y_pred_proba, test_pseudo_idxes_true = self._predict_pseudo(
2371
+ X_test=X_test, use_ensemble=use_ensemble
2372
+ )
2281
2373
  # Update the y_pred_proba_og variable if an improvement was achieved
2282
2374
  if return_pred_prob and test_pseudo_idxes_false is not None:
2283
- y_pred_proba_og.loc[test_pseudo_idxes_false.index] = y_pred_proba.loc[test_pseudo_idxes_false.index]
2375
+ y_pred_proba_og.loc[test_pseudo_idxes_false.index] = y_pred_proba.loc[
2376
+ test_pseudo_idxes_false.index
2377
+ ]
2284
2378
 
2285
2379
  if fit_ensemble and not fit_ensemble_every_iter:
2286
2380
  self._fit_weighted_ensemble_pseudo()
@@ -2423,12 +2517,20 @@ class TabularPredictor:
2423
2517
  # TODO: Consider making calculating this information easier, such as keeping track of meta-info from the latest/original fit call.
2424
2518
  # Currently we use `stack_name == core` to figure out the number of stack levels, but this is somewhat brittle.
2425
2519
  if "num_stack_levels" not in fit_extra_kwargs and not was_fit:
2426
- models_core: list[str] = [m for m, stack_name in self._trainer.get_models_attribute_dict(attribute="stack_name").items() if stack_name == "core"]
2427
- num_stack_levels = max(self._trainer.get_models_attribute_dict(attribute="level", models=models_core).values()) - 1
2520
+ models_core: list[str] = [
2521
+ m
2522
+ for m, stack_name in self._trainer.get_models_attribute_dict(attribute="stack_name").items()
2523
+ if stack_name == "core"
2524
+ ]
2525
+ num_stack_levels = (
2526
+ max(self._trainer.get_models_attribute_dict(attribute="level", models=models_core).values()) - 1
2527
+ )
2428
2528
  fit_extra_kwargs["num_stack_levels"] = num_stack_levels
2429
2529
  if is_labeled:
2430
2530
  logger.log(20, "Fitting predictor using the provided pseudolabeled examples as extra training data...")
2431
- self.fit_extra(pseudo_data=pseudo_data, name_suffix=PSEUDO_MODEL_SUFFIX.format(iter="")[:-1], **fit_extra_kwargs)
2531
+ self.fit_extra(
2532
+ pseudo_data=pseudo_data, name_suffix=PSEUDO_MODEL_SUFFIX.format(iter="")[:-1], **fit_extra_kwargs
2533
+ )
2432
2534
 
2433
2535
  if fit_ensemble:
2434
2536
  logger.log(15, "Fitting weighted ensemble model using best models")
@@ -2499,7 +2601,13 @@ class TabularPredictor:
2499
2601
  data = self._get_dataset(data)
2500
2602
  if decision_threshold is None:
2501
2603
  decision_threshold = self.decision_threshold
2502
- return self._learner.predict(X=data, model=model, as_pandas=as_pandas, transform_features=transform_features, decision_threshold=decision_threshold)
2604
+ return self._learner.predict(
2605
+ X=data,
2606
+ model=model,
2607
+ as_pandas=as_pandas,
2608
+ transform_features=transform_features,
2609
+ decision_threshold=decision_threshold,
2610
+ )
2503
2611
 
2504
2612
  def predict_proba(
2505
2613
  self,
@@ -2551,9 +2659,17 @@ class TabularPredictor:
2551
2659
  f"You can check the value of `predictor.can_predict_proba` to tell if predict_proba is valid."
2552
2660
  )
2553
2661
  data = self._get_dataset(data)
2554
- return self._learner.predict_proba(X=data, model=model, as_pandas=as_pandas, as_multiclass=as_multiclass, transform_features=transform_features)
2662
+ return self._learner.predict_proba(
2663
+ X=data,
2664
+ model=model,
2665
+ as_pandas=as_pandas,
2666
+ as_multiclass=as_multiclass,
2667
+ transform_features=transform_features,
2668
+ )
2555
2669
 
2556
- def predict_from_proba(self, y_pred_proba: pd.DataFrame | np.ndarray, decision_threshold: float | None = None) -> pd.Series | np.array:
2670
+ def predict_from_proba(
2671
+ self, y_pred_proba: pd.DataFrame | np.ndarray, decision_threshold: float | None = None
2672
+ ) -> pd.Series | np.array:
2557
2673
  """
2558
2674
  Given prediction probabilities, convert to predictions.
2559
2675
 
@@ -2586,7 +2702,9 @@ class TabularPredictor:
2586
2702
  >>> y_pred_from_proba = predictor.predict_from_proba(y_pred_proba=y_pred_proba)
2587
2703
  """
2588
2704
  if not self.can_predict_proba:
2589
- raise AssertionError(f'`predictor.predict_from_proba` is not supported when problem_type="{self.problem_type}".')
2705
+ raise AssertionError(
2706
+ f'`predictor.predict_from_proba` is not supported when problem_type="{self.problem_type}".'
2707
+ )
2590
2708
  if decision_threshold is None:
2591
2709
  decision_threshold = self.decision_threshold
2592
2710
  return self._learner.get_pred_from_proba(y_pred_proba=y_pred_proba, decision_threshold=decision_threshold)
@@ -2678,7 +2796,15 @@ class TabularPredictor:
2678
2796
  )
2679
2797
 
2680
2798
  def evaluate_predictions(
2681
- self, y_true, y_pred, sample_weight=None, decision_threshold=None, display: bool = False, auxiliary_metrics=True, detailed_report=False, **kwargs
2799
+ self,
2800
+ y_true,
2801
+ y_pred,
2802
+ sample_weight=None,
2803
+ decision_threshold=None,
2804
+ display: bool = False,
2805
+ auxiliary_metrics=True,
2806
+ detailed_report=False,
2807
+ **kwargs,
2682
2808
  ) -> dict:
2683
2809
  """
2684
2810
  Evaluate the provided prediction probabilities against ground truth labels.
@@ -3219,11 +3345,17 @@ class TabularPredictor:
3219
3345
  model_types = self._trainer.get_models_attribute_dict(attribute="type")
3220
3346
  model_inner_types = self._trainer.get_models_attribute_dict(attribute="type_inner")
3221
3347
  model_typenames = {key: model_types[key].__name__ for key in model_types}
3222
- model_innertypenames = {key: model_inner_types[key].__name__ for key in model_types if key in model_inner_types}
3348
+ model_innertypenames = {
3349
+ key: model_inner_types[key].__name__ for key in model_types if key in model_inner_types
3350
+ }
3223
3351
  MODEL_STR = "Model"
3224
3352
  ENSEMBLE_STR = "Ensemble"
3225
3353
  for model in model_typenames:
3226
- if (model in model_innertypenames) and (ENSEMBLE_STR not in model_innertypenames[model]) and (ENSEMBLE_STR in model_typenames[model]):
3354
+ if (
3355
+ (model in model_innertypenames)
3356
+ and (ENSEMBLE_STR not in model_innertypenames[model])
3357
+ and (ENSEMBLE_STR in model_typenames[model])
3358
+ ):
3227
3359
  new_model_typename = model_typenames[model] + "_" + model_innertypenames[model]
3228
3360
  if new_model_typename.endswith(MODEL_STR):
3229
3361
  new_model_typename = new_model_typename[: -len(MODEL_STR)]
@@ -3286,7 +3418,11 @@ class TabularPredictor:
3286
3418
  print(self.feature_metadata)
3287
3419
  if verbosity > 1: # create plots
3288
3420
  plot_tabular_models(
3289
- results, output_directory=self.path, save_file="SummaryOfModels.html", plot_title="Models produced during fit()", show_plot=show_plot
3421
+ results,
3422
+ output_directory=self.path,
3423
+ save_file="SummaryOfModels.html",
3424
+ plot_title="Models produced during fit()",
3425
+ show_plot=show_plot,
3290
3426
  )
3291
3427
  if hpo_used:
3292
3428
  for model_type in results["hpo_results"]:
@@ -3315,7 +3451,9 @@ class TabularPredictor:
3315
3451
  print(
3316
3452
  f"HPO for {model_type} model: Num. configurations tried = {len(hpo_model['trial_info'])}, Time spent = {hpo_model['total_time']}s, Search strategy = {hpo_model['search_strategy']}"
3317
3453
  )
3318
- print(f"Best hyperparameter-configuration (validation-performance: {self.eval_metric} = {hpo_model['validation_performance']}):")
3454
+ print(
3455
+ f"Best hyperparameter-configuration (validation-performance: {self.eval_metric} = {hpo_model['validation_performance']}):"
3456
+ )
3319
3457
  print(hpo_model["best_config"])
3320
3458
  """
3321
3459
  if bagging_used:
@@ -3403,9 +3541,13 @@ class TabularPredictor:
3403
3541
  """
3404
3542
  self._assert_is_fit("transform_features")
3405
3543
  data = self._get_dataset(data, allow_nan=True)
3406
- return self._learner.get_inputs_to_stacker(dataset=data, model=model, base_models=base_models, use_orig_features=return_original_features)
3544
+ return self._learner.get_inputs_to_stacker(
3545
+ dataset=data, model=model, base_models=base_models, use_orig_features=return_original_features
3546
+ )
3407
3547
 
3408
- def transform_labels(self, labels: np.ndarray | pd.Series, inverse: bool = False, proba: bool = False) -> pd.Series | pd.DataFrame:
3548
+ def transform_labels(
3549
+ self, labels: np.ndarray | pd.Series, inverse: bool = False, proba: bool = False
3550
+ ) -> pd.Series | pd.DataFrame:
3409
3551
  """
3410
3552
  Transforms data labels to the internal label representation.
3411
3553
  This can be useful for training your own models on the same data label representation as AutoGluon.
@@ -3672,13 +3814,17 @@ class TabularPredictor:
3672
3814
  """
3673
3815
  self._assert_is_fit("persist")
3674
3816
  try:
3675
- return self._learner.persist_trainer(low_memory=False, models=models, with_ancestors=with_ancestors, max_memory=max_memory)
3817
+ return self._learner.persist_trainer(
3818
+ low_memory=False, models=models, with_ancestors=with_ancestors, max_memory=max_memory
3819
+ )
3676
3820
  except Exception as e:
3677
3821
  valid_models = self.model_names()
3678
3822
  if isinstance(models, list):
3679
3823
  invalid_models = [m for m in models if m not in valid_models]
3680
3824
  if invalid_models:
3681
- raise ValueError(f"Invalid models specified. The following models do not exist:\n\t{invalid_models}\nValid models:\n\t{valid_models}")
3825
+ raise ValueError(
3826
+ f"Invalid models specified. The following models do not exist:\n\t{invalid_models}\nValid models:\n\t{valid_models}"
3827
+ )
3682
3828
  raise e
3683
3829
 
3684
3830
  def unpersist(self, models="all") -> list[str]:
@@ -3805,7 +3951,9 @@ class TabularPredictor:
3805
3951
  X_pseudo, y_pseudo, _ = self._sanitize_pseudo_data(pseudo_data=train_data_extra, name="train_data_extra")
3806
3952
  kwargs["X_pseudo"] = X_pseudo
3807
3953
  kwargs["y_pseudo"] = y_pseudo
3808
- refit_full_dict = self._learner.refit_ensemble_full(model=model, total_resources=total_resources, fit_strategy=fit_strategy, **kwargs)
3954
+ refit_full_dict = self._learner.refit_ensemble_full(
3955
+ model=model, total_resources=total_resources, fit_strategy=fit_strategy, **kwargs
3956
+ )
3809
3957
 
3810
3958
  if set_best_to_refit_full:
3811
3959
  if isinstance(set_best_to_refit_full, str):
@@ -3841,7 +3989,9 @@ class TabularPredictor:
3841
3989
  )
3842
3990
 
3843
3991
  te = time.time()
3844
- logger.log(20, f'Refit complete, total runtime = {round(te - ts, 2)}s ... Best model: "{self._trainer.model_best}"')
3992
+ logger.log(
3993
+ 20, f'Refit complete, total runtime = {round(te - ts, 2)}s ... Best model: "{self._trainer.model_best}"'
3994
+ )
3845
3995
  return refit_full_dict
3846
3996
 
3847
3997
  @property
@@ -4070,7 +4220,9 @@ class TabularPredictor:
4070
4220
  if base_models is None:
4071
4221
  base_models = trainer.get_model_names(stack_name="core")
4072
4222
 
4073
- X_stack_preds = trainer.get_inputs_to_stacker(X=X, base_models=base_models, fit=fit, use_orig_features=False, use_val_cache=True)
4223
+ X_stack_preds = trainer.get_inputs_to_stacker(
4224
+ X=X, base_models=base_models, fit=fit, use_orig_features=False, use_val_cache=True
4225
+ )
4074
4226
 
4075
4227
  models = []
4076
4228
 
@@ -4179,7 +4331,9 @@ class TabularPredictor:
4179
4331
  # Use `zero_division` parameter to control this behavior.
4180
4332
 
4181
4333
  self._assert_is_fit("calibrate_decision_threshold")
4182
- assert self.problem_type == BINARY, f'calibrate_decision_threshold is only available for `problem_type="{BINARY}"`'
4334
+ assert self.problem_type == BINARY, (
4335
+ f'calibrate_decision_threshold is only available for `problem_type="{BINARY}"`'
4336
+ )
4183
4337
  data = self._get_dataset(data, allow_nan=True)
4184
4338
 
4185
4339
  if metric is None:
@@ -4197,7 +4351,16 @@ class TabularPredictor:
4197
4351
  verbose=verbose,
4198
4352
  )
4199
4353
 
4200
- def predict_oof(self, model: str = None, *, transformed=False, train_data=None, internal_oof=False, decision_threshold=None, can_infer=None) -> pd.Series:
4354
+ def predict_oof(
4355
+ self,
4356
+ model: str = None,
4357
+ *,
4358
+ transformed=False,
4359
+ train_data=None,
4360
+ internal_oof=False,
4361
+ decision_threshold=None,
4362
+ can_infer=None,
4363
+ ) -> pd.Series:
4201
4364
  """
4202
4365
  Note: This is advanced functionality not intended for normal usage.
4203
4366
 
@@ -4229,16 +4392,30 @@ class TabularPredictor:
4229
4392
  if decision_threshold is None:
4230
4393
  decision_threshold = self.decision_threshold
4231
4394
  y_pred_proba_oof = self.predict_proba_oof(
4232
- model=model, transformed=transformed, as_multiclass=True, train_data=train_data, internal_oof=internal_oof, can_infer=can_infer
4395
+ model=model,
4396
+ transformed=transformed,
4397
+ as_multiclass=True,
4398
+ train_data=train_data,
4399
+ internal_oof=internal_oof,
4400
+ can_infer=can_infer,
4401
+ )
4402
+ y_pred_oof = get_pred_from_proba_df(
4403
+ y_pred_proba_oof, problem_type=self.problem_type, decision_threshold=decision_threshold
4233
4404
  )
4234
- y_pred_oof = get_pred_from_proba_df(y_pred_proba_oof, problem_type=self.problem_type, decision_threshold=decision_threshold)
4235
4405
  if transformed:
4236
4406
  return self._learner.label_cleaner.to_transformed_dtype(y_pred_oof)
4237
4407
  return y_pred_oof
4238
4408
 
4239
4409
  # TODO: Remove train_data argument once we start caching the raw original data: Can just load that instead.
4240
4410
  def predict_proba_oof(
4241
- self, model: str = None, *, transformed=False, as_multiclass=True, train_data=None, internal_oof=False, can_infer=None
4411
+ self,
4412
+ model: str = None,
4413
+ *,
4414
+ transformed=False,
4415
+ as_multiclass=True,
4416
+ train_data=None,
4417
+ internal_oof=False,
4418
+ can_infer=None,
4242
4419
  ) -> pd.DataFrame | pd.Series:
4243
4420
  """
4244
4421
  Note: This is advanced functionality not intended for normal usage.
@@ -4303,8 +4480,12 @@ class TabularPredictor:
4303
4480
  if model != model_to_get_oof:
4304
4481
  logger.log(20, f'Using OOF from "{model_to_get_oof}" as a proxy for "{model}".')
4305
4482
  if self._trainer.get_model_attribute_full(model=model_to_get_oof, attribute="val_in_fit", func=max):
4306
- raise AssertionError(f"Model {model_to_get_oof} does not have out-of-fold predictions because it used a validation set during training.")
4307
- y_pred_proba_oof_transformed = self.transform_features(base_models=[model_to_get_oof], return_original_features=False)
4483
+ raise AssertionError(
4484
+ f"Model {model_to_get_oof} does not have out-of-fold predictions because it used a validation set during training."
4485
+ )
4486
+ y_pred_proba_oof_transformed = self.transform_features(
4487
+ base_models=[model_to_get_oof], return_original_features=False
4488
+ )
4308
4489
  if not internal_oof:
4309
4490
  is_duplicate_index = y_pred_proba_oof_transformed.index.duplicated(keep="first")
4310
4491
  if is_duplicate_index.any():
@@ -4330,12 +4511,16 @@ class TabularPredictor:
4330
4511
  missing_idx = list(train_data.index.difference(y_pred_proba_oof_transformed.index))
4331
4512
  if len(missing_idx) > 0:
4332
4513
  missing_idx_data = train_data.loc[missing_idx]
4333
- missing_pred_proba = self.transform_features(data=missing_idx_data, base_models=[model], return_original_features=False)
4514
+ missing_pred_proba = self.transform_features(
4515
+ data=missing_idx_data, base_models=[model], return_original_features=False
4516
+ )
4334
4517
  y_pred_proba_oof_transformed = pd.concat([y_pred_proba_oof_transformed, missing_pred_proba])
4335
4518
  y_pred_proba_oof_transformed = y_pred_proba_oof_transformed.reindex(list(train_data.index))
4336
4519
 
4337
4520
  if self.problem_type == MULTICLASS and self._learner.label_cleaner.problem_type_transform == MULTICLASS:
4338
- y_pred_proba_oof_transformed.columns = copy.deepcopy(self._learner.label_cleaner.ordered_class_labels_transformed)
4521
+ y_pred_proba_oof_transformed.columns = copy.deepcopy(
4522
+ self._learner.label_cleaner.ordered_class_labels_transformed
4523
+ )
4339
4524
  elif self.problem_type == QUANTILE:
4340
4525
  y_pred_proba_oof_transformed.columns = self.quantile_levels
4341
4526
  else:
@@ -4347,10 +4532,14 @@ class TabularPredictor:
4347
4532
  )
4348
4533
  elif self.problem_type == MULTICLASS:
4349
4534
  if transformed:
4350
- y_pred_proba_oof_transformed = LabelCleanerMulticlassToBinary.convert_binary_proba_to_multiclass_proba(
4351
- y_pred_proba_oof_transformed, as_pandas=True
4535
+ y_pred_proba_oof_transformed = (
4536
+ LabelCleanerMulticlassToBinary.convert_binary_proba_to_multiclass_proba(
4537
+ y_pred_proba_oof_transformed, as_pandas=True
4538
+ )
4539
+ )
4540
+ y_pred_proba_oof_transformed.columns = copy.deepcopy(
4541
+ self._learner.label_cleaner.ordered_class_labels_transformed
4352
4542
  )
4353
- y_pred_proba_oof_transformed.columns = copy.deepcopy(self._learner.label_cleaner.ordered_class_labels_transformed)
4354
4543
  if transformed:
4355
4544
  return y_pred_proba_oof_transformed
4356
4545
  else:
@@ -4575,7 +4764,9 @@ class TabularPredictor:
4575
4764
  List of model names
4576
4765
  """
4577
4766
  self._assert_is_fit("model_names")
4578
- model_names = self._trainer.get_model_names(stack_name=stack_name, level=level, can_infer=can_infer, models=models)
4767
+ model_names = self._trainer.get_model_names(
4768
+ stack_name=stack_name, level=level, can_infer=can_infer, models=models
4769
+ )
4579
4770
  if persisted is not None:
4580
4771
  persisted_model_names = list(self._trainer.models.keys())
4581
4772
  if persisted:
@@ -4696,7 +4887,9 @@ class TabularPredictor:
4696
4887
  # Might require using a different tool than pygraphviz to avoid the apt-get commands
4697
4888
  # TODO: v1.0 Rename to `plot_model_graph`
4698
4889
  # TODO: v1.0 Maybe add ensemble weights to the edges.
4699
- def plot_ensemble_model(self, model: str = "best", *, prune_unused_nodes: bool = True, filename: str = "ensemble_model.png") -> str:
4890
+ def plot_ensemble_model(
4891
+ self, model: str = "best", *, prune_unused_nodes: bool = True, filename: str = "ensemble_model.png"
4892
+ ) -> str:
4700
4893
  """
4701
4894
  Output the visualized stack ensemble architecture of a model trained by `fit()`.
4702
4895
  The plot is stored to a file, `ensemble_model.png` in folder `predictor.path` (or by the name specified in `filename`)
@@ -4783,7 +4976,12 @@ class TabularPredictor:
4783
4976
  node_val_score_str = "NaN"
4784
4977
  else:
4785
4978
  node_val_score_str = f"{float(node.attr['val_score']):.4f}"
4786
- label = f"{node.name}" f"\nscore_val: {node_val_score_str}" f"\nfit_time: {fit_time_str}" f"\npred_time_val: {predict_time_str}"
4979
+ label = (
4980
+ f"{node.name}"
4981
+ f"\nscore_val: {node_val_score_str}"
4982
+ f"\nfit_time: {fit_time_str}"
4983
+ f"\npred_time_val: {predict_time_str}"
4984
+ )
4787
4985
  # Remove unnecessary attributes
4788
4986
  node.attr.clear()
4789
4987
  node.attr["label"] = label
@@ -4841,11 +5039,16 @@ class TabularPredictor:
4841
5039
  return True
4842
5040
 
4843
5041
  scheduler_cls, scheduler_params = scheduler_factory(
4844
- hyperparameter_tune_kwargs=hyperparameter_tune_kwargs, time_out=time_limit, nthreads_per_trial="auto", ngpus_per_trial="auto"
5042
+ hyperparameter_tune_kwargs=hyperparameter_tune_kwargs,
5043
+ time_out=time_limit,
5044
+ nthreads_per_trial="auto",
5045
+ ngpus_per_trial="auto",
4845
5046
  )
4846
5047
 
4847
5048
  if scheduler_params.get("dist_ip_addrs", None):
4848
- logger.warning("Warning: dist_ip_addrs does not currently work for Tabular. Distributed instances will not be utilized.")
5049
+ logger.warning(
5050
+ "Warning: dist_ip_addrs does not currently work for Tabular. Distributed instances will not be utilized."
5051
+ )
4849
5052
 
4850
5053
  if scheduler_params["num_trials"] == 1:
4851
5054
  logger.warning(
@@ -4855,18 +5058,24 @@ class TabularPredictor:
4855
5058
 
4856
5059
  scheduler_ngpus = scheduler_params["resource"].get("num_gpus", 0)
4857
5060
  if scheduler_ngpus is not None and isinstance(scheduler_ngpus, int) and scheduler_ngpus > 1:
4858
- logger.warning(f"Warning: TabularPredictor currently doesn't use >1 GPU per training run. Detected {scheduler_ngpus} GPUs.")
5061
+ logger.warning(
5062
+ f"Warning: TabularPredictor currently doesn't use >1 GPU per training run. Detected {scheduler_ngpus} GPUs."
5063
+ )
4859
5064
 
4860
5065
  return True
4861
5066
 
4862
5067
  def _set_hyperparameter_tune_kwargs_in_ag_args(self, hyperparameter_tune_kwargs, ag_args, time_limit):
4863
5068
  if hyperparameter_tune_kwargs is not None and "hyperparameter_tune_kwargs" not in ag_args:
4864
5069
  if "hyperparameter_tune_kwargs" in ag_args:
4865
- AssertionError("hyperparameter_tune_kwargs was specified in both ag_args and in kwargs. Please only specify once.")
5070
+ AssertionError(
5071
+ "hyperparameter_tune_kwargs was specified in both ag_args and in kwargs. Please only specify once."
5072
+ )
4866
5073
  else:
4867
5074
  ag_args["hyperparameter_tune_kwargs"] = hyperparameter_tune_kwargs
4868
5075
  if ag_args.get("hyperparameter_tune_kwargs", None) is not None:
4869
- logger.log(30, "Warning: hyperparameter tuning is currently experimental and may cause the process to hang.")
5076
+ logger.log(
5077
+ 30, "Warning: hyperparameter tuning is currently experimental and may cause the process to hang."
5078
+ )
4870
5079
  return ag_args
4871
5080
 
4872
5081
  def _set_post_fit_vars(self, learner: AbstractTabularLearner = None):
@@ -5066,14 +5275,16 @@ class TabularPredictor:
5066
5275
 
5067
5276
  if metadata_init is not None:
5068
5277
  try:
5069
- compare_autogluon_metadata(original=metadata_init, current=metadata_load, check_packages=check_packages)
5278
+ compare_autogluon_metadata(
5279
+ original=metadata_init, current=metadata_load, check_packages=check_packages
5280
+ )
5070
5281
  except:
5071
5282
  logger.log(30, "WARNING: Exception raised while comparing metadata files, skipping comparison...")
5072
5283
  if require_py_version_match:
5073
5284
  if metadata_init["py_version"] != metadata_load["py_version"]:
5074
5285
  raise AssertionError(
5075
- f'Predictor was created on Python version {metadata_init["py_version"]} '
5076
- f'but is being loaded with Python version {metadata_load["py_version"]}. '
5286
+ f"Predictor was created on Python version {metadata_init['py_version']} "
5287
+ f"but is being loaded with Python version {metadata_load['py_version']}. "
5077
5288
  f"Please ensure the versions match to avoid instability. While it is NOT recommended, "
5078
5289
  f"this error can be bypassed by specifying `require_py_version_match=False`."
5079
5290
  )
@@ -5105,7 +5316,9 @@ class TabularPredictor:
5105
5316
  """
5106
5317
  file_path = log_file_path
5107
5318
  if file_path is None:
5108
- assert predictor_path is not None, "Please either provide `predictor_path` or `log_file_path` to load the log file"
5319
+ assert predictor_path is not None, (
5320
+ "Please either provide `predictor_path` or `log_file_path` to load the log file"
5321
+ )
5109
5322
  file_path = os.path.join(predictor_path, "logs", cls._predictor_log_file_name)
5110
5323
  assert os.path.isfile(file_path), f"Log file does not exist at {file_path}"
5111
5324
  lines = []
@@ -5163,7 +5376,9 @@ class TabularPredictor:
5163
5376
  _experimental_dynamic_hyperparameters=False,
5164
5377
  )
5165
5378
  kwargs, ds_valid_keys = self._sanitize_dynamic_stacking_kwargs(kwargs)
5166
- kwargs = self._validate_fit_extra_kwargs(kwargs, extra_valid_keys=list(fit_kwargs_default.keys()) + ds_valid_keys)
5379
+ kwargs = self._validate_fit_extra_kwargs(
5380
+ kwargs, extra_valid_keys=list(fit_kwargs_default.keys()) + ds_valid_keys
5381
+ )
5167
5382
  kwargs_sanitized = fit_kwargs_default.copy()
5168
5383
  kwargs_sanitized.update(kwargs)
5169
5384
 
@@ -5173,7 +5388,8 @@ class TabularPredictor:
5173
5388
  valid_calibrate_decision_threshold_options = [True, False, "auto"]
5174
5389
  if calibrate_decision_threshold not in valid_calibrate_decision_threshold_options:
5175
5390
  raise ValueError(
5176
- f"`calibrate_decision_threshold` must be a value in " f"{valid_calibrate_decision_threshold_options}, but is: {calibrate_decision_threshold}"
5391
+ f"`calibrate_decision_threshold` must be a value in "
5392
+ f"{valid_calibrate_decision_threshold_options}, but is: {calibrate_decision_threshold}"
5177
5393
  )
5178
5394
 
5179
5395
  def _validate_num_cpus(self, num_cpus: int | str):
@@ -5195,7 +5411,9 @@ class TabularPredictor:
5195
5411
  if num_gpus != "auto":
5196
5412
  raise ValueError(f"`num_gpus` must be an int, float, or 'auto'. Value: {num_gpus}")
5197
5413
  elif not isinstance(num_gpus, (int, float)):
5198
- raise TypeError(f"`num_gpus` must be an int, float, or 'auto'. Found: {type(num_gpus)} | Value: {num_gpus}")
5414
+ raise TypeError(
5415
+ f"`num_gpus` must be an int, float, or 'auto'. Found: {type(num_gpus)} | Value: {num_gpus}"
5416
+ )
5199
5417
  else:
5200
5418
  if num_gpus < 0:
5201
5419
  raise ValueError(f"`num_gpus` must be greater than or equal to 0. (num_gpus={num_gpus})")
@@ -5207,7 +5425,9 @@ class TabularPredictor:
5207
5425
  if memory_limit != "auto":
5208
5426
  raise ValueError(f"`memory_limit` must be an int, float, or 'auto'. Value: {memory_limit}")
5209
5427
  elif not isinstance(memory_limit, (int, float)):
5210
- raise TypeError("`memory_limit` must be an int, float, or 'auto'." f" Found: {type(memory_limit)} | Value: {memory_limit}")
5428
+ raise TypeError(
5429
+ f"`memory_limit` must be an int, float, or 'auto'. Found: {type(memory_limit)} | Value: {memory_limit}"
5430
+ )
5211
5431
  else:
5212
5432
  if memory_limit <= 0:
5213
5433
  raise ValueError(f"`memory_limit` must be greater than 0. (memory_limit={memory_limit})")
@@ -5287,23 +5507,43 @@ class TabularPredictor:
5287
5507
  if key_mismatch:
5288
5508
  raise ValueError(f"Got invalid keys for `ds_args`. Allowed: {allowed_kes}. Got: {key_mismatch}")
5289
5509
  if ("validation_procedure" in ds_args) and (
5290
- (not isinstance(ds_args["validation_procedure"], str)) or (ds_args["validation_procedure"] not in ["holdout", "cv"])
5510
+ (not isinstance(ds_args["validation_procedure"], str))
5511
+ or (ds_args["validation_procedure"] not in ["holdout", "cv"])
5291
5512
  ):
5292
- raise ValueError("`validation_procedure` in `ds_args` must be str in {'holdout','cv'}. " + f"Got: {ds_args['validation_procedure']}")
5513
+ raise ValueError(
5514
+ "`validation_procedure` in `ds_args` must be str in {'holdout','cv'}. "
5515
+ + f"Got: {ds_args['validation_procedure']}"
5516
+ )
5293
5517
  for arg_name in ["clean_up_fits", "enable_ray_logging"]:
5294
5518
  if (arg_name in ds_args) and (not isinstance(ds_args[arg_name], bool)):
5295
5519
  raise ValueError(f"`{arg_name}` in `ds_args` must be bool. Got: {type(ds_args[arg_name])}")
5296
5520
  if "memory_safe_fits" in ds_args and not isinstance(ds_args["memory_safe_fits"], (bool, str)):
5297
- raise ValueError(f"`memory_safe_fits` in `ds_args` must be bool or 'auto'. Got: {type(ds_args['memory_safe_fits'])}")
5521
+ raise ValueError(
5522
+ f"`memory_safe_fits` in `ds_args` must be bool or 'auto'. Got: {type(ds_args['memory_safe_fits'])}"
5523
+ )
5298
5524
  for arg_name in ["detection_time_frac", "holdout_frac"]:
5299
- if (arg_name in ds_args) and ((not isinstance(ds_args[arg_name], float)) or (ds_args[arg_name] >= 1) or (ds_args[arg_name] <= 0)):
5300
- raise ValueError(f"`{arg_name}` in `ds_args` must be float in (0,1). Got: {type(ds_args[arg_name])}, {ds_args[arg_name]}")
5525
+ if (arg_name in ds_args) and (
5526
+ (not isinstance(ds_args[arg_name], float)) or (ds_args[arg_name] >= 1) or (ds_args[arg_name] <= 0)
5527
+ ):
5528
+ raise ValueError(
5529
+ f"`{arg_name}` in `ds_args` must be float in (0,1). Got: {type(ds_args[arg_name])}, {ds_args[arg_name]}"
5530
+ )
5301
5531
  if ("n_folds" in ds_args) and ((not isinstance(ds_args["n_folds"], int)) or (ds_args["n_folds"] < 2)):
5302
- raise ValueError(f"`n_folds` in `ds_args` must be int in [2, +inf). Got: {type(ds_args['n_folds'])}, {ds_args['n_folds']}")
5532
+ raise ValueError(
5533
+ f"`n_folds` in `ds_args` must be int in [2, +inf). Got: {type(ds_args['n_folds'])}, {ds_args['n_folds']}"
5534
+ )
5303
5535
  if ("n_repeats" in ds_args) and ((not isinstance(ds_args["n_repeats"], int)) or (ds_args["n_repeats"] < 1)):
5304
- raise ValueError(f"`n_repeats` in `ds_args` must be int in [1, +inf). Got: {type(ds_args['n_repeats'])}, {ds_args['n_repeats']}")
5305
- if ("holdout_data" in ds_args) and (not isinstance(ds_args["holdout_data"], (str, pd.DataFrame))) and (ds_args["holdout_data"] is not None):
5306
- raise ValueError(f"`holdout_data` in `ds_args` must be None, str, or pd.DataFrame. Got: {type(ds_args['holdout_data'])}")
5536
+ raise ValueError(
5537
+ f"`n_repeats` in `ds_args` must be int in [1, +inf). Got: {type(ds_args['n_repeats'])}, {ds_args['n_repeats']}"
5538
+ )
5539
+ if (
5540
+ ("holdout_data" in ds_args)
5541
+ and (not isinstance(ds_args["holdout_data"], (str, pd.DataFrame)))
5542
+ and (ds_args["holdout_data"] is not None)
5543
+ ):
5544
+ raise ValueError(
5545
+ f"`holdout_data` in `ds_args` must be None, str, or pd.DataFrame. Got: {type(ds_args['holdout_data'])}"
5546
+ )
5307
5547
  if (ds_args["validation_procedure"] == "cv") and (ds_args["holdout_data"] is not None):
5308
5548
  raise ValueError(
5309
5549
  "`validation_procedure` in `ds_args` is 'cv' but `holdout_data` in `ds_args` is specified."
@@ -5322,7 +5562,9 @@ class TabularPredictor:
5322
5562
  if kwarg_name not in allowed_kwarg_names:
5323
5563
  public_kwarg_options = [kwarg for kwarg in allowed_kwarg_names if kwarg[0] != "_"]
5324
5564
  public_kwarg_options.sort()
5325
- raise ValueError(f"Unknown `.fit` keyword argument specified: '{kwarg_name}'\nValid kwargs: {public_kwarg_options}")
5565
+ raise ValueError(
5566
+ f"Unknown `.fit` keyword argument specified: '{kwarg_name}'\nValid kwargs: {public_kwarg_options}"
5567
+ )
5326
5568
 
5327
5569
  kwargs_sanitized = fit_extra_kwargs_default.copy()
5328
5570
  kwargs_sanitized.update(kwargs)
@@ -5335,7 +5577,9 @@ class TabularPredictor:
5335
5577
  refit_full = kwargs_sanitized["refit_full"]
5336
5578
  set_best_to_refit_full = kwargs_sanitized["set_best_to_refit_full"]
5337
5579
  if refit_full and not self._learner.cache_data:
5338
- raise ValueError("`refit_full=True` is only available when `cache_data=True`. Set `cache_data=True` to utilize `refit_full`.")
5580
+ raise ValueError(
5581
+ "`refit_full=True` is only available when `cache_data=True`. Set `cache_data=True` to utilize `refit_full`."
5582
+ )
5339
5583
  if set_best_to_refit_full and not refit_full:
5340
5584
  raise ValueError(
5341
5585
  "`set_best_to_refit_full=True` is only available when `refit_full=True`. Set `refit_full=True` to utilize `set_best_to_refit_full`."
@@ -5388,20 +5632,30 @@ class TabularPredictor:
5388
5632
  unlabeled_data = TabularDataset(unlabeled_data)
5389
5633
 
5390
5634
  if not isinstance(train_data, pd.DataFrame):
5391
- raise AssertionError(f"train_data is required to be a pandas DataFrame, but was instead: {type(train_data)}")
5635
+ raise AssertionError(
5636
+ f"train_data is required to be a pandas DataFrame, but was instead: {type(train_data)}"
5637
+ )
5392
5638
 
5393
5639
  if len(set(train_data.columns)) < len(train_data.columns):
5394
5640
  raise ValueError(
5395
5641
  "Column names are not unique, please change duplicated column names (in pandas: train_data.rename(columns={'current_name':'new_name'})"
5396
5642
  )
5397
5643
 
5398
- self._validate_single_fit_dataset(train_data=train_data, other_data=tuning_data, name="tuning_data", is_labeled=True)
5399
- self._validate_single_fit_dataset(train_data=train_data, other_data=test_data, name="test_data", is_labeled=True)
5400
- self._validate_single_fit_dataset(train_data=train_data, other_data=unlabeled_data, name="unlabeled_data", is_labeled=False)
5644
+ self._validate_single_fit_dataset(
5645
+ train_data=train_data, other_data=tuning_data, name="tuning_data", is_labeled=True
5646
+ )
5647
+ self._validate_single_fit_dataset(
5648
+ train_data=train_data, other_data=test_data, name="test_data", is_labeled=True
5649
+ )
5650
+ self._validate_single_fit_dataset(
5651
+ train_data=train_data, other_data=unlabeled_data, name="unlabeled_data", is_labeled=False
5652
+ )
5401
5653
 
5402
5654
  return train_data, tuning_data, test_data, unlabeled_data
5403
5655
 
5404
- def _validate_single_fit_dataset(self, train_data: pd.DataFrame, other_data: pd.DataFrame, name: str, is_labeled: bool = True):
5656
+ def _validate_single_fit_dataset(
5657
+ self, train_data: pd.DataFrame, other_data: pd.DataFrame, name: str, is_labeled: bool = True
5658
+ ):
5405
5659
  """
5406
5660
  Validates additional dataset, ensuring format is consistent with train dataset.
5407
5661
 
@@ -5422,11 +5676,15 @@ class TabularPredictor:
5422
5676
  """
5423
5677
  if other_data is not None:
5424
5678
  if not isinstance(other_data, pd.DataFrame):
5425
- raise AssertionError(f"{name} is required to be a pandas DataFrame, but was instead: {type(other_data)}")
5679
+ raise AssertionError(
5680
+ f"{name} is required to be a pandas DataFrame, but was instead: {type(other_data)}"
5681
+ )
5426
5682
  self._validate_unique_indices(data=other_data, name=name)
5427
5683
  train_features = [column for column in train_data.columns if column != self.label]
5428
5684
  other_features = [column for column in other_data.columns if column != self.label]
5429
- train_features, other_features = self._prune_data_features(train_features=train_features, other_features=other_features, is_labeled=is_labeled)
5685
+ train_features, other_features = self._prune_data_features(
5686
+ train_features=train_features, other_features=other_features, is_labeled=is_labeled
5687
+ )
5430
5688
  train_features = np.array(train_features)
5431
5689
  other_features = np.array(other_features)
5432
5690
  if np.any(train_features != other_features):
@@ -5444,7 +5702,9 @@ class TabularPredictor:
5444
5702
  f"\tAutoGluon will attempt to convert the dtypes to align."
5445
5703
  )
5446
5704
 
5447
- def _initialize_learning_curve_params(self, learning_curves: dict | bool | None = None, problem_type: str | None = None) -> dict:
5705
+ def _initialize_learning_curve_params(
5706
+ self, learning_curves: dict | bool | None = None, problem_type: str | None = None
5707
+ ) -> dict:
5448
5708
  """
5449
5709
  Convert users learning_curve dict parameters into ag_param format.
5450
5710
  Also, converts all metrics into list of autogluon Scorer objects.
@@ -5524,7 +5784,9 @@ class TabularPredictor:
5524
5784
  def _validate_infer_limit(infer_limit: float, infer_limit_batch_size: int) -> tuple[float, int]:
5525
5785
  if infer_limit_batch_size is not None:
5526
5786
  if not isinstance(infer_limit_batch_size, int):
5527
- raise ValueError(f"infer_limit_batch_size must be type int, but was instead type {type(infer_limit_batch_size)}")
5787
+ raise ValueError(
5788
+ f"infer_limit_batch_size must be type int, but was instead type {type(infer_limit_batch_size)}"
5789
+ )
5528
5790
  elif infer_limit_batch_size < 1:
5529
5791
  raise AssertionError(f"infer_limit_batch_size must be >=1, value: {infer_limit_batch_size}")
5530
5792
  if infer_limit is not None:
@@ -5534,7 +5796,10 @@ class TabularPredictor:
5534
5796
  raise AssertionError(f"infer_limit must be greater than zero! (infer_limit={infer_limit})")
5535
5797
  if infer_limit is not None and infer_limit_batch_size is None:
5536
5798
  infer_limit_batch_size = 10000
5537
- logger.log(20, f"infer_limit specified, but infer_limit_batch_size was not specified. Setting infer_limit_batch_size={infer_limit_batch_size}")
5799
+ logger.log(
5800
+ 20,
5801
+ f"infer_limit specified, but infer_limit_batch_size was not specified. Setting infer_limit_batch_size={infer_limit_batch_size}",
5802
+ )
5538
5803
  return infer_limit, infer_limit_batch_size
5539
5804
 
5540
5805
  def _set_feature_generator(self, feature_generator="auto", feature_metadata=None, init_kwargs=None):
@@ -5565,7 +5830,9 @@ class TabularPredictor:
5565
5830
  if num_bag_folds < 2 and num_bag_folds != 0:
5566
5831
  raise ValueError(f"num_bag_folds must be equal to 0 or >=2. (num_bag_folds={num_bag_folds})")
5567
5832
  if num_stack_levels != 0 and num_bag_folds == 0:
5568
- raise ValueError(f"num_stack_levels must be 0 if num_bag_folds is 0. (num_stack_levels={num_stack_levels}, num_bag_folds={num_bag_folds})")
5833
+ raise ValueError(
5834
+ f"num_stack_levels must be 0 if num_bag_folds is 0. (num_stack_levels={num_stack_levels}, num_bag_folds={num_bag_folds})"
5835
+ )
5569
5836
  if not isinstance(num_bag_sets, int):
5570
5837
  raise ValueError(f"num_bag_sets must be an integer. (num_bag_sets={num_bag_sets})")
5571
5838
  if not isinstance(dynamic_stacking, bool):
@@ -5575,9 +5842,13 @@ class TabularPredictor:
5575
5842
 
5576
5843
  if use_bag_holdout_was_auto and num_bag_folds != 0:
5577
5844
  if use_bag_holdout:
5578
- log_extra = f"Reason: num_train_rows >= {USE_BAG_HOLDOUT_AUTO_THRESHOLD}. (num_train_rows={num_train_rows})"
5845
+ log_extra = (
5846
+ f"Reason: num_train_rows >= {USE_BAG_HOLDOUT_AUTO_THRESHOLD}. (num_train_rows={num_train_rows})"
5847
+ )
5579
5848
  else:
5580
- log_extra = f"Reason: num_train_rows < {USE_BAG_HOLDOUT_AUTO_THRESHOLD}. (num_train_rows={num_train_rows})"
5849
+ log_extra = (
5850
+ f"Reason: num_train_rows < {USE_BAG_HOLDOUT_AUTO_THRESHOLD}. (num_train_rows={num_train_rows})"
5851
+ )
5581
5852
  logger.log(20, f"Setting use_bag_holdout from 'auto' to {use_bag_holdout}. {log_extra}")
5582
5853
 
5583
5854
  if dynamic_stacking and num_stack_levels < 1:
@@ -5629,7 +5900,9 @@ class TabularPredictor:
5629
5900
  )
5630
5901
  return self.__class__.load(path=path_clone) if return_clone else path_clone
5631
5902
 
5632
- def clone_for_deployment(self, path: str, *, model: str = "best", return_clone: bool = False, dirs_exist_ok: bool = False) -> str | "TabularPredictor":
5903
+ def clone_for_deployment(
5904
+ self, path: str, *, model: str = "best", return_clone: bool = False, dirs_exist_ok: bool = False
5905
+ ) -> str | "TabularPredictor":
5633
5906
  """
5634
5907
  Clone the predictor and all of its artifacts to a new location on local disk,
5635
5908
  then delete the clones artifacts unnecessary during prediction.
@@ -5718,7 +5991,9 @@ class TabularPredictor:
5718
5991
  if self.can_predict_proba:
5719
5992
  pred_proba_dict_val = self.predict_proba_multi(inverse_transform=False, as_multiclass=False, models=models)
5720
5993
  if test_data is not None:
5721
- pred_proba_dict_test = self.predict_proba_multi(test_data, inverse_transform=False, as_multiclass=False, models=models)
5994
+ pred_proba_dict_test = self.predict_proba_multi(
5995
+ test_data, inverse_transform=False, as_multiclass=False, models=models
5996
+ )
5722
5997
  else:
5723
5998
  pred_proba_dict_val = self.predict_multi(inverse_transform=False, models=models)
5724
5999
  if test_data is not None:
@@ -5815,7 +6090,9 @@ class TabularPredictor:
5815
6090
  else:
5816
6091
  _validate_hyperparameters_util(params=hyperparameters)
5817
6092
 
5818
- def _sanitize_pseudo_data(self, pseudo_data: pd.DataFrame, name="pseudo_data") -> tuple[pd.DataFrame, pd.Series, pd.Series]:
6093
+ def _sanitize_pseudo_data(
6094
+ self, pseudo_data: pd.DataFrame, name="pseudo_data"
6095
+ ) -> tuple[pd.DataFrame, pd.Series, pd.Series]:
5819
6096
  assert isinstance(pseudo_data, pd.DataFrame)
5820
6097
  if self.label not in pseudo_data.columns:
5821
6098
  raise ValueError(f"'{name}' does not contain the labeled column.")
@@ -5829,7 +6106,9 @@ class TabularPredictor:
5829
6106
  y_pseudo = self._learner.label_cleaner.transform(y_pseudo_og)
5830
6107
 
5831
6108
  if np.isnan(y_pseudo.unique()).any():
5832
- raise Exception(f"NaN was found in the label column for {name}." "Please ensure no NaN values in target column")
6109
+ raise Exception(
6110
+ f"NaN was found in the label column for {name}.Please ensure no NaN values in target column"
6111
+ )
5833
6112
  return X_pseudo, y_pseudo, y_pseudo_og
5834
6113
 
5835
6114
  def _assert_is_fit(self, message_suffix: str = None):
@@ -5890,7 +6169,9 @@ def _dystack(
5890
6169
  return False, None, e
5891
6170
 
5892
6171
  if not predictor.model_names():
5893
- logger.log(20, f"Unable to determine stacked overfitting. AutoGluon's sub-fit did not successfully train any models!")
6172
+ logger.log(
6173
+ 20, f"Unable to determine stacked overfitting. AutoGluon's sub-fit did not successfully train any models!"
6174
+ )
5894
6175
  stacked_overfitting = False
5895
6176
  ho_leaderboard = None
5896
6177
  else: