autogluon.tabular 1.5.1b20260105__py3-none-any.whl → 1.5.1b20260116__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of autogluon.tabular might be problematic. Click here for more details.

Files changed (135):
  1. autogluon/tabular/__init__.py +1 -0
  2. autogluon/tabular/configs/config_helper.py +18 -6
  3. autogluon/tabular/configs/feature_generator_presets.py +3 -1
  4. autogluon/tabular/configs/hyperparameter_configs.py +42 -9
  5. autogluon/tabular/configs/presets_configs.py +38 -14
  6. autogluon/tabular/configs/zeroshot/zeroshot_portfolio_2023.py +84 -14
  7. autogluon/tabular/configs/zeroshot/zeroshot_portfolio_2025.py +48 -48
  8. autogluon/tabular/configs/zeroshot/zeroshot_portfolio_cpu_2025_12_18.py +774 -1
  9. autogluon/tabular/configs/zeroshot/zeroshot_portfolio_gpu_2025_12_18.py +421 -1
  10. autogluon/tabular/experimental/_scikit_mixin.py +6 -2
  11. autogluon/tabular/experimental/_tabular_classifier.py +3 -1
  12. autogluon/tabular/experimental/_tabular_regressor.py +3 -1
  13. autogluon/tabular/experimental/plot_leaderboard.py +73 -19
  14. autogluon/tabular/learner/abstract_learner.py +160 -42
  15. autogluon/tabular/learner/default_learner.py +78 -22
  16. autogluon/tabular/models/__init__.py +2 -2
  17. autogluon/tabular/models/_utils/rapids_utils.py +3 -1
  18. autogluon/tabular/models/abstract/abstract_torch_model.py +2 -0
  19. autogluon/tabular/models/automm/automm_model.py +12 -3
  20. autogluon/tabular/models/automm/ft_transformer.py +5 -1
  21. autogluon/tabular/models/catboost/callbacks.py +2 -2
  22. autogluon/tabular/models/catboost/catboost_model.py +93 -29
  23. autogluon/tabular/models/catboost/catboost_softclass_utils.py +4 -1
  24. autogluon/tabular/models/catboost/catboost_utils.py +3 -1
  25. autogluon/tabular/models/ebm/ebm_model.py +8 -13
  26. autogluon/tabular/models/ebm/hyperparameters/parameters.py +1 -0
  27. autogluon/tabular/models/ebm/hyperparameters/searchspaces.py +1 -0
  28. autogluon/tabular/models/fastainn/callbacks.py +20 -3
  29. autogluon/tabular/models/fastainn/hyperparameters/searchspaces.py +11 -1
  30. autogluon/tabular/models/fastainn/quantile_helpers.py +10 -2
  31. autogluon/tabular/models/fastainn/tabular_nn_fastai.py +65 -18
  32. autogluon/tabular/models/fasttext/fasttext_model.py +3 -1
  33. autogluon/tabular/models/image_prediction/image_predictor.py +7 -2
  34. autogluon/tabular/models/knn/knn_model.py +41 -8
  35. autogluon/tabular/models/lgb/callbacks.py +32 -9
  36. autogluon/tabular/models/lgb/hyperparameters/searchspaces.py +3 -1
  37. autogluon/tabular/models/lgb/lgb_model.py +150 -34
  38. autogluon/tabular/models/lgb/lgb_utils.py +12 -4
  39. autogluon/tabular/models/lr/hyperparameters/searchspaces.py +5 -1
  40. autogluon/tabular/models/lr/lr_model.py +40 -10
  41. autogluon/tabular/models/lr/lr_rapids_model.py +22 -13
  42. autogluon/tabular/models/mitra/_internal/__init__.py +1 -1
  43. autogluon/tabular/models/mitra/_internal/config/__init__.py +1 -1
  44. autogluon/tabular/models/mitra/_internal/config/config_pretrain.py +36 -40
  45. autogluon/tabular/models/mitra/_internal/config/config_run.py +2 -14
  46. autogluon/tabular/models/mitra/_internal/config/enums.py +27 -26
  47. autogluon/tabular/models/mitra/_internal/core/__init__.py +1 -1
  48. autogluon/tabular/models/mitra/_internal/core/callbacks.py +14 -21
  49. autogluon/tabular/models/mitra/_internal/core/get_loss.py +10 -12
  50. autogluon/tabular/models/mitra/_internal/core/get_optimizer.py +17 -32
  51. autogluon/tabular/models/mitra/_internal/core/get_scheduler.py +12 -27
  52. autogluon/tabular/models/mitra/_internal/core/prediction_metrics.py +16 -21
  53. autogluon/tabular/models/mitra/_internal/core/trainer_finetune.py +130 -111
  54. autogluon/tabular/models/mitra/_internal/data/__init__.py +1 -1
  55. autogluon/tabular/models/mitra/_internal/data/collator.py +30 -26
  56. autogluon/tabular/models/mitra/_internal/data/dataset_finetune.py +18 -26
  57. autogluon/tabular/models/mitra/_internal/data/dataset_split.py +10 -7
  58. autogluon/tabular/models/mitra/_internal/data/preprocessor.py +70 -100
  59. autogluon/tabular/models/mitra/_internal/models/__init__.py +1 -1
  60. autogluon/tabular/models/mitra/_internal/models/base.py +7 -10
  61. autogluon/tabular/models/mitra/_internal/models/embedding.py +46 -56
  62. autogluon/tabular/models/mitra/_internal/models/tab2d.py +140 -120
  63. autogluon/tabular/models/mitra/_internal/utils/__init__.py +1 -1
  64. autogluon/tabular/models/mitra/_internal/utils/set_seed.py +3 -1
  65. autogluon/tabular/models/mitra/mitra_model.py +16 -11
  66. autogluon/tabular/models/mitra/sklearn_interface.py +178 -162
  67. autogluon/tabular/models/realmlp/realmlp_model.py +28 -15
  68. autogluon/tabular/models/rf/compilers/onnx.py +1 -1
  69. autogluon/tabular/models/rf/rf_model.py +45 -12
  70. autogluon/tabular/models/rf/rf_quantile.py +4 -2
  71. autogluon/tabular/models/tabdpt/tabdpt_model.py +8 -17
  72. autogluon/tabular/models/tabicl/tabicl_model.py +8 -1
  73. autogluon/tabular/models/tabm/_tabm_internal.py +6 -4
  74. autogluon/tabular/models/tabm/rtdl_num_embeddings.py +80 -127
  75. autogluon/tabular/models/tabm/tabm_model.py +8 -4
  76. autogluon/tabular/models/tabm/tabm_reference.py +53 -85
  77. autogluon/tabular/models/tabpfnmix/_internal/core/callbacks.py +7 -16
  78. autogluon/tabular/models/tabpfnmix/_internal/core/collator.py +16 -24
  79. autogluon/tabular/models/tabpfnmix/_internal/core/dataset_split.py +5 -7
  80. autogluon/tabular/models/tabpfnmix/_internal/core/enums.py +0 -2
  81. autogluon/tabular/models/tabpfnmix/_internal/core/get_loss.py +0 -1
  82. autogluon/tabular/models/tabpfnmix/_internal/core/get_optimizer.py +7 -18
  83. autogluon/tabular/models/tabpfnmix/_internal/core/get_scheduler.py +3 -14
  84. autogluon/tabular/models/tabpfnmix/_internal/core/trainer_finetune.py +79 -64
  85. autogluon/tabular/models/tabpfnmix/_internal/core/y_transformer.py +3 -5
  86. autogluon/tabular/models/tabpfnmix/_internal/data/dataset_finetune.py +17 -30
  87. autogluon/tabular/models/tabpfnmix/_internal/data/preprocessor.py +15 -35
  88. autogluon/tabular/models/tabpfnmix/_internal/models/foundation/embedding.py +21 -38
  89. autogluon/tabular/models/tabpfnmix/_internal/models/foundation/foundation_transformer.py +33 -51
  90. autogluon/tabular/models/tabpfnmix/_internal/results/prediction_metrics.py +4 -4
  91. autogluon/tabular/models/tabpfnmix/_internal/tabpfnmix_classifier.py +32 -12
  92. autogluon/tabular/models/tabpfnmix/_internal/tabpfnmix_regressor.py +32 -13
  93. autogluon/tabular/models/tabpfnmix/tabpfnmix_model.py +55 -19
  94. autogluon/tabular/models/tabpfnv2/tabpfnv2_5_model.py +21 -48
  95. autogluon/tabular/models/tabprep/prep_mixin.py +34 -26
  96. autogluon/tabular/models/tabular_nn/compilers/onnx.py +36 -8
  97. autogluon/tabular/models/tabular_nn/torch/tabular_nn_torch.py +130 -36
  98. autogluon/tabular/models/tabular_nn/torch/tabular_torch_dataset.py +8 -4
  99. autogluon/tabular/models/tabular_nn/torch/torch_network_modules.py +26 -5
  100. autogluon/tabular/models/tabular_nn/utils/categorical_encoders.py +41 -24
  101. autogluon/tabular/models/tabular_nn/utils/data_preprocessor.py +33 -8
  102. autogluon/tabular/models/tabular_nn/utils/nn_architecture_utils.py +21 -6
  103. autogluon/tabular/models/xgboost/callbacks.py +9 -3
  104. autogluon/tabular/models/xgboost/xgboost_model.py +59 -11
  105. autogluon/tabular/models/xt/xt_model.py +1 -0
  106. autogluon/tabular/predictor/interpretable_predictor.py +3 -1
  107. autogluon/tabular/predictor/predictor.py +409 -128
  108. autogluon/tabular/registry/__init__.py +1 -1
  109. autogluon/tabular/registry/_ag_model_registry.py +4 -5
  110. autogluon/tabular/registry/_model_registry.py +1 -0
  111. autogluon/tabular/testing/fit_helper.py +55 -15
  112. autogluon/tabular/testing/generate_datasets.py +1 -1
  113. autogluon/tabular/testing/model_fit_helper.py +10 -4
  114. autogluon/tabular/trainer/abstract_trainer.py +644 -230
  115. autogluon/tabular/trainer/auto_trainer.py +19 -8
  116. autogluon/tabular/trainer/model_presets/presets.py +33 -9
  117. autogluon/tabular/trainer/model_presets/presets_distill.py +16 -2
  118. autogluon/tabular/version.py +1 -1
  119. {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/METADATA +26 -26
  120. {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/RECORD +127 -135
  121. autogluon/tabular/models/tabpfnv2/rfpfn/__init__.py +0 -20
  122. autogluon/tabular/models/tabpfnv2/rfpfn/configs.py +0 -40
  123. autogluon/tabular/models/tabpfnv2/rfpfn/scoring_utils.py +0 -201
  124. autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_based_decision_tree_tabpfn.py +0 -1464
  125. autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_based_random_forest_tabpfn.py +0 -747
  126. autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_compat.py +0 -863
  127. autogluon/tabular/models/tabpfnv2/rfpfn/utils.py +0 -106
  128. autogluon/tabular/models/tabpfnv2/tabpfnv2_model.py +0 -466
  129. /autogluon.tabular-1.5.1b20260105-py3.11-nspkg.pth → /autogluon.tabular-1.5.1b20260116-py3.11-nspkg.pth +0 -0
  130. {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/WHEEL +0 -0
  131. {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/licenses/LICENSE +0 -0
  132. {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/licenses/NOTICE +0 -0
  133. {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/namespace_packages.txt +0 -0
  134. {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/top_level.txt +0 -0
  135. {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/zip-safe +0 -0
@@ -27,7 +27,7 @@ from autogluon.core.calibrate.conformity_score import compute_conformity_score
27
27
  from autogluon.core.calibrate.temperature_scaling import apply_temperature_scaling, tune_temperature_scaling
28
28
  from autogluon.core.callbacks import AbstractCallback
29
29
  from autogluon.core.constants import BINARY, MULTICLASS, QUANTILE, REFIT_FULL_NAME, REGRESSION, SOFTCLASS
30
- from autogluon.core.data.label_cleaner import LabelCleanerMulticlassToBinary, LabelCleaner
30
+ from autogluon.core.data.label_cleaner import LabelCleaner, LabelCleanerMulticlassToBinary
31
31
  from autogluon.core.metrics import Scorer, compute_metric, get_metric
32
32
  from autogluon.core.models import (
33
33
  AbstractModel,
@@ -64,7 +64,6 @@ from autogluon.core.utils.feature_selection import FeatureSelector
64
64
  from autogluon.core.utils.loaders import load_pkl
65
65
  from autogluon.core.utils.savers import save_pkl
66
66
 
67
-
68
67
  logger = logging.getLogger(__name__)
69
68
 
70
69
 
@@ -349,11 +348,11 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
349
348
  self._y_test_saved = True
350
349
 
351
350
  def get_model_names(
352
- self,
353
- stack_name: list[str] | str | None = None,
354
- level: list[int] | int | None = None,
355
- can_infer: bool | None = None,
356
- models: list[str] | None = None
351
+ self,
352
+ stack_name: list[str] | str | None = None,
353
+ level: list[int] | int | None = None,
354
+ can_infer: bool | None = None,
355
+ models: list[str] | None = None,
357
356
  ) -> list[str]:
358
357
  if models is None:
359
358
  models = list(self.model_graph.nodes)
@@ -385,7 +384,9 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
385
384
  """Constructs a list of unfit models based on the hyperparameters dict."""
386
385
  raise NotImplementedError
387
386
 
388
- def construct_model_templates_distillation(self, hyperparameters: dict, **kwargs) -> tuple[list[AbstractModel], dict]:
387
+ def construct_model_templates_distillation(
388
+ self, hyperparameters: dict, **kwargs
389
+ ) -> tuple[list[AbstractModel], dict]:
389
390
  """Constructs a list of unfit models based on the hyperparameters dict for softclass distillation."""
390
391
  raise NotImplementedError
391
392
 
@@ -438,7 +439,7 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
438
439
  self._fit_setup(time_limit=time_limit, callbacks=callbacks)
439
440
  time_train_start = self._time_train_start
440
441
  assert time_train_start is not None
441
-
442
+
442
443
  if self.callbacks:
443
444
  callback_classes = [c.__class__.__name__ for c in self.callbacks]
444
445
  logger.log(20, f"User-specified callbacks ({len(self.callbacks)}): {callback_classes}")
@@ -447,7 +448,9 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
447
448
 
448
449
  if relative_stack:
449
450
  if level_start != 1:
450
- raise AssertionError(f"level_start must be 1 when `relative_stack=True`. (level_start = {level_start})")
451
+ raise AssertionError(
452
+ f"level_start must be 1 when `relative_stack=True`. (level_start = {level_start})"
453
+ )
451
454
  level_add = 0
452
455
  if base_model_names:
453
456
  max_base_model_level = self.get_max_level(models=base_model_names)
@@ -488,19 +491,30 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
488
491
 
489
492
  model_names_fit = []
490
493
  if level_start != level_end:
491
- logger.log(20, f"AutoGluon will fit {level_end - level_start + 1} stack levels (L{level_start} to L{level_end}) ...")
494
+ logger.log(
495
+ 20,
496
+ f"AutoGluon will fit {level_end - level_start + 1} stack levels (L{level_start} to L{level_end}) ...",
497
+ )
492
498
  for level in range(level_start, level_end + 1):
493
499
  core_kwargs_level = core_kwargs.copy()
494
500
  aux_kwargs_level = aux_kwargs.copy()
495
- full_weighted_ensemble = aux_kwargs_level.pop("fit_full_last_level_weighted_ensemble", True) and (level == level_end) and (level > 1)
496
- additional_full_weighted_ensemble = aux_kwargs_level.pop("full_weighted_ensemble_additionally", False) and full_weighted_ensemble
501
+ full_weighted_ensemble = (
502
+ aux_kwargs_level.pop("fit_full_last_level_weighted_ensemble", True)
503
+ and (level == level_end)
504
+ and (level > 1)
505
+ )
506
+ additional_full_weighted_ensemble = (
507
+ aux_kwargs_level.pop("full_weighted_ensemble_additionally", False) and full_weighted_ensemble
508
+ )
497
509
  if time_limit is not None:
498
510
  time_train_level_start = time.time()
499
511
  levels_left = level_end - level + 1
500
512
  time_left = time_limit - (time_train_level_start - time_train_start)
501
513
  time_limit_for_level = min(time_left / levels_left * (1 + level_time_modifier), time_left)
502
514
  time_limit_core = time_limit_for_level
503
- time_limit_aux = max(time_limit_for_level * 0.1, min(time_limit, 360)) # Allows aux to go over time_limit, but only by a small amount
515
+ time_limit_aux = max(
516
+ time_limit_for_level * 0.1, min(time_limit, 360)
517
+ ) # Allows aux to go over time_limit, but only by a small amount
504
518
  core_kwargs_level["time_limit"] = core_kwargs_level.get("time_limit", time_limit_core)
505
519
  aux_kwargs_level["time_limit"] = aux_kwargs_level.get("time_limit", time_limit_aux)
506
520
  base_model_names, aux_models = self.stack_new_level(
@@ -530,7 +544,9 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
530
544
  self.save()
531
545
  return model_names_fit
532
546
 
533
- def _fit_setup(self, time_limit: float | None = None, callbacks: list[AbstractCallback | list | tuple] | None = None):
547
+ def _fit_setup(
548
+ self, time_limit: float | None = None, callbacks: list[AbstractCallback | list | tuple] | None = None
549
+ ):
534
550
  """
535
551
  Prepare the trainer state at the start of / prior to a fit call.
536
552
  Should be paired with a `self._fit_cleanup()` at the conclusion of the fit call.
@@ -544,12 +560,16 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
544
560
  assert isinstance(callbacks, list), f"`callbacks` must be a list. Found invalid type: `{type(callbacks)}`."
545
561
  for callback in callbacks:
546
562
  if isinstance(callback, (list, tuple)):
547
- assert len(callback) == 2, f"Callback must either be an initialized object or a tuple/list of length 2, found: {callback}"
563
+ assert len(callback) == 2, (
564
+ f"Callback must either be an initialized object or a tuple/list of length 2, found: {callback}"
565
+ )
548
566
  callback_cls = callback[0]
549
567
  if isinstance(callback_cls, str):
550
- from autogluon.core.callbacks._early_stopping_count_callback import EarlyStoppingCountCallback
551
568
  from autogluon.core.callbacks._early_stopping_callback import EarlyStoppingCallback
552
- from autogluon.core.callbacks._early_stopping_ensemble_callback import EarlyStoppingEnsembleCallback
569
+ from autogluon.core.callbacks._early_stopping_count_callback import EarlyStoppingCountCallback
570
+ from autogluon.core.callbacks._early_stopping_ensemble_callback import (
571
+ EarlyStoppingEnsembleCallback,
572
+ )
553
573
 
554
574
  _callback_cls_lst = [
555
575
  EarlyStoppingCallback,
@@ -557,9 +577,7 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
557
577
  EarlyStoppingEnsembleCallback,
558
578
  ]
559
579
 
560
- _callback_cls_name_map = {
561
- c.__name__: c for c in _callback_cls_lst
562
- }
580
+ _callback_cls_name_map = {c.__name__: c for c in _callback_cls_lst}
563
581
 
564
582
  assert callback_cls in _callback_cls_name_map.keys(), (
565
583
  f"Unknown callback class: {callback_cls}. "
@@ -568,12 +586,14 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
568
586
  callback_cls = _callback_cls_name_map[callback_cls]
569
587
 
570
588
  callback_kwargs = callback[1]
571
- assert isinstance(callback_kwargs, dict), f"Callback kwargs must be a dictionary, found: {callback_kwargs}"
589
+ assert isinstance(callback_kwargs, dict), (
590
+ f"Callback kwargs must be a dictionary, found: {callback_kwargs}"
591
+ )
572
592
  callback = callback_cls(**callback_kwargs)
573
593
  else:
574
- assert isinstance(
575
- callback, AbstractCallback
576
- ), f"Elements in `callbacks` must be of type AbstractCallback. Found invalid type: `{type(callback)}`."
594
+ assert isinstance(callback, AbstractCallback), (
595
+ f"Elements in `callbacks` must be of type AbstractCallback. Found invalid type: `{type(callback)}`."
596
+ )
577
597
  callbacks_new.append(callback)
578
598
  else:
579
599
  callbacks_new = []
@@ -604,11 +624,11 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
604
624
 
605
625
  # TODO: Consider better greedy approximation method such as via fitting a weighted ensemble to evaluate the value of a subset.
606
626
  def _filter_base_models_via_infer_limit(
607
- self,
608
- base_model_names: list[str],
609
- infer_limit: float | None,
610
- infer_limit_modifier: float = 1.0,
611
- as_child: bool = True,
627
+ self,
628
+ base_model_names: list[str],
629
+ infer_limit: float | None,
630
+ infer_limit_modifier: float = 1.0,
631
+ as_child: bool = True,
612
632
  verbose: bool = True,
613
633
  ) -> list[str]:
614
634
  """
@@ -663,8 +683,12 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
663
683
  base_model_names.remove(base_model_name)
664
684
  predict_1_time_full_set = self.get_model_attribute_full(model=base_model_names, attribute=attribute)
665
685
  if verbose:
666
- predict_1_time_full_set_log, time_unit = convert_time_in_s_to_log_friendly(time_in_sec=predict_1_time_full_set)
667
- predict_1_time_full_set_old_log, time_unit_old = convert_time_in_s_to_log_friendly(time_in_sec=predict_1_time_full_set_old)
686
+ predict_1_time_full_set_log, time_unit = convert_time_in_s_to_log_friendly(
687
+ time_in_sec=predict_1_time_full_set
688
+ )
689
+ predict_1_time_full_set_old_log, time_unit_old = convert_time_in_s_to_log_friendly(
690
+ time_in_sec=predict_1_time_full_set_old
691
+ )
668
692
  messages_to_log.append(
669
693
  f"\t{round(predict_1_time_full_set_old_log, 3)}{time_unit_old}\t-> {round(predict_1_time_full_set_log, 3)}{time_unit}\t({base_model_name})"
670
694
  )
@@ -681,14 +705,20 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
681
705
  i += 1
682
706
  predict_1_time_full_set = self.get_model_attribute_full(model=base_model_names, attribute=attribute)
683
707
  if verbose:
684
- predict_1_time_full_set_log, time_unit = convert_time_in_s_to_log_friendly(time_in_sec=predict_1_time_full_set)
685
- predict_1_time_full_set_old_log, time_unit_old = convert_time_in_s_to_log_friendly(time_in_sec=predict_1_time_full_set_old)
708
+ predict_1_time_full_set_log, time_unit = convert_time_in_s_to_log_friendly(
709
+ time_in_sec=predict_1_time_full_set
710
+ )
711
+ predict_1_time_full_set_old_log, time_unit_old = convert_time_in_s_to_log_friendly(
712
+ time_in_sec=predict_1_time_full_set_old
713
+ )
686
714
  messages_to_log.append(
687
715
  f"\t{round(predict_1_time_full_set_old_log, 3)}{time_unit_old}\t-> {round(predict_1_time_full_set_log, 3)}{time_unit}\t({base_model_to_remove})"
688
716
  )
689
717
 
690
718
  if messages_to_log:
691
- infer_limit_threshold_log, time_unit_threshold = convert_time_in_s_to_log_friendly(time_in_sec=infer_limit_threshold)
719
+ infer_limit_threshold_log, time_unit_threshold = convert_time_in_s_to_log_friendly(
720
+ time_in_sec=infer_limit_threshold
721
+ )
692
722
  logger.log(
693
723
  20,
694
724
  f"Removing {len(messages_to_log)}/{num_models_og} base models to satisfy inference constraint "
@@ -729,7 +759,9 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
729
759
  if level < 1:
730
760
  raise AssertionError(f"Stack level must be >= 1, but level={level}.")
731
761
  if base_model_names and level == 1:
732
- raise AssertionError(f"Stack level 1 models cannot have base models, but base_model_names={base_model_names}.")
762
+ raise AssertionError(
763
+ f"Stack level 1 models cannot have base models, but base_model_names={base_model_names}."
764
+ )
733
765
  if name_suffix:
734
766
  core_kwargs["name_suffix"] = core_kwargs.get("name_suffix", "") + name_suffix
735
767
  aux_kwargs["name_suffix"] = aux_kwargs.get("name_suffix", "") + name_suffix
@@ -754,11 +786,17 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
754
786
  full_aux_kwargs = aux_kwargs.copy()
755
787
  if additional_full_weighted_ensemble:
756
788
  full_aux_kwargs["name_extra"] = "_ALL"
757
- all_base_model_names = self.get_model_names(stack_name="core") # Fit weighted ensemble on all previously fitted core models
758
- aux_models += self._stack_new_level_aux(X_val, y_val, X, y, all_base_model_names, level, infer_limit, infer_limit_batch_size, **full_aux_kwargs)
789
+ all_base_model_names = self.get_model_names(
790
+ stack_name="core"
791
+ ) # Fit weighted ensemble on all previously fitted core models
792
+ aux_models += self._stack_new_level_aux(
793
+ X_val, y_val, X, y, all_base_model_names, level, infer_limit, infer_limit_batch_size, **full_aux_kwargs
794
+ )
759
795
 
760
796
  if (not full_weighted_ensemble) or additional_full_weighted_ensemble:
761
- aux_models += self._stack_new_level_aux(X_val, y_val, X, y, core_models, level, infer_limit, infer_limit_batch_size, **aux_kwargs)
797
+ aux_models += self._stack_new_level_aux(
798
+ X_val, y_val, X, y, core_models, level, infer_limit, infer_limit_batch_size, **aux_kwargs
799
+ )
762
800
 
763
801
  return core_models, aux_models
764
802
 
@@ -806,8 +844,8 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
806
844
  raise ValueError("Stack Ensembling is not valid for non-bagged mode.")
807
845
 
808
846
  base_model_names = self._filter_base_models_via_infer_limit(
809
- base_model_names=base_model_names,
810
- infer_limit=infer_limit,
847
+ base_model_names=base_model_names,
848
+ infer_limit=infer_limit,
811
849
  infer_limit_modifier=0.8,
812
850
  )
813
851
  if ag_args_fit is None:
@@ -830,7 +868,9 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
830
868
  if level == 1:
831
869
  (base_model_names, base_model_paths, base_model_types) = (None, None, None)
832
870
  elif level > 1:
833
- base_model_names, base_model_paths, base_model_types = self._get_models_load_info(model_names=base_model_names)
871
+ base_model_names, base_model_paths, base_model_types = self._get_models_load_info(
872
+ model_names=base_model_names
873
+ )
834
874
  if len(base_model_names) == 0: # type: ignore
835
875
  logger.log(20, f"No base models to train on, skipping stack level {level}...")
836
876
  return []
@@ -841,8 +881,12 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
841
881
  "base_model_names": base_model_names,
842
882
  "base_model_paths_dict": base_model_paths,
843
883
  "base_model_types_dict": base_model_types,
844
- "base_model_types_inner_dict": self.get_models_attribute_dict(attribute="type_inner", models=base_model_names),
845
- "base_model_performances_dict": self.get_models_attribute_dict(attribute="val_score", models=base_model_names),
884
+ "base_model_types_inner_dict": self.get_models_attribute_dict(
885
+ attribute="type_inner", models=base_model_names
886
+ ),
887
+ "base_model_performances_dict": self.get_models_attribute_dict(
888
+ attribute="val_score", models=base_model_names
889
+ ),
846
890
  "random_state": level + self.random_state,
847
891
  }
848
892
  get_models_kwargs.update(
@@ -861,7 +905,10 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
861
905
  }
862
906
  kwargs["hyperparameter_tune_kwargs"] = hyperparameter_tune_kwargs
863
907
 
864
- logger.log(10 if ((not refit_full) and DistributedContext.is_distributed_mode()) else 20, f'Fitting {len(models)} L{level} models, fit_strategy="{fit_strategy}" ...')
908
+ logger.log(
909
+ 10 if ((not refit_full) and DistributedContext.is_distributed_mode()) else 20,
910
+ f'Fitting {len(models)} L{level} models, fit_strategy="{fit_strategy}" ...',
911
+ )
865
912
 
866
913
  X_init = self.get_inputs_to_stacker(X, base_models=base_model_names, fit=True)
867
914
  feature_metadata = self.get_feature_metadata(use_orig_features=True, base_models=base_model_names)
@@ -901,10 +948,18 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
901
948
  **kwargs,
902
949
  )
903
950
 
904
- def _stack_new_level_aux(self, X_val, y_val, X, y, core_models, level, infer_limit, infer_limit_batch_size, **kwargs):
951
+ def _stack_new_level_aux(
952
+ self, X_val, y_val, X, y, core_models, level, infer_limit, infer_limit_batch_size, **kwargs
953
+ ):
905
954
  if X_val is None:
906
955
  aux_models = self.stack_new_level_aux(
907
- X=X, y=y, base_model_names=core_models, level=level + 1, infer_limit=infer_limit, infer_limit_batch_size=infer_limit_batch_size, **kwargs
956
+ X=X,
957
+ y=y,
958
+ base_model_names=core_models,
959
+ level=level + 1,
960
+ infer_limit=infer_limit,
961
+ infer_limit_batch_size=infer_limit_batch_size,
962
+ **kwargs,
908
963
  )
909
964
  else:
910
965
  aux_models = self.stack_new_level_aux(
@@ -952,7 +1007,9 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
952
1007
  # Skip fitting of aux models
953
1008
  return []
954
1009
 
955
- base_model_names = self._filter_base_models_via_infer_limit(base_model_names=base_model_names, infer_limit=infer_limit, infer_limit_modifier=0.95)
1010
+ base_model_names = self._filter_base_models_via_infer_limit(
1011
+ base_model_names=base_model_names, infer_limit=infer_limit, infer_limit_modifier=0.95
1012
+ )
956
1013
 
957
1014
  if len(base_model_names) == 0:
958
1015
  logger.log(20, f"No base models to train on, skipping auxiliary stack level {level}...")
@@ -972,9 +1029,13 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
972
1029
  ag_args_fit["predict_1_batch_size"] = infer_limit_batch_size
973
1030
  else:
974
1031
  ag_args_fit = None
975
- X_stack_preds = self.get_inputs_to_stacker(X, base_models=base_model_names, fit=fit, use_orig_features=False, use_val_cache=use_val_cache)
1032
+ X_stack_preds = self.get_inputs_to_stacker(
1033
+ X, base_models=base_model_names, fit=fit, use_orig_features=False, use_val_cache=use_val_cache
1034
+ )
976
1035
  if self.weight_evaluation:
977
- X, w = extract_column(X, self.sample_weight) # TODO: consider redesign with w as separate arg instead of bundled inside X
1036
+ X, w = extract_column(
1037
+ X, self.sample_weight
1038
+ ) # TODO: consider redesign with w as separate arg instead of bundled inside X
978
1039
  if w is not None:
979
1040
  X_stack_preds[self.sample_weight] = w.values / w.mean()
980
1041
  child_hyperparameters = None
@@ -1036,9 +1097,18 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
1036
1097
  model_pred_proba_dict = None
1037
1098
  else:
1038
1099
  model_set = self.get_minimum_model_set(model)
1039
- model_set = [m for m in model_set if m != model.name] # TODO: Can probably be faster, get this result from graph
1040
- model_pred_proba_dict = self.get_model_pred_proba_dict(X=X, models=model_set, model_pred_proba_dict=model_pred_proba_dict)
1041
- X = model.preprocess(X=X, preprocess_nonadaptive=preprocess_nonadaptive, fit=fit, model_pred_proba_dict=model_pred_proba_dict)
1100
+ model_set = [
1101
+ m for m in model_set if m != model.name
1102
+ ] # TODO: Can probably be faster, get this result from graph
1103
+ model_pred_proba_dict = self.get_model_pred_proba_dict(
1104
+ X=X, models=model_set, model_pred_proba_dict=model_pred_proba_dict
1105
+ )
1106
+ X = model.preprocess(
1107
+ X=X,
1108
+ preprocess_nonadaptive=preprocess_nonadaptive,
1109
+ fit=fit,
1110
+ model_pred_proba_dict=model_pred_proba_dict,
1111
+ )
1042
1112
  elif preprocess_nonadaptive:
1043
1113
  X = model.preprocess(X=X, preprocess_stateful=False)
1044
1114
  return X
@@ -1180,22 +1250,27 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
1180
1250
  models_to_load = list(model_set)
1181
1251
  subgraph = nx.DiGraph(nx.subgraph(self.model_graph, models_to_load)) # Wrap subgraph in DiGraph to unfreeze it
1182
1252
  # For model in models_to_ignore, remove model node from graph and all ancestors that have no remaining descendants and are not in `models`
1183
- models_to_ignore = [model for model in models_to_load if (model not in models) and (not list(subgraph.successors(model)))]
1253
+ models_to_ignore = [
1254
+ model for model in models_to_load if (model not in models) and (not list(subgraph.successors(model)))
1255
+ ]
1184
1256
  while models_to_ignore:
1185
1257
  model = models_to_ignore[0]
1186
1258
  predecessors = list(subgraph.predecessors(model))
1187
1259
  subgraph.remove_node(model)
1188
1260
  models_to_ignore = models_to_ignore[1:]
1189
1261
  for predecessor in predecessors:
1190
- if (predecessor not in models) and (not list(subgraph.successors(predecessor))) and (predecessor not in models_to_ignore):
1262
+ if (
1263
+ (predecessor not in models)
1264
+ and (not list(subgraph.successors(predecessor)))
1265
+ and (predecessor not in models_to_ignore)
1266
+ ):
1191
1267
  models_to_ignore.append(predecessor)
1192
1268
 
1193
1269
  # Get model prediction order
1194
1270
  return list(nx.lexicographical_topological_sort(subgraph))
1195
-
1271
+
1196
1272
  def get_models_attribute_dict(self, attribute: str, models: list | None = None) -> dict[str, Any]:
1197
- """Returns dictionary of model name -> attribute value for the provided attribute.
1198
- """
1273
+ """Returns dictionary of model name -> attribute value for the provided attribute."""
1199
1274
  models_attribute_dict = nx.get_node_attributes(self.model_graph, attribute)
1200
1275
  if models is not None:
1201
1276
  model_names = []
@@ -1204,11 +1279,13 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
1204
1279
  model = model.name
1205
1280
  model_names.append(model)
1206
1281
  if attribute == "path":
1207
- models_attribute_dict = {key: os.path.join(*val) for key, val in models_attribute_dict.items() if key in model_names}
1282
+ models_attribute_dict = {
1283
+ key: os.path.join(*val) for key, val in models_attribute_dict.items() if key in model_names
1284
+ }
1208
1285
  else:
1209
1286
  models_attribute_dict = {key: val for key, val in models_attribute_dict.items() if key in model_names}
1210
1287
  return models_attribute_dict
1211
-
1288
+
1212
1289
  # TODO: Consider adding persist to disk functionality for pred_proba dictionary to lessen memory burden on large multiclass problems.
1213
1290
  # For datasets with 100+ classes, this function could potentially run the system OOM due to each pred_proba numpy array taking significant amounts of space.
1214
1291
  # This issue already existed in the previous level-based version but only had the minimum required predictions in memory at a time, whereas this has all model predictions in memory.
@@ -1261,11 +1338,15 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
1261
1338
  model_pred_time_dict = {}
1262
1339
 
1263
1340
  if use_val_cache:
1264
- _, model_pred_proba_dict = self._update_pred_proba_dict_with_val_cache(model_set=set(models), model_pred_proba_dict=model_pred_proba_dict)
1341
+ _, model_pred_proba_dict = self._update_pred_proba_dict_with_val_cache(
1342
+ model_set=set(models), model_pred_proba_dict=model_pred_proba_dict
1343
+ )
1265
1344
  if not model_pred_proba_dict:
1266
1345
  model_pred_order = self._construct_model_pred_order(models)
1267
1346
  else:
1268
- model_pred_order = self._construct_model_pred_order_with_pred_dict(models, models_to_ignore=list(model_pred_proba_dict.keys()))
1347
+ model_pred_order = self._construct_model_pred_order_with_pred_dict(
1348
+ models, models_to_ignore=list(model_pred_proba_dict.keys())
1349
+ )
1269
1350
  if use_val_cache:
1270
1351
  model_set, model_pred_proba_dict = self._update_pred_proba_dict_with_val_cache(
1271
1352
  model_set=set(model_pred_order), model_pred_proba_dict=model_pred_proba_dict
@@ -1324,7 +1405,9 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
1324
1405
  -------
1325
1406
  If `record_pred_time==True`, outputs tuple of dicts (model_pred_dict, model_pred_time_dict), else output only model_pred_dict
1326
1407
  """
1327
- model_pred_proba_dict = self.get_model_pred_proba_dict(X=X, models=models, record_pred_time=record_pred_time, **kwargs)
1408
+ model_pred_proba_dict = self.get_model_pred_proba_dict(
1409
+ X=X, models=models, record_pred_time=record_pred_time, **kwargs
1410
+ )
1328
1411
  if record_pred_time:
1329
1412
  model_pred_proba_dict, model_pred_time_dict = model_pred_proba_dict
1330
1413
  else:
@@ -1333,7 +1416,9 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
1333
1416
  model_pred_dict = {}
1334
1417
  for m in model_pred_proba_dict:
1335
1418
  # Convert pred_proba to pred
1336
- model_pred_dict[m] = get_pred_from_proba(y_pred_proba=model_pred_proba_dict[m], problem_type=self.problem_type)
1419
+ model_pred_dict[m] = get_pred_from_proba(
1420
+ y_pred_proba=model_pred_proba_dict[m], problem_type=self.problem_type
1421
+ )
1337
1422
 
1338
1423
  if record_pred_time:
1339
1424
  return model_pred_dict, model_pred_time_dict
@@ -1447,14 +1532,18 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
1447
1532
  )
1448
1533
  pred_proba_list = [model_pred_proba_dict[model] for model in base_models]
1449
1534
  stack_column_names, _ = self._get_stack_column_names(models=base_models)
1450
- X_stacker = convert_pred_probas_to_df(pred_proba_list=pred_proba_list, problem_type=self.problem_type, columns=stack_column_names, index=X.index)
1535
+ X_stacker = convert_pred_probas_to_df(
1536
+ pred_proba_list=pred_proba_list, problem_type=self.problem_type, columns=stack_column_names, index=X.index
1537
+ )
1451
1538
  if use_orig_features:
1452
1539
  X = pd.concat([X_stacker, X], axis=1)
1453
1540
  else:
1454
1541
  X = X_stacker
1455
1542
  return X
1456
1543
 
1457
- def get_feature_metadata(self, use_orig_features: bool = True, model: str | None = None, base_models: list[str] | None = None) -> FeatureMetadata:
1544
+ def get_feature_metadata(
1545
+ self, use_orig_features: bool = True, model: str | None = None, base_models: list[str] | None = None
1546
+ ) -> FeatureMetadata:
1458
1547
  """
1459
1548
  Returns the FeatureMetadata input to a `model.fit` call.
1460
1549
  Pairs with `X = self.get_inputs_to_stacker(...)`. The returned FeatureMetadata should reflect the contents of `X`.
@@ -1487,7 +1576,9 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
1487
1576
  stack_column_names, _ = self._get_stack_column_names(models=base_models)
1488
1577
  stacker_type_map_raw = {column: R_FLOAT for column in stack_column_names}
1489
1578
  stacker_type_group_map_special = {S_STACK: stack_column_names}
1490
- stacker_feature_metadata = FeatureMetadata(type_map_raw=stacker_type_map_raw, type_group_map_special=stacker_type_group_map_special)
1579
+ stacker_feature_metadata = FeatureMetadata(
1580
+ type_map_raw=stacker_type_map_raw, type_group_map_special=stacker_type_group_map_special
1581
+ )
1491
1582
  if feature_metadata is not None:
1492
1583
  feature_metadata = feature_metadata.join_metadata(stacker_feature_metadata)
1493
1584
  else:
@@ -1502,10 +1593,16 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
1502
1593
  Additionally output the number of columns per model as an int.
1503
1594
  """
1504
1595
  if self.problem_type in [MULTICLASS, SOFTCLASS]:
1505
- stack_column_names = [stack_column_prefix + "_" + str(cls) for stack_column_prefix in models for cls in range(self.num_classes)]
1596
+ stack_column_names = [
1597
+ stack_column_prefix + "_" + str(cls)
1598
+ for stack_column_prefix in models
1599
+ for cls in range(self.num_classes)
1600
+ ]
1506
1601
  num_columns_per_model = self.num_classes
1507
1602
  elif self.problem_type == QUANTILE:
1508
- stack_column_names = [stack_column_prefix + "_" + str(q) for stack_column_prefix in models for q in self.quantile_levels]
1603
+ stack_column_names = [
1604
+ stack_column_prefix + "_" + str(q) for stack_column_prefix in models for q in self.quantile_levels
1605
+ ]
1509
1606
  num_columns_per_model = len(self.quantile_levels)
1510
1607
  else:
1511
1608
  stack_column_names = models
@@ -1526,7 +1623,9 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
1526
1623
  **kwargs,
1527
1624
  ) -> list[str]:
1528
1625
  if fit_strategy == "parallel":
1529
- logger.log(30, f"Note: refit_full does not yet support fit_strategy='parallel', switching to 'sequential'...")
1626
+ logger.log(
1627
+ 30, f"Note: refit_full does not yet support fit_strategy='parallel', switching to 'sequential'..."
1628
+ )
1530
1629
  fit_strategy = "sequential"
1531
1630
  if X is None:
1532
1631
  X = self.load_X()
@@ -1544,7 +1643,9 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
1544
1643
  ignore_models = []
1545
1644
  ignore_stack_names = [REFIT_FULL_NAME]
1546
1645
  for stack_name in ignore_stack_names:
1547
- ignore_models += self.get_model_names(stack_name=stack_name) # get_model_names returns [] if stack_name does not exist
1646
+ ignore_models += self.get_model_names(
1647
+ stack_name=stack_name
1648
+ ) # get_model_names returns [] if stack_name does not exist
1548
1649
  models = [model for model in models if model not in ignore_models]
1549
1650
  for model in models:
1550
1651
  model_level = self.get_model_level(model)
@@ -1616,7 +1717,9 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
1616
1717
  distributed_manager.job_kwargs["level"] = level
1617
1718
  models_level = model_levels[level]
1618
1719
 
1619
- logger.log(20, f"Scheduling distributed model-workers for refitting {len(models_level)} L{level} models...")
1720
+ logger.log(
1721
+ 20, f"Scheduling distributed model-workers for refitting {len(models_level)} L{level} models..."
1722
+ )
1620
1723
  unfinished_job_refs = distributed_manager.schedule_jobs(models_to_fit=models_level)
1621
1724
 
1622
1725
  while unfinished_job_refs:
@@ -1624,21 +1727,21 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
1624
1727
  refit_full_parent, model_trained, model_path, model_type = ray.get(finished[0])
1625
1728
 
1626
1729
  self._add_model(
1627
- model_type.load(path=os.path.join(self.path,model_path), reset_paths=self.reset_paths),
1730
+ model_type.load(path=os.path.join(self.path, model_path), reset_paths=self.reset_paths),
1628
1731
  stack_name=REFIT_FULL_NAME,
1629
1732
  level=level,
1630
- _is_refit=True
1733
+ _is_refit=True,
1631
1734
  )
1632
1735
  model_refit_map[refit_full_parent] = model_trained
1633
1736
  self._update_model_attr(
1634
1737
  model_trained,
1635
1738
  refit_full=True,
1636
1739
  refit_full_parent=refit_full_parent,
1637
- refit_full_parent_val_score=self.get_model_attribute(refit_full_parent,"val_score"),
1740
+ refit_full_parent_val_score=self.get_model_attribute(refit_full_parent, "val_score"),
1638
1741
  )
1639
1742
  models_trained_full_level.append(model_trained)
1640
1743
 
1641
- logger.log(20,f"Finished refit model for {refit_full_parent}")
1744
+ logger.log(20, f"Finished refit model for {refit_full_parent}")
1642
1745
  unfinished_job_refs += distributed_manager.schedule_jobs()
1643
1746
 
1644
1747
  logger.log(20, f"Finished distributed refitting for {len(models_trained_full_level)} L{level} models.")
@@ -1675,7 +1778,10 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
1675
1778
  model_refit_map = self.model_refit_map()
1676
1779
  for model in ensemble_set:
1677
1780
  if model in model_refit_map and model_refit_map[model] in existing_models:
1678
- logger.log(20, f"Model '{model}' already has a refit _FULL model: '{model_refit_map[model]}', skipping refit...")
1781
+ logger.log(
1782
+ 20,
1783
+ f"Model '{model}' already has a refit _FULL model: '{model_refit_map[model]}', skipping refit...",
1784
+ )
1679
1785
  else:
1680
1786
  ensemble_set_valid.append(model)
1681
1787
  if ensemble_set_valid:
@@ -1718,11 +1824,11 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
1718
1824
  return self.get_model_attribute(model=model, attribute="refit_full_parent", default=model)
1719
1825
 
1720
1826
  def get_model_best(
1721
- self,
1722
- can_infer: bool | None = None,
1723
- allow_full: bool = True,
1724
- infer_limit: float | None = None,
1725
- infer_limit_as_child: bool = False
1827
+ self,
1828
+ can_infer: bool | None = None,
1829
+ allow_full: bool = True,
1830
+ infer_limit: float | None = None,
1831
+ infer_limit_as_child: bool = False,
1726
1832
  ) -> str:
1727
1833
  """
1728
1834
  Returns the name of the model with the best validation score that satisfies all specified constraints.
@@ -1774,7 +1880,10 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
1774
1880
  models_predict_time_list = [models_predict_1_time[m] for m in models_og]
1775
1881
  min_time = np.array(models_predict_time_list).min()
1776
1882
  infer_limit_new = min_time * 1.2 # Give 20% lee-way
1777
- logger.log(30, f"WARNING: Impossible to satisfy infer_limit constraint. Relaxing constraint from {infer_limit} to {infer_limit_new} ...")
1883
+ logger.log(
1884
+ 30,
1885
+ f"WARNING: Impossible to satisfy infer_limit constraint. Relaxing constraint from {infer_limit} to {infer_limit_new} ...",
1886
+ )
1778
1887
  models = models_og
1779
1888
  for model_key in models_predict_1_time:
1780
1889
  if models_predict_1_time[model_key] > infer_limit_new:
@@ -1788,12 +1897,19 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
1788
1897
  predict_time_attr = predict_1_time_attribute if predict_1_time_attribute is not None else "predict_time"
1789
1898
  models_predict_time = self.get_models_attribute_full(models=models, attribute=predict_time_attr)
1790
1899
 
1791
- perfs = [(m, model_performances[m], models_predict_time[m]) for m in models if model_performances[m] is not None]
1900
+ perfs = [
1901
+ (m, model_performances[m], models_predict_time[m]) for m in models if model_performances[m] is not None
1902
+ ]
1792
1903
  if not perfs:
1793
1904
  models = [m for m in models if m in models_full]
1794
- perfs = [(m, self.get_model_attribute(model=m, attribute="refit_full_parent_val_score"), models_predict_time[m]) for m in models]
1905
+ perfs = [
1906
+ (m, self.get_model_attribute(model=m, attribute="refit_full_parent_val_score"), models_predict_time[m])
1907
+ for m in models
1908
+ ]
1795
1909
  if not perfs:
1796
- raise AssertionError("No fit models that can infer exist with a validation score to choose the best model.")
1910
+ raise AssertionError(
1911
+ "No fit models that can infer exist with a validation score to choose the best model."
1912
+ )
1797
1913
  elif not allow_full:
1798
1914
  raise AssertionError(
1799
1915
  "No fit models that can infer exist with a validation score to choose the best model, but refit_full models exist. Set `allow_full=True` to get the best refit_full model."
@@ -1869,7 +1985,10 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
1869
1985
  # Check if already compiled, or if can't compile due to missing dependencies,
1870
1986
  # or if model hasn't implemented compiling.
1871
1987
  if "compiler" in config and model.get_compiler_name() == config["compiler"]:
1872
- logger.log(20, f'Skipping compilation for {model_name} ... (Already compiled with "{model.get_compiler_name()}" backend)')
1988
+ logger.log(
1989
+ 20,
1990
+ f'Skipping compilation for {model_name} ... (Already compiled with "{model.get_compiler_name()}" backend)',
1991
+ )
1873
1992
  elif model.can_compile(compiler_configs=config):
1874
1993
  logger.log(20, f"Compiling model: {model.name} ... Config = {config}")
1875
1994
  compile_start_time = time.time()
@@ -1886,7 +2005,10 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
1886
2005
  logger.log(20, f'\tCompiled model with "{compile_type}" backend ...')
1887
2006
  logger.log(20, f"\t{round(model.compile_time, 2)}s\t = Compile runtime")
1888
2007
  else:
1889
- logger.log(20, f"Skipping compilation for {model.name} ... (Unable to compile with the provided config: {config})")
2008
+ logger.log(
2009
+ 20,
2010
+ f"Skipping compilation for {model.name} ... (Unable to compile with the provided config: {config})",
2011
+ )
1890
2012
  logger.log(20, f"Finished compiling models, total runtime = {round(total_compile_time, 2)}s.")
1891
2013
  self.save()
1892
2014
  return model_names
@@ -1911,7 +2033,10 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
1911
2033
  )
1912
2034
  model_names = [model_name for model_name in model_names if model_name not in model_names_already_persisted]
1913
2035
  if not model_names:
1914
- logger.log(30, f"No valid unpersisted models were specified to be persisted, so no change in model persistence was performed.")
2036
+ logger.log(
2037
+ 30,
2038
+ f"No valid unpersisted models were specified to be persisted, so no change in model persistence was performed.",
2039
+ )
1915
2040
  return []
1916
2041
  if max_memory is not None:
1917
2042
 
@@ -1929,7 +2054,7 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
1929
2054
  if memory_proportion > max_memory:
1930
2055
  logger.log(
1931
2056
  30,
1932
- f"Models will not be persisted in memory as they are expected to require {round(memory_proportion * 100, 2)}% of memory, which is greater than the specified max_memory limit of {round(max_memory*100, 2)}%.",
2057
+ f"Models will not be persisted in memory as they are expected to require {round(memory_proportion * 100, 2)}% of memory, which is greater than the specified max_memory limit of {round(max_memory * 100, 2)}%.",
1933
2058
  )
1934
2059
  logger.log(
1935
2060
  30,
@@ -1937,7 +2062,10 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
1937
2062
  )
1938
2063
  return False
1939
2064
  else:
1940
- logger.log(20, f"Persisting {len(model_names)} models in memory. Models will require {round(memory_proportion*100, 2)}% of memory.")
2065
+ logger.log(
2066
+ 20,
2067
+ f"Persisting {len(model_names)} models in memory. Models will require {round(memory_proportion * 100, 2)}% of memory.",
2068
+ )
1941
2069
  return True
1942
2070
 
1943
2071
  if not _check_memory():
@@ -1970,7 +2098,10 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
1970
2098
  if unpersisted_models:
1971
2099
  logger.log(20, f"Unpersisted {len(unpersisted_models)} models: {unpersisted_models}")
1972
2100
  else:
1973
- logger.log(30, f"No valid persisted models were specified to be unpersisted, so no change in model persistence was performed.")
2101
+ logger.log(
2102
+ 30,
2103
+ f"No valid persisted models were specified to be unpersisted, so no change in model persistence was performed.",
2104
+ )
1974
2105
  return unpersisted_models
1975
2106
 
1976
2107
  def generate_weighted_ensemble(
@@ -2023,8 +2154,12 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
2023
2154
  base_model_names=base_model_names,
2024
2155
  base_model_paths_dict=base_model_paths_dict,
2025
2156
  base_model_types_dict=self.get_models_attribute_dict(attribute="type", models=base_model_names),
2026
- base_model_types_inner_dict=self.get_models_attribute_dict(attribute="type_inner", models=base_model_names),
2027
- base_model_performances_dict=self.get_models_attribute_dict(attribute="val_score", models=base_model_names),
2157
+ base_model_types_inner_dict=self.get_models_attribute_dict(
2158
+ attribute="type_inner", models=base_model_names
2159
+ ),
2160
+ base_model_performances_dict=self.get_models_attribute_dict(
2161
+ attribute="val_score", models=base_model_names
2162
+ ),
2028
2163
  hyperparameters=hyperparameters,
2029
2164
  random_state=level + self.random_state,
2030
2165
  ),
@@ -2051,7 +2186,9 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
2051
2186
  level=level,
2052
2187
  time_limit=time_limit,
2053
2188
  ens_sample_weight=w,
2054
- fit_kwargs=dict(feature_metadata=feature_metadata, num_classes=self.num_classes, groups=None), # FIXME: Is this the right way to do this?
2189
+ fit_kwargs=dict(
2190
+ feature_metadata=feature_metadata, num_classes=self.num_classes, groups=None
2191
+ ), # FIXME: Is this the right way to do this?
2055
2192
  total_resources=total_resources,
2056
2193
  )
2057
2194
  for weighted_ensemble_model_name in models:
@@ -2082,7 +2219,16 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
2082
2219
  Trains model but does not add the trained model to this Trainer.
2083
2220
  Returns trained model object.
2084
2221
  """
2085
- model = model.fit(X=X, y=y, X_val=X_val, y_val=y_val, X_test=X_test, y_test=y_test, total_resources=total_resources, **model_fit_kwargs)
2222
+ model = model.fit(
2223
+ X=X,
2224
+ y=y,
2225
+ X_val=X_val,
2226
+ y_val=y_val,
2227
+ X_test=X_test,
2228
+ y_test=y_test,
2229
+ total_resources=total_resources,
2230
+ **model_fit_kwargs,
2231
+ )
2086
2232
  return model
2087
2233
 
2088
2234
  def _train_and_save(
@@ -2150,12 +2296,19 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
2150
2296
  if not_enough_time:
2151
2297
  skip_msg = f"Skipping {model.name} due to lack of time remaining."
2152
2298
  not_enough_time_exception = InsufficientTime(skip_msg)
2153
- if self._check_raise_exception(exception=not_enough_time_exception, errors=errors, errors_ignore=errors_ignore, errors_raise=errors_raise):
2299
+ if self._check_raise_exception(
2300
+ exception=not_enough_time_exception,
2301
+ errors=errors,
2302
+ errors_ignore=errors_ignore,
2303
+ errors_raise=errors_raise,
2304
+ ):
2154
2305
  raise not_enough_time_exception
2155
2306
  else:
2156
2307
  logger.log(15, skip_msg)
2157
2308
  return []
2158
- fit_log_message += f" Training model for up to {time_limit:.2f}s of the {time_left_total:.2f}s of remaining time."
2309
+ fit_log_message += (
2310
+ f" Training model for up to {time_limit:.2f}s of the {time_left_total:.2f}s of remaining time."
2311
+ )
2159
2312
  logger.log(10 if is_distributed_mode else 20, fit_log_message)
2160
2313
 
2161
2314
  if isinstance(model, BaggedEnsembleModel) and not compute_score:
@@ -2178,7 +2331,12 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
2178
2331
  # If model is not bagged model and not stacked then pseudolabeled data needs to be incorporated at this level
2179
2332
  # Bagged model does validation on the fit level where as single models do it separately. Hence this if statement
2180
2333
  # is required
2181
- if not isinstance(model, BaggedEnsembleModel) and X_pseudo is not None and y_pseudo is not None and X_pseudo.columns.equals(X.columns):
2334
+ if (
2335
+ not isinstance(model, BaggedEnsembleModel)
2336
+ and X_pseudo is not None
2337
+ and y_pseudo is not None
2338
+ and X_pseudo.columns.equals(X.columns)
2339
+ ):
2182
2340
  assert_pseudo_column_match(X=X, X_pseudo=X_pseudo)
2183
2341
  # Needs .astype(X.dtypes) because pd.concat will convert categorical features to int/float unexpectedly. Need to convert them back to original.
2184
2342
  X_w_pseudo = pd.concat([X, X_pseudo], ignore_index=True).astype(X.dtypes)
@@ -2231,7 +2389,9 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
2231
2389
  except Exception as exc:
2232
2390
  if self.raise_on_model_failure:
2233
2391
  # immediately raise instead of skipping to next model, useful for debugging during development
2234
- logger.warning("Model failure occurred... Raising exception instead of continuing to next model. (raise_on_model_failure=True)")
2392
+ logger.warning(
2393
+ "Model failure occurred... Raising exception instead of continuing to next model. (raise_on_model_failure=True)"
2394
+ )
2235
2395
  raise exc
2236
2396
  exception = exc # required to reference exc outside of `except` statement
2237
2397
  del_model = True
@@ -2250,13 +2410,17 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
2250
2410
  elif isinstance(exception, NotEnoughCudaMemoryError):
2251
2411
  logger.warning(f"\tNot enough CUDA memory available to train {model.name}... Skipping this model.")
2252
2412
  elif isinstance(exception, ImportError):
2253
- logger.error(f"\tWarning: Exception caused {model.name} to fail during training (ImportError)... Skipping this model.")
2413
+ logger.error(
2414
+ f"\tWarning: Exception caused {model.name} to fail during training (ImportError)... Skipping this model."
2415
+ )
2254
2416
  logger.error(f"\t\t{exception}")
2255
2417
  del_model = False
2256
2418
  if self.verbosity > 2:
2257
2419
  logger.exception("Detailed Traceback:")
2258
2420
  else: # all other exceptions
2259
- logger.error(f"\tWarning: Exception caused {model.name} to fail during training... Skipping this model.")
2421
+ logger.error(
2422
+ f"\tWarning: Exception caused {model.name} to fail during training... Skipping this model."
2423
+ )
2260
2424
  logger.error(f"\t\t{exception}")
2261
2425
  if self.verbosity > 0:
2262
2426
  logger.exception("Detailed Traceback:")
@@ -2275,12 +2439,20 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
2275
2439
  if del_model:
2276
2440
  del model
2277
2441
  else:
2278
- self._add_model(model=model, stack_name=stack_name, level=level, y_pred_proba_val=y_pred_proba_val, is_ray_worker=is_ray_worker)
2442
+ self._add_model(
2443
+ model=model,
2444
+ stack_name=stack_name,
2445
+ level=level,
2446
+ y_pred_proba_val=y_pred_proba_val,
2447
+ is_ray_worker=is_ray_worker,
2448
+ )
2279
2449
  model_names_trained.append(model.name)
2280
2450
  if self.low_memory:
2281
2451
  del model
2282
2452
  if exception is not None:
2283
- if self._check_raise_exception(exception=exception, errors=errors, errors_ignore=errors_ignore, errors_raise=errors_raise):
2453
+ if self._check_raise_exception(
2454
+ exception=exception, errors=errors, errors_ignore=errors_ignore, errors_raise=errors_raise
2455
+ ):
2284
2456
  raise exception
2285
2457
  return model_names_trained
2286
2458
 
@@ -2324,12 +2496,23 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
2324
2496
  fit_num_gpus=model.fit_num_gpus,
2325
2497
  fit_num_cpus_child=model.fit_num_cpus_child,
2326
2498
  fit_num_gpus_child=model.fit_num_gpus_child,
2327
- refit_full_requires_gpu=(model.fit_num_gpus_child is not None) and (model.fit_num_gpus_child >= 1) and model._user_params.get("refit_folds", False),
2499
+ refit_full_requires_gpu=(model.fit_num_gpus_child is not None)
2500
+ and (model.fit_num_gpus_child >= 1)
2501
+ and model._user_params.get("refit_folds", False),
2328
2502
  **fit_metadata,
2329
2503
  )
2330
2504
  return model_metadata
2331
2505
 
2332
- def _add_model(self, model: AbstractModel, stack_name: str = "core", level: int = 1, y_pred_proba_val=None, _is_refit=False, is_distributed_main=False, is_ray_worker: bool = False) -> bool:
2506
+ def _add_model(
2507
+ self,
2508
+ model: AbstractModel,
2509
+ stack_name: str = "core",
2510
+ level: int = 1,
2511
+ y_pred_proba_val=None,
2512
+ _is_refit=False,
2513
+ is_distributed_main=False,
2514
+ is_ray_worker: bool = False,
2515
+ ) -> bool:
2333
2516
  """
2334
2517
  Registers the fit model in the Trainer object. Stores information such as model performance, save path, model type, and more.
2335
2518
  To use a model in Trainer, self._add_model must be called.
@@ -2391,7 +2574,9 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
2391
2574
  f"Model '{model.name}' depends on model '{base_model_name}', but '{base_model_name}' is not in a lower stack level. ('{model.name}' level: {level}, '{base_model_name}' level: {self.model_graph.nodes[base_model_name]['level']})"
2392
2575
  )
2393
2576
  self.model_graph.add_edge(base_model_name, model.name)
2394
- self._log_model_stats(model, _is_refit=_is_refit, is_distributed_main=is_distributed_main, is_ray_worker=is_ray_worker)
2577
+ self._log_model_stats(
2578
+ model, _is_refit=_is_refit, is_distributed_main=is_distributed_main, is_ray_worker=is_ray_worker
2579
+ )
2395
2580
  if self.low_memory:
2396
2581
  del model
2397
2582
  return True
@@ -2406,7 +2591,9 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
2406
2591
 
2407
2592
  def _save_model_y_pred_proba_val(self, model: str, y_pred_proba_val):
2408
2593
  """Cache y_pred_proba_val for later reuse to avoid redundant predict calls"""
2409
- save_pkl.save(path=self._path_to_model_attr(model=model, attribute="y_pred_proba_val"), object=y_pred_proba_val)
2594
+ save_pkl.save(
2595
+ path=self._path_to_model_attr(model=model, attribute="y_pred_proba_val"), object=y_pred_proba_val
2596
+ )
2410
2597
 
2411
2598
  def _load_model_y_pred_proba_val(self, model: str):
2412
2599
  """Load cached y_pred_proba_val for a given model"""
@@ -2449,7 +2636,9 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
2449
2636
  sign_str = "-"
2450
2637
  else:
2451
2638
  sign_str = ""
2452
- logger.log(log_level, f"\t{round(model.val_score, 4)}\t = Validation score ({sign_str}{model.eval_metric.name})")
2639
+ logger.log(
2640
+ log_level, f"\t{round(model.val_score, 4)}\t = Validation score ({sign_str}{model.eval_metric.name})"
2641
+ )
2453
2642
  if model.fit_time is not None:
2454
2643
  logger.log(log_level, f"\t{round(model.fit_time, 2)}s\t = Training runtime")
2455
2644
  if model.predict_time is not None:
@@ -2459,13 +2648,15 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
2459
2648
  if predict_n_time_per_row is not None and predict_n_size is not None:
2460
2649
  logger.log(
2461
2650
  15,
2462
- f"\t{round(1/(predict_n_time_per_row if predict_n_time_per_row else np.finfo(np.float16).eps), 1)}"
2651
+ f"\t{round(1 / (predict_n_time_per_row if predict_n_time_per_row else np.finfo(np.float16).eps), 1)}"
2463
2652
  f"\t = Inference throughput (rows/s | {int(predict_n_size)} batch size)",
2464
2653
  )
2465
2654
  if model.predict_1_time is not None:
2466
2655
  fit_metadata = model.get_fit_metadata()
2467
2656
  predict_1_batch_size = fit_metadata.get("predict_1_batch_size", None)
2468
- assert predict_1_batch_size is not None, "predict_1_batch_size cannot be None if predict_1_time is not None"
2657
+ assert predict_1_batch_size is not None, (
2658
+ "predict_1_batch_size cannot be None if predict_1_time is not None"
2659
+ )
2469
2660
 
2470
2661
  if _is_refit:
2471
2662
  predict_1_time = self.get_model_attribute(model=model.name, attribute="predict_1_child_time")
@@ -2475,23 +2666,36 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
2475
2666
  predict_1_time_full = self.get_model_attribute_full(model=model.name, attribute="predict_1_time")
2476
2667
 
2477
2668
  predict_1_time_log, time_unit = convert_time_in_s_to_log_friendly(time_in_sec=predict_1_time)
2478
- logger.log(log_level, f"\t{round(predict_1_time_log, 3)}{time_unit}\t = Validation runtime (1 row | {predict_1_batch_size} batch size | MARGINAL)")
2669
+ logger.log(
2670
+ log_level,
2671
+ f"\t{round(predict_1_time_log, 3)}{time_unit}\t = Validation runtime (1 row | {predict_1_batch_size} batch size | MARGINAL)",
2672
+ )
2479
2673
 
2480
2674
  predict_1_time_full_log, time_unit = convert_time_in_s_to_log_friendly(time_in_sec=predict_1_time_full)
2481
- logger.log(log_level, f"\t{round(predict_1_time_full_log, 3)}{time_unit}\t = Validation runtime (1 row | {predict_1_batch_size} batch size)")
2675
+ logger.log(
2676
+ log_level,
2677
+ f"\t{round(predict_1_time_full_log, 3)}{time_unit}\t = Validation runtime (1 row | {predict_1_batch_size} batch size)",
2678
+ )
2482
2679
 
2483
2680
  if not _is_refit:
2484
2681
  predict_1_time_child = self.get_model_attribute(model=model.name, attribute="predict_1_child_time")
2485
- predict_1_time_child_log, time_unit = convert_time_in_s_to_log_friendly(time_in_sec=predict_1_time_child)
2682
+ predict_1_time_child_log, time_unit = convert_time_in_s_to_log_friendly(
2683
+ time_in_sec=predict_1_time_child
2684
+ )
2486
2685
  logger.log(
2487
2686
  log_level,
2488
2687
  f"\t{round(predict_1_time_child_log, 3)}{time_unit}\t = Validation runtime (1 row | {predict_1_batch_size} batch size | REFIT | MARGINAL)",
2489
2688
  )
2490
2689
 
2491
- predict_1_time_full_child = self.get_model_attribute_full(model=model.name, attribute="predict_1_child_time")
2492
- predict_1_time_full_child_log, time_unit = convert_time_in_s_to_log_friendly(time_in_sec=predict_1_time_full_child)
2690
+ predict_1_time_full_child = self.get_model_attribute_full(
2691
+ model=model.name, attribute="predict_1_child_time"
2692
+ )
2693
+ predict_1_time_full_child_log, time_unit = convert_time_in_s_to_log_friendly(
2694
+ time_in_sec=predict_1_time_full_child
2695
+ )
2493
2696
  logger.log(
2494
- log_level, f"\t{round(predict_1_time_full_child_log, 3)}{time_unit}\t = Validation runtime (1 row | {predict_1_batch_size} batch size | REFIT)"
2697
+ log_level,
2698
+ f"\t{round(predict_1_time_full_child_log, 3)}{time_unit}\t = Validation runtime (1 row | {predict_1_batch_size} batch size | REFIT)",
2495
2699
  )
2496
2700
 
2497
2701
  # TODO: Split this to avoid confusion, HPO should go elsewhere?
@@ -2558,8 +2762,13 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
2558
2762
  return []
2559
2763
 
2560
2764
  model_fit_kwargs = self._get_model_fit_kwargs(
2561
- X=X, X_val=X_val, time_limit=time_limit, k_fold=k_fold, fit_kwargs=fit_kwargs,
2562
- ens_sample_weight=kwargs.get("ens_sample_weight", None), label_cleaner=label_cleaner,
2765
+ X=X,
2766
+ X_val=X_val,
2767
+ time_limit=time_limit,
2768
+ k_fold=k_fold,
2769
+ fit_kwargs=fit_kwargs,
2770
+ ens_sample_weight=kwargs.get("ens_sample_weight", None),
2771
+ label_cleaner=label_cleaner,
2563
2772
  )
2564
2773
  exception = None
2565
2774
  if hyperparameter_tune_kwargs:
@@ -2583,7 +2792,11 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
2583
2792
  try:
2584
2793
  if isinstance(model, BaggedEnsembleModel):
2585
2794
  bagged_model_fit_kwargs = self._get_bagged_model_fit_kwargs(
2586
- k_fold=k_fold, k_fold_start=k_fold_start, k_fold_end=k_fold_end, n_repeats=n_repeats, n_repeat_start=n_repeat_start
2795
+ k_fold=k_fold,
2796
+ k_fold_start=k_fold_start,
2797
+ k_fold_end=k_fold_end,
2798
+ n_repeats=n_repeats,
2799
+ n_repeat_start=n_repeat_start,
2587
2800
  )
2588
2801
  model_fit_kwargs.update(bagged_model_fit_kwargs)
2589
2802
  hpo_models, hpo_results = model.hyperparameter_tune(
@@ -2611,7 +2824,9 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
2611
2824
  **model_fit_kwargs,
2612
2825
  )
2613
2826
  if len(hpo_models) == 0:
2614
- logger.warning(f"No model was trained during hyperparameter tuning {model.name}... Skipping this model.")
2827
+ logger.warning(
2828
+ f"No model was trained during hyperparameter tuning {model.name}... Skipping this model."
2829
+ )
2615
2830
  except Exception as exc:
2616
2831
  exception = exc # required to provide exc outside of `except` statement
2617
2832
  if isinstance(exception, NoStackFeatures):
@@ -2621,7 +2836,9 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
2621
2836
  elif isinstance(exception, NoValidFeatures):
2622
2837
  logger.warning(f"\tNo valid features to train {model.name}... Skipping this model.")
2623
2838
  else:
2624
- logger.exception(f"Warning: Exception caused {model.name} to fail during hyperparameter tuning... Skipping this model.")
2839
+ logger.exception(
2840
+ f"Warning: Exception caused {model.name} to fail during hyperparameter tuning... Skipping this model."
2841
+ )
2625
2842
  logger.warning(exception)
2626
2843
  del model
2627
2844
  model_names_trained = []
@@ -2631,7 +2848,9 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
2631
2848
  model_names_trained = []
2632
2849
  self._extra_banned_names.add(model.name)
2633
2850
  for model_hpo_name, model_info in hpo_models.items():
2634
- model_hpo = self.load_model(model_hpo_name, path=os.path.relpath(model_info["path"], self.path), model_type=type(model))
2851
+ model_hpo = self.load_model(
2852
+ model_hpo_name, path=os.path.relpath(model_info["path"], self.path), model_type=type(model)
2853
+ )
2635
2854
  logger.log(20, f"Fitted model: {model_hpo.name} ...")
2636
2855
  if self._add_model(model=model_hpo, stack_name=stack_name, level=level):
2637
2856
  model_names_trained.append(model_hpo.name)
@@ -2639,7 +2858,11 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
2639
2858
  model_fit_kwargs.update(dict(X_pseudo=X_pseudo, y_pseudo=y_pseudo))
2640
2859
  if isinstance(model, BaggedEnsembleModel):
2641
2860
  bagged_model_fit_kwargs = self._get_bagged_model_fit_kwargs(
2642
- k_fold=k_fold, k_fold_start=k_fold_start, k_fold_end=k_fold_end, n_repeats=n_repeats, n_repeat_start=n_repeat_start
2861
+ k_fold=k_fold,
2862
+ k_fold_start=k_fold_start,
2863
+ k_fold_end=k_fold_end,
2864
+ n_repeats=n_repeats,
2865
+ n_repeat_start=n_repeat_start,
2643
2866
  )
2644
2867
  model_fit_kwargs.update(bagged_model_fit_kwargs)
2645
2868
  model_names_trained = self._train_and_save(
@@ -2665,7 +2888,9 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
2665
2888
  self._callbacks_after_fit(model_names=model_names_trained, stack_name=stack_name, level=level)
2666
2889
  self.save()
2667
2890
  if exception is not None:
2668
- if self._check_raise_exception(exception=exception, errors=errors, errors_ignore=errors_ignore, errors_raise=errors_raise):
2891
+ if self._check_raise_exception(
2892
+ exception=exception, errors=errors, errors_ignore=errors_ignore, errors_raise=errors_raise
2893
+ ):
2669
2894
  raise exception
2670
2895
  return model_names_trained
2671
2896
 
@@ -2769,7 +2994,9 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
2769
2994
  # TODO: Time allowance not accurate if running from fit_continue
2770
2995
  # TODO: Remove level and stack_name arguments, can get them automatically
2771
2996
  # TODO: Make sure that pretraining on X_unlabeled only happens 1 time rather than every fold of bagging. (Do during pretrain API work?)
2772
- def _train_multi_repeats(self, X, y, models: list, n_repeats, n_repeat_start=1, time_limit=None, time_limit_total_level=None, **kwargs) -> list[str]:
2997
+ def _train_multi_repeats(
2998
+ self, X, y, models: list, n_repeats, n_repeat_start=1, time_limit=None, time_limit_total_level=None, **kwargs
2999
+ ) -> list[str]:
2773
3000
  """
2774
3001
  Fits bagged ensemble models with additional folds and/or bagged repeats.
2775
3002
  Models must have already been fit prior to entering this method.
@@ -2795,7 +3022,7 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
2795
3022
  if time_left < time_required:
2796
3023
  logger.log(15, "Not enough time left to finish repeated k-fold bagging, stopping early ...")
2797
3024
  break
2798
- logger.log(20, f"Repeating k-fold bagging: {n+1}/{n_repeats}")
3025
+ logger.log(20, f"Repeating k-fold bagging: {n + 1}/{n_repeats}")
2799
3026
  for i, model in enumerate(models_valid):
2800
3027
  if self._callback_early_stop:
2801
3028
  break
@@ -2819,7 +3046,15 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
2819
3046
  time_left = time_limit - (time_start_model - time_start)
2820
3047
 
2821
3048
  models_valid_next += self._train_single_full(
2822
- X=X, y=y, model=model, k_fold_start=0, k_fold_end=None, n_repeats=n + 1, n_repeat_start=n, time_limit=time_left, **kwargs
3049
+ X=X,
3050
+ y=y,
3051
+ model=model,
3052
+ k_fold_start=0,
3053
+ k_fold_end=None,
3054
+ n_repeats=n + 1,
3055
+ n_repeat_start=n,
3056
+ time_limit=time_left,
3057
+ **kwargs,
2823
3058
  )
2824
3059
  models_valid = copy.deepcopy(models_valid_next)
2825
3060
  models_valid_next = []
@@ -2828,7 +3063,16 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
2828
3063
  return models_valid
2829
3064
 
2830
3065
  def _train_multi_initial(
2831
- self, X, y, models: list[AbstractModel], k_fold, n_repeats, hyperparameter_tune_kwargs=None, time_limit=None, feature_prune_kwargs=None, **kwargs
3066
+ self,
3067
+ X,
3068
+ y,
3069
+ models: list[AbstractModel],
3070
+ k_fold,
3071
+ n_repeats,
3072
+ hyperparameter_tune_kwargs=None,
3073
+ time_limit=None,
3074
+ feature_prune_kwargs=None,
3075
+ **kwargs,
2832
3076
  ):
2833
3077
  """
2834
3078
  Fits models that have not previously been fit.
@@ -2917,7 +3161,11 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
2917
3161
  time_limit = time_limit - (time.time() - feature_prune_time_start)
2918
3162
 
2919
3163
  fit_args["X"] = X[candidate_features]
2920
- fit_args["X_val"] = kwargs["X_val"][candidate_features] if isinstance(kwargs.get("X_val", None), pd.DataFrame) else kwargs.get("X_val", None)
3164
+ fit_args["X_val"] = (
3165
+ kwargs["X_val"][candidate_features]
3166
+ if isinstance(kwargs.get("X_val", None), pd.DataFrame)
3167
+ else kwargs.get("X_val", None)
3168
+ )
2921
3169
 
2922
3170
  if len(candidate_features) < len(X.columns):
2923
3171
  unfit_models = []
@@ -2938,7 +3186,9 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
2938
3186
  **fit_args,
2939
3187
  )
2940
3188
  force_prune = feature_prune_kwargs.get("force_prune", False)
2941
- models = self._retain_better_pruned_models(pruned_models=pruned_models, original_prune_map=original_prune_map, force_prune=force_prune)
3189
+ models = self._retain_better_pruned_models(
3190
+ pruned_models=pruned_models, original_prune_map=original_prune_map, force_prune=force_prune
3191
+ )
2942
3192
  return models
2943
3193
 
2944
3194
  # TODO: Ban KNN from being a Stacker model outside of aux. Will need to ensemble select on all stack layers ensemble selector to make it work
@@ -2977,7 +3227,7 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
2977
3227
  30,
2978
3228
  f"WARNING: fit_strategy='parallel', but `hyperparameter_tune_kwargs` is specified for model '{k}' with value {v}. "
2979
3229
  f"Hyperparameter tuning does not yet support `parallel` fit_strategy. "
2980
- f"Falling back to fit_strategy='sequential' ... "
3230
+ f"Falling back to fit_strategy='sequential' ... ",
2981
3231
  )
2982
3232
  fit_strategy = "sequential"
2983
3233
  break
@@ -2993,7 +3243,7 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
2993
3243
  f"Note: fit_strategy='parallel', but `num_cpus={num_cpus}`. "
2994
3244
  f"Running parallel mode with fewer than 12 CPUs is not recommended and has been disabled. "
2995
3245
  f'You can override this by specifying `os.environ["AG_FORCE_PARALLEL"] = "True"`. '
2996
- f"Falling back to fit_strategy='sequential' ..."
3246
+ f"Falling back to fit_strategy='sequential' ...",
2997
3247
  )
2998
3248
  fit_strategy = "sequential"
2999
3249
  if fit_strategy == "parallel":
@@ -3005,7 +3255,7 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
3005
3255
  30,
3006
3256
  f"WARNING: fit_strategy='parallel', but `num_gpus={num_gpus}` is specified. "
3007
3257
  f"GPU is not yet supported for `parallel` fit_strategy. To enable parallel, ensure you specify `num_gpus=0` in the fit call. "
3008
- f"Falling back to fit_strategy='sequential' ... "
3258
+ f"Falling back to fit_strategy='sequential' ... ",
3009
3259
  )
3010
3260
  fit_strategy = "sequential"
3011
3261
  if fit_strategy == "parallel":
@@ -3016,7 +3266,7 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
3016
3266
  30,
3017
3267
  f"WARNING: Exception encountered when trying to import ray (fit_strategy='parallel'). "
3018
3268
  f"ray is required for 'parallel' fit_strategy. Falling back to fit_strategy='sequential' ... "
3019
- f"\n\tException details: {e.__class__.__name__}: {e}"
3269
+ f"\n\tException details: {e.__class__.__name__}: {e}",
3020
3270
  )
3021
3271
  fit_strategy = "sequential"
3022
3272
 
@@ -3120,9 +3370,13 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
3120
3370
  if time_limit is not None:
3121
3371
  # allow between 5 and 60 seconds overhead before force killing jobs to give some leniency to jobs with overhead.
3122
3372
  time_overhead = min(max(time_limit * 0.01, 5), 60)
3123
- min_time_required_base = min(self._time_limit * 0.01, 10) # This is checked in the worker thread, will skip if not satisfied
3373
+ min_time_required_base = min(
3374
+ self._time_limit * 0.01, 10
3375
+ ) # This is checked in the worker thread, will skip if not satisfied
3124
3376
  # If time remaining is less than min_time_required, avoid scheduling new jobs and only wait for existing ones to finish.
3125
- min_time_required = min_time_required_base * 1.5 + 1 # Add 50% buffer and 1 second to account for ray overhead
3377
+ min_time_required = (
3378
+ min_time_required_base * 1.5 + 1
3379
+ ) # Add 50% buffer and 1 second to account for ray overhead
3126
3380
  else:
3127
3381
  time_overhead = None
3128
3382
  min_time_required = None
@@ -3143,9 +3397,11 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
3143
3397
 
3144
3398
  distributed_manager.deallocate_resources(job_ref=finished[0])
3145
3399
  model_name, model_path, model_type, exc, model_failure_info = ray.get(finished[0])
3146
- assert model_name in expected_model_names, (f"Unexpected model name outputted during parallel fit: {model_name}\n"
3147
- f"Valid Names: {expected_model_names}\n"
3148
- f"This should never happen. Please create a GitHub Issue.")
3400
+ assert model_name in expected_model_names, (
3401
+ f"Unexpected model name outputted during parallel fit: {model_name}\n"
3402
+ f"Valid Names: {expected_model_names}\n"
3403
+ f"This should never happen. Please create a GitHub Issue."
3404
+ )
3149
3405
  jobs_finished += 1
3150
3406
 
3151
3407
  if exc is not None or model_path is None:
@@ -3166,7 +3422,10 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
3166
3422
  if exc_type is not None and issubclass(exc_type, InsufficientTime):
3167
3423
  logger.log(20, exc_str)
3168
3424
  else:
3169
- logger.log(20, f"Skipping {model_name if isinstance(model_name, str) else model_name.name} due to exception{extra_log}")
3425
+ logger.log(
3426
+ 20,
3427
+ f"Skipping {model_name if isinstance(model_name, str) else model_name.name} due to exception{extra_log}",
3428
+ )
3170
3429
  if model_failure_info is not None:
3171
3430
  self._models_failed_to_train_errors[model_name] = model_failure_info
3172
3431
  else:
@@ -3179,9 +3438,9 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
3179
3438
  # Self object is not permanently mutated during worker execution, so we need to add model to the "main" self (again).
3180
3439
  # This is the synchronization point between the distributed and main processes.
3181
3440
  if self._add_model(
3182
- model_type.load(path=os.path.join(self.path, model_path), reset_paths=self.reset_paths),
3183
- stack_name=kwargs["stack_name"],
3184
- level=kwargs["level"]
3441
+ model_type.load(path=os.path.join(self.path, model_path), reset_paths=self.reset_paths),
3442
+ stack_name=kwargs["stack_name"],
3443
+ level=kwargs["level"],
3185
3444
  ):
3186
3445
  jobs_running = len(unfinished_job_refs)
3187
3446
  if can_schedule_jobs:
@@ -3199,7 +3458,10 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
3199
3458
  logger.log(20, parallel_status_log)
3200
3459
  models_valid.append(model_name)
3201
3460
  else:
3202
- logger.log(40, f"Failed to add {model_name} to model graph. This should never happen. Please create a GitHub issue.")
3461
+ logger.log(
3462
+ 40,
3463
+ f"Failed to add {model_name} to model graph. This should never happen. Please create a GitHub issue.",
3464
+ )
3203
3465
 
3204
3466
  if not unfinished_job_refs and not distributed_manager.models_to_schedule:
3205
3467
  # Completed all jobs
@@ -3207,7 +3469,10 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
3207
3469
 
3208
3470
  # TODO: look into what this does / how this works for distributed training
3209
3471
  if self._callback_early_stop:
3210
- logger.log(20, "Callback triggered in parallel setting. Stopping model training and cancelling remaining jobs.")
3472
+ logger.log(
3473
+ 20,
3474
+ "Callback triggered in parallel setting. Stopping model training and cancelling remaining jobs.",
3475
+ )
3211
3476
  break
3212
3477
 
3213
3478
  # Stop due to time limit after adding model
@@ -3216,7 +3481,10 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
3216
3481
  time_left = time_limit - time_elapsed
3217
3482
  time_left_models = time_limit_models - time_elapsed
3218
3483
  if (time_left + time_overhead) <= 0:
3219
- logger.log(20, "Time limit reached for this stacking layer. Stopping model training and cancelling remaining jobs.")
3484
+ logger.log(
3485
+ 20,
3486
+ "Time limit reached for this stacking layer. Stopping model training and cancelling remaining jobs.",
3487
+ )
3220
3488
  break
3221
3489
  elif time_left_models < min_time_required:
3222
3490
  if can_schedule_jobs:
@@ -3224,7 +3492,7 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
3224
3492
  logger.log(
3225
3493
  20,
3226
3494
  f"Low on time, skipping {len(distributed_manager.models_to_schedule)} "
3227
- f"pending jobs and waiting for running jobs to finish... ({time_left:.0f}s remaining time)"
3495
+ f"pending jobs and waiting for running jobs to finish... ({time_left:.0f}s remaining time)",
3228
3496
  )
3229
3497
  can_schedule_jobs = False
3230
3498
 
@@ -3321,7 +3589,9 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
3321
3589
  ) -> list[str]:
3322
3590
  """Identical to self.train_multi_levels, but also saves the data to disk. This should only ever be called once."""
3323
3591
  if time_limit is not None and time_limit <= 0:
3324
- raise AssertionError(f"Not enough time left to train models. Consider specifying a larger time_limit. Time remaining: {round(time_limit, 2)}s")
3592
+ raise AssertionError(
3593
+ f"Not enough time left to train models. Consider specifying a larger time_limit. Time remaining: {round(time_limit, 2)}s"
3594
+ )
3325
3595
  if self.save_data and not self.is_data_saved:
3326
3596
  self.save_X(X)
3327
3597
  self.save_y(y)
@@ -3365,14 +3635,24 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
3365
3635
  y_pred_proba = self._predict_proba_model(X=X, model=model, model_pred_proba_dict=model_pred_proba_dict)
3366
3636
  return get_pred_from_proba(y_pred_proba=y_pred_proba, problem_type=self.problem_type)
3367
3637
 
3368
- def _predict_proba_model(self, X: pd.DataFrame, model: str, model_pred_proba_dict: dict | None = None) -> np.ndarray:
3369
- model_pred_proba_dict = self.get_model_pred_proba_dict(X=X, models=[model], model_pred_proba_dict=model_pred_proba_dict)
3638
+ def _predict_proba_model(
3639
+ self, X: pd.DataFrame, model: str, model_pred_proba_dict: dict | None = None
3640
+ ) -> np.ndarray:
3641
+ model_pred_proba_dict = self.get_model_pred_proba_dict(
3642
+ X=X, models=[model], model_pred_proba_dict=model_pred_proba_dict
3643
+ )
3370
3644
  if not isinstance(model, str):
3371
3645
  model = model.name
3372
3646
  return model_pred_proba_dict[model]
3373
3647
 
3374
3648
  def _proxy_model_feature_prune(
3375
- self, model_fit_kwargs: dict, time_limit: float, layer_fit_time: float, level: int, features: list[str], **feature_prune_kwargs: dict
3649
+ self,
3650
+ model_fit_kwargs: dict,
3651
+ time_limit: float,
3652
+ layer_fit_time: float,
3653
+ level: int,
3654
+ features: list[str],
3655
+ **feature_prune_kwargs: dict,
3376
3656
  ) -> list[str]:
3377
3657
  """
3378
3658
  Uses the best LightGBM-based base learner of this layer to perform time-aware permutation feature importance based feature pruning.
@@ -3414,7 +3694,9 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
3414
3694
  if feature_prune_time_limit is not None:
3415
3695
  feature_prune_time_limit = min(max(time_limit - layer_fit_time, 0), feature_prune_time_limit)
3416
3696
  elif time_limit is not None:
3417
- feature_prune_time_limit = min(max(time_limit - layer_fit_time, 0), max(k * layer_fit_time, 0.05 * time_limit))
3697
+ feature_prune_time_limit = min(
3698
+ max(time_limit - layer_fit_time, 0), max(k * layer_fit_time, 0.05 * time_limit)
3699
+ )
3418
3700
  else:
3419
3701
  feature_prune_time_limit = max(k * layer_fit_time, 300)
3420
3702
 
@@ -3425,7 +3707,10 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
3425
3707
  )
3426
3708
  return features
3427
3709
  selector = FeatureSelector(
3428
- model=proxy_model, time_limit=feature_prune_time_limit, raise_exception=raise_exception_on_fail, problem_type=self.problem_type
3710
+ model=proxy_model,
3711
+ time_limit=feature_prune_time_limit,
3712
+ raise_exception=raise_exception_on_fail,
3713
+ problem_type=self.problem_type,
3429
3714
  )
3430
3715
  candidate_features = selector.select_features(**feature_prune_kwargs, **model_fit_kwargs)
3431
3716
  return candidate_features
@@ -3433,7 +3718,9 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
3433
3718
  def _get_default_proxy_model_class(self):
3434
3719
  return None
3435
3720
 
3436
- def _retain_better_pruned_models(self, pruned_models: list[str], original_prune_map: dict, force_prune: bool = False) -> list[str]:
3721
+ def _retain_better_pruned_models(
3722
+ self, pruned_models: list[str], original_prune_map: dict, force_prune: bool = False
3723
+ ) -> list[str]:
3437
3724
  """
3438
3725
  Compares models fit on the pruned set of features with their counterpart, models fit on full set of features.
3439
3726
  Take the model that achieved a higher validation set score and delete the other from self.model_graph.
@@ -3460,15 +3747,24 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
3460
3747
  pruned_score = leaderboard[leaderboard["model"] == pruned_model]["score_val"].item()
3461
3748
  score_str = f"({round(pruned_score, 4)} vs {round(original_score, 4)})"
3462
3749
  if force_prune:
3463
- logger.log(30, f"Pruned score vs original score is {score_str}. Replacing original model since force_prune=True...")
3750
+ logger.log(
3751
+ 30,
3752
+ f"Pruned score vs original score is {score_str}. Replacing original model since force_prune=True...",
3753
+ )
3464
3754
  self.delete_models(models_to_delete=original_model, dry_run=False)
3465
3755
  models.append(pruned_model)
3466
3756
  elif pruned_score > original_score:
3467
- logger.log(30, f"Model trained with feature pruning score is better than original model's score {score_str}. Replacing original model...")
3757
+ logger.log(
3758
+ 30,
3759
+ f"Model trained with feature pruning score is better than original model's score {score_str}. Replacing original model...",
3760
+ )
3468
3761
  self.delete_models(models_to_delete=original_model, dry_run=False)
3469
3762
  models.append(pruned_model)
3470
3763
  else:
3471
- logger.log(30, f"Model trained with feature pruning score is not better than original model's score {score_str}. Keeping original model...")
3764
+ logger.log(
3765
+ 30,
3766
+ f"Model trained with feature pruning score is not better than original model's score {score_str}. Keeping original model...",
3767
+ )
3472
3768
  self.delete_models(models_to_delete=pruned_model, dry_run=False)
3473
3769
  models.append(original_model)
3474
3770
  return models
@@ -3759,7 +4055,9 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
3759
4055
  bagged_info = model_info[model_name].get("bagged_info", {})
3760
4056
  custom_info["num_models"] = bagged_info.get("num_child_models", 1)
3761
4057
  custom_info["memory_size"] = bagged_info.get("max_memory_size", model_info[model_name]["memory_size"])
3762
- custom_info["memory_size_min"] = bagged_info.get("min_memory_size", model_info[model_name]["memory_size"])
4058
+ custom_info["memory_size_min"] = bagged_info.get(
4059
+ "min_memory_size", model_info[model_name]["memory_size"]
4060
+ )
3763
4061
  custom_info["compile_time"] = bagged_info.get("compile_time", model_info[model_name]["compile_time"])
3764
4062
  custom_info["child_model_type"] = bagged_info.get("child_model_type", None)
3765
4063
  custom_info["child_hyperparameters"] = bagged_info.get("child_hyperparameters", None)
@@ -3767,13 +4065,23 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
3767
4065
  custom_info["child_ag_args_fit"] = bagged_info.get("child_ag_args_fit", None)
3768
4066
  custom_model_info[model_name] = custom_info
3769
4067
 
3770
- model_info_keys = ["num_features", "model_type", "hyperparameters", "hyperparameters_fit", "ag_args_fit", "features"]
4068
+ model_info_keys = [
4069
+ "num_features",
4070
+ "model_type",
4071
+ "hyperparameters",
4072
+ "hyperparameters_fit",
4073
+ "ag_args_fit",
4074
+ "features",
4075
+ ]
3771
4076
  model_info_sum_keys = []
3772
4077
  for key in model_info_keys:
3773
4078
  model_info_dict[key] = [model_info[model_name][key] for model_name in model_names]
3774
4079
  if key in model_info_sum_keys:
3775
4080
  key_dict = {model_name: model_info[model_name][key] for model_name in model_names}
3776
- model_info_dict[key + "_full"] = [self.get_model_attribute_full(model=model_name, attribute=key_dict) for model_name in model_names]
4081
+ model_info_dict[key + "_full"] = [
4082
+ self.get_model_attribute_full(model=model_name, attribute=key_dict)
4083
+ for model_name in model_names
4084
+ ]
3777
4085
 
3778
4086
  model_info_keys = [
3779
4087
  "num_models",
@@ -3796,7 +4104,8 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
3796
4104
  key_dict = {model_name: custom_model_info[model_name][key] for model_name in model_names}
3797
4105
  for column_name, func in model_info_full_keys[key]:
3798
4106
  model_info_dict[column_name] = [
3799
- self.get_model_attribute_full(model=model_name, attribute=key_dict, func=func) for model_name in model_names
4107
+ self.get_model_attribute_full(model=model_name, attribute=key_dict, func=func)
4108
+ for model_name in model_names
3800
4109
  ]
3801
4110
 
3802
4111
  ancestors = [list(nx.dag.ancestors(self.model_graph, model_name)) for model_name in model_names]
@@ -3827,7 +4136,9 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
3827
4136
  **model_info_dict,
3828
4137
  }
3829
4138
  )
3830
- df_sorted = df.sort_values(by=["score_val", "pred_time_val", "model"], ascending=[False, True, False]).reset_index(drop=True)
4139
+ df_sorted = df.sort_values(
4140
+ by=["score_val", "pred_time_val", "model"], ascending=[False, True, False]
4141
+ ).reset_index(drop=True)
3831
4142
 
3832
4143
  df_columns_lst = df_sorted.columns.tolist()
3833
4144
  explicit_order = [
@@ -4009,7 +4320,14 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
4009
4320
  return info
4010
4321
 
4011
4322
  def reduce_memory_size(
4012
- self, remove_data=True, remove_fit_stack=False, remove_fit=True, remove_info=False, requires_save=True, reduce_children=False, **kwargs
4323
+ self,
4324
+ remove_data=True,
4325
+ remove_fit_stack=False,
4326
+ remove_fit=True,
4327
+ remove_info=False,
4328
+ requires_save=True,
4329
+ reduce_children=False,
4330
+ **kwargs,
4013
4331
  ):
4014
4332
  if remove_data and self.is_data_saved:
4015
4333
  data_files = [
@@ -4056,7 +4374,14 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
4056
4374
  # TODO: Also enable deletion of models which didn't succeed in training (files may still be persisted)
4057
4375
  # This includes the original HPO fold for stacking
4058
4376
  # Deletes specified models from trainer and from disk (if delete_from_disk=True).
4059
- def delete_models(self, models_to_keep=None, models_to_delete=None, allow_delete_cascade=False, delete_from_disk=True, dry_run=True):
4377
+ def delete_models(
4378
+ self,
4379
+ models_to_keep=None,
4380
+ models_to_delete=None,
4381
+ allow_delete_cascade=False,
4382
+ delete_from_disk=True,
4383
+ dry_run=True,
4384
+ ):
4060
4385
  if models_to_keep is not None and models_to_delete is not None:
4061
4386
  raise ValueError("Exactly one of [models_to_keep, models_to_delete] must be set.")
4062
4387
  if models_to_keep is not None:
@@ -4176,7 +4501,10 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
4176
4501
  if augmentation_data is not None and teacher_preds is None:
4177
4502
  raise ValueError("augmentation_data must be None if teacher_preds is None")
4178
4503
 
4179
- logger.log(20, f"Distilling with teacher='{teacher}', teacher_preds={str(teacher_preds)}, augment_method={str(augment_method)} ...")
4504
+ logger.log(
4505
+ 20,
4506
+ f"Distilling with teacher='{teacher}', teacher_preds={str(teacher_preds)}, augment_method={str(augment_method)} ...",
4507
+ )
4180
4508
  if teacher not in self.get_model_names(can_infer=True):
4181
4509
  raise AssertionError(
4182
4510
  f"Teacher model '{teacher}' is not a valid teacher model! Either it does not exist or it cannot infer on new data.\n"
@@ -4197,7 +4525,9 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
4197
4525
  raise ValueError("X_val cannot be None when y_val specified.")
4198
4526
  if holdout_frac is None:
4199
4527
  holdout_frac = default_holdout_frac(len(X), hyperparameter_tune)
4200
- X, X_val, y, y_val = generate_train_test_split(X, y, problem_type=self.problem_type, test_size=holdout_frac)
4528
+ X, X_val, y, y_val = generate_train_test_split(
4529
+ X, y, problem_type=self.problem_type, test_size=holdout_frac
4530
+ )
4201
4531
 
4202
4532
  y_val_og = y_val.copy()
4203
4533
  og_bagged_mode = self.bagged_mode
@@ -4211,7 +4541,8 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
4211
4541
  if teacher_preds is None or teacher_preds == "onehot":
4212
4542
  augment_method = None
4213
4543
  logger.log(
4214
- 20, "Training students without a teacher model. Set teacher_preds = 'soft' or 'hard' to distill using the best AutoGluon predictor as teacher."
4544
+ 20,
4545
+ "Training students without a teacher model. Set teacher_preds = 'soft' or 'hard' to distill using the best AutoGluon predictor as teacher.",
4215
4546
  )
4216
4547
 
4217
4548
  if teacher_preds in ["onehot", "soft"]:
@@ -4221,8 +4552,12 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
4221
4552
  if augment_method is None and augmentation_data is None:
4222
4553
  if teacher_preds == "hard":
4223
4554
  y_pred = pd.Series(self.predict(X, model=teacher))
4224
- if (self.problem_type != REGRESSION) and (len(y_pred.unique()) < len(y.unique())): # add missing labels
4225
- logger.log(15, "Adding missing labels to distillation dataset by including some real training examples")
4555
+ if (self.problem_type != REGRESSION) and (
4556
+ len(y_pred.unique()) < len(y.unique())
4557
+ ): # add missing labels
4558
+ logger.log(
4559
+ 15, "Adding missing labels to distillation dataset by including some real training examples"
4560
+ )
4226
4561
  indices_to_add = []
4227
4562
  for clss in y.unique():
4228
4563
  if clss not in y_pred.unique():
@@ -4244,7 +4579,11 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
4244
4579
  y = pd.Series(y)
4245
4580
  else:
4246
4581
  X_aug = augment_data(
4247
- X=X, feature_metadata=self.feature_metadata, augmentation_data=augmentation_data, augment_method=augment_method, augment_args=augment_args
4582
+ X=X,
4583
+ feature_metadata=self.feature_metadata,
4584
+ augmentation_data=augmentation_data,
4585
+ augment_method=augment_method,
4586
+ augment_args=augment_args,
4248
4587
  )
4249
4588
  if len(X_aug) > 0:
4250
4589
  if teacher_preds == "hard":
@@ -4326,8 +4665,14 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
4326
4665
  return distilled_model_names
4327
4666
 
4328
4667
  def _get_model_fit_kwargs(
4329
- self, X: pd.DataFrame, X_val: pd.DataFrame, time_limit: float, k_fold: int,
4330
- fit_kwargs: dict, ens_sample_weight: list | None = None, label_cleaner: None | LabelCleaner = None
4668
+ self,
4669
+ X: pd.DataFrame,
4670
+ X_val: pd.DataFrame,
4671
+ time_limit: float,
4672
+ k_fold: int,
4673
+ fit_kwargs: dict,
4674
+ ens_sample_weight: list | None = None,
4675
+ label_cleaner: None | LabelCleaner = None,
4331
4676
  ) -> dict:
4332
4677
  # Returns kwargs to be passed to AbstractModel's fit function
4333
4678
  if fit_kwargs is None:
@@ -4338,13 +4683,19 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
4338
4683
  X, w_train = extract_column(X, self.sample_weight)
4339
4684
  if w_train is not None: # may be None for ensemble
4340
4685
  # TODO: consider moving weight normalization into AbstractModel.fit()
4341
- model_fit_kwargs["sample_weight"] = w_train.values / w_train.mean() # normalization can affect gradient algorithms like boosting
4686
+ model_fit_kwargs["sample_weight"] = (
4687
+ w_train.values / w_train.mean()
4688
+ ) # normalization can affect gradient algorithms like boosting
4342
4689
  if X_val is not None:
4343
4690
  X_val, w_val = extract_column(X_val, self.sample_weight)
4344
- if self.weight_evaluation and w_val is not None: # ignore validation sample weights unless weight_evaluation specified
4691
+ if (
4692
+ self.weight_evaluation and w_val is not None
4693
+ ): # ignore validation sample weights unless weight_evaluation specified
4345
4694
  model_fit_kwargs["sample_weight_val"] = w_val.values / w_val.mean()
4346
4695
  if ens_sample_weight is not None:
4347
- model_fit_kwargs["sample_weight"] = ens_sample_weight # sample weights to use for weighted ensemble only
4696
+ model_fit_kwargs["sample_weight"] = (
4697
+ ens_sample_weight # sample weights to use for weighted ensemble only
4698
+ )
4348
4699
  if self._groups is not None and "groups" not in model_fit_kwargs:
4349
4700
  if k_fold == self.k_fold: # don't do this on refit full
4350
4701
  model_fit_kwargs["groups"] = self._groups
@@ -4357,14 +4708,21 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
4357
4708
  raise AssertionError(f"Missing expected parameter 'feature_metadata'.")
4358
4709
  return model_fit_kwargs
4359
4710
 
4360
- def _get_bagged_model_fit_kwargs(self, k_fold: int, k_fold_start: int, k_fold_end: int, n_repeats: int, n_repeat_start: int) -> dict:
4711
+ def _get_bagged_model_fit_kwargs(
4712
+ self, k_fold: int, k_fold_start: int, k_fold_end: int, n_repeats: int, n_repeat_start: int
4713
+ ) -> dict:
4361
4714
  # Returns additional kwargs (aside from _get_model_fit_kwargs) to be passed to BaggedEnsembleModel's fit function
4362
4715
  if k_fold is None:
4363
4716
  k_fold = self.k_fold
4364
4717
  if n_repeats is None:
4365
4718
  n_repeats = self.n_repeats
4366
4719
  return dict(
4367
- k_fold=k_fold, k_fold_start=k_fold_start, k_fold_end=k_fold_end, n_repeats=n_repeats, n_repeat_start=n_repeat_start, compute_base_preds=False
4720
+ k_fold=k_fold,
4721
+ k_fold_start=k_fold_start,
4722
+ k_fold_end=k_fold_end,
4723
+ n_repeats=n_repeats,
4724
+ n_repeat_start=n_repeat_start,
4725
+ compute_base_preds=False,
4368
4726
  )
4369
4727
 
4370
4728
  def _get_feature_prune_proxy_model(self, proxy_model_class: AbstractModel | None, level: int) -> AbstractModel:
@@ -4375,14 +4733,20 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
4375
4733
  """
4376
4734
  proxy_model = None
4377
4735
  if isinstance(proxy_model_class, str):
4378
- raise AssertionError(f"proxy_model_class must be a subclass of AbstractModel. Was instead a string: {proxy_model_class}")
4736
+ raise AssertionError(
4737
+ f"proxy_model_class must be a subclass of AbstractModel. Was instead a string: {proxy_model_class}"
4738
+ )
4379
4739
  banned_models = [GreedyWeightedEnsembleModel, SimpleWeightedEnsembleModel]
4380
- assert proxy_model_class not in banned_models, "WeightedEnsemble models cannot be feature pruning proxy models."
4740
+ assert proxy_model_class not in banned_models, (
4741
+ "WeightedEnsemble models cannot be feature pruning proxy models."
4742
+ )
4381
4743
 
4382
4744
  leaderboard = self.leaderboard()
4383
4745
  banned_names = []
4384
4746
  candidate_model_rows = leaderboard[(~leaderboard["score_val"].isna()) & (leaderboard["stack_level"] == level)]
4385
- candidate_models_type_inner = self.get_models_attribute_dict(attribute="type_inner", models=candidate_model_rows["model"])
4747
+ candidate_models_type_inner = self.get_models_attribute_dict(
4748
+ attribute="type_inner", models=candidate_model_rows["model"]
4749
+ )
4386
4750
  for model_name, type_inner in candidate_models_type_inner.copy().items():
4387
4751
  if type_inner in banned_models:
4388
4752
  banned_names.append(model_name)
@@ -4390,18 +4754,28 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
4390
4754
  banned_names = set(banned_names)
4391
4755
  candidate_model_rows = candidate_model_rows[~candidate_model_rows["model"].isin(banned_names)]
4392
4756
  if proxy_model_class is not None:
4393
- candidate_model_names = [model_name for model_name, model_class in candidate_models_type_inner.items() if model_class == proxy_model_class]
4757
+ candidate_model_names = [
4758
+ model_name
4759
+ for model_name, model_class in candidate_models_type_inner.items()
4760
+ if model_class == proxy_model_class
4761
+ ]
4394
4762
  candidate_model_rows = candidate_model_rows[candidate_model_rows["model"].isin(candidate_model_names)]
4395
4763
  if len(candidate_model_rows) == 0:
4396
4764
  if proxy_model_class is None:
4397
4765
  logger.warning(f"No models from level {level} have been successfully fit. Skipping feature pruning.")
4398
4766
  else:
4399
- logger.warning(f"No models of type {proxy_model_class} have finished training in level {level}. Skipping feature pruning.")
4767
+ logger.warning(
4768
+ f"No models of type {proxy_model_class} have finished training in level {level}. Skipping feature pruning."
4769
+ )
4400
4770
  return proxy_model
4401
- best_candidate_model_rows = candidate_model_rows.loc[candidate_model_rows["score_val"] == candidate_model_rows["score_val"].max()]
4771
+ best_candidate_model_rows = candidate_model_rows.loc[
4772
+ candidate_model_rows["score_val"] == candidate_model_rows["score_val"].max()
4773
+ ]
4402
4774
  return self.load_model(best_candidate_model_rows.loc[best_candidate_model_rows["fit_time"].idxmin()]["model"])
4403
4775
 
4404
- def calibrate_model(self, model_name: str | None = None, lr: float = 0.1, max_iter: int = 200, init_val: float = 1.0):
4776
+ def calibrate_model(
4777
+ self, model_name: str | None = None, lr: float = 0.1, max_iter: int = 200, init_val: float = 1.0
4778
+ ):
4405
4779
  """
4406
4780
  Applies temperature scaling to a model.
4407
4781
  Applies inverse softmax to predicted probs then trains temperature scalar
@@ -4464,12 +4838,16 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
4464
4838
  model = self.load_model(model_name=model_name)
4465
4839
  if self.problem_type == QUANTILE:
4466
4840
  logger.log(15, f"Conformity scores being computed to calibrate model: {model_name}")
4467
- conformalize = compute_conformity_score(y_val_pred=y_val_probs, y_val=y_val, quantile_levels=self.quantile_levels)
4841
+ conformalize = compute_conformity_score(
4842
+ y_val_pred=y_val_probs, y_val=y_val, quantile_levels=self.quantile_levels
4843
+ )
4468
4844
  model.conformalize = conformalize
4469
4845
  model.save()
4470
4846
  else:
4471
4847
  logger.log(15, f"Temperature scaling term being tuned for model: {model_name}")
4472
- temp_scalar = tune_temperature_scaling(y_val_probs=y_val_probs, y_val=y_val, init_val=init_val, max_iter=max_iter, lr=lr)
4848
+ temp_scalar = tune_temperature_scaling(
4849
+ y_val_probs=y_val_probs, y_val=y_val, init_val=init_val, max_iter=max_iter, lr=lr
4850
+ )
4473
4851
  if temp_scalar is None:
4474
4852
  logger.log(
4475
4853
  15,
@@ -4484,7 +4862,9 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
4484
4862
  else:
4485
4863
  # Check that scaling improves performance for the target metric
4486
4864
  score_without_temp = self.score_with_y_pred_proba(y=y_val, y_pred_proba=y_val_probs_og, weights=None)
4487
- scaled_y_val_probs = apply_temperature_scaling(y_val_probs, temp_scalar, problem_type=self.problem_type, transform_binary_proba=False)
4865
+ scaled_y_val_probs = apply_temperature_scaling(
4866
+ y_val_probs, temp_scalar, problem_type=self.problem_type, transform_binary_proba=False
4867
+ )
4488
4868
  score_with_temp = self.score_with_y_pred_proba(y=y_val, y_pred_proba=scaled_y_val_probs, weights=None)
4489
4869
 
4490
4870
  if score_with_temp > score_without_temp:
@@ -4507,7 +4887,9 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
4507
4887
  **kwargs,
4508
4888
  ) -> float:
4509
4889
  # TODO: Docstring
4510
- assert self.problem_type == BINARY, f'calibrate_decision_threshold is only available for `problem_type="{BINARY}"`'
4890
+ assert self.problem_type == BINARY, (
4891
+ f'calibrate_decision_threshold is only available for `problem_type="{BINARY}"`'
4892
+ )
4511
4893
 
4512
4894
  if metric is None:
4513
4895
  metric = self.eval_metric
@@ -4570,22 +4952,38 @@ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
4570
4952
  @staticmethod
4571
4953
  def _validate_num_classes(num_classes: int | None, problem_type: str):
4572
4954
  if problem_type == BINARY:
4573
- assert num_classes is not None and num_classes == 2, f"num_classes must be 2 when problem_type='{problem_type}' (num_classes={num_classes})"
4955
+ assert num_classes is not None and num_classes == 2, (
4956
+ f"num_classes must be 2 when problem_type='{problem_type}' (num_classes={num_classes})"
4957
+ )
4574
4958
  elif problem_type in [MULTICLASS, SOFTCLASS]:
4575
- assert num_classes is not None and num_classes >= 2, f"num_classes must be >=2 when problem_type='{problem_type}' (num_classes={num_classes})"
4959
+ assert num_classes is not None and num_classes >= 2, (
4960
+ f"num_classes must be >=2 when problem_type='{problem_type}' (num_classes={num_classes})"
4961
+ )
4576
4962
  elif problem_type in [REGRESSION, QUANTILE]:
4577
- assert num_classes is None, f"num_classes must be None when problem_type='{problem_type}' (num_classes={num_classes})"
4963
+ assert num_classes is None, (
4964
+ f"num_classes must be None when problem_type='{problem_type}' (num_classes={num_classes})"
4965
+ )
4578
4966
  else:
4579
- raise AssertionError(f"Unknown problem_type: '{problem_type}'. Valid problem types: {[BINARY, MULTICLASS, REGRESSION, SOFTCLASS, QUANTILE]}")
4967
+ raise AssertionError(
4968
+ f"Unknown problem_type: '{problem_type}'. Valid problem types: {[BINARY, MULTICLASS, REGRESSION, SOFTCLASS, QUANTILE]}"
4969
+ )
4580
4970
 
4581
4971
  @staticmethod
4582
4972
  def _validate_quantile_levels(quantile_levels: list[float] | np.ndarray | None, problem_type: str):
4583
4973
  if problem_type == QUANTILE:
4584
- assert quantile_levels is not None, f"quantile_levels must not be None when problem_type='{problem_type}' (quantile_levels={quantile_levels})"
4585
- assert isinstance(quantile_levels, (list, np.ndarray)), f"quantile_levels must be a list or np.ndarray (quantile_levels={quantile_levels})"
4586
- assert len(quantile_levels) > 0, f"quantile_levels must not be an empty list (quantile_levels={quantile_levels})"
4974
+ assert quantile_levels is not None, (
4975
+ f"quantile_levels must not be None when problem_type='{problem_type}' (quantile_levels={quantile_levels})"
4976
+ )
4977
+ assert isinstance(quantile_levels, (list, np.ndarray)), (
4978
+ f"quantile_levels must be a list or np.ndarray (quantile_levels={quantile_levels})"
4979
+ )
4980
+ assert len(quantile_levels) > 0, (
4981
+ f"quantile_levels must not be an empty list (quantile_levels={quantile_levels})"
4982
+ )
4587
4983
  else:
4588
- assert quantile_levels is None, f"quantile_levels must be None when problem_type='{problem_type}' (quantile_levels={quantile_levels})"
4984
+ assert quantile_levels is None, (
4985
+ f"quantile_levels must be None when problem_type='{problem_type}' (quantile_levels={quantile_levels})"
4986
+ )
4589
4987
 
4590
4988
 
4591
4989
  def _detached_train_multi_fold(
@@ -4603,23 +5001,23 @@ def _detached_train_multi_fold(
4603
5001
  kwargs: dict,
4604
5002
  ) -> list[str]:
4605
5003
  """Dedicated class-detached function to train a single model on multiple folds."""
4606
- if isinstance(model,str):
5004
+ if isinstance(model, str):
4607
5005
  model = _self.load_model(model)
4608
5006
  elif _self.low_memory:
4609
5007
  model = copy.deepcopy(model)
4610
- if hyperparameter_tune_kwargs is not None and isinstance(hyperparameter_tune_kwargs,dict):
4611
- hyperparameter_tune_kwargs_model = hyperparameter_tune_kwargs.get(model.name,None)
5008
+ if hyperparameter_tune_kwargs is not None and isinstance(hyperparameter_tune_kwargs, dict):
5009
+ hyperparameter_tune_kwargs_model = hyperparameter_tune_kwargs.get(model.name, None)
4612
5010
  else:
4613
- hyperparameter_tune_kwargs_model=None
5011
+ hyperparameter_tune_kwargs_model = None
4614
5012
  # TODO: Only update scores when finished, only update model as part of final models if finished!
4615
5013
  if time_split:
4616
- time_left=time_limit_model_split
5014
+ time_left = time_limit_model_split
4617
5015
  else:
4618
5016
  if time_limit is None:
4619
- time_left=None
5017
+ time_left = None
4620
5018
  else:
4621
- time_start_model=time.time()
4622
- time_left=time_limit-(time_start_model-time_start)
5019
+ time_start_model = time.time()
5020
+ time_left = time_limit - (time_start_model - time_start)
4623
5021
 
4624
5022
  model_name_trained_lst = _self._train_single_full(
4625
5023
  X,
@@ -4628,7 +5026,7 @@ def _detached_train_multi_fold(
4628
5026
  time_limit=time_left,
4629
5027
  hyperparameter_tune_kwargs=hyperparameter_tune_kwargs_model,
4630
5028
  is_ray_worker=is_ray_worker,
4631
- **kwargs
5029
+ **kwargs,
4632
5030
  )
4633
5031
 
4634
5032
  if _self.low_memory:
@@ -4692,7 +5090,13 @@ def _remote_train_multi_fold(
4692
5090
  model_name = model if isinstance(model, str) else model.name
4693
5091
  return model_name, None, None, None, None
4694
5092
  model_name = model_name_list[0]
4695
- return model_name, _self.get_model_attribute(model=model_name, attribute="path"), _self.get_model_attribute(model=model_name, attribute="type"), None, None
5093
+ return (
5094
+ model_name,
5095
+ _self.get_model_attribute(model=model_name, attribute="path"),
5096
+ _self.get_model_attribute(model=model_name, attribute="type"),
5097
+ None,
5098
+ None,
5099
+ )
4696
5100
 
4697
5101
 
4698
5102
  def _detached_refit_single_full(
@@ -4709,26 +5113,26 @@ def _detached_refit_single_full(
4709
5113
  fit_strategy: Literal["sequential", "parallel"] = "sequential",
4710
5114
  ) -> tuple[str, list[str]]:
4711
5115
  # TODO: loading the model is the reasons we must allocate GPU resources for this job in cases where models require GPU when loaded from disk
4712
- model=_self.load_model(model)
5116
+ model = _self.load_model(model)
4713
5117
  model_name = model.name
4714
5118
  reuse_first_fold = False
4715
5119
 
4716
- if isinstance(model,BaggedEnsembleModel):
5120
+ if isinstance(model, BaggedEnsembleModel):
4717
5121
  # Reuse if model is already _FULL and no X_val
4718
5122
  if X_val is None:
4719
5123
  reuse_first_fold = not model._bagged_mode
4720
5124
 
4721
5125
  if not reuse_first_fold:
4722
- if isinstance(model,BaggedEnsembleModel):
4723
- can_refit_full=model._get_tags_child().get("can_refit_full",False)
5126
+ if isinstance(model, BaggedEnsembleModel):
5127
+ can_refit_full = model._get_tags_child().get("can_refit_full", False)
4724
5128
  else:
4725
- can_refit_full=model._get_tags().get("can_refit_full",False)
5129
+ can_refit_full = model._get_tags().get("can_refit_full", False)
4726
5130
  reuse_first_fold = not can_refit_full
4727
5131
 
4728
5132
  if not reuse_first_fold:
4729
- model_full=model.convert_to_refit_full_template()
5133
+ model_full = model.convert_to_refit_full_template()
4730
5134
  # Mitigates situation where bagged models barely had enough memory and refit requires more. Worst case results in OOM, but this lowers chance of failure.
4731
- model_full._user_params_aux["max_memory_usage_ratio"]=model.params_aux["max_memory_usage_ratio"]*1.15
5135
+ model_full._user_params_aux["max_memory_usage_ratio"] = model.params_aux["max_memory_usage_ratio"] * 1.15
4732
5136
  # Re-set user specified training resources.
4733
5137
  # FIXME: this is technically also a bug for non-distributed mode, but there it is good to use more/all resources per refit.
4734
5138
  # FIXME: Unsure if it is better to do model.fit_num_cpus or model.fit_num_cpus_child,
@@ -4742,7 +5146,7 @@ def _detached_refit_single_full(
4742
5146
  if model.fit_num_gpus_child is not None:
4743
5147
  model_full._user_params_aux["num_gpus"] = model.fit_num_gpus_child
4744
5148
  # TODO: Do it for all models in the level at once to avoid repeated processing of data?
4745
- base_model_names=_self.get_base_model_names(model_name)
5149
+ base_model_names = _self.get_base_model_names(model_name)
4746
5150
  # FIXME: Logs for inference speed (1 row) are incorrect because
4747
5151
  # parents are non-refit models in this sequence and later correct after logging.
4748
5152
  # Avoiding fix at present to minimize hacks in the code.
@@ -4765,25 +5169,30 @@ def _detached_refit_single_full(
4765
5169
  refit_full=True,
4766
5170
  **kwargs,
4767
5171
  )
4768
- if len(models_trained)==0:
4769
- reuse_first_fold=True
4770
- logger.log(30,f"WARNING: Refit training failure detected for '{model_name}'... "
4771
- f"Falling back to using first fold to avoid downstream exception."
4772
- f"\n\tThis is likely due to an out-of-memory error or other memory related issue. "
4773
- f"\n\tPlease create a GitHub issue if this was triggered from a non-memory related problem.",)
4774
- if not model.params.get("save_bag_folds",True):
4775
- raise AssertionError(f"Cannot avoid training failure during refit for '{model_name}' by falling back to "
4776
- f"copying the first fold because it does not exist! (save_bag_folds=False)"
4777
- f"\n\tPlease specify `save_bag_folds=True` in the `.fit` call to avoid this exception.")
5172
+ if len(models_trained) == 0:
5173
+ reuse_first_fold = True
5174
+ logger.log(
5175
+ 30,
5176
+ f"WARNING: Refit training failure detected for '{model_name}'... "
5177
+ f"Falling back to using first fold to avoid downstream exception."
5178
+ f"\n\tThis is likely due to an out-of-memory error or other memory related issue. "
5179
+ f"\n\tPlease create a GitHub issue if this was triggered from a non-memory related problem.",
5180
+ )
5181
+ if not model.params.get("save_bag_folds", True):
5182
+ raise AssertionError(
5183
+ f"Cannot avoid training failure during refit for '{model_name}' by falling back to "
5184
+ f"copying the first fold because it does not exist! (save_bag_folds=False)"
5185
+ f"\n\tPlease specify `save_bag_folds=True` in the `.fit` call to avoid this exception."
5186
+ )
4778
5187
 
4779
5188
  if reuse_first_fold:
4780
5189
  # Perform fallback black-box refit logic that doesn't retrain.
4781
- model_full=model.convert_to_refit_full_via_copy()
5190
+ model_full = model.convert_to_refit_full_via_copy()
4782
5191
  # FIXME: validation time not correct for infer 1 batch time, needed to hack _is_refit=True to fix
4783
- logger.log(20,f"Fitting model: {model_full.name} | Skipping fit via cloning parent ...")
4784
- _self._add_model(model_full,stack_name=REFIT_FULL_NAME,level=level,_is_refit=True)
5192
+ logger.log(20, f"Fitting model: {model_full.name} | Skipping fit via cloning parent ...")
5193
+ _self._add_model(model_full, stack_name=REFIT_FULL_NAME, level=level, _is_refit=True)
4785
5194
  _self.save_model(model_full)
4786
- models_trained=[model_full.name]
5195
+ models_trained = [model_full.name]
4787
5196
 
4788
5197
  return model_name, models_trained
4789
5198
 
@@ -4819,4 +5228,9 @@ def _remote_refit_single_full(
4819
5228
  # We always just refit one model per call, so this must be the case.
4820
5229
  assert len(models_trained) == 1
4821
5230
  refitted_model_name = models_trained[0]
4822
- return model_name, refitted_model_name, _self.get_model_attribute(model=refitted_model_name,attribute="path"),_self.get_model_attribute(model=refitted_model_name, attribute="type")
5231
+ return (
5232
+ model_name,
5233
+ refitted_model_name,
5234
+ _self.get_model_attribute(model=refitted_model_name, attribute="path"),
5235
+ _self.get_model_attribute(model=refitted_model_name, attribute="type"),
5236
+ )