autogluon.tabular 1.5.0b20251228__py3-none-any.whl → 1.5.1b20260116__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of autogluon.tabular has been flagged as potentially problematic.

Files changed (135)
  1. autogluon/tabular/__init__.py +1 -0
  2. autogluon/tabular/configs/config_helper.py +18 -6
  3. autogluon/tabular/configs/feature_generator_presets.py +3 -1
  4. autogluon/tabular/configs/hyperparameter_configs.py +42 -9
  5. autogluon/tabular/configs/presets_configs.py +38 -14
  6. autogluon/tabular/configs/zeroshot/zeroshot_portfolio_2023.py +84 -14
  7. autogluon/tabular/configs/zeroshot/zeroshot_portfolio_2025.py +48 -48
  8. autogluon/tabular/configs/zeroshot/zeroshot_portfolio_cpu_2025_12_18.py +774 -1
  9. autogluon/tabular/configs/zeroshot/zeroshot_portfolio_gpu_2025_12_18.py +421 -1
  10. autogluon/tabular/experimental/_scikit_mixin.py +6 -2
  11. autogluon/tabular/experimental/_tabular_classifier.py +3 -1
  12. autogluon/tabular/experimental/_tabular_regressor.py +3 -1
  13. autogluon/tabular/experimental/plot_leaderboard.py +73 -19
  14. autogluon/tabular/learner/abstract_learner.py +160 -42
  15. autogluon/tabular/learner/default_learner.py +78 -22
  16. autogluon/tabular/models/__init__.py +2 -2
  17. autogluon/tabular/models/_utils/rapids_utils.py +3 -1
  18. autogluon/tabular/models/abstract/abstract_torch_model.py +2 -0
  19. autogluon/tabular/models/automm/automm_model.py +12 -3
  20. autogluon/tabular/models/automm/ft_transformer.py +5 -1
  21. autogluon/tabular/models/catboost/callbacks.py +2 -2
  22. autogluon/tabular/models/catboost/catboost_model.py +93 -29
  23. autogluon/tabular/models/catboost/catboost_softclass_utils.py +4 -1
  24. autogluon/tabular/models/catboost/catboost_utils.py +3 -1
  25. autogluon/tabular/models/ebm/ebm_model.py +8 -13
  26. autogluon/tabular/models/ebm/hyperparameters/parameters.py +1 -0
  27. autogluon/tabular/models/ebm/hyperparameters/searchspaces.py +1 -0
  28. autogluon/tabular/models/fastainn/callbacks.py +20 -3
  29. autogluon/tabular/models/fastainn/hyperparameters/searchspaces.py +11 -1
  30. autogluon/tabular/models/fastainn/quantile_helpers.py +10 -2
  31. autogluon/tabular/models/fastainn/tabular_nn_fastai.py +65 -18
  32. autogluon/tabular/models/fasttext/fasttext_model.py +3 -1
  33. autogluon/tabular/models/image_prediction/image_predictor.py +7 -2
  34. autogluon/tabular/models/knn/knn_model.py +41 -8
  35. autogluon/tabular/models/lgb/callbacks.py +32 -9
  36. autogluon/tabular/models/lgb/hyperparameters/searchspaces.py +3 -1
  37. autogluon/tabular/models/lgb/lgb_model.py +150 -34
  38. autogluon/tabular/models/lgb/lgb_utils.py +12 -4
  39. autogluon/tabular/models/lr/hyperparameters/searchspaces.py +5 -1
  40. autogluon/tabular/models/lr/lr_model.py +40 -10
  41. autogluon/tabular/models/lr/lr_rapids_model.py +22 -13
  42. autogluon/tabular/models/mitra/_internal/__init__.py +1 -1
  43. autogluon/tabular/models/mitra/_internal/config/__init__.py +1 -1
  44. autogluon/tabular/models/mitra/_internal/config/config_pretrain.py +36 -40
  45. autogluon/tabular/models/mitra/_internal/config/config_run.py +2 -14
  46. autogluon/tabular/models/mitra/_internal/config/enums.py +27 -26
  47. autogluon/tabular/models/mitra/_internal/core/__init__.py +1 -1
  48. autogluon/tabular/models/mitra/_internal/core/callbacks.py +14 -21
  49. autogluon/tabular/models/mitra/_internal/core/get_loss.py +10 -12
  50. autogluon/tabular/models/mitra/_internal/core/get_optimizer.py +17 -32
  51. autogluon/tabular/models/mitra/_internal/core/get_scheduler.py +12 -27
  52. autogluon/tabular/models/mitra/_internal/core/prediction_metrics.py +16 -21
  53. autogluon/tabular/models/mitra/_internal/core/trainer_finetune.py +130 -111
  54. autogluon/tabular/models/mitra/_internal/data/__init__.py +1 -1
  55. autogluon/tabular/models/mitra/_internal/data/collator.py +30 -26
  56. autogluon/tabular/models/mitra/_internal/data/dataset_finetune.py +18 -26
  57. autogluon/tabular/models/mitra/_internal/data/dataset_split.py +10 -7
  58. autogluon/tabular/models/mitra/_internal/data/preprocessor.py +70 -100
  59. autogluon/tabular/models/mitra/_internal/models/__init__.py +1 -1
  60. autogluon/tabular/models/mitra/_internal/models/base.py +7 -10
  61. autogluon/tabular/models/mitra/_internal/models/embedding.py +46 -56
  62. autogluon/tabular/models/mitra/_internal/models/tab2d.py +140 -120
  63. autogluon/tabular/models/mitra/_internal/utils/__init__.py +1 -1
  64. autogluon/tabular/models/mitra/_internal/utils/set_seed.py +3 -1
  65. autogluon/tabular/models/mitra/mitra_model.py +16 -11
  66. autogluon/tabular/models/mitra/sklearn_interface.py +178 -162
  67. autogluon/tabular/models/realmlp/realmlp_model.py +28 -15
  68. autogluon/tabular/models/rf/compilers/onnx.py +1 -1
  69. autogluon/tabular/models/rf/rf_model.py +45 -12
  70. autogluon/tabular/models/rf/rf_quantile.py +4 -2
  71. autogluon/tabular/models/tabdpt/tabdpt_model.py +8 -17
  72. autogluon/tabular/models/tabicl/tabicl_model.py +8 -1
  73. autogluon/tabular/models/tabm/_tabm_internal.py +6 -4
  74. autogluon/tabular/models/tabm/rtdl_num_embeddings.py +80 -127
  75. autogluon/tabular/models/tabm/tabm_model.py +8 -4
  76. autogluon/tabular/models/tabm/tabm_reference.py +53 -85
  77. autogluon/tabular/models/tabpfnmix/_internal/core/callbacks.py +7 -16
  78. autogluon/tabular/models/tabpfnmix/_internal/core/collator.py +16 -24
  79. autogluon/tabular/models/tabpfnmix/_internal/core/dataset_split.py +5 -7
  80. autogluon/tabular/models/tabpfnmix/_internal/core/enums.py +0 -2
  81. autogluon/tabular/models/tabpfnmix/_internal/core/get_loss.py +0 -1
  82. autogluon/tabular/models/tabpfnmix/_internal/core/get_optimizer.py +7 -18
  83. autogluon/tabular/models/tabpfnmix/_internal/core/get_scheduler.py +3 -14
  84. autogluon/tabular/models/tabpfnmix/_internal/core/trainer_finetune.py +79 -64
  85. autogluon/tabular/models/tabpfnmix/_internal/core/y_transformer.py +3 -5
  86. autogluon/tabular/models/tabpfnmix/_internal/data/dataset_finetune.py +17 -30
  87. autogluon/tabular/models/tabpfnmix/_internal/data/preprocessor.py +15 -35
  88. autogluon/tabular/models/tabpfnmix/_internal/models/foundation/embedding.py +21 -38
  89. autogluon/tabular/models/tabpfnmix/_internal/models/foundation/foundation_transformer.py +33 -51
  90. autogluon/tabular/models/tabpfnmix/_internal/results/prediction_metrics.py +4 -4
  91. autogluon/tabular/models/tabpfnmix/_internal/tabpfnmix_classifier.py +32 -12
  92. autogluon/tabular/models/tabpfnmix/_internal/tabpfnmix_regressor.py +32 -13
  93. autogluon/tabular/models/tabpfnmix/tabpfnmix_model.py +55 -19
  94. autogluon/tabular/models/tabpfnv2/tabpfnv2_5_model.py +21 -48
  95. autogluon/tabular/models/tabprep/prep_mixin.py +34 -26
  96. autogluon/tabular/models/tabular_nn/compilers/onnx.py +36 -8
  97. autogluon/tabular/models/tabular_nn/torch/tabular_nn_torch.py +130 -36
  98. autogluon/tabular/models/tabular_nn/torch/tabular_torch_dataset.py +8 -4
  99. autogluon/tabular/models/tabular_nn/torch/torch_network_modules.py +26 -5
  100. autogluon/tabular/models/tabular_nn/utils/categorical_encoders.py +41 -24
  101. autogluon/tabular/models/tabular_nn/utils/data_preprocessor.py +33 -8
  102. autogluon/tabular/models/tabular_nn/utils/nn_architecture_utils.py +21 -6
  103. autogluon/tabular/models/xgboost/callbacks.py +9 -3
  104. autogluon/tabular/models/xgboost/xgboost_model.py +59 -11
  105. autogluon/tabular/models/xt/xt_model.py +1 -0
  106. autogluon/tabular/predictor/interpretable_predictor.py +3 -1
  107. autogluon/tabular/predictor/predictor.py +409 -128
  108. autogluon/tabular/registry/__init__.py +1 -1
  109. autogluon/tabular/registry/_ag_model_registry.py +4 -5
  110. autogluon/tabular/registry/_model_registry.py +1 -0
  111. autogluon/tabular/testing/fit_helper.py +55 -15
  112. autogluon/tabular/testing/generate_datasets.py +1 -1
  113. autogluon/tabular/testing/model_fit_helper.py +10 -4
  114. autogluon/tabular/trainer/abstract_trainer.py +644 -230
  115. autogluon/tabular/trainer/auto_trainer.py +19 -8
  116. autogluon/tabular/trainer/model_presets/presets.py +33 -9
  117. autogluon/tabular/trainer/model_presets/presets_distill.py +16 -2
  118. autogluon/tabular/version.py +1 -1
  119. {autogluon_tabular-1.5.0b20251228.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/METADATA +26 -26
  120. {autogluon_tabular-1.5.0b20251228.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/RECORD +127 -135
  121. autogluon/tabular/models/tabpfnv2/rfpfn/__init__.py +0 -20
  122. autogluon/tabular/models/tabpfnv2/rfpfn/configs.py +0 -40
  123. autogluon/tabular/models/tabpfnv2/rfpfn/scoring_utils.py +0 -201
  124. autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_based_decision_tree_tabpfn.py +0 -1464
  125. autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_based_random_forest_tabpfn.py +0 -747
  126. autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_compat.py +0 -863
  127. autogluon/tabular/models/tabpfnv2/rfpfn/utils.py +0 -106
  128. autogluon/tabular/models/tabpfnv2/tabpfnv2_model.py +0 -466
  129. /autogluon.tabular-1.5.0b20251228-py3.11-nspkg.pth → /autogluon.tabular-1.5.1b20260116-py3.11-nspkg.pth +0 -0
  130. {autogluon_tabular-1.5.0b20251228.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/WHEEL +0 -0
  131. {autogluon_tabular-1.5.0b20251228.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/licenses/LICENSE +0 -0
  132. {autogluon_tabular-1.5.0b20251228.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/licenses/NOTICE +0 -0
  133. {autogluon_tabular-1.5.0b20251228.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/namespace_packages.txt +0 -0
  134. {autogluon_tabular-1.5.0b20251228.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/top_level.txt +0 -0
  135. {autogluon_tabular-1.5.0b20251228.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/zip-safe +0 -0
@@ -70,7 +70,10 @@ class AbstractTabularLearner(AbstractLearner):
  if isinstance(quantile_levels, Iterable):
  for quantile in quantile_levels:
  if quantile <= 0.0 or quantile >= 1.0:
- raise ValueError("quantile values have to be non-negative and less than 1.0 (0.0 < q < 1.0). " "For example, 0.95 quantile = 95 percentile")
+ raise ValueError(
+ "quantile values have to be non-negative and less than 1.0 (0.0 < q < 1.0). "
+ "For example, 0.95 quantile = 95 percentile"
+ )
  quantile_levels = np.sort(np.array(quantile_levels))
  self.quantile_levels = quantile_levels
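
Note: the hunk above only reformats the error message, but it documents the constraint that every quantile level must satisfy 0 < q < 1. A minimal, hedged usage sketch (label name and training data are illustrative, not taken from this diff):

    from autogluon.tabular import TabularPredictor

    # Each quantile level must be strictly between 0 and 1; 0.95 corresponds to the 95th percentile.
    predictor = TabularPredictor(
        label="target",
        problem_type="quantile",
        quantile_levels=[0.1, 0.5, 0.9],
    )
    # predictor.fit(train_data)  # train_data: a pandas DataFrame containing the "target" column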
@@ -188,7 +191,11 @@ class AbstractTabularLearner(AbstractLearner):
  X = self.transform_features(X)
  y_pred_proba = self.load_trainer().predict_proba(X, model=model)
  y_pred_proba = self._post_process_predict_proba(
- y_pred_proba=y_pred_proba, as_pandas=as_pandas, index=X_index, as_multiclass=as_multiclass, inverse_transform=inverse_transform
+ y_pred_proba=y_pred_proba,
+ as_pandas=as_pandas,
+ index=X_index,
+ as_multiclass=as_multiclass,
+ inverse_transform=inverse_transform,
  )
  return y_pred_proba
@@ -206,11 +213,20 @@ class AbstractTabularLearner(AbstractLearner):
  decision_threshold = 0.5
  X_index = copy.deepcopy(X.index) if as_pandas else None
  y_pred_proba = self.predict_proba(
- X=X, model=model, as_pandas=False, as_multiclass=False, inverse_transform=False, transform_features=transform_features
+ X=X,
+ model=model,
+ as_pandas=False,
+ as_multiclass=False,
+ inverse_transform=False,
+ transform_features=transform_features,
  )
  problem_type = self.label_cleaner.problem_type_transform or self.problem_type
- y_pred = get_pred_from_proba(y_pred_proba=y_pred_proba, problem_type=problem_type, decision_threshold=decision_threshold)
- y_pred = self._post_process_predict(y_pred=y_pred, as_pandas=as_pandas, index=X_index, inverse_transform=inverse_transform)
+ y_pred = get_pred_from_proba(
+ y_pred_proba=y_pred_proba, problem_type=problem_type, decision_threshold=decision_threshold
+ )
+ y_pred = self._post_process_predict(
+ y_pred=y_pred, as_pandas=as_pandas, index=X_index, inverse_transform=inverse_transform
+ )
  return y_pred

  def _post_process_predict(
@@ -242,7 +258,12 @@ class AbstractTabularLearner(AbstractLearner):
  return y_pred

  def _post_process_predict_proba(
- self, y_pred_proba: np.ndarray, as_pandas: bool = True, index=None, as_multiclass: bool = True, inverse_transform: bool = True
+ self,
+ y_pred_proba: np.ndarray,
+ as_pandas: bool = True,
+ index=None,
+ as_multiclass: bool = True,
+ inverse_transform: bool = True,
  ):
  """
  Given internal prediction probabilities, post-process them to vend to user.
@@ -338,7 +359,11 @@ class AbstractTabularLearner(AbstractLearner):
  # Inverse Transform labels
  for m, pred_proba in predict_proba_dict.items():
  predict_proba_dict[m] = self._post_process_predict_proba(
- y_pred_proba=pred_proba, as_pandas=as_pandas, as_multiclass=as_multiclass, index=X_index, inverse_transform=inverse_transform
+ y_pred_proba=pred_proba,
+ as_pandas=as_pandas,
+ as_multiclass=as_multiclass,
+ index=X_index,
+ inverse_transform=inverse_transform,
  )
  return predict_proba_dict
@@ -369,18 +394,29 @@ class AbstractTabularLearner(AbstractLearner):
  predict_dict = {}
  for m in predict_proba_dict:
  predict_dict[m] = self.get_pred_from_proba(
- y_pred_proba=predict_proba_dict[m], decision_threshold=decision_threshold, inverse_transform=inverse_transform
+ y_pred_proba=predict_proba_dict[m],
+ decision_threshold=decision_threshold,
+ inverse_transform=inverse_transform,
  )
  return predict_dict

  def get_pred_from_proba(
- self, y_pred_proba: np.ndarray | pd.DataFrame, decision_threshold: float | None = None, inverse_transform: bool = True
+ self,
+ y_pred_proba: np.ndarray | pd.DataFrame,
+ decision_threshold: float | None = None,
+ inverse_transform: bool = True,
  ) -> np.array | pd.Series:
  if isinstance(y_pred_proba, pd.DataFrame):
- y_pred = get_pred_from_proba_df(y_pred_proba, problem_type=self.problem_type, decision_threshold=decision_threshold)
+ y_pred = get_pred_from_proba_df(
+ y_pred_proba, problem_type=self.problem_type, decision_threshold=decision_threshold
+ )
  else:
- y_pred = get_pred_from_proba(y_pred_proba, problem_type=self.problem_type, decision_threshold=decision_threshold)
- y_pred = self._post_process_predict(y_pred=y_pred, as_pandas=False, index=None, inverse_transform=inverse_transform)
+ y_pred = get_pred_from_proba(
+ y_pred_proba, problem_type=self.problem_type, decision_threshold=decision_threshold
+ )
+ y_pred = self._post_process_predict(
+ y_pred=y_pred, as_pandas=False, index=None, inverse_transform=inverse_transform
+ )
  return y_pred

  def _validate_fit_input(self, X: DataFrame, **kwargs):
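
Note: get_pred_from_proba converts predicted probabilities into class predictions, applying a decision threshold (0.5 by default) for binary problems. A standalone numpy sketch of the idea, not the AutoGluon helper itself:

    import numpy as np

    def pred_from_proba(y_pred_proba, decision_threshold=0.5):
        # Binary case: positive-class probability at or above the threshold maps to class 1.
        if y_pred_proba.ndim == 1:
            return (y_pred_proba >= decision_threshold).astype(int)
        # Multiclass case: take the highest-probability class; the threshold is not used.
        return np.argmax(y_pred_proba, axis=1)

    print(pred_from_proba(np.array([0.2, 0.55, 0.9])))  # [0 1 1]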
@@ -398,7 +434,9 @@ class AbstractTabularLearner(AbstractLearner):
  Ensure that the label column is present in the training data
  """
  if self.label not in X.columns:
- raise KeyError(f"Label column '{self.label}' is missing from training data. Training data columns: {list(X.columns)}")
+ raise KeyError(
+ f"Label column '{self.label}' is missing from training data. Training data columns: {list(X.columns)}"
+ )

  def _validate_sample_weight(self, X, X_val):
  if self.sample_weight is not None:
@@ -408,7 +446,9 @@ class AbstractTabularLearner(AbstractLearner):
  prefix += " Warning: We do not recommend weight_evaluation=True with predefined sample weighting."
  else:
  if self.sample_weight not in X.columns:
- raise KeyError(f"sample_weight column '{self.sample_weight}' is missing from training data. Training data columns: {list(X.columns)}")
+ raise KeyError(
+ f"sample_weight column '{self.sample_weight}' is missing from training data. Training data columns: {list(X.columns)}"
+ )
  weight_vals = X[self.sample_weight]
  if weight_vals.isna().sum() > 0:
  raise ValueError(f"Sample weights in column '{self.sample_weight}' cannot be nan")
@@ -417,8 +457,12 @@ class AbstractTabularLearner(AbstractLearner):
  if weight_vals.min() < 0:
  raise ValueError(f"Sample weights in column '{self.sample_weight}' must be nonnegative")
  if self.weight_evaluation and X_val is not None and self.sample_weight not in X_val.columns:
- raise KeyError(f"sample_weight column '{self.sample_weight}' cannot be missing from validation data if weight_evaluation=True")
- prefix = f"Values in column '{self.sample_weight}' used as sample weights instead of predictive features."
+ raise KeyError(
+ f"sample_weight column '{self.sample_weight}' cannot be missing from validation data if weight_evaluation=True"
+ )
+ prefix = (
+ f"Values in column '{self.sample_weight}' used as sample weights instead of predictive features."
+ )
  if self.weight_evaluation:
  suffix = " Evaluation will report weighted metrics, so ensure same column exists in test data."
  else:
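
Note: these checks run when a sample_weight column is configured on the predictor; weights must be present, non-NaN, and nonnegative, and the column is excluded from the predictive features. A hedged sketch of how such a column is typically supplied (column names and data are illustrative):

    import pandas as pd
    from autogluon.tabular import TabularPredictor

    train = pd.DataFrame({
        "feature": [1.0, 2.0, 3.0, 4.0],
        "weight": [1.0, 0.5, 2.0, 1.0],  # non-NaN and nonnegative
        "label": [0, 1, 0, 1],
    })
    # "weight" is consumed as sample weights rather than as a feature.
    predictor = TabularPredictor(label="label", sample_weight="weight")
    # predictor.fit(train)  # illustrative; a realistically sized DataFrame is needed in practice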
@@ -428,12 +472,18 @@ class AbstractTabularLearner(AbstractLearner):
  def _validate_groups(self, X, X_val):
  if self.groups is not None:
  if self.groups not in X.columns:
- raise KeyError(f"groups column '{self.groups}' is missing from training data. Training data columns: {list(X.columns)}")
+ raise KeyError(
+ f"groups column '{self.groups}' is missing from training data. Training data columns: {list(X.columns)}"
+ )
  groups_vals = X[self.groups]
  if len(groups_vals.unique()) < 2:
- raise ValueError(f"Groups in column '{self.groups}' cannot have fewer than 2 unique values. Values: {list(groups_vals.unique())}")
+ raise ValueError(
+ f"Groups in column '{self.groups}' cannot have fewer than 2 unique values. Values: {list(groups_vals.unique())}"
+ )
  if X_val is not None and self.groups in X_val.columns:
- raise KeyError(f"groups column '{self.groups}' cannot be in validation data. Validation data columns: {list(X_val.columns)}")
+ raise KeyError(
+ f"groups column '{self.groups}' cannot be in validation data. Validation data columns: {list(X_val.columns)}"
+ )
  logger.log(
  20,
  f"Values in column '{self.groups}' used as split folds instead of being automatically set. Bagged models will have {len(groups_vals.unique())} splits.",
@@ -534,7 +584,12 @@ class AbstractTabularLearner(AbstractLearner):
  set_refit_score_to_parent=False,
  display=False,
  ):
- leaderboard_df = self.leaderboard(extra_info=extra_info, refit_full=refit_full, set_refit_score_to_parent=set_refit_score_to_parent, display=display)
+ leaderboard_df = self.leaderboard(
+ extra_info=extra_info,
+ refit_full=refit_full,
+ set_refit_score_to_parent=set_refit_score_to_parent,
+ display=display,
+ )
  if extra_metrics is None:
  extra_metrics = []
  if y is None:
@@ -559,14 +614,21 @@ class AbstractTabularLearner(AbstractLearner):
  all_trained_models = [m for m in all_trained_models if m in leaderboard_models]
  all_trained_models_can_infer = trainer.get_model_names(models=all_trained_models, can_infer=True)
  all_trained_models_original = all_trained_models.copy()
- model_pred_proba_dict, pred_time_test_marginal = trainer.get_model_pred_proba_dict(X=X, models=all_trained_models_can_infer, record_pred_time=True)
+ model_pred_proba_dict, pred_time_test_marginal = trainer.get_model_pred_proba_dict(
+ X=X, models=all_trained_models_can_infer, record_pred_time=True
+ )

  if compute_oracle:
  pred_probas = list(model_pred_proba_dict.values())
  ensemble_selection = EnsembleSelection(
- ensemble_size=100, problem_type=trainer.problem_type, metric=self.eval_metric, quantile_levels=self.quantile_levels
+ ensemble_size=100,
+ problem_type=trainer.problem_type,
+ metric=self.eval_metric,
+ quantile_levels=self.quantile_levels,
  )
- ensemble_selection.fit(predictions=pred_probas, labels=y_internal, identifiers=None, sample_weight=w) # TODO: Only fit non-nan
+ ensemble_selection.fit(
+ predictions=pred_probas, labels=y_internal, identifiers=None, sample_weight=w
+ ) # TODO: Only fit non-nan

  oracle_weights = ensemble_selection.weights_
  oracle_pred_time_start = time.time()
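
Note: the oracle ensemble above is built with greedy (Caruana-style) ensemble selection: models are repeatedly added, with replacement, whenever they improve the averaged prediction, and the selection counts become the oracle weights. A simplified sketch of that idea (not AutoGluon's EnsembleSelection; a squared-error stand-in replaces the configured metric):

    import numpy as np

    def greedy_ensemble_selection(predictions, labels, ensemble_size=100):
        chosen = []
        current = np.zeros_like(predictions[0], dtype=float)
        for _ in range(ensemble_size):
            best_idx, best_err = None, np.inf
            for i, p in enumerate(predictions):
                blended = (current * len(chosen) + p) / (len(chosen) + 1)
                err = np.mean((blended - labels) ** 2)  # stand-in for the real eval metric
                if err < best_err:
                    best_idx, best_err = i, err
            chosen.append(best_idx)
            current = (current * (len(chosen) - 1) + predictions[best_idx]) / len(chosen)
        # Normalized selection counts play the role of the oracle ensemble weights.
        return np.bincount(chosen, minlength=len(predictions)) / ensemble_size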
@@ -585,14 +647,20 @@ class AbstractTabularLearner(AbstractLearner):
  scores[model_name] = np.nan
  else:
  scores[model_name] = self.score_with_pred_proba(
- y_pred_proba_internal=y_pred_proba_internal, metric=self.eval_metric, decision_threshold=decision_threshold, **scoring_args
+ y_pred_proba_internal=y_pred_proba_internal,
+ metric=self.eval_metric,
+ decision_threshold=decision_threshold,
+ **scoring_args,
  )
  for metric in extra_metrics:
  metric = get_metric(metric, self.problem_type, "leaderboard_metric")
  if metric.name not in extra_scores:
  extra_scores[metric.name] = {}
  extra_scores[metric.name][model_name] = self.score_with_pred_proba(
- y_pred_proba_internal=y_pred_proba_internal, metric=metric, decision_threshold=decision_threshold, **scoring_args
+ y_pred_proba_internal=y_pred_proba_internal,
+ metric=metric,
+ decision_threshold=decision_threshold,
+ **scoring_args,
  )

  if extra_scores:
@@ -629,8 +697,6 @@ class AbstractTabularLearner(AbstractLearner):
  pred_time_test[model] = None
  pred_time_test_marginal[model] = None

- logger.debug("Model scores:")
- logger.debug(str(scores))
  model_names_final = list(scores.keys())
  df = pd.DataFrame(
  data={
@@ -645,7 +711,8 @@ class AbstractTabularLearner(AbstractLearner):

  df_merged = pd.merge(df, leaderboard_df, on="model", how="left")
  df_merged = df_merged.sort_values(
- by=["score_test", "pred_time_test", "score_val", "pred_time_val", "model"], ascending=[False, True, False, True, False]
+ by=["score_test", "pred_time_test", "score_val", "pred_time_val", "model"],
+ ascending=[False, True, False, True, False],
  ).reset_index(drop=True)
  df_columns_lst = df_merged.columns.tolist()
  explicit_order = [
@@ -692,7 +759,9 @@ class AbstractTabularLearner(AbstractLearner):
  if metric.needs_pred or metric.needs_quantile:
  if self.problem_type == BINARY:
  # Use 1 and 0, otherwise f1 can crash due to unknown pos_label.
- y_pred = self.get_pred_from_proba(y_pred_proba_internal, decision_threshold=decision_threshold, inverse_transform=False)
+ y_pred = self.get_pred_from_proba(
+ y_pred_proba_internal, decision_threshold=decision_threshold, inverse_transform=False
+ )
  y_pred_proba = None
  y_tmp = y_internal
  else:
@@ -777,7 +846,16 @@ class AbstractTabularLearner(AbstractLearner):
  f"\n\t Known classes: {self.class_labels}"
  )

- def evaluate_predictions(self, y_true, y_pred, sample_weight=None, decision_threshold=None, display=False, auxiliary_metrics=True, detailed_report=False):
+ def evaluate_predictions(
+ self,
+ y_true,
+ y_pred,
+ sample_weight=None,
+ decision_threshold=None,
+ display=False,
+ auxiliary_metrics=True,
+ detailed_report=False,
+ ):
  """Evaluate predictions. Does not support sample weights since this method reports a variety of metrics.
  Args:
  display (bool): Should we print which metric is being used as well as performance.
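
Note: evaluate_predictions scores already-computed predictions against ground-truth labels and, with auxiliary_metrics=True, reports several related metrics. A hedged usage sketch (labels and predictions are made up; a fitted predictor is assumed):

    import pandas as pd

    y_true = pd.Series([0, 1, 1, 0, 1])
    y_pred = pd.Series([0, 1, 0, 0, 1])
    # Assuming `predictor` is a fitted TabularPredictor:
    # perf = predictor.evaluate_predictions(
    #     y_true=y_true, y_pred=y_pred, auxiliary_metrics=True, display=True
    # )
    # `perf` maps metric names (e.g. "accuracy") to scores, reported as higher-is-better.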
@@ -868,13 +946,18 @@ class AbstractTabularLearner(AbstractLearner):
  if isinstance(aux_metric, str):
  aux_metric = get_metric(metric=aux_metric, problem_type=self.problem_type, metric_type="aux_metric")
  if not aux_metric.needs_pred and y_pred_proba_internal is None:
- logger.log(15, f"Skipping {aux_metric.name} because no prediction probabilities are available to score.")
+ logger.log(
+ 15, f"Skipping {aux_metric.name} because no prediction probabilities are available to score."
+ )
  continue

  if aux_metric.name not in performance_dict:
  if y_pred_proba_internal is not None:
  score = self.score_with_pred_proba(
- y_pred_proba_internal=y_pred_proba_internal, metric=aux_metric, decision_threshold=decision_threshold, **scoring_args
+ y_pred_proba_internal=y_pred_proba_internal,
+ metric=aux_metric,
+ decision_threshold=decision_threshold,
+ **scoring_args,
  )
  else:
  score = self.score_with_pred(y_pred_internal=y_pred_internal, metric=aux_metric, **scoring_args)
@@ -885,7 +968,10 @@ class AbstractTabularLearner(AbstractLearner):
  score_eval = performance_dict[self.eval_metric.name]
  logger.log(20, f"Evaluation: {self.eval_metric.name} on test data: {score_eval}")
  if not self.eval_metric.greater_is_better_internal:
- logger.log(20, f"\tNote: Scores are always higher_is_better. This metric score can be multiplied by -1 to get the metric value.")
+ logger.log(
+ 20,
+ f"\tNote: Scores are always higher_is_better. This metric score can be multiplied by -1 to get the metric value.",
+ )
  logger.log(20, "Evaluations on test data:")
  logger.log(20, json.dumps(performance_dict, indent=4))
@@ -951,7 +1037,9 @@ class AbstractTabularLearner(AbstractLearner):
  if extra_metrics:
  raise AssertionError("`extra_metrics` is only valid when data is specified.")
  trainer = self.load_trainer()
- leaderboard = trainer.leaderboard(extra_info=extra_info, refit_full=refit_full, set_refit_score_to_parent=set_refit_score_to_parent)
+ leaderboard = trainer.leaderboard(
+ extra_info=extra_info, refit_full=refit_full, set_refit_score_to_parent=set_refit_score_to_parent
+ )
  if only_pareto_frontier:
  if "score_test" in leaderboard.columns and "pred_time_test" in leaderboard.columns:
  score_col = "score_test"
@@ -959,7 +1047,9 @@ class AbstractTabularLearner(AbstractLearner):
  else:
  score_col = "score_val"
  inference_time_col = "pred_time_val"
- leaderboard = get_leaderboard_pareto_frontier(leaderboard=leaderboard, score_col=score_col, inference_time_col=inference_time_col)
+ leaderboard = get_leaderboard_pareto_frontier(
+ leaderboard=leaderboard, score_col=score_col, inference_time_col=inference_time_col
+ )
  if score_format == "error":
  leaderboard.rename(
  columns={
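
Note: the Pareto-frontier option keeps only models that are not dominated on both score and inference time. A small pandas sketch of that filter (not the get_leaderboard_pareto_frontier implementation; column names follow the leaderboard convention above):

    import pandas as pd

    def pareto_frontier(leaderboard, score_col="score_val", time_col="pred_time_val"):
        # Fastest models first; keep a model only if it beats every faster model's score.
        lb = leaderboard.sort_values(time_col, ascending=True)
        keep, best_so_far = [], float("-inf")
        for _, row in lb.iterrows():
            if row[score_col] > best_so_far:
                keep.append(row)
                best_so_far = row[score_col]
        return pd.DataFrame(keep)

    lb = pd.DataFrame({
        "model": ["A", "B", "C"],
        "score_val": [0.80, 0.78, 0.85],
        "pred_time_val": [0.1, 0.5, 1.0],
    })
    print(pareto_frontier(lb)["model"].tolist())  # ['A', 'C']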
@@ -988,7 +1078,15 @@ class AbstractTabularLearner(AbstractLearner):
  # features: list of feature names that feature importances are calculated for and returned, specify None to get all feature importances.
  # feature_stage: Whether to compute feature importance on raw original features ('original'), transformed features ('transformed') or on the features used by the particular model ('transformed_model').
  def get_feature_importance(
- self, model=None, X=None, y=None, features: list = None, feature_stage="original", subsample_size=5000, silent=False, **kwargs
+ self,
+ model=None,
+ X=None,
+ y=None,
+ features: list = None,
+ feature_stage="original",
+ subsample_size=5000,
+ silent=False,
+ **kwargs,
  ) -> DataFrame:
  valid_feature_stages = ["original", "transformed", "transformed_model"]
  if feature_stage not in valid_feature_stages:
@@ -1003,20 +1101,34 @@ class AbstractTabularLearner(AbstractLearner):
  X = X.drop(columns=self.ignored_columns, errors="ignore")
  unused_features = [f for f in list(X.columns) if f not in self.features]
  if len(unused_features) > 0:
- logger.log(30, f"These features in provided data are not utilized by the predictor and will be ignored: {unused_features}")
+ logger.log(
+ 30,
+ f"These features in provided data are not utilized by the predictor and will be ignored: {unused_features}",
+ )
  X = X.drop(columns=unused_features)

  if feature_stage == "original":
  return trainer._get_feature_importance_raw(
- model=model, X=X, y=y, features=features, subsample_size=subsample_size, transform_func=self.transform_features, silent=silent, **kwargs
+ model=model,
+ X=X,
+ y=y,
+ features=features,
+ subsample_size=subsample_size,
+ transform_func=self.transform_features,
+ silent=silent,
+ **kwargs,
  )
  X = self.transform_features(X)
  else:
  if feature_stage == "original":
- raise AssertionError("Feature importance `dataset` cannot be None if `feature_stage=='original'`. A test dataset must be specified.")
+ raise AssertionError(
+ "Feature importance `dataset` cannot be None if `feature_stage=='original'`. A test dataset must be specified."
+ )
  y = None
  raw = feature_stage == "transformed"
- return trainer.get_feature_importance(X=X, y=y, model=model, features=features, raw=raw, subsample_size=subsample_size, silent=silent, **kwargs)
+ return trainer.get_feature_importance(
+ X=X, y=y, model=model, features=features, raw=raw, subsample_size=subsample_size, silent=silent, **kwargs
+ )

  @staticmethod
  def _remove_nan_label_rows(X, y):
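
Note: the feature importance computed here is permutation-based: a feature's importance is the drop in score when that feature's values are shuffled, optionally on a subsample. A generic sketch of the idea (not AutoGluon's implementation; model and score_fn are placeholders):

    import numpy as np
    import pandas as pd

    def permutation_importance(model, X: pd.DataFrame, y, score_fn, n_repeats=3, seed=0):
        rng = np.random.default_rng(seed)
        baseline = score_fn(y, model.predict(X))
        importances = {}
        for col in X.columns:
            drops = []
            for _ in range(n_repeats):
                X_perm = X.copy()
                X_perm[col] = rng.permutation(X_perm[col].to_numpy())  # shuffle one column
                drops.append(baseline - score_fn(y, model.predict(X_perm)))
            importances[col] = float(np.mean(drops))
        return pd.Series(importances).sort_values(ascending=False)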
@@ -1029,7 +1141,9 @@ class AbstractTabularLearner(AbstractLearner):
  problem_type = self._infer_problem_type(y, silent=silent)
  if problem_type == QUANTILE:
  if self.quantile_levels is None:
- raise AssertionError(f"problem_type is inferred to be {QUANTILE}, yet quantile_levels is not specified.")
+ raise AssertionError(
+ f"problem_type is inferred to be {QUANTILE}, yet quantile_levels is not specified."
+ )
  elif self.quantile_levels is not None:
  if problem_type == REGRESSION:
  problem_type = QUANTILE
@@ -1073,7 +1187,11 @@ class AbstractTabularLearner(AbstractLearner):
  ):
  """See abstract_trainer.distill() for details."""
  if X is not None:
- if (self.eval_metric is not None) and (self.eval_metric.name == "log_loss") and (self.problem_type == MULTICLASS):
+ if (
+ (self.eval_metric is not None)
+ and (self.eval_metric.name == "log_loss")
+ and (self.problem_type == MULTICLASS)
+ ):
  X = augment_rare_classes(X, self.label, self.threshold)
  if y is None:
  X, y = self.extract_label(X)
@@ -92,16 +92,27 @@ class DefaultLearner(AbstractTabularLearner):
  num_bag_folds = len(X[self.groups].unique())
  X_og = None if infer_limit_batch_size is None else X
  logger.log(20, "Preprocessing data ...")
- X, y, X_val, y_val, X_test, y_test, X_unlabeled, holdout_frac, num_bag_folds, groups = self.general_data_processing(
- X=X, X_val=X_val, X_test=X_test, X_unlabeled=X_unlabeled, holdout_frac=holdout_frac, num_bag_folds=num_bag_folds
+ X, y, X_val, y_val, X_test, y_test, X_unlabeled, holdout_frac, num_bag_folds, groups = (
+ self.general_data_processing(
+ X=X,
+ X_val=X_val,
+ X_test=X_test,
+ X_unlabeled=X_unlabeled,
+ holdout_frac=holdout_frac,
+ num_bag_folds=num_bag_folds,
+ )
  )
  if X_og is not None:
- infer_limit = self._update_infer_limit(X=X_og, infer_limit_batch_size=infer_limit_batch_size, infer_limit=infer_limit)
+ infer_limit = self._update_infer_limit(
+ X=X_og, infer_limit_batch_size=infer_limit_batch_size, infer_limit=infer_limit
+ )

  self._post_X_rows = len(X)
  time_preprocessing_end = time.time()
  self._time_fit_preprocessing = time_preprocessing_end - time_preprocessing_start
- logger.log(20, f"Data preprocessing and feature engineering runtime = {round(self._time_fit_preprocessing, 2)}s ...")
+ logger.log(
+ 20, f"Data preprocessing and feature engineering runtime = {round(self._time_fit_preprocessing, 2)}s ..."
+ )
  if time_limit:
  time_limit_trainer = time_limit - self._time_fit_preprocessing
  else:
@@ -152,12 +163,18 @@ class DefaultLearner(AbstractTabularLearner):
  self._time_fit_total = time_end - time_preprocessing_start
  log_throughput = ""
  if trainer.model_best is not None:
- predict_n_time_per_row = trainer.get_model_attribute_full(model=trainer.model_best, attribute="predict_n_time_per_row")
- predict_n_size = trainer.get_model_attribute_full(model=trainer.model_best, attribute="predict_n_size", func=min)
+ predict_n_time_per_row = trainer.get_model_attribute_full(
+ model=trainer.model_best, attribute="predict_n_time_per_row"
+ )
+ predict_n_size = trainer.get_model_attribute_full(
+ model=trainer.model_best, attribute="predict_n_size", func=min
+ )
  if predict_n_time_per_row is not None and predict_n_size is not None:
- log_throughput = f" | Estimated inference throughput: {1/(predict_n_time_per_row if predict_n_time_per_row else np.finfo(np.float16).eps):.1f} rows/s ({int(predict_n_size)} batch size)"
+ log_throughput = f" | Estimated inference throughput: {1 / (predict_n_time_per_row if predict_n_time_per_row else np.finfo(np.float16).eps):.1f} rows/s ({int(predict_n_size)} batch size)"
  logger.log(
- 20, f"AutoGluon training complete, total runtime = {round(self._time_fit_total, 2)}s ... Best model: {trainer.model_best}" f"{log_throughput}"
+ 20,
+ f"AutoGluon training complete, total runtime = {round(self._time_fit_total, 2)}s ... Best model: {trainer.model_best}"
+ f"{log_throughput}",
  )

  def _update_infer_limit(self, X: DataFrame, *, infer_limit_batch_size: int, infer_limit: float = None):
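
Note: the throughput logged above is just the reciprocal of the best model's measured per-row prediction time, with a small epsilon guarding against a zero measurement. A standalone sketch of that arithmetic:

    import numpy as np

    def estimated_throughput(predict_time_per_row: float, batch_size: int) -> str:
        # Mirror the eps fallback used in the log line above.
        per_row = predict_time_per_row if predict_time_per_row else float(np.finfo(np.float16).eps)
        return f"{1 / per_row:.1f} rows/s ({batch_size} batch size)"

    print(estimated_throughput(0.002, 10000))  # 500.0 rows/s (10000 batch size)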
@@ -172,7 +189,8 @@ class DefaultLearner(AbstractTabularLearner):
  self.preprocess_1_batch_size = infer_limit_batch_size
  preprocess_1_time_log, time_unit_preprocess_1_time = convert_time_in_s_to_log_friendly(self.preprocess_1_time)
  logger.log(
- 20, f"\t{round(preprocess_1_time_log, 3)}{time_unit_preprocess_1_time}\t= Feature Preprocessing Time (1 row | {infer_limit_batch_size} batch size)"
+ 20,
+ f"\t{round(preprocess_1_time_log, 3)}{time_unit_preprocess_1_time}\t= Feature Preprocessing Time (1 row | {infer_limit_batch_size} batch size)",
  )

  if infer_limit is not None:
@@ -182,7 +200,7 @@ class DefaultLearner(AbstractTabularLearner):

  logger.log(
  20,
- f"\t\tFeature Preprocessing requires {round(self.preprocess_1_time/infer_limit*100, 2)}% "
+ f"\t\tFeature Preprocessing requires {round(self.preprocess_1_time / infer_limit * 100, 2)}% "
  f"of the overall inference constraint ({infer_limit_log}{time_unit_infer_limit})\n"
  f"\t\t{round(infer_limit_new_log, 3)}{time_unit_infer_limit_new} inference time budget remaining for models...",
  )
@@ -199,7 +217,13 @@ class DefaultLearner(AbstractTabularLearner):

  # TODO: Add default values to X_val, X_unlabeled, holdout_frac, and num_bag_folds
  def general_data_processing(
- self, X: DataFrame, X_val: DataFrame = None, X_test: DataFrame = None, X_unlabeled: DataFrame = None, holdout_frac: float = 1, num_bag_folds: int = 0
+ self,
+ X: DataFrame,
+ X_val: DataFrame = None,
+ X_test: DataFrame = None,
+ X_unlabeled: DataFrame = None,
+ holdout_frac: float = 1,
+ num_bag_folds: int = 0,
  ):
  """General data processing steps used for all models."""
  X = self._check_for_non_finite_values(X, name="train", is_train=True)
@@ -231,7 +255,9 @@ class DefaultLearner(AbstractTabularLearner):
  self.cleaner = Cleaner.construct(problem_type=self.problem_type, label=self.label, threshold=self.threshold)
  X = self.cleaner.fit_transform(X) # TODO: Consider merging cleaner into label_cleaner
  X, y = self.extract_label(X)
- self.label_cleaner = LabelCleaner.construct(problem_type=self.problem_type, y=y, y_uncleaned=y_uncleaned, positive_class=self._positive_class)
+ self.label_cleaner = LabelCleaner.construct(
+ problem_type=self.problem_type, y=y, y_uncleaned=y_uncleaned, positive_class=self._positive_class
+ )
  y = self.label_cleaner.transform(y)
  X = self.set_predefined_weights(X, y)
  X, w = extract_column(X, self.sample_weight)
@@ -240,10 +266,20 @@ class DefaultLearner(AbstractTabularLearner):
  logger.log(20, f"Train Data Class Count: {self.label_cleaner.num_classes}")

  X_val, y_val, w_val, holdout_frac = self._apply_cleaner_transform(
- X=X_val, y_uncleaned=y_uncleaned, holdout_frac=holdout_frac, holdout_frac_og=holdout_frac_og, name="val", is_test=False
+ X=X_val,
+ y_uncleaned=y_uncleaned,
+ holdout_frac=holdout_frac,
+ holdout_frac_og=holdout_frac_og,
+ name="val",
+ is_test=False,
  )
  X_test, y_test, w_test, _ = self._apply_cleaner_transform(
- X=X_test, y_uncleaned=y_uncleaned, holdout_frac=holdout_frac, holdout_frac_og=holdout_frac_og, name="test", is_test=True
+ X=X_test,
+ y_uncleaned=y_uncleaned,
+ holdout_frac=holdout_frac,
+ holdout_frac_og=holdout_frac_og,
+ name="test",
+ is_test=True,
  )

  self._original_features = list(X.columns)
@@ -281,7 +317,9 @@ class DefaultLearner(AbstractTabularLearner):
  y_unlabeled = pd.Series(np.nan, index=X_unlabeled.index) if X_unlabeled is not None else None
  y_list = [y, y_val, y_test_super, y_unlabeled]
  y_super = pd.concat(y_list, ignore_index=True)
- X_super = self.fit_transform_features(X_super, y_super, problem_type=self.label_cleaner.problem_type_transform, eval_metric=self.eval_metric)
+ X_super = self.fit_transform_features(
+ X_super, y_super, problem_type=self.label_cleaner.problem_type_transform, eval_metric=self.eval_metric
+ )
  if not transform_with_test and X_test is not None:
  X_test = self.feature_generator.transform(X_test)
@@ -360,7 +398,13 @@ class DefaultLearner(AbstractTabularLearner):
  return X

  def _apply_cleaner_transform(
- self, X: DataFrame, y_uncleaned: Series, holdout_frac: float | int, holdout_frac_og: float | int, name: str, is_test: bool = False
+ self,
+ X: DataFrame,
+ y_uncleaned: Series,
+ holdout_frac: float | int,
+ holdout_frac_og: float | int,
+ name: str,
+ is_test: bool = False,
  ) -> tuple[DataFrame, Series, Series | None, float | int]:
  if X is not None and self.label in X.columns:
  y_og = X[self.label]
@@ -387,7 +431,9 @@ class DefaultLearner(AbstractTabularLearner):
  logger.warning(f"\t{name} Class Dtype: {y_og.dtype}")
  missing_classes = [c for c in val_classes if c not in train_classes]
  logger.warning(f"\tClasses missing from Training Data: {missing_classes}")
- logger.warning("############################################################################################################")
+ logger.warning(
+ "############################################################################################################"
+ )

  X = None
  y = None
@@ -405,15 +451,23 @@ class DefaultLearner(AbstractTabularLearner):
  return X, y, w, holdout_frac

  def adjust_threshold_if_necessary(self, y, threshold, holdout_frac, num_bag_folds):
- new_threshold, new_holdout_frac, new_num_bag_folds = self._adjust_threshold_if_necessary(y, threshold, holdout_frac, num_bag_folds)
+ new_threshold, new_holdout_frac, new_num_bag_folds = self._adjust_threshold_if_necessary(
+ y, threshold, holdout_frac, num_bag_folds
+ )
  if new_threshold != threshold:
  if new_threshold < threshold:
- logger.warning(f"Warning: Updated label_count_threshold from {threshold} to {new_threshold} to avoid cutting too many classes.")
+ logger.warning(
+ f"Warning: Updated label_count_threshold from {threshold} to {new_threshold} to avoid cutting too many classes."
+ )
  if new_holdout_frac != holdout_frac:
  if new_holdout_frac > holdout_frac:
- logger.warning(f"Warning: Updated holdout_frac from {holdout_frac} to {new_holdout_frac} to avoid cutting too many classes.")
+ logger.warning(
+ f"Warning: Updated holdout_frac from {holdout_frac} to {new_holdout_frac} to avoid cutting too many classes."
+ )
  if new_num_bag_folds != num_bag_folds:
- logger.warning(f"Warning: Updated num_bag_folds from {num_bag_folds} to {new_num_bag_folds} to avoid cutting too many classes.")
+ logger.warning(
+ f"Warning: Updated num_bag_folds from {num_bag_folds} to {new_num_bag_folds} to avoid cutting too many classes."
+ )
  return new_threshold, new_holdout_frac, new_num_bag_folds

  def _adjust_threshold_if_necessary(self, y, threshold, holdout_frac, num_bag_folds):
@@ -462,7 +516,9 @@ class DefaultLearner(AbstractTabularLearner):
  def get_info(self, include_model_info=False, include_model_failures=False, **kwargs):
  learner_info = super().get_info(**kwargs)
  trainer = self.load_trainer()
- trainer_info = trainer.get_info(include_model_info=include_model_info, include_model_failures=include_model_failures)
+ trainer_info = trainer.get_info(
+ include_model_info=include_model_info, include_model_failures=include_model_failures
+ )
  learner_info.update(
  {
  "time_fit_preprocessing": self._time_fit_preprocessing,
@@ -1,6 +1,5 @@
  from autogluon.core.models.abstract.abstract_model import AbstractModel

- from .tabprep.prep_lgb_model import PrepLGBModel
  from .automm.automm_model import MultiModalPredictorModel
  from .automm.ft_transformer import FTTransformerModel
  from .catboost.catboost_model import CatBoostModel
@@ -19,6 +18,7 @@ from .imodels.imodels_models import (
  from .knn.knn_model import KNNModel
  from .lgb.lgb_model import LGBModel
  from .lr.lr_model import LinearModel
+ from .mitra.mitra_model import MitraModel
  from .realmlp.realmlp_model import RealMLPModel
  from .rf.rf_model import RFModel
  from .tabdpt.tabdpt_model import TabDPTModel
@@ -26,7 +26,7 @@ from .tabicl.tabicl_model import TabICLModel
  from .tabm.tabm_model import TabMModel
  from .tabpfnmix.tabpfnmix_model import TabPFNMixModel
  from .tabpfnv2.tabpfnv2_5_model import RealTabPFNv2Model, RealTabPFNv25Model
- from .mitra.mitra_model import MitraModel
+ from .tabprep.prep_lgb_model import PrepLGBModel
  from .tabular_nn.torch.tabular_nn_torch import TabularNeuralNetTorchModel
  from .text_prediction.text_prediction_v1_model import TextPredictorModel
  from .xgboost.xgboost_model import XGBoostModel
@@ -16,7 +16,9 @@ class RapidsModelMixin:

  def _get_default_resources(self):
  num_cpus, _ = super()._get_default_resources()
- num_gpus = min(ResourceManager.get_gpu_count_torch(), 1) # Use single gpu training by default. Consider revising it later.
+ num_gpus = min(
+ ResourceManager.get_gpu_count_torch(), 1
+ ) # Use single gpu training by default. Consider revising it later.
  return num_cpus, num_gpus

  def get_minimum_resources(self, is_gpu_available=False) -> Dict[str, int]:
@@ -12,6 +12,7 @@ class AbstractTorchModel(AbstractModel):
  """
  .. versionadded:: 1.5.0
  """
+
  def __init__(self, **kwargs):
  super().__init__(**kwargs)
  self.device = None
@@ -51,6 +52,7 @@ class AbstractTorchModel(AbstractModel):
  @classmethod
  def to_torch_device(cls, device: str):
  import torch
+
  return torch.device(device)

  def get_device(self) -> str: