autogluon.tabular 1.5.1b20260105__py3-none-any.whl → 1.5.1b20260117__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of autogluon.tabular might be problematic. Click here for more details.

Files changed (135) hide show
  1. autogluon/tabular/__init__.py +1 -0
  2. autogluon/tabular/configs/config_helper.py +18 -6
  3. autogluon/tabular/configs/feature_generator_presets.py +3 -1
  4. autogluon/tabular/configs/hyperparameter_configs.py +42 -9
  5. autogluon/tabular/configs/presets_configs.py +38 -14
  6. autogluon/tabular/configs/zeroshot/zeroshot_portfolio_2023.py +84 -14
  7. autogluon/tabular/configs/zeroshot/zeroshot_portfolio_2025.py +48 -48
  8. autogluon/tabular/configs/zeroshot/zeroshot_portfolio_cpu_2025_12_18.py +774 -1
  9. autogluon/tabular/configs/zeroshot/zeroshot_portfolio_gpu_2025_12_18.py +421 -1
  10. autogluon/tabular/experimental/_scikit_mixin.py +6 -2
  11. autogluon/tabular/experimental/_tabular_classifier.py +3 -1
  12. autogluon/tabular/experimental/_tabular_regressor.py +3 -1
  13. autogluon/tabular/experimental/plot_leaderboard.py +73 -19
  14. autogluon/tabular/learner/abstract_learner.py +160 -42
  15. autogluon/tabular/learner/default_learner.py +78 -22
  16. autogluon/tabular/models/__init__.py +2 -2
  17. autogluon/tabular/models/_utils/rapids_utils.py +3 -1
  18. autogluon/tabular/models/abstract/abstract_torch_model.py +2 -0
  19. autogluon/tabular/models/automm/automm_model.py +12 -3
  20. autogluon/tabular/models/automm/ft_transformer.py +5 -1
  21. autogluon/tabular/models/catboost/callbacks.py +2 -2
  22. autogluon/tabular/models/catboost/catboost_model.py +93 -29
  23. autogluon/tabular/models/catboost/catboost_softclass_utils.py +4 -1
  24. autogluon/tabular/models/catboost/catboost_utils.py +3 -1
  25. autogluon/tabular/models/ebm/ebm_model.py +8 -13
  26. autogluon/tabular/models/ebm/hyperparameters/parameters.py +1 -0
  27. autogluon/tabular/models/ebm/hyperparameters/searchspaces.py +1 -0
  28. autogluon/tabular/models/fastainn/callbacks.py +20 -3
  29. autogluon/tabular/models/fastainn/hyperparameters/searchspaces.py +11 -1
  30. autogluon/tabular/models/fastainn/quantile_helpers.py +10 -2
  31. autogluon/tabular/models/fastainn/tabular_nn_fastai.py +65 -18
  32. autogluon/tabular/models/fasttext/fasttext_model.py +3 -1
  33. autogluon/tabular/models/image_prediction/image_predictor.py +7 -2
  34. autogluon/tabular/models/knn/knn_model.py +41 -8
  35. autogluon/tabular/models/lgb/callbacks.py +32 -9
  36. autogluon/tabular/models/lgb/hyperparameters/searchspaces.py +3 -1
  37. autogluon/tabular/models/lgb/lgb_model.py +150 -34
  38. autogluon/tabular/models/lgb/lgb_utils.py +12 -4
  39. autogluon/tabular/models/lr/hyperparameters/searchspaces.py +5 -1
  40. autogluon/tabular/models/lr/lr_model.py +40 -10
  41. autogluon/tabular/models/lr/lr_rapids_model.py +22 -13
  42. autogluon/tabular/models/mitra/_internal/__init__.py +1 -1
  43. autogluon/tabular/models/mitra/_internal/config/__init__.py +1 -1
  44. autogluon/tabular/models/mitra/_internal/config/config_pretrain.py +36 -40
  45. autogluon/tabular/models/mitra/_internal/config/config_run.py +2 -14
  46. autogluon/tabular/models/mitra/_internal/config/enums.py +27 -26
  47. autogluon/tabular/models/mitra/_internal/core/__init__.py +1 -1
  48. autogluon/tabular/models/mitra/_internal/core/callbacks.py +14 -21
  49. autogluon/tabular/models/mitra/_internal/core/get_loss.py +10 -12
  50. autogluon/tabular/models/mitra/_internal/core/get_optimizer.py +17 -32
  51. autogluon/tabular/models/mitra/_internal/core/get_scheduler.py +12 -27
  52. autogluon/tabular/models/mitra/_internal/core/prediction_metrics.py +16 -21
  53. autogluon/tabular/models/mitra/_internal/core/trainer_finetune.py +130 -111
  54. autogluon/tabular/models/mitra/_internal/data/__init__.py +1 -1
  55. autogluon/tabular/models/mitra/_internal/data/collator.py +30 -26
  56. autogluon/tabular/models/mitra/_internal/data/dataset_finetune.py +18 -26
  57. autogluon/tabular/models/mitra/_internal/data/dataset_split.py +10 -7
  58. autogluon/tabular/models/mitra/_internal/data/preprocessor.py +70 -100
  59. autogluon/tabular/models/mitra/_internal/models/__init__.py +1 -1
  60. autogluon/tabular/models/mitra/_internal/models/base.py +7 -10
  61. autogluon/tabular/models/mitra/_internal/models/embedding.py +46 -56
  62. autogluon/tabular/models/mitra/_internal/models/tab2d.py +140 -120
  63. autogluon/tabular/models/mitra/_internal/utils/__init__.py +1 -1
  64. autogluon/tabular/models/mitra/_internal/utils/set_seed.py +3 -1
  65. autogluon/tabular/models/mitra/mitra_model.py +16 -11
  66. autogluon/tabular/models/mitra/sklearn_interface.py +178 -162
  67. autogluon/tabular/models/realmlp/realmlp_model.py +28 -15
  68. autogluon/tabular/models/rf/compilers/onnx.py +1 -1
  69. autogluon/tabular/models/rf/rf_model.py +45 -12
  70. autogluon/tabular/models/rf/rf_quantile.py +4 -2
  71. autogluon/tabular/models/tabdpt/tabdpt_model.py +8 -17
  72. autogluon/tabular/models/tabicl/tabicl_model.py +8 -1
  73. autogluon/tabular/models/tabm/_tabm_internal.py +6 -4
  74. autogluon/tabular/models/tabm/rtdl_num_embeddings.py +80 -127
  75. autogluon/tabular/models/tabm/tabm_model.py +8 -4
  76. autogluon/tabular/models/tabm/tabm_reference.py +53 -85
  77. autogluon/tabular/models/tabpfnmix/_internal/core/callbacks.py +7 -16
  78. autogluon/tabular/models/tabpfnmix/_internal/core/collator.py +16 -24
  79. autogluon/tabular/models/tabpfnmix/_internal/core/dataset_split.py +5 -7
  80. autogluon/tabular/models/tabpfnmix/_internal/core/enums.py +0 -2
  81. autogluon/tabular/models/tabpfnmix/_internal/core/get_loss.py +0 -1
  82. autogluon/tabular/models/tabpfnmix/_internal/core/get_optimizer.py +7 -18
  83. autogluon/tabular/models/tabpfnmix/_internal/core/get_scheduler.py +3 -14
  84. autogluon/tabular/models/tabpfnmix/_internal/core/trainer_finetune.py +79 -64
  85. autogluon/tabular/models/tabpfnmix/_internal/core/y_transformer.py +3 -5
  86. autogluon/tabular/models/tabpfnmix/_internal/data/dataset_finetune.py +17 -30
  87. autogluon/tabular/models/tabpfnmix/_internal/data/preprocessor.py +15 -35
  88. autogluon/tabular/models/tabpfnmix/_internal/models/foundation/embedding.py +21 -38
  89. autogluon/tabular/models/tabpfnmix/_internal/models/foundation/foundation_transformer.py +33 -51
  90. autogluon/tabular/models/tabpfnmix/_internal/results/prediction_metrics.py +4 -4
  91. autogluon/tabular/models/tabpfnmix/_internal/tabpfnmix_classifier.py +32 -12
  92. autogluon/tabular/models/tabpfnmix/_internal/tabpfnmix_regressor.py +32 -13
  93. autogluon/tabular/models/tabpfnmix/tabpfnmix_model.py +55 -19
  94. autogluon/tabular/models/tabpfnv2/tabpfnv2_5_model.py +21 -48
  95. autogluon/tabular/models/tabprep/prep_mixin.py +34 -26
  96. autogluon/tabular/models/tabular_nn/compilers/onnx.py +36 -8
  97. autogluon/tabular/models/tabular_nn/torch/tabular_nn_torch.py +130 -36
  98. autogluon/tabular/models/tabular_nn/torch/tabular_torch_dataset.py +8 -4
  99. autogluon/tabular/models/tabular_nn/torch/torch_network_modules.py +26 -5
  100. autogluon/tabular/models/tabular_nn/utils/categorical_encoders.py +41 -24
  101. autogluon/tabular/models/tabular_nn/utils/data_preprocessor.py +33 -8
  102. autogluon/tabular/models/tabular_nn/utils/nn_architecture_utils.py +21 -6
  103. autogluon/tabular/models/xgboost/callbacks.py +9 -3
  104. autogluon/tabular/models/xgboost/xgboost_model.py +59 -11
  105. autogluon/tabular/models/xt/xt_model.py +1 -0
  106. autogluon/tabular/predictor/interpretable_predictor.py +3 -1
  107. autogluon/tabular/predictor/predictor.py +409 -128
  108. autogluon/tabular/registry/__init__.py +1 -1
  109. autogluon/tabular/registry/_ag_model_registry.py +4 -5
  110. autogluon/tabular/registry/_model_registry.py +1 -0
  111. autogluon/tabular/testing/fit_helper.py +55 -15
  112. autogluon/tabular/testing/generate_datasets.py +1 -1
  113. autogluon/tabular/testing/model_fit_helper.py +10 -4
  114. autogluon/tabular/trainer/abstract_trainer.py +644 -230
  115. autogluon/tabular/trainer/auto_trainer.py +19 -8
  116. autogluon/tabular/trainer/model_presets/presets.py +33 -9
  117. autogluon/tabular/trainer/model_presets/presets_distill.py +16 -2
  118. autogluon/tabular/version.py +1 -1
  119. {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260117.dist-info}/METADATA +27 -27
  120. {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260117.dist-info}/RECORD +127 -135
  121. autogluon/tabular/models/tabpfnv2/rfpfn/__init__.py +0 -20
  122. autogluon/tabular/models/tabpfnv2/rfpfn/configs.py +0 -40
  123. autogluon/tabular/models/tabpfnv2/rfpfn/scoring_utils.py +0 -201
  124. autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_based_decision_tree_tabpfn.py +0 -1464
  125. autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_based_random_forest_tabpfn.py +0 -747
  126. autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_compat.py +0 -863
  127. autogluon/tabular/models/tabpfnv2/rfpfn/utils.py +0 -106
  128. autogluon/tabular/models/tabpfnv2/tabpfnv2_model.py +0 -466
  129. /autogluon.tabular-1.5.1b20260105-py3.11-nspkg.pth → /autogluon.tabular-1.5.1b20260117-py3.11-nspkg.pth +0 -0
  130. {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260117.dist-info}/WHEEL +0 -0
  131. {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260117.dist-info}/licenses/LICENSE +0 -0
  132. {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260117.dist-info}/licenses/NOTICE +0 -0
  133. {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260117.dist-info}/namespace_packages.txt +0 -0
  134. {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260117.dist-info}/top_level.txt +0 -0
  135. {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260117.dist-info}/zip-safe +0 -0
@@ -1,2 +1,422 @@
1
1
  # State-of-the-art for datasets < 100k samples. Requires a GPU with at least 20 GB VRAM.
2
- hyperparameter_portfolio_zeroshot_gpu_2025_12_18 = {'TABDPT': [{'ag_args': {'name_suffix': '_c1', 'priority': -3}, 'ag_args_ensemble': {'model_random_seed': 0, 'vary_seed_across_folds': False}}, {'ag_args': {'name_suffix': '_r20', 'priority': -5}, 'ag_args_ensemble': {'model_random_seed': 0, 'vary_seed_across_folds': False}, 'clip_sigma': 8, 'feature_reduction': 'subsample', 'missing_indicators': False, 'normalizer': 'quantile-uniform', 'permute_classes': False, 'temperature': 0.5}, {'ag_args': {'name_suffix': '_r1', 'priority': -7}, 'ag_args_ensemble': {'model_random_seed': 0, 'vary_seed_across_folds': False}, 'clip_sigma': 16, 'feature_reduction': 'subsample', 'missing_indicators': False, 'normalizer': 'log1p', 'permute_classes': False, 'temperature': 0.5}, {'ag_args': {'name_suffix': '_r15', 'priority': -9}, 'ag_args_ensemble': {'model_random_seed': 0, 'vary_seed_across_folds': False}, 'clip_sigma': 16, 'feature_reduction': 'subsample', 'missing_indicators': False, 'normalizer': 'standard', 'permute_classes': True, 'temperature': 0.7}, {'ag_args': {'name_suffix': '_r22', 'priority': -11}, 'ag_args_ensemble': {'model_random_seed': 0, 'vary_seed_across_folds': False}, 'clip_sigma': 8, 'feature_reduction': 'pca', 'missing_indicators': True, 'normalizer': 'robust', 'permute_classes': False, 'temperature': 0.5}], 'TABICL': [{'ag_args': {'name_suffix': '_c1', 'priority': -4}, 'ag_args_ensemble': {'refit_folds': True}}], 'MITRA': [{'ag_args': {'name_suffix': '_c1', 'priority': -12}, 'ag_args_ensemble': {'model_random_seed': 0, 'vary_seed_across_folds': True}}], 'TABM': [{'ag_args': {'name_suffix': '_r99', 'priority': -13}, 'amp': False, 'arch_type': 'tabm-mini', 'batch_size': 'auto', 'd_block': 880, 'd_embedding': 24, 'dropout': 0.10792355695428629, 'gradient_clipping_norm': 1.0, 'lr': 0.0013641856391615784, 'n_blocks': 5, 'num_emb_n_bins': 16, 'num_emb_type': 'pwl', 'patience': 16, 'share_training_batches': False, 'tabm_k': 32, 'weight_decay': 0.0}, {'ag_args': {'name_suffix': '_r124', 'priority': -17}, 'amp': False, 'arch_type': 'tabm-mini', 'batch_size': 'auto', 'd_block': 208, 'd_embedding': 16, 'dropout': 0.0, 'gradient_clipping_norm': 1.0, 'lr': 0.00042152744054701374, 'n_blocks': 2, 'num_emb_n_bins': 109, 'num_emb_type': 'pwl', 'patience': 16, 'share_training_batches': False, 'tabm_k': 32, 'weight_decay': 0.00014007839435474664}, {'ag_args': {'name_suffix': '_r69', 'priority': -21}, 'amp': False, 'arch_type': 'tabm-mini', 'batch_size': 'auto', 'd_block': 848, 'd_embedding': 28, 'dropout': 0.40215621636031007, 'gradient_clipping_norm': 1.0, 'lr': 0.0010413640454559532, 'n_blocks': 3, 'num_emb_n_bins': 18, 'num_emb_type': 'pwl', 'patience': 16, 'share_training_batches': False, 'tabm_k': 32, 'weight_decay': 0.0}, {'ag_args': {'name_suffix': '_r184', 'priority': -24}, 'amp': False, 'arch_type': 'tabm-mini', 'batch_size': 'auto', 'd_block': 864, 'd_embedding': 24, 'dropout': 0.0, 'gradient_clipping_norm': 1.0, 'lr': 0.0019256819924656217, 'n_blocks': 3, 'num_emb_n_bins': 3, 'num_emb_type': 'pwl', 'patience': 16, 'share_training_batches': False, 'tabm_k': 32, 'weight_decay': 0.0}, {'ag_args': {'name_suffix': '_r34', 'priority': -26}, 'amp': False, 'arch_type': 'tabm-mini', 'batch_size': 'auto', 'd_block': 896, 'd_embedding': 8, 'dropout': 0.0, 'gradient_clipping_norm': 1.0, 'lr': 0.002459175026451607, 'n_blocks': 4, 'num_emb_n_bins': 104, 'num_emb_type': 'pwl', 'patience': 16, 'share_training_batches': False, 'tabm_k': 32, 'weight_decay': 0.0006299584388562901}], 'GBM_PREP': [{'ag.prep_params': [[[['ArithmeticFeatureGenerator', {}]], [['CategoricalInteractionFeatureGenerator', {'passthrough': True}], ['OOFTargetEncodingFeatureGenerator', {}]]]], 'ag.prep_params.passthrough_types': {'invalid_raw_types': ['category', 'object']}, 'ag_args': {'name_suffix': '_r13', 'priority': -14}, 'ag_args_ensemble': {'model_random_seed': 0, 'vary_seed_across_folds': True}, 'bagging_fraction': 0.9923026236907, 'bagging_freq': 1, 'cat_l2': 0.014290368488, 'cat_smooth': 1.8662939903973, 'extra_trees': True, 'feature_fraction': 0.5533919718605, 'lambda_l1': 0.914411672958, 'lambda_l2': 1.90439560009, 'learning_rate': 0.0193225778401, 'max_cat_to_onehot': 18, 'min_data_in_leaf': 28, 'min_data_per_group': 54, 'num_leaves': 64}, {'ag.prep_params': [[[['ArithmeticFeatureGenerator', {}]], [['CategoricalInteractionFeatureGenerator', {'passthrough': True}], ['OOFTargetEncodingFeatureGenerator', {}]]]], 'ag.prep_params.passthrough_types': {'invalid_raw_types': ['category', 'object']}, 'ag_args': {'name_suffix': '_r41', 'priority': -16}, 'ag_args_ensemble': {'model_random_seed': 0, 'vary_seed_across_folds': True}, 'bagging_fraction': 0.7215411996558, 'bagging_freq': 1, 'cat_l2': 1.887369154362, 'cat_smooth': 0.0278693980873, 'extra_trees': True, 'feature_fraction': 0.4247583287144, 'lambda_l1': 0.1129800247772, 'lambda_l2': 0.2623265718536, 'learning_rate': 0.0074201920651, 'max_cat_to_onehot': 9, 'min_data_in_leaf': 15, 'min_data_per_group': 10, 'num_leaves': 8}, {'ag.prep_params': [[[['ArithmeticFeatureGenerator', {}]], [['CategoricalInteractionFeatureGenerator', {'passthrough': True}], ['OOFTargetEncodingFeatureGenerator', {}]]]], 'ag.prep_params.passthrough_types': {'invalid_raw_types': ['category', 'object']}, 'ag_args': {'name_suffix': '_r31', 'priority': -18}, 'ag_args_ensemble': {'model_random_seed': 0, 'vary_seed_across_folds': True}, 'bagging_fraction': 0.9591526242875, 'bagging_freq': 1, 'cat_l2': 1.8962346412823, 'cat_smooth': 0.0215219089995, 'extra_trees': False, 'feature_fraction': 0.5791844062459, 'lambda_l1': 0.938461750637, 'lambda_l2': 0.9899852075056, 'learning_rate': 0.0397613094741, 'max_cat_to_onehot': 27, 'min_data_in_leaf': 1, 'min_data_per_group': 39, 'num_leaves': 16}, {'ag.prep_params': [], 'ag_args': {'name_suffix': '_r21', 'priority': -20}, 'ag_args_ensemble': {'model_random_seed': 0, 'vary_seed_across_folds': True}, 'bagging_fraction': 0.7111549514262, 'bagging_freq': 1, 'cat_l2': 0.8679131150136, 'cat_smooth': 48.7244965504817, 'extra_trees': False, 'feature_fraction': 0.425140839263, 'lambda_l1': 0.5140528525242, 'lambda_l2': 0.5134051978198, 'learning_rate': 0.0134375321277, 'max_cat_to_onehot': 16, 'min_data_in_leaf': 2, 'min_data_per_group': 32, 'num_leaves': 20}, {'ag.prep_params': [[[['ArithmeticFeatureGenerator', {}]], [['CategoricalInteractionFeatureGenerator', {'passthrough': True}], ['OOFTargetEncodingFeatureGenerator', {}]]]], 'ag.prep_params.passthrough_types': {'invalid_raw_types': ['category', 'object']}, 'ag_args': {'name_suffix': '_r17', 'priority': -23}, 'ag_args_ensemble': {'model_random_seed': 0, 'vary_seed_across_folds': True}, 'bagging_fraction': 0.9277474245702, 'bagging_freq': 1, 'cat_l2': 0.0731876168104, 'cat_smooth': 0.1369210915339, 'extra_trees': False, 'feature_fraction': 0.6680440910385, 'lambda_l1': 0.0125057410295, 'lambda_l2': 0.7157181359874, 'learning_rate': 0.0351342879995, 'max_cat_to_onehot': 20, 'min_data_in_leaf': 1, 'min_data_per_group': 2, 'num_leaves': 64}], 'CAT': [{'ag_args': {'name_suffix': '_c1', 'priority': -15}}], 'GBM': [{'ag_args': {'name_suffix': '_r73', 'priority': -19}, 'bagging_fraction': 0.7295548973583, 'bagging_freq': 1, 'cat_l2': 1.8025485263237, 'cat_smooth': 59.6178463268351, 'extra_trees': False, 'feature_fraction': 0.8242607305914, 'lambda_l1': 0.7265522905459, 'lambda_l2': 0.3492160682092, 'learning_rate': 0.0068803786367, 'max_cat_to_onehot': 16, 'min_data_in_leaf': 1, 'min_data_per_group': 10, 'num_leaves': 24}, {'ag_args': {'name_suffix': '_r37', 'priority': -22}, 'bagging_fraction': 0.8096374561947, 'bagging_freq': 1, 'cat_l2': 1.6385754694703, 'cat_smooth': 16.1922506671724, 'extra_trees': True, 'feature_fraction': 0.885927003286, 'lambda_l1': 0.0430386950502, 'lambda_l2': 0.2507506811761, 'learning_rate': 0.0079622660542, 'max_cat_to_onehot': 23, 'min_data_in_leaf': 7, 'min_data_per_group': 49, 'num_leaves': 6}, {'ag_args': {'name_suffix': '_r162', 'priority': -25}, 'bagging_fraction': 0.7552878818396, 'bagging_freq': 1, 'cat_l2': 0.0081083103544, 'cat_smooth': 75.7373446363438, 'extra_trees': False, 'feature_fraction': 0.6171258454584, 'lambda_l1': 0.1071522383181, 'lambda_l2': 1.7882554584069, 'learning_rate': 0.0229328987255, 'max_cat_to_onehot': 24, 'min_data_in_leaf': 23, 'min_data_per_group': 2, 'num_leaves': 125}, {'ag_args': {'name_suffix': '_r57', 'priority': -27}, 'bagging_fraction': 0.8515739264605, 'bagging_freq': 1, 'cat_l2': 0.2263901847144, 'cat_smooth': 1.7397457971767, 'extra_trees': True, 'feature_fraction': 0.6284015946887, 'lambda_l1': 0.6935431676756, 'lambda_l2': 1.7605230133162, 'learning_rate': 0.0294830579218, 'max_cat_to_onehot': 52, 'min_data_in_leaf': 8, 'min_data_per_group': 3, 'num_leaves': 43}, {'ag_args': {'name_suffix': '_r33', 'priority': -28}, 'bagging_fraction': 0.9625293420216, 'bagging_freq': 1, 'cat_l2': 0.1236875455555, 'cat_smooth': 68.8584757332856, 'extra_trees': False, 'feature_fraction': 0.6189215809382, 'lambda_l1': 0.1641757352921, 'lambda_l2': 0.6937755557881, 'learning_rate': 0.0154031028561, 'max_cat_to_onehot': 17, 'min_data_in_leaf': 1, 'min_data_per_group': 30, 'num_leaves': 68}], 'REALTABPFN-V2': [{'ag_args': {'name_suffix': '_r13', 'priority': -1}, 'ag_args_ensemble': {'model_random_seed': 104, 'vary_seed_across_folds': True}, 'balance_probabilities': False, 'inference_config/OUTLIER_REMOVAL_STD': 6, 'inference_config/POLYNOMIAL_FEATURES': 'no', 'inference_config/REGRESSION_Y_PREPROCESS_TRANSFORMS': [None, 'safepower'], 'preprocessing/append_original': False, 'preprocessing/categoricals': 'numeric', 'preprocessing/global': None, 'preprocessing/scaling': ['squashing_scaler_default', 'quantile_uni_coarse'], 'softmax_temperature': 1.0, 'zip_model_path': ['tabpfn-v2-classifier-finetuned-zk73skhh.ckpt', 'tabpfn-v2-regressor-v2_default.ckpt']}, {'ag_args': {'name_suffix': '_r106', 'priority': -2}, 'ag_args_ensemble': {'model_random_seed': 848, 'vary_seed_across_folds': True}, 'balance_probabilities': False, 'inference_config/OUTLIER_REMOVAL_STD': 6, 'inference_config/POLYNOMIAL_FEATURES': 'no', 'inference_config/REGRESSION_Y_PREPROCESS_TRANSFORMS': [None], 'preprocessing/append_original': True, 'preprocessing/categoricals': 'numeric', 'preprocessing/global': 'svd_quarter_components', 'preprocessing/scaling': ['quantile_uni_coarse'], 'softmax_temperature': 0.8, 'zip_model_path': ['tabpfn-v2-classifier-finetuned-zk73skhh.ckpt', 'tabpfn-v2-regressor-v2_default.ckpt']}, {'ag_args': {'name_suffix': '_r11', 'priority': -6}, 'ag_args_ensemble': {'model_random_seed': 88, 'vary_seed_across_folds': True}, 'balance_probabilities': True, 'inference_config/OUTLIER_REMOVAL_STD': 6, 'inference_config/POLYNOMIAL_FEATURES': 25, 'inference_config/REGRESSION_Y_PREPROCESS_TRANSFORMS': [None], 'preprocessing/append_original': True, 'preprocessing/categoricals': 'onehot', 'preprocessing/global': 'svd_quarter_components', 'preprocessing/scaling': ['safepower', 'quantile_uni'], 'softmax_temperature': 0.7, 'zip_model_path': ['tabpfn-v2-classifier-finetuned-zk73skhh.ckpt', 'tabpfn-v2-regressor-v2_default.ckpt']}, {'ag_args': {'name_suffix': '_c1', 'priority': -8}, 'ag_args_ensemble': {'model_random_seed': 0, 'vary_seed_across_folds': True}, 'zip_model_path': ['tabpfn-v2-classifier-finetuned-zk73skhh.ckpt', 'tabpfn-v2-regressor-v2_default.ckpt']}, {'ag_args': {'name_suffix': '_r196', 'priority': -10}, 'ag_args_ensemble': {'model_random_seed': 1568, 'vary_seed_across_folds': True}, 'balance_probabilities': False, 'inference_config/OUTLIER_REMOVAL_STD': 12, 'inference_config/POLYNOMIAL_FEATURES': 'no', 'inference_config/REGRESSION_Y_PREPROCESS_TRANSFORMS': ['kdi_alpha_1.0'], 'preprocessing/append_original': False, 'preprocessing/categoricals': 'numeric', 'preprocessing/global': None, 'preprocessing/scaling': ['squashing_scaler_default'], 'softmax_temperature': 1.25, 'zip_model_path': ['tabpfn-v2-classifier-finetuned-zk73skhh.ckpt', 'tabpfn-v2-regressor-v2_default.ckpt']}]}
2
+ hyperparameter_portfolio_zeroshot_gpu_2025_12_18 = {
3
+ "TABDPT": [
4
+ {
5
+ "ag_args": {"name_suffix": "_c1", "priority": -3},
6
+ "ag_args_ensemble": {"model_random_seed": 0, "vary_seed_across_folds": False},
7
+ },
8
+ {
9
+ "ag_args": {"name_suffix": "_r20", "priority": -5},
10
+ "ag_args_ensemble": {"model_random_seed": 0, "vary_seed_across_folds": False},
11
+ "clip_sigma": 8,
12
+ "feature_reduction": "subsample",
13
+ "missing_indicators": False,
14
+ "normalizer": "quantile-uniform",
15
+ "permute_classes": False,
16
+ "temperature": 0.5,
17
+ },
18
+ {
19
+ "ag_args": {"name_suffix": "_r1", "priority": -7},
20
+ "ag_args_ensemble": {"model_random_seed": 0, "vary_seed_across_folds": False},
21
+ "clip_sigma": 16,
22
+ "feature_reduction": "subsample",
23
+ "missing_indicators": False,
24
+ "normalizer": "log1p",
25
+ "permute_classes": False,
26
+ "temperature": 0.5,
27
+ },
28
+ {
29
+ "ag_args": {"name_suffix": "_r15", "priority": -9},
30
+ "ag_args_ensemble": {"model_random_seed": 0, "vary_seed_across_folds": False},
31
+ "clip_sigma": 16,
32
+ "feature_reduction": "subsample",
33
+ "missing_indicators": False,
34
+ "normalizer": "standard",
35
+ "permute_classes": True,
36
+ "temperature": 0.7,
37
+ },
38
+ {
39
+ "ag_args": {"name_suffix": "_r22", "priority": -11},
40
+ "ag_args_ensemble": {"model_random_seed": 0, "vary_seed_across_folds": False},
41
+ "clip_sigma": 8,
42
+ "feature_reduction": "pca",
43
+ "missing_indicators": True,
44
+ "normalizer": "robust",
45
+ "permute_classes": False,
46
+ "temperature": 0.5,
47
+ },
48
+ ],
49
+ "TABICL": [{"ag_args": {"name_suffix": "_c1", "priority": -4}, "ag_args_ensemble": {"refit_folds": True}}],
50
+ "MITRA": [
51
+ {
52
+ "ag_args": {"name_suffix": "_c1", "priority": -12},
53
+ "ag_args_ensemble": {"model_random_seed": 0, "vary_seed_across_folds": True},
54
+ }
55
+ ],
56
+ "TABM": [
57
+ {
58
+ "ag_args": {"name_suffix": "_r99", "priority": -13},
59
+ "amp": False,
60
+ "arch_type": "tabm-mini",
61
+ "batch_size": "auto",
62
+ "d_block": 880,
63
+ "d_embedding": 24,
64
+ "dropout": 0.10792355695428629,
65
+ "gradient_clipping_norm": 1.0,
66
+ "lr": 0.0013641856391615784,
67
+ "n_blocks": 5,
68
+ "num_emb_n_bins": 16,
69
+ "num_emb_type": "pwl",
70
+ "patience": 16,
71
+ "share_training_batches": False,
72
+ "tabm_k": 32,
73
+ "weight_decay": 0.0,
74
+ },
75
+ {
76
+ "ag_args": {"name_suffix": "_r124", "priority": -17},
77
+ "amp": False,
78
+ "arch_type": "tabm-mini",
79
+ "batch_size": "auto",
80
+ "d_block": 208,
81
+ "d_embedding": 16,
82
+ "dropout": 0.0,
83
+ "gradient_clipping_norm": 1.0,
84
+ "lr": 0.00042152744054701374,
85
+ "n_blocks": 2,
86
+ "num_emb_n_bins": 109,
87
+ "num_emb_type": "pwl",
88
+ "patience": 16,
89
+ "share_training_batches": False,
90
+ "tabm_k": 32,
91
+ "weight_decay": 0.00014007839435474664,
92
+ },
93
+ {
94
+ "ag_args": {"name_suffix": "_r69", "priority": -21},
95
+ "amp": False,
96
+ "arch_type": "tabm-mini",
97
+ "batch_size": "auto",
98
+ "d_block": 848,
99
+ "d_embedding": 28,
100
+ "dropout": 0.40215621636031007,
101
+ "gradient_clipping_norm": 1.0,
102
+ "lr": 0.0010413640454559532,
103
+ "n_blocks": 3,
104
+ "num_emb_n_bins": 18,
105
+ "num_emb_type": "pwl",
106
+ "patience": 16,
107
+ "share_training_batches": False,
108
+ "tabm_k": 32,
109
+ "weight_decay": 0.0,
110
+ },
111
+ {
112
+ "ag_args": {"name_suffix": "_r184", "priority": -24},
113
+ "amp": False,
114
+ "arch_type": "tabm-mini",
115
+ "batch_size": "auto",
116
+ "d_block": 864,
117
+ "d_embedding": 24,
118
+ "dropout": 0.0,
119
+ "gradient_clipping_norm": 1.0,
120
+ "lr": 0.0019256819924656217,
121
+ "n_blocks": 3,
122
+ "num_emb_n_bins": 3,
123
+ "num_emb_type": "pwl",
124
+ "patience": 16,
125
+ "share_training_batches": False,
126
+ "tabm_k": 32,
127
+ "weight_decay": 0.0,
128
+ },
129
+ {
130
+ "ag_args": {"name_suffix": "_r34", "priority": -26},
131
+ "amp": False,
132
+ "arch_type": "tabm-mini",
133
+ "batch_size": "auto",
134
+ "d_block": 896,
135
+ "d_embedding": 8,
136
+ "dropout": 0.0,
137
+ "gradient_clipping_norm": 1.0,
138
+ "lr": 0.002459175026451607,
139
+ "n_blocks": 4,
140
+ "num_emb_n_bins": 104,
141
+ "num_emb_type": "pwl",
142
+ "patience": 16,
143
+ "share_training_batches": False,
144
+ "tabm_k": 32,
145
+ "weight_decay": 0.0006299584388562901,
146
+ },
147
+ ],
148
+ "GBM_PREP": [
149
+ {
150
+ "ag.prep_params": [
151
+ [
152
+ [["ArithmeticFeatureGenerator", {}]],
153
+ [
154
+ ["CategoricalInteractionFeatureGenerator", {"passthrough": True}],
155
+ ["OOFTargetEncodingFeatureGenerator", {}],
156
+ ],
157
+ ]
158
+ ],
159
+ "ag.prep_params.passthrough_types": {"invalid_raw_types": ["category", "object"]},
160
+ "ag_args": {"name_suffix": "_r13", "priority": -14},
161
+ "ag_args_ensemble": {"model_random_seed": 0, "vary_seed_across_folds": True},
162
+ "bagging_fraction": 0.9923026236907,
163
+ "bagging_freq": 1,
164
+ "cat_l2": 0.014290368488,
165
+ "cat_smooth": 1.8662939903973,
166
+ "extra_trees": True,
167
+ "feature_fraction": 0.5533919718605,
168
+ "lambda_l1": 0.914411672958,
169
+ "lambda_l2": 1.90439560009,
170
+ "learning_rate": 0.0193225778401,
171
+ "max_cat_to_onehot": 18,
172
+ "min_data_in_leaf": 28,
173
+ "min_data_per_group": 54,
174
+ "num_leaves": 64,
175
+ },
176
+ {
177
+ "ag.prep_params": [
178
+ [
179
+ [["ArithmeticFeatureGenerator", {}]],
180
+ [
181
+ ["CategoricalInteractionFeatureGenerator", {"passthrough": True}],
182
+ ["OOFTargetEncodingFeatureGenerator", {}],
183
+ ],
184
+ ]
185
+ ],
186
+ "ag.prep_params.passthrough_types": {"invalid_raw_types": ["category", "object"]},
187
+ "ag_args": {"name_suffix": "_r41", "priority": -16},
188
+ "ag_args_ensemble": {"model_random_seed": 0, "vary_seed_across_folds": True},
189
+ "bagging_fraction": 0.7215411996558,
190
+ "bagging_freq": 1,
191
+ "cat_l2": 1.887369154362,
192
+ "cat_smooth": 0.0278693980873,
193
+ "extra_trees": True,
194
+ "feature_fraction": 0.4247583287144,
195
+ "lambda_l1": 0.1129800247772,
196
+ "lambda_l2": 0.2623265718536,
197
+ "learning_rate": 0.0074201920651,
198
+ "max_cat_to_onehot": 9,
199
+ "min_data_in_leaf": 15,
200
+ "min_data_per_group": 10,
201
+ "num_leaves": 8,
202
+ },
203
+ {
204
+ "ag.prep_params": [
205
+ [
206
+ [["ArithmeticFeatureGenerator", {}]],
207
+ [
208
+ ["CategoricalInteractionFeatureGenerator", {"passthrough": True}],
209
+ ["OOFTargetEncodingFeatureGenerator", {}],
210
+ ],
211
+ ]
212
+ ],
213
+ "ag.prep_params.passthrough_types": {"invalid_raw_types": ["category", "object"]},
214
+ "ag_args": {"name_suffix": "_r31", "priority": -18},
215
+ "ag_args_ensemble": {"model_random_seed": 0, "vary_seed_across_folds": True},
216
+ "bagging_fraction": 0.9591526242875,
217
+ "bagging_freq": 1,
218
+ "cat_l2": 1.8962346412823,
219
+ "cat_smooth": 0.0215219089995,
220
+ "extra_trees": False,
221
+ "feature_fraction": 0.5791844062459,
222
+ "lambda_l1": 0.938461750637,
223
+ "lambda_l2": 0.9899852075056,
224
+ "learning_rate": 0.0397613094741,
225
+ "max_cat_to_onehot": 27,
226
+ "min_data_in_leaf": 1,
227
+ "min_data_per_group": 39,
228
+ "num_leaves": 16,
229
+ },
230
+ {
231
+ "ag.prep_params": [],
232
+ "ag_args": {"name_suffix": "_r21", "priority": -20},
233
+ "ag_args_ensemble": {"model_random_seed": 0, "vary_seed_across_folds": True},
234
+ "bagging_fraction": 0.7111549514262,
235
+ "bagging_freq": 1,
236
+ "cat_l2": 0.8679131150136,
237
+ "cat_smooth": 48.7244965504817,
238
+ "extra_trees": False,
239
+ "feature_fraction": 0.425140839263,
240
+ "lambda_l1": 0.5140528525242,
241
+ "lambda_l2": 0.5134051978198,
242
+ "learning_rate": 0.0134375321277,
243
+ "max_cat_to_onehot": 16,
244
+ "min_data_in_leaf": 2,
245
+ "min_data_per_group": 32,
246
+ "num_leaves": 20,
247
+ },
248
+ {
249
+ "ag.prep_params": [
250
+ [
251
+ [["ArithmeticFeatureGenerator", {}]],
252
+ [
253
+ ["CategoricalInteractionFeatureGenerator", {"passthrough": True}],
254
+ ["OOFTargetEncodingFeatureGenerator", {}],
255
+ ],
256
+ ]
257
+ ],
258
+ "ag.prep_params.passthrough_types": {"invalid_raw_types": ["category", "object"]},
259
+ "ag_args": {"name_suffix": "_r17", "priority": -23},
260
+ "ag_args_ensemble": {"model_random_seed": 0, "vary_seed_across_folds": True},
261
+ "bagging_fraction": 0.9277474245702,
262
+ "bagging_freq": 1,
263
+ "cat_l2": 0.0731876168104,
264
+ "cat_smooth": 0.1369210915339,
265
+ "extra_trees": False,
266
+ "feature_fraction": 0.6680440910385,
267
+ "lambda_l1": 0.0125057410295,
268
+ "lambda_l2": 0.7157181359874,
269
+ "learning_rate": 0.0351342879995,
270
+ "max_cat_to_onehot": 20,
271
+ "min_data_in_leaf": 1,
272
+ "min_data_per_group": 2,
273
+ "num_leaves": 64,
274
+ },
275
+ ],
276
+ "CAT": [{"ag_args": {"name_suffix": "_c1", "priority": -15}}],
277
+ "GBM": [
278
+ {
279
+ "ag_args": {"name_suffix": "_r73", "priority": -19},
280
+ "bagging_fraction": 0.7295548973583,
281
+ "bagging_freq": 1,
282
+ "cat_l2": 1.8025485263237,
283
+ "cat_smooth": 59.6178463268351,
284
+ "extra_trees": False,
285
+ "feature_fraction": 0.8242607305914,
286
+ "lambda_l1": 0.7265522905459,
287
+ "lambda_l2": 0.3492160682092,
288
+ "learning_rate": 0.0068803786367,
289
+ "max_cat_to_onehot": 16,
290
+ "min_data_in_leaf": 1,
291
+ "min_data_per_group": 10,
292
+ "num_leaves": 24,
293
+ },
294
+ {
295
+ "ag_args": {"name_suffix": "_r37", "priority": -22},
296
+ "bagging_fraction": 0.8096374561947,
297
+ "bagging_freq": 1,
298
+ "cat_l2": 1.6385754694703,
299
+ "cat_smooth": 16.1922506671724,
300
+ "extra_trees": True,
301
+ "feature_fraction": 0.885927003286,
302
+ "lambda_l1": 0.0430386950502,
303
+ "lambda_l2": 0.2507506811761,
304
+ "learning_rate": 0.0079622660542,
305
+ "max_cat_to_onehot": 23,
306
+ "min_data_in_leaf": 7,
307
+ "min_data_per_group": 49,
308
+ "num_leaves": 6,
309
+ },
310
+ {
311
+ "ag_args": {"name_suffix": "_r162", "priority": -25},
312
+ "bagging_fraction": 0.7552878818396,
313
+ "bagging_freq": 1,
314
+ "cat_l2": 0.0081083103544,
315
+ "cat_smooth": 75.7373446363438,
316
+ "extra_trees": False,
317
+ "feature_fraction": 0.6171258454584,
318
+ "lambda_l1": 0.1071522383181,
319
+ "lambda_l2": 1.7882554584069,
320
+ "learning_rate": 0.0229328987255,
321
+ "max_cat_to_onehot": 24,
322
+ "min_data_in_leaf": 23,
323
+ "min_data_per_group": 2,
324
+ "num_leaves": 125,
325
+ },
326
+ {
327
+ "ag_args": {"name_suffix": "_r57", "priority": -27},
328
+ "bagging_fraction": 0.8515739264605,
329
+ "bagging_freq": 1,
330
+ "cat_l2": 0.2263901847144,
331
+ "cat_smooth": 1.7397457971767,
332
+ "extra_trees": True,
333
+ "feature_fraction": 0.6284015946887,
334
+ "lambda_l1": 0.6935431676756,
335
+ "lambda_l2": 1.7605230133162,
336
+ "learning_rate": 0.0294830579218,
337
+ "max_cat_to_onehot": 52,
338
+ "min_data_in_leaf": 8,
339
+ "min_data_per_group": 3,
340
+ "num_leaves": 43,
341
+ },
342
+ {
343
+ "ag_args": {"name_suffix": "_r33", "priority": -28},
344
+ "bagging_fraction": 0.9625293420216,
345
+ "bagging_freq": 1,
346
+ "cat_l2": 0.1236875455555,
347
+ "cat_smooth": 68.8584757332856,
348
+ "extra_trees": False,
349
+ "feature_fraction": 0.6189215809382,
350
+ "lambda_l1": 0.1641757352921,
351
+ "lambda_l2": 0.6937755557881,
352
+ "learning_rate": 0.0154031028561,
353
+ "max_cat_to_onehot": 17,
354
+ "min_data_in_leaf": 1,
355
+ "min_data_per_group": 30,
356
+ "num_leaves": 68,
357
+ },
358
+ ],
359
+ "REALTABPFN-V2": [
360
+ {
361
+ "ag_args": {"name_suffix": "_r13", "priority": -1},
362
+ "ag_args_ensemble": {"model_random_seed": 104, "vary_seed_across_folds": True},
363
+ "balance_probabilities": False,
364
+ "inference_config/OUTLIER_REMOVAL_STD": 6,
365
+ "inference_config/POLYNOMIAL_FEATURES": "no",
366
+ "inference_config/REGRESSION_Y_PREPROCESS_TRANSFORMS": [None, "safepower"],
367
+ "preprocessing/append_original": False,
368
+ "preprocessing/categoricals": "numeric",
369
+ "preprocessing/global": None,
370
+ "preprocessing/scaling": ["squashing_scaler_default", "quantile_uni_coarse"],
371
+ "softmax_temperature": 1.0,
372
+ "zip_model_path": ["tabpfn-v2-classifier-finetuned-zk73skhh.ckpt", "tabpfn-v2-regressor-v2_default.ckpt"],
373
+ },
374
+ {
375
+ "ag_args": {"name_suffix": "_r106", "priority": -2},
376
+ "ag_args_ensemble": {"model_random_seed": 848, "vary_seed_across_folds": True},
377
+ "balance_probabilities": False,
378
+ "inference_config/OUTLIER_REMOVAL_STD": 6,
379
+ "inference_config/POLYNOMIAL_FEATURES": "no",
380
+ "inference_config/REGRESSION_Y_PREPROCESS_TRANSFORMS": [None],
381
+ "preprocessing/append_original": True,
382
+ "preprocessing/categoricals": "numeric",
383
+ "preprocessing/global": "svd_quarter_components",
384
+ "preprocessing/scaling": ["quantile_uni_coarse"],
385
+ "softmax_temperature": 0.8,
386
+ "zip_model_path": ["tabpfn-v2-classifier-finetuned-zk73skhh.ckpt", "tabpfn-v2-regressor-v2_default.ckpt"],
387
+ },
388
+ {
389
+ "ag_args": {"name_suffix": "_r11", "priority": -6},
390
+ "ag_args_ensemble": {"model_random_seed": 88, "vary_seed_across_folds": True},
391
+ "balance_probabilities": True,
392
+ "inference_config/OUTLIER_REMOVAL_STD": 6,
393
+ "inference_config/POLYNOMIAL_FEATURES": 25,
394
+ "inference_config/REGRESSION_Y_PREPROCESS_TRANSFORMS": [None],
395
+ "preprocessing/append_original": True,
396
+ "preprocessing/categoricals": "onehot",
397
+ "preprocessing/global": "svd_quarter_components",
398
+ "preprocessing/scaling": ["safepower", "quantile_uni"],
399
+ "softmax_temperature": 0.7,
400
+ "zip_model_path": ["tabpfn-v2-classifier-finetuned-zk73skhh.ckpt", "tabpfn-v2-regressor-v2_default.ckpt"],
401
+ },
402
+ {
403
+ "ag_args": {"name_suffix": "_c1", "priority": -8},
404
+ "ag_args_ensemble": {"model_random_seed": 0, "vary_seed_across_folds": True},
405
+ "zip_model_path": ["tabpfn-v2-classifier-finetuned-zk73skhh.ckpt", "tabpfn-v2-regressor-v2_default.ckpt"],
406
+ },
407
+ {
408
+ "ag_args": {"name_suffix": "_r196", "priority": -10},
409
+ "ag_args_ensemble": {"model_random_seed": 1568, "vary_seed_across_folds": True},
410
+ "balance_probabilities": False,
411
+ "inference_config/OUTLIER_REMOVAL_STD": 12,
412
+ "inference_config/POLYNOMIAL_FEATURES": "no",
413
+ "inference_config/REGRESSION_Y_PREPROCESS_TRANSFORMS": ["kdi_alpha_1.0"],
414
+ "preprocessing/append_original": False,
415
+ "preprocessing/categoricals": "numeric",
416
+ "preprocessing/global": None,
417
+ "preprocessing/scaling": ["squashing_scaler_default"],
418
+ "softmax_temperature": 1.25,
419
+ "zip_model_path": ["tabpfn-v2-classifier-finetuned-zk73skhh.ckpt", "tabpfn-v2-regressor-v2_default.ckpt"],
420
+ },
421
+ ],
422
+ }
@@ -44,13 +44,17 @@ class ScikitMixin:
44
44
  # Input validation
45
45
  X = check_array(X)
46
46
  if X.shape[1] != self.n_features_in_:
47
- raise ValueError(f"Inconsistent number of features between fit and predict calls: ({self.n_features_in_}, {X.shape[1]})")
47
+ raise ValueError(
48
+ f"Inconsistent number of features between fit and predict calls: ({self.n_features_in_}, {X.shape[1]})"
49
+ )
48
50
  return X
49
51
 
50
52
  def _combine_X_y(self, X, y) -> pd.DataFrame:
51
53
  label = self.predictor_.label
52
54
  X = pd.DataFrame(X)
53
- assert label not in list(X.columns), f"Cannot have column named {label}. Please rename the column to a different value."
55
+ assert label not in list(X.columns), (
56
+ f"Cannot have column named {label}. Please rename the column to a different value."
57
+ )
54
58
  X[label] = y
55
59
  return X
56
60
 
@@ -65,7 +65,9 @@ class TabularClassifier(BaseEstimator, ClassifierMixin, ScikitMixin):
65
65
  # Input validation
66
66
  X = check_array(X)
67
67
  if X.shape[1] != self.n_features_in_:
68
- raise ValueError(f"Inconsistent number of features between fit and predict calls: ({self.n_features_in_}, {X.shape[1]})")
68
+ raise ValueError(
69
+ f"Inconsistent number of features between fit and predict calls: ({self.n_features_in_}, {X.shape[1]})"
70
+ )
69
71
 
70
72
  data = pd.DataFrame(X)
71
73
  y_pred = self.predictor_.predict(data=data).to_numpy()
@@ -55,7 +55,9 @@ class TabularRegressor(BaseEstimator, RegressorMixin, ScikitMixin):
55
55
  # Input validation
56
56
  X = check_array(X)
57
57
  if X.shape[1] != self.n_features_in_:
58
- raise ValueError(f"Inconsistent number of features between fit and predict calls: ({self.n_features_in_}, {X.shape[1]})")
58
+ raise ValueError(
59
+ f"Inconsistent number of features between fit and predict calls: ({self.n_features_in_}, {X.shape[1]})"
60
+ )
59
61
 
60
62
  data = pd.DataFrame(X)
61
63
  y_pred = self.predictor_.predict(data=data).to_numpy()
@@ -1,7 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
- import pandas as pd
4
3
  import matplotlib.pyplot as plt
4
+ import pandas as pd
5
5
  from matplotlib.figure import Figure
6
6
 
7
7
  from autogluon.tabular import TabularPredictor
@@ -9,7 +9,7 @@ from autogluon.tabular import TabularPredictor
9
9
 
10
10
  def _cumulative_min_idx(series: pd.Series) -> pd.Series:
11
11
  """
12
-
12
+
13
13
  Parameters
14
14
  ----------
15
15
  series: pd.Series
@@ -20,7 +20,7 @@ def _cumulative_min_idx(series: pd.Series) -> pd.Series:
20
20
  The index of the cumulative min of the series values.
21
21
 
22
22
  """
23
- min_val = float('inf')
23
+ min_val = float("inf")
24
24
  min_index = -1
25
25
  result = []
26
26
  for i, val in enumerate(series):
@@ -54,7 +54,9 @@ def compute_cumulative_leaderboard_stats(leaderboard: pd.DataFrame) -> pd.DataFr
54
54
  leaderboard["time_so_far"] = leaderboard["fit_time_marginal"].cumsum()
55
55
  leaderboard["metric_error_val_so_far"] = leaderboard["best_model_so_far"].map(leaderboard["metric_error_val"])
56
56
  if "metric_error_test" in leaderboard:
57
- leaderboard["metric_error_test_so_far"] = leaderboard["best_model_so_far"].map(leaderboard["metric_error_test"])
57
+ leaderboard["metric_error_test_so_far"] = leaderboard["best_model_so_far"].map(
58
+ leaderboard["metric_error_test"]
59
+ )
58
60
  leaderboard = leaderboard.reset_index(drop=False).set_index("fit_order")
59
61
  return leaderboard
60
62
 
@@ -88,7 +90,7 @@ def compute_cumulative_leaderboard_stats_ensemble(
88
90
  model_fit_order = list(leaderboard_stats["model"])
89
91
  ens_names = []
90
92
  for i in range(len(model_fit_order)):
91
- models_to_ens = model_fit_order[:i + 1]
93
+ models_to_ens = model_fit_order[: i + 1]
92
94
  ens_name = predictor.fit_weighted_ensemble(base_models=models_to_ens, name_suffix=f"_fit_{i + 1}")[0]
93
95
  ens_names.append(ens_name)
94
96
 
@@ -144,10 +146,14 @@ def plot_leaderboard_from_predictor(
144
146
  """
145
147
  leaderboard = predictor.leaderboard(test_data, score_format="error", display=False)
146
148
  if ensemble:
147
- leaderboard_order_sorted = compute_cumulative_leaderboard_stats_ensemble(leaderboard=leaderboard, test_data=test_data, predictor=predictor)
149
+ leaderboard_order_sorted = compute_cumulative_leaderboard_stats_ensemble(
150
+ leaderboard=leaderboard, test_data=test_data, predictor=predictor
151
+ )
148
152
  else:
149
153
  leaderboard_order_sorted = compute_cumulative_leaderboard_stats(leaderboard=leaderboard)
150
- return plot_leaderboard(leaderboard=leaderboard_order_sorted, preprocess=False, ensemble=ensemble, include_val=include_val)
154
+ return plot_leaderboard(
155
+ leaderboard=leaderboard_order_sorted, preprocess=False, ensemble=ensemble, include_val=include_val
156
+ )
151
157
 
152
158
 
153
159
  def plot_leaderboard(
@@ -198,36 +204,84 @@ def plot_leaderboard(
198
204
 
199
205
  # TODO: View on inference time, can take from ensemble model, 3rd dimension, color?
200
206
  fig, axes = plt.subplots(1, 2, sharey=True)
201
- fig.suptitle('AutoGluon Metric Error Over Time')
207
+ fig.suptitle("AutoGluon Metric Error Over Time")
202
208
 
203
209
  ax = axes[0]
204
210
 
205
211
  if include_test:
206
- ax.plot(leaderboard_order_sorted.index, leaderboard_order_sorted["metric_error_test_so_far"].values, '-', color="b", label="test")
212
+ ax.plot(
213
+ leaderboard_order_sorted.index,
214
+ leaderboard_order_sorted["metric_error_test_so_far"].values,
215
+ "-",
216
+ color="b",
217
+ label="test",
218
+ )
207
219
  if include_val:
208
- ax.plot(leaderboard_order_sorted.index, leaderboard_order_sorted["metric_error_val_so_far"].values, '-', color="orange", label="val")
220
+ ax.plot(
221
+ leaderboard_order_sorted.index,
222
+ leaderboard_order_sorted["metric_error_val_so_far"].values,
223
+ "-",
224
+ color="orange",
225
+ label="val",
226
+ )
209
227
  if ensemble:
210
228
  if include_test:
211
- ax.plot(leaderboard_order_sorted.index, leaderboard_order_sorted["metric_error_test_so_far_ens"].values, '--', color="b", label="test (ens)")
229
+ ax.plot(
230
+ leaderboard_order_sorted.index,
231
+ leaderboard_order_sorted["metric_error_test_so_far_ens"].values,
232
+ "--",
233
+ color="b",
234
+ label="test (ens)",
235
+ )
212
236
  if include_val:
213
- ax.plot(leaderboard_order_sorted.index, leaderboard_order_sorted["metric_error_val_so_far_ens"].values, '--', color="orange", label="val (ens)")
237
+ ax.plot(
238
+ leaderboard_order_sorted.index,
239
+ leaderboard_order_sorted["metric_error_val_so_far_ens"].values,
240
+ "--",
241
+ color="orange",
242
+ label="val (ens)",
243
+ )
214
244
  ax.set_xlim(left=1, right=leaderboard_order_sorted.index.max())
215
- ax.set_xlabel('# Models Fit')
216
- ax.set_ylabel(f'Metric Error ({eval_metric})')
245
+ ax.set_xlabel("# Models Fit")
246
+ ax.set_ylabel(f"Metric Error ({eval_metric})")
217
247
  ax.grid()
218
248
 
219
249
  ax = axes[1]
220
250
 
221
251
  if include_test:
222
- ax.plot(leaderboard_order_sorted["time_so_far"].values, leaderboard_order_sorted["metric_error_test_so_far"].values, '-', color="b", label="test")
252
+ ax.plot(
253
+ leaderboard_order_sorted["time_so_far"].values,
254
+ leaderboard_order_sorted["metric_error_test_so_far"].values,
255
+ "-",
256
+ color="b",
257
+ label="test",
258
+ )
223
259
  if include_val:
224
- ax.plot(leaderboard_order_sorted["time_so_far"].values, leaderboard_order_sorted["metric_error_val_so_far"].values, '-', color="orange", label="val")
260
+ ax.plot(
261
+ leaderboard_order_sorted["time_so_far"].values,
262
+ leaderboard_order_sorted["metric_error_val_so_far"].values,
263
+ "-",
264
+ color="orange",
265
+ label="val",
266
+ )
225
267
  if ensemble:
226
268
  if include_test:
227
- ax.plot(leaderboard_order_sorted["time_so_far"].values, leaderboard_order_sorted["metric_error_test_so_far_ens"].values, '--', color="b", label="test (ens)")
269
+ ax.plot(
270
+ leaderboard_order_sorted["time_so_far"].values,
271
+ leaderboard_order_sorted["metric_error_test_so_far_ens"].values,
272
+ "--",
273
+ color="b",
274
+ label="test (ens)",
275
+ )
228
276
  if include_val:
229
- ax.plot(leaderboard_order_sorted["time_so_far"].values, leaderboard_order_sorted["metric_error_val_so_far_ens"].values, '--', color="orange", label="val (ens)")
230
- ax.set_xlabel('Time Elapsed (s)')
277
+ ax.plot(
278
+ leaderboard_order_sorted["time_so_far"].values,
279
+ leaderboard_order_sorted["metric_error_val_so_far_ens"].values,
280
+ "--",
281
+ color="orange",
282
+ label="val (ens)",
283
+ )
284
+ ax.set_xlabel("Time Elapsed (s)")
231
285
  ax.grid()
232
286
  ax.legend()
233
287