autogluon.tabular 1.5.0b20251228__py3-none-any.whl → 1.5.1b20260116__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of autogluon.tabular might be problematic. Click here for more details.

Files changed (135)
  1. autogluon/tabular/__init__.py +1 -0
  2. autogluon/tabular/configs/config_helper.py +18 -6
  3. autogluon/tabular/configs/feature_generator_presets.py +3 -1
  4. autogluon/tabular/configs/hyperparameter_configs.py +42 -9
  5. autogluon/tabular/configs/presets_configs.py +38 -14
  6. autogluon/tabular/configs/zeroshot/zeroshot_portfolio_2023.py +84 -14
  7. autogluon/tabular/configs/zeroshot/zeroshot_portfolio_2025.py +48 -48
  8. autogluon/tabular/configs/zeroshot/zeroshot_portfolio_cpu_2025_12_18.py +774 -1
  9. autogluon/tabular/configs/zeroshot/zeroshot_portfolio_gpu_2025_12_18.py +421 -1
  10. autogluon/tabular/experimental/_scikit_mixin.py +6 -2
  11. autogluon/tabular/experimental/_tabular_classifier.py +3 -1
  12. autogluon/tabular/experimental/_tabular_regressor.py +3 -1
  13. autogluon/tabular/experimental/plot_leaderboard.py +73 -19
  14. autogluon/tabular/learner/abstract_learner.py +160 -42
  15. autogluon/tabular/learner/default_learner.py +78 -22
  16. autogluon/tabular/models/__init__.py +2 -2
  17. autogluon/tabular/models/_utils/rapids_utils.py +3 -1
  18. autogluon/tabular/models/abstract/abstract_torch_model.py +2 -0
  19. autogluon/tabular/models/automm/automm_model.py +12 -3
  20. autogluon/tabular/models/automm/ft_transformer.py +5 -1
  21. autogluon/tabular/models/catboost/callbacks.py +2 -2
  22. autogluon/tabular/models/catboost/catboost_model.py +93 -29
  23. autogluon/tabular/models/catboost/catboost_softclass_utils.py +4 -1
  24. autogluon/tabular/models/catboost/catboost_utils.py +3 -1
  25. autogluon/tabular/models/ebm/ebm_model.py +8 -13
  26. autogluon/tabular/models/ebm/hyperparameters/parameters.py +1 -0
  27. autogluon/tabular/models/ebm/hyperparameters/searchspaces.py +1 -0
  28. autogluon/tabular/models/fastainn/callbacks.py +20 -3
  29. autogluon/tabular/models/fastainn/hyperparameters/searchspaces.py +11 -1
  30. autogluon/tabular/models/fastainn/quantile_helpers.py +10 -2
  31. autogluon/tabular/models/fastainn/tabular_nn_fastai.py +65 -18
  32. autogluon/tabular/models/fasttext/fasttext_model.py +3 -1
  33. autogluon/tabular/models/image_prediction/image_predictor.py +7 -2
  34. autogluon/tabular/models/knn/knn_model.py +41 -8
  35. autogluon/tabular/models/lgb/callbacks.py +32 -9
  36. autogluon/tabular/models/lgb/hyperparameters/searchspaces.py +3 -1
  37. autogluon/tabular/models/lgb/lgb_model.py +150 -34
  38. autogluon/tabular/models/lgb/lgb_utils.py +12 -4
  39. autogluon/tabular/models/lr/hyperparameters/searchspaces.py +5 -1
  40. autogluon/tabular/models/lr/lr_model.py +40 -10
  41. autogluon/tabular/models/lr/lr_rapids_model.py +22 -13
  42. autogluon/tabular/models/mitra/_internal/__init__.py +1 -1
  43. autogluon/tabular/models/mitra/_internal/config/__init__.py +1 -1
  44. autogluon/tabular/models/mitra/_internal/config/config_pretrain.py +36 -40
  45. autogluon/tabular/models/mitra/_internal/config/config_run.py +2 -14
  46. autogluon/tabular/models/mitra/_internal/config/enums.py +27 -26
  47. autogluon/tabular/models/mitra/_internal/core/__init__.py +1 -1
  48. autogluon/tabular/models/mitra/_internal/core/callbacks.py +14 -21
  49. autogluon/tabular/models/mitra/_internal/core/get_loss.py +10 -12
  50. autogluon/tabular/models/mitra/_internal/core/get_optimizer.py +17 -32
  51. autogluon/tabular/models/mitra/_internal/core/get_scheduler.py +12 -27
  52. autogluon/tabular/models/mitra/_internal/core/prediction_metrics.py +16 -21
  53. autogluon/tabular/models/mitra/_internal/core/trainer_finetune.py +130 -111
  54. autogluon/tabular/models/mitra/_internal/data/__init__.py +1 -1
  55. autogluon/tabular/models/mitra/_internal/data/collator.py +30 -26
  56. autogluon/tabular/models/mitra/_internal/data/dataset_finetune.py +18 -26
  57. autogluon/tabular/models/mitra/_internal/data/dataset_split.py +10 -7
  58. autogluon/tabular/models/mitra/_internal/data/preprocessor.py +70 -100
  59. autogluon/tabular/models/mitra/_internal/models/__init__.py +1 -1
  60. autogluon/tabular/models/mitra/_internal/models/base.py +7 -10
  61. autogluon/tabular/models/mitra/_internal/models/embedding.py +46 -56
  62. autogluon/tabular/models/mitra/_internal/models/tab2d.py +140 -120
  63. autogluon/tabular/models/mitra/_internal/utils/__init__.py +1 -1
  64. autogluon/tabular/models/mitra/_internal/utils/set_seed.py +3 -1
  65. autogluon/tabular/models/mitra/mitra_model.py +16 -11
  66. autogluon/tabular/models/mitra/sklearn_interface.py +178 -162
  67. autogluon/tabular/models/realmlp/realmlp_model.py +28 -15
  68. autogluon/tabular/models/rf/compilers/onnx.py +1 -1
  69. autogluon/tabular/models/rf/rf_model.py +45 -12
  70. autogluon/tabular/models/rf/rf_quantile.py +4 -2
  71. autogluon/tabular/models/tabdpt/tabdpt_model.py +8 -17
  72. autogluon/tabular/models/tabicl/tabicl_model.py +8 -1
  73. autogluon/tabular/models/tabm/_tabm_internal.py +6 -4
  74. autogluon/tabular/models/tabm/rtdl_num_embeddings.py +80 -127
  75. autogluon/tabular/models/tabm/tabm_model.py +8 -4
  76. autogluon/tabular/models/tabm/tabm_reference.py +53 -85
  77. autogluon/tabular/models/tabpfnmix/_internal/core/callbacks.py +7 -16
  78. autogluon/tabular/models/tabpfnmix/_internal/core/collator.py +16 -24
  79. autogluon/tabular/models/tabpfnmix/_internal/core/dataset_split.py +5 -7
  80. autogluon/tabular/models/tabpfnmix/_internal/core/enums.py +0 -2
  81. autogluon/tabular/models/tabpfnmix/_internal/core/get_loss.py +0 -1
  82. autogluon/tabular/models/tabpfnmix/_internal/core/get_optimizer.py +7 -18
  83. autogluon/tabular/models/tabpfnmix/_internal/core/get_scheduler.py +3 -14
  84. autogluon/tabular/models/tabpfnmix/_internal/core/trainer_finetune.py +79 -64
  85. autogluon/tabular/models/tabpfnmix/_internal/core/y_transformer.py +3 -5
  86. autogluon/tabular/models/tabpfnmix/_internal/data/dataset_finetune.py +17 -30
  87. autogluon/tabular/models/tabpfnmix/_internal/data/preprocessor.py +15 -35
  88. autogluon/tabular/models/tabpfnmix/_internal/models/foundation/embedding.py +21 -38
  89. autogluon/tabular/models/tabpfnmix/_internal/models/foundation/foundation_transformer.py +33 -51
  90. autogluon/tabular/models/tabpfnmix/_internal/results/prediction_metrics.py +4 -4
  91. autogluon/tabular/models/tabpfnmix/_internal/tabpfnmix_classifier.py +32 -12
  92. autogluon/tabular/models/tabpfnmix/_internal/tabpfnmix_regressor.py +32 -13
  93. autogluon/tabular/models/tabpfnmix/tabpfnmix_model.py +55 -19
  94. autogluon/tabular/models/tabpfnv2/tabpfnv2_5_model.py +21 -48
  95. autogluon/tabular/models/tabprep/prep_mixin.py +34 -26
  96. autogluon/tabular/models/tabular_nn/compilers/onnx.py +36 -8
  97. autogluon/tabular/models/tabular_nn/torch/tabular_nn_torch.py +130 -36
  98. autogluon/tabular/models/tabular_nn/torch/tabular_torch_dataset.py +8 -4
  99. autogluon/tabular/models/tabular_nn/torch/torch_network_modules.py +26 -5
  100. autogluon/tabular/models/tabular_nn/utils/categorical_encoders.py +41 -24
  101. autogluon/tabular/models/tabular_nn/utils/data_preprocessor.py +33 -8
  102. autogluon/tabular/models/tabular_nn/utils/nn_architecture_utils.py +21 -6
  103. autogluon/tabular/models/xgboost/callbacks.py +9 -3
  104. autogluon/tabular/models/xgboost/xgboost_model.py +59 -11
  105. autogluon/tabular/models/xt/xt_model.py +1 -0
  106. autogluon/tabular/predictor/interpretable_predictor.py +3 -1
  107. autogluon/tabular/predictor/predictor.py +409 -128
  108. autogluon/tabular/registry/__init__.py +1 -1
  109. autogluon/tabular/registry/_ag_model_registry.py +4 -5
  110. autogluon/tabular/registry/_model_registry.py +1 -0
  111. autogluon/tabular/testing/fit_helper.py +55 -15
  112. autogluon/tabular/testing/generate_datasets.py +1 -1
  113. autogluon/tabular/testing/model_fit_helper.py +10 -4
  114. autogluon/tabular/trainer/abstract_trainer.py +644 -230
  115. autogluon/tabular/trainer/auto_trainer.py +19 -8
  116. autogluon/tabular/trainer/model_presets/presets.py +33 -9
  117. autogluon/tabular/trainer/model_presets/presets_distill.py +16 -2
  118. autogluon/tabular/version.py +1 -1
  119. {autogluon_tabular-1.5.0b20251228.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/METADATA +26 -26
  120. {autogluon_tabular-1.5.0b20251228.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/RECORD +127 -135
  121. autogluon/tabular/models/tabpfnv2/rfpfn/__init__.py +0 -20
  122. autogluon/tabular/models/tabpfnv2/rfpfn/configs.py +0 -40
  123. autogluon/tabular/models/tabpfnv2/rfpfn/scoring_utils.py +0 -201
  124. autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_based_decision_tree_tabpfn.py +0 -1464
  125. autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_based_random_forest_tabpfn.py +0 -747
  126. autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_compat.py +0 -863
  127. autogluon/tabular/models/tabpfnv2/rfpfn/utils.py +0 -106
  128. autogluon/tabular/models/tabpfnv2/tabpfnv2_model.py +0 -466
  129. /autogluon.tabular-1.5.0b20251228-py3.11-nspkg.pth → /autogluon.tabular-1.5.1b20260116-py3.11-nspkg.pth +0 -0
  130. {autogluon_tabular-1.5.0b20251228.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/WHEEL +0 -0
  131. {autogluon_tabular-1.5.0b20251228.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/licenses/LICENSE +0 -0
  132. {autogluon_tabular-1.5.0b20251228.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/licenses/NOTICE +0 -0
  133. {autogluon_tabular-1.5.0b20251228.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/namespace_packages.txt +0 -0
  134. {autogluon_tabular-1.5.0b20251228.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/top_level.txt +0 -0
  135. {autogluon_tabular-1.5.0b20251228.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/zip-safe +0 -0
@@ -1,2 +1,2 @@
1
- from ._model_registry import ModelRegistry
2
1
  from ._ag_model_registry import ag_model_registry
2
+ from ._model_registry import ModelRegistry
@@ -4,7 +4,6 @@ from autogluon.core.models import (
4
4
  SimpleWeightedEnsembleModel,
5
5
  )
6
6
 
7
- from . import ModelRegistry
8
7
  from ..models import (
9
8
  BoostedRulesModel,
10
9
  CatBoostModel,
@@ -18,25 +17,25 @@ from ..models import (
18
17
  KNNModel,
19
18
  LGBModel,
20
19
  LinearModel,
20
+ MitraModel,
21
21
  MultiModalPredictorModel,
22
22
  NNFastAiTabularModel,
23
23
  PrepLGBModel,
24
24
  RealMLPModel,
25
+ RealTabPFNv2Model,
26
+ RealTabPFNv25Model,
25
27
  RFModel,
26
28
  RuleFitModel,
27
29
  TabDPTModel,
28
30
  TabICLModel,
29
31
  TabMModel,
30
32
  TabPFNMixModel,
31
- MitraModel,
32
- RealTabPFNv2Model,
33
- RealTabPFNv25Model,
34
33
  TabularNeuralNetTorchModel,
35
34
  TextPredictorModel,
36
35
  XGBoostModel,
37
36
  XTModel,
38
37
  )
39
-
38
+ from ._model_registry import ModelRegistry
40
39
 
41
40
  # When adding a new model officially to AutoGluon, the model class should be added to the bottom of this list.
42
41
  REGISTERED_MODEL_CLS_LST = [
@@ -35,6 +35,7 @@ class ModelRegistry:
35
35
  predictor.fit(..., hyperparameters={"MY_MODEL": ...})
36
36
  ```
37
37
  """
38
+
38
39
  def __init__(self, model_cls_list: list[Type[AbstractModel]] | None = None):
39
40
  if model_cls_list is None:
40
41
  model_cls_list = []
@@ -2,14 +2,16 @@ from __future__ import annotations
2
2
 
3
3
  import copy
4
4
  import os
5
- import pandas as pd
6
5
  import shutil
7
- import sys
8
6
  import subprocess
7
+ import sys
9
8
  import textwrap
10
9
  import uuid
11
10
  from typing import Any, Type
12
11
 
12
+ import numpy as np
13
+ import pandas as pd
14
+
13
15
  from autogluon.common.utils.path_converter import PathConverter
14
16
  from autogluon.core.constants import BINARY, MULTICLASS, REGRESSION
15
17
  from autogluon.core.metrics import METRICS
@@ -17,19 +19,18 @@ from autogluon.core.models import AbstractModel, BaggedEnsembleModel
17
19
  from autogluon.core.stacked_overfitting.utils import check_stacked_overfitting_from_leaderboard
18
20
  from autogluon.core.testing.global_context_snapshot import GlobalContextSnapshot
19
21
  from autogluon.core.utils import download, generate_train_test_split_combined, infer_problem_type, unzip
20
-
21
22
  from autogluon.tabular import TabularDataset, TabularPredictor
22
23
  from autogluon.tabular.testing.generate_datasets import (
23
- generate_toy_binary_dataset,
24
24
  generate_toy_binary_10_dataset,
25
+ generate_toy_binary_dataset,
26
+ generate_toy_multiclass_10_dataset,
27
+ generate_toy_multiclass_30_dataset,
25
28
  generate_toy_multiclass_dataset,
26
- generate_toy_regression_dataset,
29
+ generate_toy_quantile_10_dataset,
27
30
  generate_toy_quantile_dataset,
28
31
  generate_toy_quantile_single_level_dataset,
29
- generate_toy_multiclass_10_dataset,
30
32
  generate_toy_regression_10_dataset,
31
- generate_toy_quantile_10_dataset,
32
- generate_toy_multiclass_30_dataset,
33
+ generate_toy_regression_dataset,
33
34
  )
34
35
 
35
36
 
@@ -154,6 +155,7 @@ class FitHelper:
154
155
  """
155
156
  Helper functions to test and verify predictors and models when fit through TabularPredictor's API.
156
157
  """
158
+
157
159
  @staticmethod
158
160
  def fit_and_validate_dataset(
159
161
  dataset_name: str,
@@ -181,11 +183,14 @@ class FitHelper:
181
183
  deepcopy_fit_args: bool = True,
182
184
  verify_model_seed: bool = False,
183
185
  verify_load_wo_cuda: bool = False,
186
+ verify_single_prediction_equivalent_to_multi: bool = True,
184
187
  ) -> TabularPredictor:
185
188
  if compiler_configs is None:
186
189
  compiler_configs = {}
187
190
  directory_prefix = "./datasets/"
188
- train_data, test_data, dataset_info = DatasetLoaderHelper.load_dataset(name=dataset_name, directory_prefix=directory_prefix)
191
+ train_data, test_data, dataset_info = DatasetLoaderHelper.load_dataset(
192
+ name=dataset_name, directory_prefix=directory_prefix
193
+ )
189
194
  label = dataset_info["label"]
190
195
  problem_type = dataset_info["problem_type"]
191
196
  _init_args = dict(
@@ -234,7 +239,7 @@ class FitHelper:
234
239
  scikit_api=scikit_api,
235
240
  min_cls_count_train=min_cls_count_train,
236
241
  )
237
-
242
+
238
243
  ctx_after = GlobalContextSnapshot.capture()
239
244
  ctx_before.assert_unchanged(ctx_after)
240
245
 
@@ -249,6 +254,26 @@ class FitHelper:
249
254
  if predictor.can_predict_proba:
250
255
  pred_proba = predictor.predict_proba(test_data)
251
256
  predictor.evaluate_predictions(y_true=test_data[label], y_pred=pred_proba)
257
+
258
+ pred_proba_repeat = predictor.predict_proba(test_data)
259
+ are_close = np.isclose(pred_proba, pred_proba_repeat).all()
260
+ if not are_close:
261
+ raise AssertionError(
262
+ "Predictions differ when predicting on the same data multiple times\n"
263
+ f"First Predict:\n{pred_proba}\n"
264
+ f"Second Predict:\n{pred_proba_repeat}\n"
265
+ )
266
+
267
+ pred_proba_1 = predictor.predict_proba(test_data.head(1)) # Verify model can predict on a single sample
268
+ if verify_single_prediction_equivalent_to_multi:
269
+ pred_proba_1_from_multi = pred_proba.head(1)
270
+ are_close = np.isclose(pred_proba_1, pred_proba_1_from_multi).all()
271
+ if not are_close:
272
+ raise AssertionError(
273
+ "Predictions differ when predicting a single sample vs predicting multiple samples\n"
274
+ f"Single Sample:\n{pred_proba_1}\n"
275
+ f"Multi Sample:\n{pred_proba_1_from_multi}\n"
276
+ )
252
277
  else:
253
278
  try:
254
279
  predictor.predict_proba(test_data)
@@ -278,7 +303,9 @@ class FitHelper:
278
303
  model_info = model.get_info()
279
304
  can_refit_full = model._get_tags()["can_refit_full"]
280
305
  if can_refit_full:
281
- assert not model_info["val_in_fit"], f"val data must not be present in refit model if `can_refit_full=True`. Maybe an exception occurred?"
306
+ assert not model_info["val_in_fit"], (
307
+ f"val data must not be present in refit model if `can_refit_full=True`. Maybe an exception occurred?"
308
+ )
282
309
  else:
283
310
  assert model_info["val_in_fit"], f"val data must be present in refit model if `can_refit_full=False`"
284
311
  if verify_model_seed:
@@ -293,7 +320,9 @@ class FitHelper:
293
320
  if extra_info:
294
321
  lb_kwargs["extra_info"] = True
295
322
  lb = predictor.leaderboard(test_data, extra_metrics=extra_metrics, **lb_kwargs)
296
- stacked_overfitting_assert(lb, predictor, expected_stacked_overfitting_at_val, expected_stacked_overfitting_at_test)
323
+ stacked_overfitting_assert(
324
+ lb, predictor, expected_stacked_overfitting_at_val, expected_stacked_overfitting_at_test
325
+ )
297
326
 
298
327
  predictor_load = predictor.load(path=predictor.path)
299
328
  predictor_load.predict(test_data)
@@ -301,6 +330,7 @@ class FitHelper:
301
330
  # TODO: This is expensive, only do this sparingly.
302
331
  if verify_load_wo_cuda:
303
332
  import torch
333
+
304
334
  if torch.cuda.is_available():
305
335
  # Checks if the model is able to predict w/o CUDA.
306
336
  # This verifies that a model artifact works on a CPU machine.
@@ -322,7 +352,9 @@ class FitHelper:
322
352
 
323
353
  assert os.path.realpath(save_path) == os.path.realpath(predictor.path)
324
354
  if delete_directory:
325
- shutil.rmtree(save_path, ignore_errors=True) # Delete AutoGluon output directory to ensure runs' information has been removed.
355
+ shutil.rmtree(
356
+ save_path, ignore_errors=True
357
+ ) # Delete AutoGluon output directory to ensure runs' information has been removed.
326
358
  return predictor
327
359
 
328
360
  @staticmethod
@@ -379,6 +411,7 @@ class FitHelper:
379
411
  raise_on_model_failure: bool = True,
380
412
  problem_types: list[str] | None = None,
381
413
  verify_model_seed: bool = True,
414
+ verify_single_prediction_equivalent_to_multi: bool = True,
382
415
  **kwargs,
383
416
  ):
384
417
  """
@@ -396,6 +429,7 @@ class FitHelper:
396
429
  If specified, checks the given problem_types.
397
430
  If None, checks `model_cls.supported_problem_types()`
398
431
  verify_model_seed: bool = True
432
+ verify_single_prediction_equivalent_to_multi: bool = True
399
433
  **kwargs
400
434
 
401
435
  Returns
@@ -476,6 +510,7 @@ class FitHelper:
476
510
  extra_metrics=_extra_metrics,
477
511
  raise_on_model_failure=raise_on_model_failure,
478
512
  verify_model_seed=verify_model_seed,
513
+ verify_single_prediction_equivalent_to_multi=verify_single_prediction_equivalent_to_multi,
479
514
  **kwargs,
480
515
  )
481
516
 
@@ -508,6 +543,7 @@ class FitHelper:
508
543
  extra_metrics=_extra_metrics,
509
544
  raise_on_model_failure=raise_on_model_failure,
510
545
  verify_model_seed=verify_model_seed,
546
+ verify_single_prediction_equivalent_to_multi=verify_single_prediction_equivalent_to_multi,
511
547
  **kwargs,
512
548
  )
513
549
 
@@ -519,11 +555,15 @@ def stacked_overfitting_assert(
519
555
  expected_stacked_overfitting_at_test: bool | None,
520
556
  ):
521
557
  if expected_stacked_overfitting_at_val is not None:
522
- assert predictor._stacked_overfitting_occurred == expected_stacked_overfitting_at_val, "Expected stacked overfitting at val mismatch!"
558
+ assert predictor._stacked_overfitting_occurred == expected_stacked_overfitting_at_val, (
559
+ "Expected stacked overfitting at val mismatch!"
560
+ )
523
561
 
524
562
  if expected_stacked_overfitting_at_test is not None:
525
563
  stacked_overfitting = check_stacked_overfitting_from_leaderboard(lb)
526
- assert stacked_overfitting == expected_stacked_overfitting_at_test, "Expected stacked overfitting at test mismatch!"
564
+ assert stacked_overfitting == expected_stacked_overfitting_at_test, (
565
+ "Expected stacked overfitting at test mismatch!"
566
+ )
527
567
 
528
568
 
529
569
  def _verify_model_seed(model: AbstractModel):
@@ -3,7 +3,7 @@ from __future__ import annotations
3
3
  import pandas as pd
4
4
  from sklearn.datasets import make_blobs
5
5
 
6
- from autogluon.core.constants import BINARY, MULTICLASS, REGRESSION, QUANTILE
6
+ from autogluon.core.constants import BINARY, MULTICLASS, QUANTILE, REGRESSION
7
7
 
8
8
 
9
9
  def generate_toy_binary_dataset():
@@ -9,7 +9,6 @@ from autogluon.core.data.label_cleaner import LabelCleaner
9
9
  from autogluon.core.models import AbstractModel, BaggedEnsembleModel
10
10
  from autogluon.core.utils import generate_train_test_split, infer_problem_type
11
11
  from autogluon.features.generators import AbstractFeatureGenerator, AutoMLPipelineFeatureGenerator
12
-
13
12
  from autogluon.tabular.testing.fit_helper import FitHelper
14
13
 
15
14
 
@@ -18,6 +17,7 @@ class ModelFitHelper:
18
17
  """
19
18
  Helper functions to test and verify models when fit outside TabularPredictor's API (aka as stand-alone models)
20
19
  """
20
+
21
21
  @staticmethod
22
22
  def fit_and_validate_dataset(
23
23
  dataset_name: str,
@@ -27,7 +27,9 @@ class ModelFitHelper:
27
27
  check_predict_children: bool = False,
28
28
  ) -> AbstractModel:
29
29
  directory_prefix = "./datasets/"
30
- train_data, test_data, dataset_info = FitHelper.load_dataset(name=dataset_name, directory_prefix=directory_prefix)
30
+ train_data, test_data, dataset_info = FitHelper.load_dataset(
31
+ name=dataset_name, directory_prefix=directory_prefix
32
+ )
31
33
  label = dataset_info["label"]
32
34
  model, label_cleaner, feature_generator = ModelFitHelper.fit_dataset(
33
35
  train_data=train_data, model=model, label=label, fit_args=fit_args, sample_size=sample_size
@@ -39,10 +41,14 @@ class ModelFitHelper:
39
41
  X_test = feature_generator.transform(X_test)
40
42
 
41
43
  y_pred = model.predict(X_test)
42
- assert isinstance(y_pred, np.ndarray), f"Expected np.ndarray as model.predict(X_test) output. Got: {y_pred.__class__}"
44
+ assert isinstance(y_pred, np.ndarray), (
45
+ f"Expected np.ndarray as model.predict(X_test) output. Got: {y_pred.__class__}"
46
+ )
43
47
 
44
48
  y_pred_proba = model.predict_proba(X_test)
45
- assert isinstance(y_pred_proba, np.ndarray), f"Expected np.ndarray as model.predict_proba(X_test) output. Got: {y_pred.__class__}"
49
+ assert isinstance(y_pred_proba, np.ndarray), (
50
+ f"Expected np.ndarray as model.predict_proba(X_test) output. Got: {y_pred.__class__}"
51
+ )
46
52
  model.get_info()
47
53
 
48
54
  if check_predict_children: