autogluon.tabular 1.4.1b20251214__py3-none-any.whl → 1.5.0b20251222__py3-none-any.whl
This diff shows the content differences between two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registries.
- autogluon/tabular/configs/hyperparameter_configs.py +4 -0
- autogluon/tabular/configs/presets_configs.py +39 -2
- autogluon/tabular/configs/zeroshot/zeroshot_portfolio_2025.py +2 -44
- autogluon/tabular/configs/zeroshot/zeroshot_portfolio_cpu_2025_12_18.py +2 -0
- autogluon/tabular/configs/zeroshot/zeroshot_portfolio_gpu_2025_12_18.py +2 -0
- autogluon/tabular/learner/default_learner.py +1 -0
- autogluon/tabular/models/__init__.py +3 -1
- autogluon/tabular/models/abstract/__init__.py +0 -0
- autogluon/tabular/models/abstract/abstract_torch_model.py +148 -0
- autogluon/tabular/models/catboost/catboost_model.py +1 -1
- autogluon/tabular/models/fastainn/tabular_nn_fastai.py +5 -1
- autogluon/tabular/models/lgb/lgb_model.py +58 -8
- autogluon/tabular/models/lgb/lgb_utils.py +2 -2
- autogluon/tabular/models/mitra/_internal/core/trainer_finetune.py +14 -1
- autogluon/tabular/models/mitra/mitra_model.py +53 -22
- autogluon/tabular/models/realmlp/realmlp_model.py +8 -2
- autogluon/tabular/models/tabdpt/__init__.py +0 -0
- autogluon/tabular/models/tabdpt/tabdpt_model.py +253 -0
- autogluon/tabular/models/tabicl/tabicl_model.py +15 -2
- autogluon/tabular/models/tabm/tabm_model.py +23 -79
- autogluon/tabular/models/tabpfnv2/tabpfnv2_5_model.py +451 -0
- autogluon/tabular/models/tabpfnv2/tabpfnv2_model.py +86 -8
- autogluon/tabular/models/tabprep/__init__.py +0 -0
- autogluon/tabular/models/tabprep/prep_lgb_model.py +21 -0
- autogluon/tabular/models/tabprep/prep_mixin.py +220 -0
- autogluon/tabular/models/tabular_nn/torch/tabular_nn_torch.py +1 -1
- autogluon/tabular/models/tabular_nn/utils/data_preprocessor.py +12 -4
- autogluon/tabular/models/xgboost/xgboost_model.py +2 -0
- autogluon/tabular/predictor/predictor.py +47 -18
- autogluon/tabular/registry/_ag_model_registry.py +8 -2
- autogluon/tabular/testing/fit_helper.py +33 -0
- autogluon/tabular/trainer/abstract_trainer.py +45 -9
- autogluon/tabular/trainer/auto_trainer.py +5 -0
- autogluon/tabular/version.py +1 -1
- {autogluon_tabular-1.4.1b20251214.dist-info → autogluon_tabular-1.5.0b20251222.dist-info}/METADATA +36 -35
- {autogluon_tabular-1.4.1b20251214.dist-info → autogluon_tabular-1.5.0b20251222.dist-info}/RECORD +43 -33
- /autogluon.tabular-1.4.1b20251214-py3.11-nspkg.pth → /autogluon.tabular-1.5.0b20251222-py3.11-nspkg.pth +0 -0
- {autogluon_tabular-1.4.1b20251214.dist-info → autogluon_tabular-1.5.0b20251222.dist-info}/WHEEL +0 -0
- {autogluon_tabular-1.4.1b20251214.dist-info → autogluon_tabular-1.5.0b20251222.dist-info}/licenses/LICENSE +0 -0
- {autogluon_tabular-1.4.1b20251214.dist-info → autogluon_tabular-1.5.0b20251222.dist-info}/licenses/NOTICE +0 -0
- {autogluon_tabular-1.4.1b20251214.dist-info → autogluon_tabular-1.5.0b20251222.dist-info}/namespace_packages.txt +0 -0
- {autogluon_tabular-1.4.1b20251214.dist-info → autogluon_tabular-1.5.0b20251222.dist-info}/top_level.txt +0 -0
- {autogluon_tabular-1.4.1b20251214.dist-info → autogluon_tabular-1.5.0b20251222.dist-info}/zip-safe +0 -0
autogluon/tabular/configs/hyperparameter_configs.py

@@ -2,6 +2,8 @@ import copy
 
 from .zeroshot.zeroshot_portfolio_2023 import hyperparameter_portfolio_zeroshot_2023
 from .zeroshot.zeroshot_portfolio_2025 import hyperparameter_portfolio_zeroshot_2025_small
+from .zeroshot.zeroshot_portfolio_cpu_2025_12_18 import hyperparameter_portfolio_zeroshot_cpu_2025_12_18
+from .zeroshot.zeroshot_portfolio_gpu_2025_12_18 import hyperparameter_portfolio_zeroshot_gpu_2025_12_18
 
 # Dictionary of preset hyperparameter configurations.
 hyperparameter_config_dict = dict(
@@ -117,6 +119,8 @@ hyperparameter_config_dict = dict(
     zeroshot=hyperparameter_portfolio_zeroshot_2023,
     zeroshot_2023=hyperparameter_portfolio_zeroshot_2023,
     zeroshot_2025_tabfm=hyperparameter_portfolio_zeroshot_2025_small,
+    zeroshot_2025_12_18_gpu=hyperparameter_portfolio_zeroshot_gpu_2025_12_18,
+    zeroshot_2025_12_18_cpu=hyperparameter_portfolio_zeroshot_cpu_2025_12_18,
 )
 
 tabpfnmix_default = {
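The two new portfolio names registered above become valid values for the `hyperparameters` argument of `TabularPredictor.fit`, the same way `"zeroshot"` is used today. A minimal usage sketch (`train.csv` and the `target` label column are placeholders):

```python
import pandas as pd
from autogluon.tabular import TabularPredictor

train_data = pd.read_csv("train.csv")  # placeholder dataset
predictor = TabularPredictor(label="target").fit(
    train_data,
    hyperparameters="zeroshot_2025_12_18_cpu",  # or "zeroshot_2025_12_18_gpu" on a GPU machine
    time_limit=3600,
)
```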
autogluon/tabular/configs/presets_configs.py

@@ -9,6 +9,15 @@ tabular_presets_dict = dict(
         "hyperparameters": "zeroshot",
         "time_limit": 3600,
     },
+
+    best_quality_v150={
+        "auto_stack": True,
+        "dynamic_stacking": "auto",
+        "num_stack_levels": 0,
+        "hyperparameters": "zeroshot_2025_12_18_cpu",
+        "time_limit": 3600,
+        "callbacks": [["EarlyStoppingCountCallback", {"patience": [[100, 4], [500, 8], [2500, 15], [10000, 40], [100000, 100], None]}]],
+    },
     # High predictive accuracy with fast inference. ~8x faster inference and ~8x lower disk usage than `best_quality`.
     # Recommended for applications that require fast inference speed and/or small model size.
     # Aliases: high
@@ -21,6 +30,19 @@ tabular_presets_dict = dict(
         "set_best_to_refit_full": True,
         "save_bag_folds": False,
     },
+
+    high_quality_v150={
+        "auto_stack": True,
+        "dynamic_stacking": "auto",
+        "num_stack_levels": 0,
+        "hyperparameters": "zeroshot_2025_12_18_cpu",
+        "time_limit": 3600,
+        "callbacks": [["EarlyStoppingCountCallback", {"patience": [[100, 4], [500, 8], [2500, 15], [10000, 40], [100000, 100], None]}]],
+        "refit_full": True,
+        "set_best_to_refit_full": True,
+        "save_bag_folds": False,
+    },
+
     # Good predictive accuracy with very fast inference. ~4x faster training, ~8x faster inference and ~8x lower disk usage than `high_quality`.
     # Recommended for applications that require very fast inference speed.
     # Aliases: good
@@ -78,11 +100,20 @@ tabular_presets_dict = dict(
     # Absolute best predictive accuracy with **zero** consideration to inference time or disk usage.
     # Recommended for applications that benefit from the best possible model accuracy and **do not** care about inference speed.
     # Significantly stronger than `best_quality`, but can be over 10x slower in inference.
-    # Uses pre-trained tabular foundation models, which add a minimum of
+    # Uses pre-trained tabular foundation models, which add a minimum of 100 MB to the predictor artifact's size.
     # For best results, use as large of an instance as possible with a GPU and as many CPU cores as possible (ideally 64+ cores)
     # Aliases: extreme, experimental, experimental_quality
     # GPU STRONGLY RECOMMENDED
     extreme_quality={
+        "auto_stack": True,
+        "dynamic_stacking": "auto",
+        "num_stack_levels": 0,
+        "hyperparameters": "zeroshot_2025_12_18_gpu",
+        "time_limit": 3600,
+        "callbacks": [["EarlyStoppingCountCallback", {"patience": [[100, 4], [500, 8], [2500, 15], [10000, 40], [100000, 100], None]}]],
+    },
+
+    extreme_quality_v140={
         "auto_stack": True,
         "dynamic_stacking": "auto",
         "num_bag_sets": 1,
@@ -140,5 +171,11 @@ tabular_presets_alias = dict(
     mq="medium_quality",
     experimental="extreme_quality",
     experimental_quality="extreme_quality",
-    experimental_quality_v140="
+    experimental_quality_v140="extreme_quality_v140",
+    best_v140="best_quality",
+    best_v150="best_quality_v150",
+    best_quality_v140="best_quality",
+    high_v150="high_quality_v150",
+    extreme_v140="extreme_quality_v140",
+    extreme_v150="extreme_quality",
 )
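The `*_v150` presets defined above (and the aliases added to `tabular_presets_alias`) are selected like any other preset; a minimal sketch reusing the placeholder data from the previous example:

```python
# "best_v150" is an alias for "best_quality_v150"; "extreme_quality" now resolves to the GPU portfolio.
predictor = TabularPredictor(label="target").fit(
    train_data,
    presets="best_quality_v150",
)
```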
@@ -1,50 +1,8 @@
|
|
|
1
1
|
# optimized for <=10000 samples and <=500 features, with a GPU present
|
|
2
2
|
hyperparameter_portfolio_zeroshot_2025_small = {
|
|
3
|
-
"
|
|
3
|
+
"REALTABPFN-V2": [
|
|
4
4
|
{
|
|
5
|
-
"ag_args": {'
|
|
6
|
-
"average_before_softmax": False,
|
|
7
|
-
"classification_model_path": 'tabpfn-v2-classifier-od3j1g5m.ckpt',
|
|
8
|
-
"inference_config/FINGERPRINT_FEATURE": False,
|
|
9
|
-
"inference_config/OUTLIER_REMOVAL_STD": None,
|
|
10
|
-
"inference_config/POLYNOMIAL_FEATURES": 'no',
|
|
11
|
-
"inference_config/PREPROCESS_TRANSFORMS": [{'append_original': True, 'categorical_name': 'ordinal_very_common_categories_shuffled', 'global_transformer_name': None, 'name': 'safepower', 'subsample_features': -1}, {'append_original': True, 'categorical_name': 'ordinal_very_common_categories_shuffled', 'global_transformer_name': None, 'name': 'quantile_uni', 'subsample_features': -1}],
|
|
12
|
-
"inference_config/REGRESSION_Y_PREPROCESS_TRANSFORMS": [None, 'power'],
|
|
13
|
-
"inference_config/SUBSAMPLE_SAMPLES": 0.99,
|
|
14
|
-
"model_type": 'single',
|
|
15
|
-
"n_ensemble_repeats": 4,
|
|
16
|
-
"regression_model_path": 'tabpfn-v2-regressor-wyl4o83o.ckpt',
|
|
17
|
-
"softmax_temperature": 0.75,
|
|
18
|
-
},
|
|
19
|
-
{
|
|
20
|
-
"ag_args": {'name_suffix': '_r94', 'priority': -3},
|
|
21
|
-
"average_before_softmax": True,
|
|
22
|
-
"classification_model_path": 'tabpfn-v2-classifier-vutqq28w.ckpt',
|
|
23
|
-
"inference_config/FINGERPRINT_FEATURE": True,
|
|
24
|
-
"inference_config/OUTLIER_REMOVAL_STD": None,
|
|
25
|
-
"inference_config/POLYNOMIAL_FEATURES": 'no',
|
|
26
|
-
"inference_config/PREPROCESS_TRANSFORMS": [{'append_original': True, 'categorical_name': 'ordinal_very_common_categories_shuffled', 'global_transformer_name': None, 'name': 'quantile_uni', 'subsample_features': 0.99}],
|
|
27
|
-
"inference_config/REGRESSION_Y_PREPROCESS_TRANSFORMS": [None],
|
|
28
|
-
"inference_config/SUBSAMPLE_SAMPLES": None,
|
|
29
|
-
"model_type": 'single',
|
|
30
|
-
"n_ensemble_repeats": 4,
|
|
31
|
-
"regression_model_path": 'tabpfn-v2-regressor-5wof9ojf.ckpt',
|
|
32
|
-
"softmax_temperature": 0.9,
|
|
33
|
-
},
|
|
34
|
-
{
|
|
35
|
-
"ag_args": {'name_suffix': '_r181', 'priority': -4},
|
|
36
|
-
"average_before_softmax": False,
|
|
37
|
-
"classification_model_path": 'tabpfn-v2-classifier-llderlii.ckpt',
|
|
38
|
-
"inference_config/FINGERPRINT_FEATURE": False,
|
|
39
|
-
"inference_config/OUTLIER_REMOVAL_STD": 9.0,
|
|
40
|
-
"inference_config/POLYNOMIAL_FEATURES": 50,
|
|
41
|
-
"inference_config/PREPROCESS_TRANSFORMS": [{'append_original': True, 'categorical_name': 'onehot', 'global_transformer_name': 'svd', 'name': 'quantile_uni_coarse', 'subsample_features': 0.99}],
|
|
42
|
-
"inference_config/REGRESSION_Y_PREPROCESS_TRANSFORMS": ['power'],
|
|
43
|
-
"inference_config/SUBSAMPLE_SAMPLES": None,
|
|
44
|
-
"model_type": 'single',
|
|
45
|
-
"n_ensemble_repeats": 4,
|
|
46
|
-
"regression_model_path": 'tabpfn-v2-regressor.ckpt',
|
|
47
|
-
"softmax_temperature": 0.95,
|
|
5
|
+
"ag_args": {'priority': -1},
|
|
48
6
|
},
|
|
49
7
|
],
|
|
50
8
|
"GBM": [
|
|
autogluon/tabular/configs/zeroshot/zeroshot_portfolio_cpu_2025_12_18.py (new file)

@@ -0,0 +1,2 @@
+# On par with `best_quality` while being much faster for smaller datasets. Runs on CPU.
hyperparameter_portfolio_zeroshot_cpu_2025_12_18 = {'CAT': [{'ag_args': {'name_suffix': '_c1', 'priority': -1}}], 'GBM_PREP': [{'ag.prep_params': [[[['ArithmeticFeatureGenerator', {}]], [['CategoricalInteractionFeatureGenerator', {'passthrough': True}], ['OOFTargetEncodingFeatureGenerator', {}]]]], 'ag.prep_params.passthrough_types': {'invalid_raw_types': ['category', 'object']}, 'ag_args': {'name_suffix': '_r13', 'priority': -2}, 'ag_args_ensemble': {'model_random_seed': 0, 'vary_seed_across_folds': True}, 'bagging_fraction': 0.9923026236907, 'bagging_freq': 1, 'cat_l2': 0.014290368488, 'cat_smooth': 1.8662939903973, 'extra_trees': True, 'feature_fraction': 0.5533919718605, 'lambda_l1': 0.914411672958, 'lambda_l2': 1.90439560009, 'learning_rate': 0.0193225778401, 'max_cat_to_onehot': 18, 'min_data_in_leaf': 28, 'min_data_per_group': 54, 'num_leaves': 64}, {'ag.prep_params': [[[['ArithmeticFeatureGenerator', {}]], [['CategoricalInteractionFeatureGenerator', {'passthrough': True}], ['OOFTargetEncodingFeatureGenerator', {}]]]], 'ag.prep_params.passthrough_types': {'invalid_raw_types': ['category', 'object']}, 'ag_args': {'name_suffix': '_r41', 'priority': -7}, 'ag_args_ensemble': {'model_random_seed': 0, 'vary_seed_across_folds': True}, 'bagging_fraction': 0.7215411996558, 'bagging_freq': 1, 'cat_l2': 1.887369154362, 'cat_smooth': 0.0278693980873, 'extra_trees': True, 'feature_fraction': 0.4247583287144, 'lambda_l1': 0.1129800247772, 'lambda_l2': 0.2623265718536, 'learning_rate': 0.0074201920651, 'max_cat_to_onehot': 9, 'min_data_in_leaf': 15, 'min_data_per_group': 10, 'num_leaves': 8}, {'ag.prep_params': [[[['ArithmeticFeatureGenerator', {}]], [['CategoricalInteractionFeatureGenerator', {'passthrough': True}], ['OOFTargetEncodingFeatureGenerator', {}]]]], 'ag.prep_params.passthrough_types': {'invalid_raw_types': ['category', 'object']}, 'ag_args': {'name_suffix': '_r31', 'priority': -10}, 'ag_args_ensemble': {'model_random_seed': 0, 'vary_seed_across_folds': True}, 'bagging_fraction': 0.9591526242875, 'bagging_freq': 1, 'cat_l2': 1.8962346412823, 'cat_smooth': 0.0215219089995, 'extra_trees': False, 'feature_fraction': 0.5791844062459, 'lambda_l1': 0.938461750637, 'lambda_l2': 0.9899852075056, 'learning_rate': 0.0397613094741, 'max_cat_to_onehot': 27, 'min_data_in_leaf': 1, 'min_data_per_group': 39, 'num_leaves': 16}, {'ag.prep_params': [], 'ag_args': {'name_suffix': '_r21', 'priority': -12}, 'ag_args_ensemble': {'model_random_seed': 0, 'vary_seed_across_folds': True}, 'bagging_fraction': 0.7111549514262, 'bagging_freq': 1, 'cat_l2': 0.8679131150136, 'cat_smooth': 48.7244965504817, 'extra_trees': False, 'feature_fraction': 0.425140839263, 'lambda_l1': 0.5140528525242, 'lambda_l2': 0.5134051978198, 'learning_rate': 0.0134375321277, 'max_cat_to_onehot': 16, 'min_data_in_leaf': 2, 'min_data_per_group': 32, 'num_leaves': 20}, {'ag.prep_params': [[[['ArithmeticFeatureGenerator', {}]], [['CategoricalInteractionFeatureGenerator', {'passthrough': True}], ['OOFTargetEncodingFeatureGenerator', {}]]]], 'ag.prep_params.passthrough_types': {'invalid_raw_types': ['category', 'object']}, 'ag_args': {'name_suffix': '_r17', 'priority': -17}, 'ag_args_ensemble': {'model_random_seed': 0, 'vary_seed_across_folds': True}, 'bagging_fraction': 0.9277474245702, 'bagging_freq': 1, 'cat_l2': 0.0731876168104, 'cat_smooth': 0.1369210915339, 'extra_trees': False, 'feature_fraction': 0.6680440910385, 'lambda_l1': 0.0125057410295, 'lambda_l2': 0.7157181359874, 'learning_rate': 0.0351342879995, 'max_cat_to_onehot': 20, 
'min_data_in_leaf': 1, 'min_data_per_group': 2, 'num_leaves': 64}, {'ag.prep_params': [[[['ArithmeticFeatureGenerator', {}]]]], 'ag_args': {'name_suffix': '_r47', 'priority': -18}, 'ag_args_ensemble': {'model_random_seed': 0, 'vary_seed_across_folds': True}, 'bagging_fraction': 0.9918048278435, 'bagging_freq': 1, 'cat_l2': 0.984162386723, 'cat_smooth': 0.0049687445294, 'extra_trees': True, 'feature_fraction': 0.4974006116018, 'lambda_l1': 0.7970644065518, 'lambda_l2': 1.2179933810825, 'learning_rate': 0.0537072755122, 'max_cat_to_onehot': 13, 'min_data_in_leaf': 1, 'min_data_per_group': 4, 'num_leaves': 32}, {'ag.prep_params': [[[['CategoricalInteractionFeatureGenerator', {'passthrough': True}], ['OOFTargetEncodingFeatureGenerator', {}]]]], 'ag.prep_params.passthrough_types': {'invalid_raw_types': ['category', 'object']}, 'ag_args': {'name_suffix': '_r1', 'priority': -19}, 'ag_args_ensemble': {'model_random_seed': 0, 'vary_seed_across_folds': True}, 'bagging_fraction': 0.8836335684032, 'bagging_freq': 1, 'cat_l2': 0.6608043016307, 'cat_smooth': 0.0451936212097, 'extra_trees': True, 'feature_fraction': 0.6189315903408, 'lambda_l1': 0.6514130054123, 'lambda_l2': 1.7382678663835, 'learning_rate': 0.0412716109215, 'max_cat_to_onehot': 9, 'min_data_in_leaf': 9, 'min_data_per_group': 3, 'num_leaves': 128}, {'ag.prep_params': [[[['CategoricalInteractionFeatureGenerator', {'passthrough': True}], ['OOFTargetEncodingFeatureGenerator', {}]]]], 'ag.prep_params.passthrough_types': {'invalid_raw_types': ['category', 'object']}, 'ag_args': {'name_suffix': '_r19', 'priority': -26}, 'ag_args_ensemble': {'model_random_seed': 0, 'vary_seed_across_folds': True}, 'bagging_fraction': 0.7106002663401, 'bagging_freq': 1, 'cat_l2': 0.1559746777257, 'cat_smooth': 0.0036366126697, 'extra_trees': False, 'feature_fraction': 0.688233104808, 'lambda_l1': 0.8732887427372, 'lambda_l2': 0.446716114323, 'learning_rate': 0.0815946452855, 'max_cat_to_onehot': 78, 'min_data_in_leaf': 12, 'min_data_per_group': 2, 'num_leaves': 16}, {'ag.prep_params': [[[['ArithmeticFeatureGenerator', {}]], [['CategoricalInteractionFeatureGenerator', {'passthrough': True}], ['OOFTargetEncodingFeatureGenerator', {}]]]], 'ag.prep_params.passthrough_types': {'invalid_raw_types': ['category', 'object']}, 'ag_args': {'name_suffix': '_r34', 'priority': -32}, 'ag_args_ensemble': {'model_random_seed': 0, 'vary_seed_across_folds': True}, 'bagging_fraction': 0.8453534561545, 'bagging_freq': 1, 'cat_l2': 0.0321580936847, 'cat_smooth': 0.0011470238114, 'extra_trees': True, 'feature_fraction': 0.8611499511087, 'lambda_l1': 0.910743969343, 'lambda_l2': 1.2750027607225, 'learning_rate': 0.0151455176168, 'max_cat_to_onehot': 8, 'min_data_in_leaf': 60, 'min_data_per_group': 4, 'num_leaves': 32}, {'ag.prep_params': [[[['ArithmeticFeatureGenerator', {}]], [['CategoricalInteractionFeatureGenerator', {'passthrough': True}], ['OOFTargetEncodingFeatureGenerator', {}]]]], 'ag.prep_params.passthrough_types': {'invalid_raw_types': ['category', 'object']}, 'ag_args': {'name_suffix': '_r32', 'priority': -37}, 'ag_args_ensemble': {'model_random_seed': 0, 'vary_seed_across_folds': True}, 'bagging_fraction': 0.927947070297, 'bagging_freq': 1, 'cat_l2': 0.0082294539727, 'cat_smooth': 0.0671878797989, 'extra_trees': True, 'feature_fraction': 0.9169657691675, 'lambda_l1': 0.9386485912678, 'lambda_l2': 1.619775689786, 'learning_rate': 0.0056864355547, 'max_cat_to_onehot': 11, 'min_data_in_leaf': 1, 'min_data_per_group': 10, 'num_leaves': 32}, {'ag.prep_params': 
[[[['ArithmeticFeatureGenerator', {}]], [['CategoricalInteractionFeatureGenerator', {'passthrough': True}], ['OOFTargetEncodingFeatureGenerator', {}]]]], 'ag.prep_params.passthrough_types': {'invalid_raw_types': ['category', 'object']}, 'ag_args': {'name_suffix': '_r7', 'priority': -38}, 'ag_args_ensemble': {'model_random_seed': 0, 'vary_seed_across_folds': True}, 'bagging_fraction': 0.8984634022103, 'bagging_freq': 1, 'cat_l2': 0.0053608956358, 'cat_smooth': 89.7168790664636, 'extra_trees': False, 'feature_fraction': 0.847638045482, 'lambda_l1': 0.5684527742857, 'lambda_l2': 1.0738026980295, 'learning_rate': 0.0417108779005, 'max_cat_to_onehot': 8, 'min_data_in_leaf': 2, 'min_data_per_group': 7, 'num_leaves': 128}, {'ag.prep_params': [[[['CategoricalInteractionFeatureGenerator', {'passthrough': True}], ['OOFTargetEncodingFeatureGenerator', {}]]]], 'ag.prep_params.passthrough_types': {'invalid_raw_types': ['category', 'object']}, 'ag_args': {'name_suffix': '_r14', 'priority': -40}, 'ag_args_ensemble': {'model_random_seed': 0, 'vary_seed_across_folds': True}, 'bagging_fraction': 0.9318953983366, 'bagging_freq': 1, 'cat_l2': 0.065532200068, 'cat_smooth': 0.0696287198368, 'extra_trees': True, 'feature_fraction': 0.4649868965096, 'lambda_l1': 0.6586569196642, 'lambda_l2': 1.7799375779553, 'learning_rate': 0.072046289471, 'max_cat_to_onehot': 72, 'min_data_in_leaf': 26, 'min_data_per_group': 32, 'num_leaves': 32}, {'ag.prep_params': [[[['ArithmeticFeatureGenerator', {}]]]], 'ag_args': {'name_suffix': '_r27', 'priority': -42}, 'ag_args_ensemble': {'model_random_seed': 0, 'vary_seed_across_folds': True}, 'bagging_fraction': 0.811983527375, 'bagging_freq': 1, 'cat_l2': 0.0255048028385, 'cat_smooth': 1.5339379274002, 'extra_trees': True, 'feature_fraction': 0.5246746068724, 'lambda_l1': 0.9737915306165, 'lambda_l2': 1.929596568261, 'learning_rate': 0.0172284745143, 'max_cat_to_onehot': 9, 'min_data_in_leaf': 8, 'min_data_per_group': 51, 'num_leaves': 20}, {'ag.prep_params': [[[['ArithmeticFeatureGenerator', {}]]]], 'ag_args': {'name_suffix': '_r37', 'priority': -46}, 'ag_args_ensemble': {'model_random_seed': 0, 'vary_seed_across_folds': True}, 'bagging_fraction': 0.7853761603489, 'bagging_freq': 1, 'cat_l2': 0.2934796127084, 'cat_smooth': 10.1721684646257, 'extra_trees': False, 'feature_fraction': 0.4813265290277, 'lambda_l1': 0.9744837697365, 'lambda_l2': 0.6058665958153, 'learning_rate': 0.0371000014124, 'max_cat_to_onehot': 85, 'min_data_in_leaf': 22, 'min_data_per_group': 3, 'num_leaves': 32}], 'GBM': [{'ag_args': {'name_suffix': '_r177', 'priority': -3}, 'bagging_fraction': 0.8769107816033, 'bagging_freq': 1, 'cat_l2': 0.3418014393813, 'cat_smooth': 15.4304556649114, 'extra_trees': True, 'feature_fraction': 0.4622189821941, 'lambda_l1': 0.2375070586896, 'lambda_l2': 0.3551561351804, 'learning_rate': 0.0178593900218, 'max_cat_to_onehot': 16, 'min_data_in_leaf': 3, 'min_data_per_group': 9, 'num_leaves': 39}, {'ag_args': {'name_suffix': '_r163', 'priority': -5}, 'bagging_fraction': 0.9783898288461, 'bagging_freq': 1, 'cat_l2': 0.1553395260142, 'cat_smooth': 0.0093122749318, 'extra_trees': False, 'feature_fraction': 0.5279825611461, 'lambda_l1': 0.0269274915833, 'lambda_l2': 0.8375250972309, 'learning_rate': 0.0113913650333, 'max_cat_to_onehot': 42, 'min_data_in_leaf': 3, 'min_data_per_group': 75, 'num_leaves': 84}, {'ag_args': {'name_suffix': '_r72', 'priority': -8}, 'bagging_fraction': 0.950146543918, 'bagging_freq': 1, 'cat_l2': 0.2159137242663, 'cat_smooth': 0.0638204395719, 'extra_trees': 
True, 'feature_fraction': 0.4044759649281, 'lambda_l1': 0.7661581500422, 'lambda_l2': 1.6041759693902, 'learning_rate': 0.0179845918984, 'max_cat_to_onehot': 11, 'min_data_in_leaf': 12, 'min_data_per_group': 3, 'num_leaves': 180}, {'ag_args': {'name_suffix': '_r120', 'priority': -13}, 'bagging_fraction': 0.8541333332514, 'bagging_freq': 1, 'cat_l2': 0.0110343197541, 'cat_smooth': 5.0905236124522, 'extra_trees': True, 'feature_fraction': 0.7334718346252, 'lambda_l1': 0.241338427726, 'lambda_l2': 0.298107723769, 'learning_rate': 0.0126654490778, 'max_cat_to_onehot': 67, 'min_data_in_leaf': 12, 'min_data_per_group': 93, 'num_leaves': 5}, {'ag_args': {'name_suffix': '_r6', 'priority': -16}, 'bagging_fraction': 0.8148132107231, 'bagging_freq': 1, 'cat_l2': 0.0058363329714, 'cat_smooth': 0.0289414318324, 'extra_trees': False, 'feature_fraction': 0.939979116902, 'lambda_l1': 0.4369494828584, 'lambda_l2': 0.2997524486083, 'learning_rate': 0.0078971749764, 'max_cat_to_onehot': 28, 'min_data_in_leaf': 24, 'min_data_per_group': 3, 'num_leaves': 8}, {'ag_args': {'name_suffix': '_r184', 'priority': -21}, 'bagging_fraction': 0.8406256713136, 'bagging_freq': 1, 'cat_l2': 0.9284921901786, 'cat_smooth': 0.0898191451684, 'extra_trees': False, 'feature_fraction': 0.5876132298377, 'lambda_l1': 0.078943697912, 'lambda_l2': 0.7713118402478, 'learning_rate': 0.0090676429159, 'max_cat_to_onehot': 16, 'min_data_in_leaf': 17, 'min_data_per_group': 11, 'num_leaves': 2}, {'ag_args': {'name_suffix': '_r46', 'priority': -23}, 'bagging_fraction': 0.999426150416, 'bagging_freq': 1, 'cat_l2': 0.0076879104679, 'cat_smooth': 89.4599055435924, 'extra_trees': False, 'feature_fraction': 0.8588138897928, 'lambda_l1': 0.0413597548025, 'lambda_l2': 0.2258713386858, 'learning_rate': 0.0074056102479, 'max_cat_to_onehot': 11, 'min_data_in_leaf': 1, 'min_data_per_group': 26, 'num_leaves': 14}, {'ag_args': {'name_suffix': '_r68', 'priority': -24}, 'bagging_fraction': 0.7199080522958, 'bagging_freq': 1, 'cat_l2': 0.9369509319667, 'cat_smooth': 11.0984745216942, 'extra_trees': False, 'feature_fraction': 0.9550596478029, 'lambda_l1': 0.1109843723892, 'lambda_l2': 0.5969094177111, 'learning_rate': 0.0079480499426, 'max_cat_to_onehot': 8, 'min_data_in_leaf': 3, 'min_data_per_group': 8, 'num_leaves': 111}, {'ag_args': {'name_suffix': '_r47', 'priority': -29}, 'bagging_fraction': 0.8831228358892, 'bagging_freq': 1, 'cat_l2': 0.1402622388062, 'cat_smooth': 3.3545774392409, 'extra_trees': True, 'feature_fraction': 0.6155890374887, 'lambda_l1': 0.1749502746898, 'lambda_l2': 0.8761391715812, 'learning_rate': 0.00891978331, 'max_cat_to_onehot': 84, 'min_data_in_leaf': 1, 'min_data_per_group': 21, 'num_leaves': 55}, {'ag_args': {'name_suffix': '_r63', 'priority': -31}, 'bagging_fraction': 0.7801003412553, 'bagging_freq': 1, 'cat_l2': 0.0071438335269, 'cat_smooth': 0.1338043459574, 'extra_trees': False, 'feature_fraction': 0.490455360592, 'lambda_l1': 0.6420805635778, 'lambda_l2': 0.5813319300456, 'learning_rate': 0.0308746408751, 'max_cat_to_onehot': 38, 'min_data_in_leaf': 1, 'min_data_per_group': 83, 'num_leaves': 24}, {'ag_args': {'name_suffix': '_r39', 'priority': -36}, 'bagging_fraction': 0.7035743460186, 'bagging_freq': 1, 'cat_l2': 0.0134845084619, 'cat_smooth': 56.4934757686511, 'extra_trees': True, 'feature_fraction': 0.7824899527144, 'lambda_l1': 0.3700115211248, 'lambda_l2': 0.0341499593689, 'learning_rate': 0.094652390088, 'max_cat_to_onehot': 13, 'min_data_in_leaf': 13, 'min_data_per_group': 4, 'num_leaves': 23}, {'ag_args': 
{'name_suffix': '_r18', 'priority': -43}, 'bagging_fraction': 0.7041134150362, 'bagging_freq': 1, 'cat_l2': 0.1139031650222, 'cat_smooth': 41.8937939300815, 'extra_trees': True, 'feature_fraction': 0.5028791565785, 'lambda_l1': 0.1031941284118, 'lambda_l2': 1.2554010747358, 'learning_rate': 0.0186530122901, 'max_cat_to_onehot': 29, 'min_data_in_leaf': 5, 'min_data_per_group': 74, 'num_leaves': 5}, {'ag_args': {'name_suffix': '_r50', 'priority': -45}, 'bagging_fraction': 0.9673434664048, 'bagging_freq': 1, 'cat_l2': 1.7662226703416, 'cat_smooth': 0.0097667848046, 'extra_trees': True, 'feature_fraction': 0.9286299570284, 'lambda_l1': 0.0448644389135, 'lambda_l2': 1.7322446850205, 'learning_rate': 0.0507909494543, 'max_cat_to_onehot': 11, 'min_data_in_leaf': 4, 'min_data_per_group': 2, 'num_leaves': 106}, {'ag_args': {'name_suffix': '_r104', 'priority': -48}, 'bagging_fraction': 0.9327643671568, 'bagging_freq': 1, 'cat_l2': 0.0067636494662, 'cat_smooth': 29.2351010915576, 'extra_trees': False, 'feature_fraction': 0.660864035482, 'lambda_l1': 0.556745328417, 'lambda_l2': 1.2717605868201, 'learning_rate': 0.0433336000175, 'max_cat_to_onehot': 42, 'min_data_in_leaf': 18, 'min_data_per_group': 6, 'num_leaves': 19}], 'NN_TORCH': [{'activation': 'elu', 'ag_args': {'name_suffix': '_r37', 'priority': -4}, 'dropout_prob': 0.0889772897547275, 'hidden_size': 109, 'learning_rate': 0.02184363543226557, 'num_layers': 3, 'use_batchnorm': True, 'weight_decay': 3.1736637236578543e-10}, {'activation': 'elu', 'ag_args': {'name_suffix': '_r31', 'priority': -9}, 'dropout_prob': 0.013288954106470907, 'hidden_size': 81, 'learning_rate': 0.005340914647396153, 'num_layers': 4, 'use_batchnorm': False, 'weight_decay': 8.76216837077536e-05}, {'activation': 'elu', 'ag_args': {'name_suffix': '_r193', 'priority': -14}, 'dropout_prob': 0.2976404923811552, 'hidden_size': 131, 'learning_rate': 0.0038408014156739775, 'num_layers': 3, 'use_batchnorm': False, 'weight_decay': 0.01745189206113213}, {'activation': 'elu', 'ag_args': {'name_suffix': '_r144', 'priority': -15}, 'dropout_prob': 0.2670859555485912, 'hidden_size': 52, 'learning_rate': 0.015189605588375421, 'num_layers': 4, 'use_batchnorm': True, 'weight_decay': 2.8013784883244263e-08}, {'activation': 'relu', 'ag_args': {'name_suffix': '_r82', 'priority': -22}, 'dropout_prob': 0.27342918414623907, 'hidden_size': 207, 'learning_rate': 0.0004069380929899853, 'num_layers': 4, 'use_batchnorm': False, 'weight_decay': 0.002473667327700422}, {'activation': 'elu', 'ag_args': {'name_suffix': '_r39', 'priority': -27}, 'dropout_prob': 0.21699951000415899, 'hidden_size': 182, 'learning_rate': 0.00014675249427915203, 'num_layers': 2, 'use_batchnorm': False, 'weight_decay': 9.787353852692089e-08}, {'activation': 'relu', 'ag_args': {'name_suffix': '_r1', 'priority': -30}, 'dropout_prob': 0.23713784729000734, 'hidden_size': 200, 'learning_rate': 0.0031125617090901805, 'num_layers': 4, 'use_batchnorm': True, 'weight_decay': 4.57301675647447e-08}, {'activation': 'relu', 'ag_args': {'name_suffix': '_r48', 'priority': -34}, 'dropout_prob': 0.14224509513998226, 'hidden_size': 26, 'learning_rate': 0.007085904739869829, 'num_layers': 2, 'use_batchnorm': False, 'weight_decay': 2.465786211798467e-10}, {'activation': 'elu', 'ag_args': {'name_suffix': '_r135', 'priority': -39}, 'dropout_prob': 0.06134755114373829, 'hidden_size': 144, 'learning_rate': 0.005834535148903802, 'num_layers': 5, 'use_batchnorm': True, 'weight_decay': 2.0826540090463376e-09}, {'activation': 'elu', 'ag_args': {'name_suffix': 
'_r24', 'priority': -49}, 'dropout_prob': 0.257596079691855, 'hidden_size': 168, 'learning_rate': 0.0034108596383714608, 'num_layers': 4, 'use_batchnorm': True, 'weight_decay': 1.4840689603685264e-07}, {'activation': 'relu', 'ag_args': {'name_suffix': '_r159', 'priority': -50}, 'dropout_prob': 0.16724368469920037, 'hidden_size': 44, 'learning_rate': 0.011043937174833164, 'num_layers': 4, 'use_batchnorm': False, 'weight_decay': 0.007265742373924609}], 'FASTAI': [{'ag_args': {'name_suffix': '_r25', 'priority': -6}, 'bs': 1024, 'emb_drop': 0.6167722379778131, 'epochs': 44, 'layers': [200, 100, 50], 'lr': 0.05344037785562929, 'ps': 0.48477211305443607}, {'ag_args': {'name_suffix': '_r162', 'priority': -11}, 'bs': 2048, 'emb_drop': 0.5474625640581479, 'epochs': 45, 'layers': [400, 200], 'lr': 0.0047438648957706655, 'ps': 0.07533239360470734}, {'ag_args': {'name_suffix': '_r147', 'priority': -20}, 'bs': 128, 'emb_drop': 0.6378380130337095, 'epochs': 48, 'layers': [200], 'lr': 0.058027179860229344, 'ps': 0.23253362133888375}, {'ag_args': {'name_suffix': '_r192', 'priority': -25}, 'bs': 1024, 'emb_drop': 0.0698130630643278, 'epochs': 37, 'layers': [400, 200], 'lr': 0.0018949411343821322, 'ps': 0.6526067160491229}, {'ag_args': {'name_suffix': '_r109', 'priority': -28}, 'bs': 128, 'emb_drop': 0.1978897556618756, 'epochs': 49, 'layers': [400, 200, 100], 'lr': 0.02155144303508465, 'ps': 0.005518872455908264}, {'ag_args': {'name_suffix': '_r78', 'priority': -33}, 'bs': 512, 'emb_drop': 0.4897354379753617, 'epochs': 26, 'layers': [400, 200, 100], 'lr': 0.027563880686468895, 'ps': 0.44524273881299886}, {'ag_args': {'name_suffix': '_r150', 'priority': -35}, 'bs': 2048, 'emb_drop': 0.6148607467659958, 'epochs': 27, 'layers': [400, 200], 'lr': 0.09351668652547614, 'ps': 0.5314977162016676}, {'ag_args': {'name_suffix': '_r133', 'priority': -41}, 'bs': 256, 'emb_drop': 0.6242606757570891, 'epochs': 43, 'layers': [200, 100, 50], 'lr': 0.001533613235987637, 'ps': 0.5354961132962562}, {'ag_args': {'name_suffix': '_r99', 'priority': -44}, 'bs': 512, 'emb_drop': 0.6071025838237253, 'epochs': 49, 'layers': [400, 200], 'lr': 0.02669945959641021, 'ps': 0.4897025421573259}, {'ag_args': {'name_suffix': '_r197', 'priority': -47}, 'bs': 256, 'emb_drop': 0.5277230463737563, 'epochs': 45, 'layers': [400, 200], 'lr': 0.006908743712130657, 'ps': 0.08262909528632323}]}
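Each top-level key in the portfolio dict above is a model registry key (`'CAT'`, `'GBM'`, `'GBM_PREP'`, `'NN_TORCH'`, `'FASTAI'`) mapped to a list of hyperparameter configurations, and `ag_args.priority` controls fit order (higher priority trains first). A small inspection sketch, assuming the module path from the file list at the top of this diff:

```python
from autogluon.tabular.configs.zeroshot.zeroshot_portfolio_cpu_2025_12_18 import (
    hyperparameter_portfolio_zeroshot_cpu_2025_12_18 as portfolio,
)

# Number of configs contributed by each model family.
for family, configs in portfolio.items():
    print(family, len(configs))

# Configs sorted into fit order (higher ag_args.priority trains first).
flat = [(family, cfg["ag_args"]["priority"]) for family, cfgs in portfolio.items() for cfg in cfgs]
print(sorted(flat, key=lambda item: -item[1])[:5])
```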
autogluon/tabular/configs/zeroshot/zeroshot_portfolio_gpu_2025_12_18.py (new file)

@@ -0,0 +1,2 @@
+# State-of-the-art for datasets < 100k samples. Requires a GPU with at least 20 GB VRAM.
hyperparameter_portfolio_zeroshot_gpu_2025_12_18 = {'TABDPT': [{'ag_args': {'name_suffix': '_c1', 'priority': -3}, 'ag_args_ensemble': {'model_random_seed': 0, 'vary_seed_across_folds': False}}, {'ag_args': {'name_suffix': '_r20', 'priority': -5}, 'ag_args_ensemble': {'model_random_seed': 0, 'vary_seed_across_folds': False}, 'clip_sigma': 8, 'feature_reduction': 'subsample', 'missing_indicators': False, 'normalizer': 'quantile-uniform', 'permute_classes': False, 'temperature': 0.5}, {'ag_args': {'name_suffix': '_r1', 'priority': -7}, 'ag_args_ensemble': {'model_random_seed': 0, 'vary_seed_across_folds': False}, 'clip_sigma': 16, 'feature_reduction': 'subsample', 'missing_indicators': False, 'normalizer': 'log1p', 'permute_classes': False, 'temperature': 0.5}, {'ag_args': {'name_suffix': '_r15', 'priority': -9}, 'ag_args_ensemble': {'model_random_seed': 0, 'vary_seed_across_folds': False}, 'clip_sigma': 16, 'feature_reduction': 'subsample', 'missing_indicators': False, 'normalizer': 'standard', 'permute_classes': True, 'temperature': 0.7}, {'ag_args': {'name_suffix': '_r22', 'priority': -11}, 'ag_args_ensemble': {'model_random_seed': 0, 'vary_seed_across_folds': False}, 'clip_sigma': 8, 'feature_reduction': 'pca', 'missing_indicators': True, 'normalizer': 'robust', 'permute_classes': False, 'temperature': 0.5}], 'TABICL': [{'ag_args': {'name_suffix': '_c1', 'priority': -4}, 'ag_args_ensemble': {'refit_folds': True}}], 'MITRA': [{'ag_args': {'name_suffix': '_c1', 'priority': -12}, 'ag_args_ensemble': {'model_random_seed': 0, 'vary_seed_across_folds': True}}], 'TABM': [{'ag_args': {'name_suffix': '_r99', 'priority': -13}, 'amp': False, 'arch_type': 'tabm-mini', 'batch_size': 'auto', 'd_block': 880, 'd_embedding': 24, 'dropout': 0.10792355695428629, 'gradient_clipping_norm': 1.0, 'lr': 0.0013641856391615784, 'n_blocks': 5, 'num_emb_n_bins': 16, 'num_emb_type': 'pwl', 'patience': 16, 'share_training_batches': False, 'tabm_k': 32, 'weight_decay': 0.0}, {'ag_args': {'name_suffix': '_r124', 'priority': -17}, 'amp': False, 'arch_type': 'tabm-mini', 'batch_size': 'auto', 'd_block': 208, 'd_embedding': 16, 'dropout': 0.0, 'gradient_clipping_norm': 1.0, 'lr': 0.00042152744054701374, 'n_blocks': 2, 'num_emb_n_bins': 109, 'num_emb_type': 'pwl', 'patience': 16, 'share_training_batches': False, 'tabm_k': 32, 'weight_decay': 0.00014007839435474664}, {'ag_args': {'name_suffix': '_r69', 'priority': -21}, 'amp': False, 'arch_type': 'tabm-mini', 'batch_size': 'auto', 'd_block': 848, 'd_embedding': 28, 'dropout': 0.40215621636031007, 'gradient_clipping_norm': 1.0, 'lr': 0.0010413640454559532, 'n_blocks': 3, 'num_emb_n_bins': 18, 'num_emb_type': 'pwl', 'patience': 16, 'share_training_batches': False, 'tabm_k': 32, 'weight_decay': 0.0}, {'ag_args': {'name_suffix': '_r184', 'priority': -24}, 'amp': False, 'arch_type': 'tabm-mini', 'batch_size': 'auto', 'd_block': 864, 'd_embedding': 24, 'dropout': 0.0, 'gradient_clipping_norm': 1.0, 'lr': 0.0019256819924656217, 'n_blocks': 3, 'num_emb_n_bins': 3, 'num_emb_type': 'pwl', 'patience': 16, 'share_training_batches': False, 'tabm_k': 32, 'weight_decay': 0.0}, {'ag_args': {'name_suffix': '_r34', 'priority': -26}, 'amp': False, 'arch_type': 'tabm-mini', 'batch_size': 'auto', 'd_block': 896, 'd_embedding': 8, 'dropout': 0.0, 'gradient_clipping_norm': 1.0, 'lr': 0.002459175026451607, 'n_blocks': 4, 'num_emb_n_bins': 104, 'num_emb_type': 'pwl', 'patience': 16, 'share_training_batches': False, 'tabm_k': 32, 'weight_decay': 0.0006299584388562901}], 'GBM_PREP': 
[{'ag.prep_params': [[[['ArithmeticFeatureGenerator', {}]], [['CategoricalInteractionFeatureGenerator', {'passthrough': True}], ['OOFTargetEncodingFeatureGenerator', {}]]]], 'ag.prep_params.passthrough_types': {'invalid_raw_types': ['category', 'object']}, 'ag_args': {'name_suffix': '_r13', 'priority': -14}, 'ag_args_ensemble': {'model_random_seed': 0, 'vary_seed_across_folds': True}, 'bagging_fraction': 0.9923026236907, 'bagging_freq': 1, 'cat_l2': 0.014290368488, 'cat_smooth': 1.8662939903973, 'extra_trees': True, 'feature_fraction': 0.5533919718605, 'lambda_l1': 0.914411672958, 'lambda_l2': 1.90439560009, 'learning_rate': 0.0193225778401, 'max_cat_to_onehot': 18, 'min_data_in_leaf': 28, 'min_data_per_group': 54, 'num_leaves': 64}, {'ag.prep_params': [[[['ArithmeticFeatureGenerator', {}]], [['CategoricalInteractionFeatureGenerator', {'passthrough': True}], ['OOFTargetEncodingFeatureGenerator', {}]]]], 'ag.prep_params.passthrough_types': {'invalid_raw_types': ['category', 'object']}, 'ag_args': {'name_suffix': '_r41', 'priority': -16}, 'ag_args_ensemble': {'model_random_seed': 0, 'vary_seed_across_folds': True}, 'bagging_fraction': 0.7215411996558, 'bagging_freq': 1, 'cat_l2': 1.887369154362, 'cat_smooth': 0.0278693980873, 'extra_trees': True, 'feature_fraction': 0.4247583287144, 'lambda_l1': 0.1129800247772, 'lambda_l2': 0.2623265718536, 'learning_rate': 0.0074201920651, 'max_cat_to_onehot': 9, 'min_data_in_leaf': 15, 'min_data_per_group': 10, 'num_leaves': 8}, {'ag.prep_params': [[[['ArithmeticFeatureGenerator', {}]], [['CategoricalInteractionFeatureGenerator', {'passthrough': True}], ['OOFTargetEncodingFeatureGenerator', {}]]]], 'ag.prep_params.passthrough_types': {'invalid_raw_types': ['category', 'object']}, 'ag_args': {'name_suffix': '_r31', 'priority': -18}, 'ag_args_ensemble': {'model_random_seed': 0, 'vary_seed_across_folds': True}, 'bagging_fraction': 0.9591526242875, 'bagging_freq': 1, 'cat_l2': 1.8962346412823, 'cat_smooth': 0.0215219089995, 'extra_trees': False, 'feature_fraction': 0.5791844062459, 'lambda_l1': 0.938461750637, 'lambda_l2': 0.9899852075056, 'learning_rate': 0.0397613094741, 'max_cat_to_onehot': 27, 'min_data_in_leaf': 1, 'min_data_per_group': 39, 'num_leaves': 16}, {'ag.prep_params': [], 'ag_args': {'name_suffix': '_r21', 'priority': -20}, 'ag_args_ensemble': {'model_random_seed': 0, 'vary_seed_across_folds': True}, 'bagging_fraction': 0.7111549514262, 'bagging_freq': 1, 'cat_l2': 0.8679131150136, 'cat_smooth': 48.7244965504817, 'extra_trees': False, 'feature_fraction': 0.425140839263, 'lambda_l1': 0.5140528525242, 'lambda_l2': 0.5134051978198, 'learning_rate': 0.0134375321277, 'max_cat_to_onehot': 16, 'min_data_in_leaf': 2, 'min_data_per_group': 32, 'num_leaves': 20}, {'ag.prep_params': [[[['ArithmeticFeatureGenerator', {}]], [['CategoricalInteractionFeatureGenerator', {'passthrough': True}], ['OOFTargetEncodingFeatureGenerator', {}]]]], 'ag.prep_params.passthrough_types': {'invalid_raw_types': ['category', 'object']}, 'ag_args': {'name_suffix': '_r17', 'priority': -23}, 'ag_args_ensemble': {'model_random_seed': 0, 'vary_seed_across_folds': True}, 'bagging_fraction': 0.9277474245702, 'bagging_freq': 1, 'cat_l2': 0.0731876168104, 'cat_smooth': 0.1369210915339, 'extra_trees': False, 'feature_fraction': 0.6680440910385, 'lambda_l1': 0.0125057410295, 'lambda_l2': 0.7157181359874, 'learning_rate': 0.0351342879995, 'max_cat_to_onehot': 20, 'min_data_in_leaf': 1, 'min_data_per_group': 2, 'num_leaves': 64}], 'CAT': [{'ag_args': {'name_suffix': '_c1', 'priority': 
-15}}], 'GBM': [{'ag_args': {'name_suffix': '_r73', 'priority': -19}, 'bagging_fraction': 0.7295548973583, 'bagging_freq': 1, 'cat_l2': 1.8025485263237, 'cat_smooth': 59.6178463268351, 'extra_trees': False, 'feature_fraction': 0.8242607305914, 'lambda_l1': 0.7265522905459, 'lambda_l2': 0.3492160682092, 'learning_rate': 0.0068803786367, 'max_cat_to_onehot': 16, 'min_data_in_leaf': 1, 'min_data_per_group': 10, 'num_leaves': 24}, {'ag_args': {'name_suffix': '_r37', 'priority': -22}, 'bagging_fraction': 0.8096374561947, 'bagging_freq': 1, 'cat_l2': 1.6385754694703, 'cat_smooth': 16.1922506671724, 'extra_trees': True, 'feature_fraction': 0.885927003286, 'lambda_l1': 0.0430386950502, 'lambda_l2': 0.2507506811761, 'learning_rate': 0.0079622660542, 'max_cat_to_onehot': 23, 'min_data_in_leaf': 7, 'min_data_per_group': 49, 'num_leaves': 6}, {'ag_args': {'name_suffix': '_r162', 'priority': -25}, 'bagging_fraction': 0.7552878818396, 'bagging_freq': 1, 'cat_l2': 0.0081083103544, 'cat_smooth': 75.7373446363438, 'extra_trees': False, 'feature_fraction': 0.6171258454584, 'lambda_l1': 0.1071522383181, 'lambda_l2': 1.7882554584069, 'learning_rate': 0.0229328987255, 'max_cat_to_onehot': 24, 'min_data_in_leaf': 23, 'min_data_per_group': 2, 'num_leaves': 125}, {'ag_args': {'name_suffix': '_r57', 'priority': -27}, 'bagging_fraction': 0.8515739264605, 'bagging_freq': 1, 'cat_l2': 0.2263901847144, 'cat_smooth': 1.7397457971767, 'extra_trees': True, 'feature_fraction': 0.6284015946887, 'lambda_l1': 0.6935431676756, 'lambda_l2': 1.7605230133162, 'learning_rate': 0.0294830579218, 'max_cat_to_onehot': 52, 'min_data_in_leaf': 8, 'min_data_per_group': 3, 'num_leaves': 43}, {'ag_args': {'name_suffix': '_r33', 'priority': -28}, 'bagging_fraction': 0.9625293420216, 'bagging_freq': 1, 'cat_l2': 0.1236875455555, 'cat_smooth': 68.8584757332856, 'extra_trees': False, 'feature_fraction': 0.6189215809382, 'lambda_l1': 0.1641757352921, 'lambda_l2': 0.6937755557881, 'learning_rate': 0.0154031028561, 'max_cat_to_onehot': 17, 'min_data_in_leaf': 1, 'min_data_per_group': 30, 'num_leaves': 68}], 'REALTABPFN-V2': [{'ag_args': {'name_suffix': '_r13', 'priority': -1}, 'ag_args_ensemble': {'model_random_seed': 104, 'vary_seed_across_folds': True}, 'balance_probabilities': False, 'inference_config/OUTLIER_REMOVAL_STD': 6, 'inference_config/POLYNOMIAL_FEATURES': 'no', 'inference_config/REGRESSION_Y_PREPROCESS_TRANSFORMS': [None, 'safepower'], 'preprocessing/append_original': False, 'preprocessing/categoricals': 'numeric', 'preprocessing/global': None, 'preprocessing/scaling': ['squashing_scaler_default', 'quantile_uni_coarse'], 'softmax_temperature': 1.0, 'zip_model_path': ['tabpfn-v2-classifier-finetuned-zk73skhh.ckpt', 'tabpfn-v2-regressor-v2_default.ckpt']}, {'ag_args': {'name_suffix': '_r106', 'priority': -2}, 'ag_args_ensemble': {'model_random_seed': 848, 'vary_seed_across_folds': True}, 'balance_probabilities': False, 'inference_config/OUTLIER_REMOVAL_STD': 6, 'inference_config/POLYNOMIAL_FEATURES': 'no', 'inference_config/REGRESSION_Y_PREPROCESS_TRANSFORMS': [None], 'preprocessing/append_original': True, 'preprocessing/categoricals': 'numeric', 'preprocessing/global': 'svd_quarter_components', 'preprocessing/scaling': ['quantile_uni_coarse'], 'softmax_temperature': 0.8, 'zip_model_path': ['tabpfn-v2-classifier-finetuned-zk73skhh.ckpt', 'tabpfn-v2-regressor-v2_default.ckpt']}, {'ag_args': {'name_suffix': '_r11', 'priority': -6}, 'ag_args_ensemble': {'model_random_seed': 88, 'vary_seed_across_folds': True}, 'balance_probabilities': 
True, 'inference_config/OUTLIER_REMOVAL_STD': 6, 'inference_config/POLYNOMIAL_FEATURES': 25, 'inference_config/REGRESSION_Y_PREPROCESS_TRANSFORMS': [None], 'preprocessing/append_original': True, 'preprocessing/categoricals': 'onehot', 'preprocessing/global': 'svd_quarter_components', 'preprocessing/scaling': ['safepower', 'quantile_uni'], 'softmax_temperature': 0.7, 'zip_model_path': ['tabpfn-v2-classifier-finetuned-zk73skhh.ckpt', 'tabpfn-v2-regressor-v2_default.ckpt']}, {'ag_args': {'name_suffix': '_c1', 'priority': -8}, 'ag_args_ensemble': {'model_random_seed': 0, 'vary_seed_across_folds': True}, 'zip_model_path': ['tabpfn-v2-classifier-finetuned-zk73skhh.ckpt', 'tabpfn-v2-regressor-v2_default.ckpt']}, {'ag_args': {'name_suffix': '_r196', 'priority': -10}, 'ag_args_ensemble': {'model_random_seed': 1568, 'vary_seed_across_folds': True}, 'balance_probabilities': False, 'inference_config/OUTLIER_REMOVAL_STD': 12, 'inference_config/POLYNOMIAL_FEATURES': 'no', 'inference_config/REGRESSION_Y_PREPROCESS_TRANSFORMS': ['kdi_alpha_1.0'], 'preprocessing/append_original': False, 'preprocessing/categoricals': 'numeric', 'preprocessing/global': None, 'preprocessing/scaling': ['squashing_scaler_default'], 'softmax_temperature': 1.25, 'zip_model_path': ['tabpfn-v2-classifier-finetuned-zk73skhh.ckpt', 'tabpfn-v2-regressor-v2_default.ckpt']}]}
autogluon/tabular/learner/default_learner.py

@@ -143,6 +143,7 @@ class DefaultLearner(AbstractTabularLearner):
             infer_limit=infer_limit,
             infer_limit_batch_size=infer_limit_batch_size,
             groups=groups,
+            label_cleaner=copy.deepcopy(self.label_cleaner),
             **trainer_fit_kwargs,
         )
         self.save_trainer(trainer=trainer)
autogluon/tabular/models/__init__.py

@@ -1,5 +1,6 @@
 from autogluon.core.models.abstract.abstract_model import AbstractModel
 
+from .tabprep.prep_lgb_model import PrepLGBModel
 from .automm.automm_model import MultiModalPredictorModel
 from .automm.ft_transformer import FTTransformerModel
 from .catboost.catboost_model import CatBoostModel
@@ -20,10 +21,11 @@ from .lgb.lgb_model import LGBModel
 from .lr.lr_model import LinearModel
 from .realmlp.realmlp_model import RealMLPModel
 from .rf.rf_model import RFModel
+from .tabdpt.tabdpt_model import TabDPTModel
 from .tabicl.tabicl_model import TabICLModel
 from .tabm.tabm_model import TabMModel
-from .tabpfnv2.tabpfnv2_model import TabPFNV2Model
 from .tabpfnmix.tabpfnmix_model import TabPFNMixModel
+from .tabpfnv2.tabpfnv2_5_model import RealTabPFNv2Model, RealTabPFNv25Model
 from .mitra.mitra_model import MitraModel
 from .tabular_nn.torch.tabular_nn_torch import TabularNeuralNetTorchModel
 from .text_prediction.text_prediction_v1_model import TextPredictorModel
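`TabDPTModel` and the TabPFNv2.5 wrappers imported above are referenced elsewhere in this diff by their registry keys (for example `'TABDPT'` and `'REALTABPFN-V2'` in the GPU portfolio). A hedged sketch of requesting the new TabDPT model directly, assuming the registry key used in the portfolio and that the optional TabDPT dependency is installed:

```python
# Per the GPU portfolio comment, a GPU is expected for this model family.
predictor = TabularPredictor(label="target").fit(
    train_data,
    hyperparameters={"TABDPT": [{}]},  # one TabDPT config with default hyperparameters
)
```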
autogluon/tabular/models/abstract/__init__.py

File without changes

autogluon/tabular/models/abstract/abstract_torch_model.py (new file)

@@ -0,0 +1,148 @@
+from __future__ import annotations
+
+import logging
+
+from autogluon.core.models import AbstractModel
+
+logger = logging.getLogger(__name__)
+
+
+# TODO: Add type hints once torch is a required dependency
+class AbstractTorchModel(AbstractModel):
+    """
+    .. versionadded:: 1.5.0
+    """
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.device = None
+        self.device_train = None
+
+    def suggest_device_infer(self, verbose: bool = False) -> str:
+        import torch
+
+        # Put the model on the same device it was trained on (GPU/MPS) if it is available; otherwise use CPU
+        if self.device_train is None:
+            original_device_type = None  # skip update because no device is recorded
+        elif isinstance(self.device_train, str):
+            original_device_type = self.device_train
+        else:
+            original_device_type = self.device_train.type
+        if original_device_type is None:
+            # fallback to CPU
+            device = torch.device("cpu")
+        elif "cuda" in original_device_type:
+            # cuda: nvidia GPU
+            device = torch.device(original_device_type if torch.cuda.is_available() else "cpu")
+        elif "mps" in original_device_type:
+            # mps: Apple Silicon
+            device = torch.device(original_device_type if torch.backends.mps.is_available() else "cpu")
+        else:
+            device = torch.device(original_device_type)
+
+        if verbose and (original_device_type != device.type):
+            logger.log(
+                15,
+                f"Model is trained on {original_device_type}, but the device is not available - "
+                f"loading on {device.type}...",
+            )
+
+        return device.type
+
+    @classmethod
+    def to_torch_device(cls, device: str):
+        import torch
+        return torch.device(device)
+
+    def get_device(self) -> str:
+        """
+        Returns torch.device(...) of the fitted model
+
+        Requires implementation by the inheriting model class.
+        Refer to overriding methods in existing models for reference implementations.
+        """
+        raise NotImplementedError
+
+    def set_device(self, device: str):
+        if not isinstance(device, str):
+            device = device.type
+        self.device = device
+        self._set_device(device=device)
+
+    def _set_device(self, device: str):
+        """
+        Sets the device for the inner model object.
+
+        Requires implementation by the inheriting model class.
+        Refer to overriding methods in existing models for reference implementations.
+
+        If your model does not need to edit inner model object details, you can simply make the logic `pass`.
+        """
+        raise NotImplementedError
+
+    def _post_fit(self, **kwargs):
+        super()._post_fit(**kwargs)
+        if self._get_class_tags().get("can_set_device", False):
+            self.device_train = self.get_device()
+            self.device = self.device_train
+        return self
+
+    def save(self, path: str = None, verbose=True) -> str:
+        """
+        Need to set device to CPU to be able to load on a non-GPU environment
+        """
+        reset_device = False
+        og_device = self.device
+
+        # Save on CPU to ensure the model can be loaded without GPU
+        if self.is_fit():
+            device_save = self._get_class_tags().get("set_device_on_save_to", None)
+            if device_save is not None:
+                self.set_device(device=device_save)
+                reset_device = True
+        path = super().save(path=path, verbose=verbose)
+        # Put the model back to the device after the save
+        if reset_device:
+            self.set_device(device=og_device)
+        return path
+
+    @classmethod
+    def load(cls, path: str, reset_paths=True, verbose=True):
+        """
+        Loads the model from disk to memory.
+        The loaded model will be on the same device it was trained on (cuda/mps);
+        if the device is not available (trained on GPU, deployed on CPU), then `cpu` will be used.
+
+        Parameters
+        ----------
+        path : str
+            Path to the saved model, minus the file name.
+            This should generally be a directory path ending with a '/' character (or appropriate path separator value depending on OS).
+            The model file is typically located in os.path.join(path, cls.model_file_name).
+        reset_paths : bool, default True
+            Whether to reset the self.path value of the loaded model to be equal to path.
+            It is highly recommended to keep this value as True unless accessing the original self.path value is important.
+            If False, the actual valid path and self.path may differ, leading to strange behaviour and potential exceptions if the model needs to load any other files at a later time.
+        verbose : bool, default True
+            Whether to log the location of the loaded file.
+
+        Returns
+        -------
+        model : cls
+            Loaded model object.
+        """
+        model = super().load(path=path, reset_paths=reset_paths, verbose=verbose)
+
+        # Put the model on the same device it was trained on (GPU/MPS) if it is available; otherwise use CPU
+        if model.is_fit() and model._get_class_tags().get("set_device_on_load", False):
+            device = model.suggest_device_infer(verbose=verbose)
+            model.set_device(device=device)
+
+        return model
+
+    @classmethod
+    def _class_tags(cls):
+        return {
+            "can_set_device": True,
+            "set_device_on_save_to": "cpu",
+            "set_device_on_load": True,
+        }
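A subclass of `AbstractTorchModel` only has to implement `get_device` and `_set_device` for its inner torch module; `save()`/`load()` then handle moving the model to CPU for serialization and back to the best available device on load. A minimal illustrative sketch (`MyTorchModel` and its `self.model` attribute are assumptions, not part of the diff):

```python
import torch

from autogluon.tabular.models.abstract.abstract_torch_model import AbstractTorchModel


class MyTorchModel(AbstractTorchModel):  # illustrative subclass
    def get_device(self):
        # Device the inner torch module currently lives on (assumes self.model is an nn.Module).
        return next(self.model.parameters()).device

    def _set_device(self, device: str):
        # Move the inner module; called by set_device(), which save()/load() delegate to.
        self.model = self.model.to(torch.device(device))
```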
autogluon/tabular/models/catboost/catboost_model.py

@@ -146,7 +146,7 @@ class CatBoostModel(AbstractModel):
         num_cols_train = len(X.columns)
         num_classes = self.num_classes if self.num_classes else 1  # self.num_classes could be None after initialization if it's a regression problem
 
-        X = self.preprocess(X)
+        X = self.preprocess(X, y=y, is_train=True)
         cat_features = list(X.select_dtypes(include="category").columns)
         X = Pool(data=X, label=y, cat_features=cat_features, weight=sample_weight)
 
autogluon/tabular/models/fastainn/tabular_nn_fastai.py

@@ -660,7 +660,11 @@ class NNFastAiTabularModel(AbstractModel):
 
     @classmethod
     def _class_tags(cls):
-        return {
+        return {
+            "can_estimate_memory_usage_static": True,
+            "reset_torch_threads": True,
+            "reset_torch_cudnn_deterministic": True,
+        }
 
     def _more_tags(self):
        return {"can_refit_full": True}
autogluon/tabular/models/lgb/lgb_model.py

@@ -103,10 +103,46 @@ class LGBModel(AbstractModel):
         Scales linearly with the number of estimators, number of classes, and number of leaves.
         Memory usage peaks during model saving, with the peak consuming approximately 2-4x the size of the model in memory.
         """
+        data_mem_usage = get_approximate_df_mem_usage(X).sum()
+        return cls._estimate_memory_usage_common(
+            num_features=X.shape[1],
+            data_mem_usage=data_mem_usage,
+            hyperparameters=hyperparameters,
+            num_classes=num_classes,
+        )
+
+    @classmethod
+    def _estimate_memory_usage_static_lite(
+        cls,
+        num_samples: int,
+        num_features: int,
+        num_bytes_per_cell: float = 4,
+        hyperparameters: dict = None,
+        num_classes: int = 1,
+        **kwargs,
+    ) -> int:
+        data_mem_usage = num_samples * num_features * num_bytes_per_cell
+        return cls._estimate_memory_usage_common(
+            num_features=num_features,
+            data_mem_usage=data_mem_usage,
+            hyperparameters=hyperparameters,
+            num_classes=num_classes,
+        )
+
+    @classmethod
+    def _estimate_memory_usage_common(
+        cls,
+        num_features: int,
+        data_mem_usage: int | float,
+        hyperparameters: dict | None = None,
+        num_classes: int = 1,
+    ) -> int:
+        """
+        Utility method to avoid code duplication
+        """
         if hyperparameters is None:
             hyperparameters = {}
         num_classes = num_classes if num_classes else 1  # num_classes could be None after initialization if it's a regression problem
-        data_mem_usage = get_approximate_df_mem_usage(X).sum()
         data_mem_usage_bytes = data_mem_usage * 5 + data_mem_usage / 4 * num_classes  # TODO: Extremely crude approximation, can be vastly improved
 
         n_trees_per_estimator = num_classes if num_classes > 2 else 1
@@ -114,7 +150,7 @@ class LGBModel(AbstractModel):
         max_bins = hyperparameters.get("max_bins", 255)
         num_leaves = hyperparameters.get("num_leaves", 31)
         # Memory usage of histogram based on https://github.com/microsoft/LightGBM/issues/562#issuecomment-304524592
-        histogram_mem_usage_bytes = 20 * max_bins *
+        histogram_mem_usage_bytes = 20 * max_bins * num_features * num_leaves
         histogram_mem_usage_bytes_max = hyperparameters.get("histogram_pool_size", None)
         if histogram_mem_usage_bytes_max is not None:
             histogram_mem_usage_bytes_max *= 1e6  # Convert megabytes to bytes, `histogram_pool_size` is in MB.
@@ -124,11 +160,11 @@ class LGBModel(AbstractModel):
 
         mem_size_per_estimator = n_trees_per_estimator * num_leaves * 100  # very rough estimate
         n_estimators = hyperparameters.get("num_boost_round", DEFAULT_NUM_BOOST_ROUND)
-        n_estimators_min = min(n_estimators,
-        mem_size_estimators = n_estimators_min * mem_size_per_estimator  # memory estimate after fitting up to
+        n_estimators_min = min(n_estimators, 5000)
+        mem_size_estimators = n_estimators_min * mem_size_per_estimator  # memory estimate after fitting up to 5000 estimators
 
         approx_mem_size_req = data_mem_usage_bytes + histogram_mem_usage_bytes + mem_size_estimators
-        return approx_mem_size_req
+        return int(approx_mem_size_req)
 
     def _fit(self, X, y, X_val=None, y_val=None, time_limit=None, num_gpus=0, num_cpus=0, sample_weight=None, sample_weight_val=None, verbosity=2, **kwargs):
         try_import_lightgbm()  # raise helpful error message if LightGBM isn't installed
@@ -371,6 +407,9 @@ class LGBModel(AbstractModel):
         X = self.preprocess(X, **kwargs)
 
         y_pred_proba = self.model.predict(X, num_threads=num_cpus)
+        return self._post_process_predictions(y_pred_proba=y_pred_proba)
+
+    def _post_process_predictions(self, y_pred_proba) -> np.ndarray:
         if self.problem_type == QUANTILE:
             # y_pred_proba is a pd.DataFrame, need to convert
             y_pred_proba = y_pred_proba.to_numpy()
@@ -423,7 +462,7 @@ class LGBModel(AbstractModel):
         self,
         X: DataFrame,
         y: Series,
-        params,
+        params: dict,
         X_val=None,
         y_val=None,
         X_test=None,
@@ -432,11 +471,14 @@ class LGBModel(AbstractModel):
         sample_weight_val=None,
         sample_weight_test=None,
         save=False,
+        init_train=None,
+        init_val=None,
+        init_test=None,
     ):
         lgb_dataset_params_keys = ["two_round"]  # Keys that are specific to lightGBM Dataset object construction.
         data_params = {key: params[key] for key in lgb_dataset_params_keys if key in params}.copy()
 
-        X = self.preprocess(X, is_train=True)
+        X = self.preprocess(X, y=y, is_train=True)
         if X_val is not None:
             X_val = self.preprocess(X_val)
         if X_test is not None:
@@ -458,7 +500,13 @@ class LGBModel(AbstractModel):
 
         # X, W_train = self.convert_to_weight(X=X)
         dataset_train = construct_dataset(
-            x=X,
+            x=X,
+            y=y,
+            location=os.path.join("self.path", "datasets", "train"),
+            params=data_params,
+            save=save,
+            weight=sample_weight,
+            init_score=init_train,
         )
         # dataset_train = construct_dataset_lowest_memory(X=X, y=y, location=self.path + 'datasets/train', params=data_params)
         if X_val is not None:
@@ -471,6 +519,7 @@ class LGBModel(AbstractModel):
                 params=data_params,
                 save=save,
                 weight=sample_weight_val,
+                init_score=init_val,
             )
             # dataset_val = construct_dataset_lowest_memory(X=X_val, y=y_val, location=self.path + 'datasets/val', reference=dataset_train, params=data_params)
         else:
@@ -485,6 +534,7 @@ class LGBModel(AbstractModel):
                 params=data_params,
                 save=save,
                 weight=sample_weight_test,
+                init_score=init_test,
             )
         else:
             dataset_test = None
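For intuition on the estimator above: with the defaults `max_bins=255` and `num_leaves=31`, the histogram term alone for a 100-feature dataset is `20 * 255 * 100 * 31 ≈ 15.8 MB`. Below is a standalone re-implementation of the arithmetic in `_estimate_memory_usage_static_lite`, for illustration only (the 10,000-round default boost count is an assumption):

```python
def estimate_lgb_mem_bytes(num_samples, num_features, num_classes=1,
                           max_bins=255, num_leaves=31, num_boost_round=10_000,
                           num_bytes_per_cell=4):
    num_classes = num_classes or 1
    data_mem = num_samples * num_features * num_bytes_per_cell
    data_mem_bytes = data_mem * 5 + data_mem / 4 * num_classes    # crude data overhead factor
    n_trees_per_round = num_classes if num_classes > 2 else 1
    histogram_bytes = 20 * max_bins * num_features * num_leaves   # LightGBM issue #562 heuristic
    mem_per_round = n_trees_per_round * num_leaves * 100          # rough model size per boosting round
    estimators_bytes = min(num_boost_round, 5000) * mem_per_round
    return int(data_mem_bytes + histogram_bytes + estimators_bytes)


print(estimate_lgb_mem_bytes(100_000, 100) / 1e6)  # ~241 MB for 100k rows x 100 numeric features
```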
autogluon/tabular/models/lgb/lgb_utils.py

@@ -104,11 +104,11 @@ def softclass_lgbobj(preds, train_data):
     return grad.flatten("F"), hess.flatten("F")
 
 
-def construct_dataset(x: DataFrame, y: Series, location=None, reference=None, params=None, save=False, weight=None):
+def construct_dataset(x: DataFrame, y: Series, location=None, reference=None, params=None, save=False, weight=None, init_score=None):
     try_import_lightgbm()
     import lightgbm as lgb
 
-    dataset = lgb.Dataset(data=x, label=y, reference=reference, free_raw_data=True, params=params, weight=weight)
+    dataset = lgb.Dataset(data=x, label=y, reference=reference, free_raw_data=True, params=params, weight=weight, init_score=init_score)
 
     if save:
         assert location is not None
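`init_score` is standard LightGBM functionality: per-row initial raw scores that boosting continues from, with the booster learning a correction on top of them. A minimal sketch outside AutoGluon:

```python
import lightgbm as lgb
import numpy as np

rng = np.random.default_rng(0)
X = rng.random((1000, 10))
y = rng.random(1000)
base = np.full(1000, y.mean())  # e.g. raw-score predictions of an earlier model

train_set = lgb.Dataset(X, label=y, init_score=base, free_raw_data=True)
booster = lgb.train({"objective": "regression", "verbosity": -1}, train_set, num_boost_round=50)

# Predictions exclude the init_score, so add the baseline back for final outputs.
final_pred = booster.predict(X) + base
```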
autogluon/tabular/models/mitra/_internal/core/trainer_finetune.py

@@ -73,6 +73,20 @@ class TrainerFinetune(BaseEstimator):
 
         self.metric = self.cfg.hyperparams['metric']
 
+    def set_device(self, device: str):
+        self.device = device
+        self.model = self.model.to(device=device, non_blocking=True)
+
+    def post_fit_optimize(self):
+        # Minimize memory usage post-fit
+        self.checkpoint = None
+        self.optimizer = None
+        self.scaler = None
+        self.scheduler_warmup = None
+        self.scheduler_reduce_on_plateau = None
+        self.loss = None
+        self.early_stopping = None
+        self.metric = None
 
     def train(self, x_train: np.ndarray, y_train: np.ndarray, x_val: np.ndarray, y_val: np.ndarray):
 
@@ -184,7 +198,6 @@ class TrainerFinetune(BaseEstimator):
 
         self.checkpoint.set_to_best(self.model)
 
-
     def evaluate(self, x_support: np.ndarray, y_support: np.ndarray, x_query: np.ndarray, y_query: np.ndarray) -> PredictionMetrics:
 
         self.model.eval()