autogluon.tabular 1.2.1b20250225__tar.gz → 1.2.1b20250226__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/PKG-INFO +1 -1
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/automm/automm_model.py +2 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/automm/ft_transformer.py +3 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/catboost/catboost_model.py +7 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/fastainn/tabular_nn_fastai.py +10 -1
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/fasttext/fasttext_model.py +3 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/image_prediction/image_predictor.py +2 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/imodels/imodels_models.py +15 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/knn/knn_model.py +3 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/lgb/lgb_model.py +7 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/lr/lr_model.py +3 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/rf/rf_model.py +3 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tab_transformer/tab_transformer_model.py +2 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabpfn/tabpfn_model.py +3 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabpfnmix/tabpfnmix_model.py +4 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabular_nn/torch/tabular_nn_torch.py +3 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/text_prediction/text_prediction_v1_model.py +3 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/vowpalwabbit/vowpalwabbit_model.py +3 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/xgboost/xgboost_model.py +3 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/xt/xt_model.py +3 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/predictor/predictor.py +19 -3
- autogluon.tabular-1.2.1b20250226/src/autogluon/tabular/register/__init__.py +2 -0
- autogluon.tabular-1.2.1b20250226/src/autogluon/tabular/register/_ag_model_register.py +66 -0
- autogluon.tabular-1.2.1b20250226/src/autogluon/tabular/register/_model_register.py +146 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/trainer/model_presets/presets.py +10 -116
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/version.py +1 -1
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon.tabular.egg-info/PKG-INFO +1 -1
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon.tabular.egg-info/SOURCES.txt +3 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon.tabular.egg-info/requires.txt +11 -11
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/setup.cfg +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/setup.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/__init__.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/configs/__init__.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/configs/config_helper.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/configs/feature_generator_presets.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/configs/hyperparameter_configs.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/configs/presets_configs.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/configs/zeroshot/__init__.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/configs/zeroshot/zeroshot_portfolio_2023.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/experimental/__init__.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/experimental/_scikit_mixin.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/experimental/_tabular_classifier.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/experimental/_tabular_regressor.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/experimental/plot_leaderboard.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/learner/__init__.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/learner/abstract_learner.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/learner/default_learner.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/__init__.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/_utils/__init__.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/_utils/rapids_utils.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/_utils/torch_utils.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/automm/__init__.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/catboost/__init__.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/catboost/callbacks.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/catboost/catboost_softclass_utils.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/catboost/catboost_utils.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/catboost/hyperparameters/__init__.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/catboost/hyperparameters/parameters.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/catboost/hyperparameters/searchspaces.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/fastainn/__init__.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/fastainn/callbacks.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/fastainn/fastai_helpers.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/fastainn/hyperparameters/__init__.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/fastainn/hyperparameters/parameters.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/fastainn/hyperparameters/searchspaces.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/fastainn/imports_helper.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/fastainn/quantile_helpers.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/fasttext/__init__.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/fasttext/hyperparameters/__init__.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/fasttext/hyperparameters/parameters.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/image_prediction/__init__.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/imodels/__init__.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/knn/__init__.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/knn/_knn_loo_variants.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/knn/knn_rapids_model.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/knn/knn_utils.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/lgb/__init__.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/lgb/callbacks.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/lgb/hyperparameters/__init__.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/lgb/hyperparameters/parameters.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/lgb/hyperparameters/searchspaces.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/lgb/lgb_utils.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/lr/__init__.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/lr/hyperparameters/__init__.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/lr/hyperparameters/parameters.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/lr/hyperparameters/searchspaces.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/lr/lr_preprocessing_utils.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/lr/lr_rapids_model.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/rf/__init__.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/rf/compilers/__init__.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/rf/compilers/native.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/rf/compilers/onnx.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/rf/rf_quantile.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/rf/rf_rapids_model.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tab_transformer/__init__.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tab_transformer/hyperparameters/__init__.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tab_transformer/hyperparameters/parameters.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tab_transformer/hyperparameters/searchspaces.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tab_transformer/modified_transformer.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tab_transformer/pretexts.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tab_transformer/tab_model_base.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tab_transformer/tab_transformer.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tab_transformer/tab_transformer_encoder.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tab_transformer/utils.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabpfn/__init__.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabpfnmix/__init__.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabpfnmix/_internal/__init__.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabpfnmix/_internal/config/__init__.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabpfnmix/_internal/config/config_run.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabpfnmix/_internal/core/__init__.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabpfnmix/_internal/core/callbacks.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabpfnmix/_internal/core/collator.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabpfnmix/_internal/core/dataset_split.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabpfnmix/_internal/core/enums.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabpfnmix/_internal/core/get_loss.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabpfnmix/_internal/core/get_optimizer.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabpfnmix/_internal/core/get_scheduler.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabpfnmix/_internal/core/trainer_finetune.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabpfnmix/_internal/core/y_transformer.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabpfnmix/_internal/data/__init__.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabpfnmix/_internal/data/dataset_finetune.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabpfnmix/_internal/data/preprocessor.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabpfnmix/_internal/models/__init__.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabpfnmix/_internal/models/foundation/__init__.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabpfnmix/_internal/models/foundation/embedding.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabpfnmix/_internal/models/foundation/foundation_transformer.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabpfnmix/_internal/results/__init__.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabpfnmix/_internal/results/prediction_metrics.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabpfnmix/_internal/tabpfnmix_classifier.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabpfnmix/_internal/tabpfnmix_regressor.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabular_nn/__init__.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabular_nn/compilers/__init__.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabular_nn/compilers/native.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabular_nn/compilers/onnx.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabular_nn/hyperparameters/__init__.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabular_nn/hyperparameters/parameters.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabular_nn/hyperparameters/searchspaces.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabular_nn/torch/__init__.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabular_nn/torch/tabular_torch_dataset.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabular_nn/torch/torch_network_modules.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabular_nn/utils/__init__.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabular_nn/utils/categorical_encoders.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabular_nn/utils/data_preprocessor.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabular_nn/utils/nn_architecture_utils.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/text_prediction/__init__.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/vowpalwabbit/__init__.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/vowpalwabbit/vowpalwabbit_utils.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/xgboost/__init__.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/xgboost/callbacks.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/xgboost/hyperparameters/__init__.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/xgboost/hyperparameters/parameters.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/xgboost/hyperparameters/searchspaces.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/xgboost/xgboost_utils.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/xt/__init__.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/predictor/__init__.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/predictor/interpretable_predictor.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/trainer/__init__.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/trainer/abstract_trainer.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/trainer/auto_trainer.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/trainer/model_presets/__init__.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/trainer/model_presets/presets_distill.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/tuning/__init__.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/tuning/feature_pruner.py +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon.tabular.egg-info/dependency_links.txt +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon.tabular.egg-info/namespace_packages.txt +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon.tabular.egg-info/top_level.txt +0 -0
- {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon.tabular.egg-info/zip-safe +0 -0
@@ -12,6 +12,9 @@ logger = logging.getLogger(__name__)
|
|
12
12
|
|
13
13
|
# TODO: Add unit tests
|
14
14
|
class FTTransformerModel(MultiModalPredictorModel):
|
15
|
+
ag_key = "FT_TRANSFORMER"
|
16
|
+
ag_name = "FTTransformer"
|
17
|
+
|
15
18
|
def __init__(self, **kwargs):
|
16
19
|
"""Wrapper of autogluon.multimodal.MultiModalPredictor.
|
17
20
|
|
@@ -2,6 +2,7 @@ import logging
|
|
2
2
|
import math
|
3
3
|
import os
|
4
4
|
import time
|
5
|
+
from types import MappingProxyType
|
5
6
|
|
6
7
|
import numpy as np
|
7
8
|
import pandas as pd
|
@@ -30,6 +31,12 @@ class CatBoostModel(AbstractModel):
|
|
30
31
|
|
31
32
|
Hyperparameter options: https://catboost.ai/en/docs/references/training-parameters
|
32
33
|
"""
|
34
|
+
ag_key = "CAT"
|
35
|
+
ag_name = "CatBoost"
|
36
|
+
ag_priority = 70
|
37
|
+
ag_priority_by_problem_type = MappingProxyType({
|
38
|
+
SOFTCLASS: 60
|
39
|
+
})
|
33
40
|
|
34
41
|
def __init__(self, **kwargs):
|
35
42
|
super().__init__(**kwargs)
|
@@ -8,6 +8,7 @@ import warnings
|
|
8
8
|
from builtins import classmethod
|
9
9
|
from functools import partial
|
10
10
|
from pathlib import Path
|
11
|
+
from types import MappingProxyType
|
11
12
|
from typing import Union
|
12
13
|
|
13
14
|
import numpy as np
|
@@ -28,7 +29,7 @@ from autogluon.common.features.types import (
|
|
28
29
|
from autogluon.common.utils.pandas_utils import get_approximate_df_mem_usage
|
29
30
|
from autogluon.common.utils.resource_utils import ResourceManager
|
30
31
|
from autogluon.common.utils.try_import import try_import_fastai
|
31
|
-
from autogluon.core.constants import BINARY, QUANTILE, REGRESSION
|
32
|
+
from autogluon.core.constants import BINARY, MULTICLASS, QUANTILE, REGRESSION
|
32
33
|
from autogluon.core.hpo.constants import RAY_BACKEND
|
33
34
|
from autogluon.core.models import AbstractModel
|
34
35
|
from autogluon.core.utils.exceptions import TimeLimitExceeded
|
@@ -92,6 +93,14 @@ class NNFastAiTabularModel(AbstractModel):
|
|
92
93
|
'early.stopping.min_delta': 0.0001,
|
93
94
|
'early.stopping.patience': 10,
|
94
95
|
"""
|
96
|
+
ag_key = "FASTAI"
|
97
|
+
ag_name = "NeuralNetFastAI"
|
98
|
+
ag_priority = 50
|
99
|
+
# Increase priority for multiclass since neural networks
|
100
|
+
# scale better than trees as a function of n_classes.
|
101
|
+
ag_priority_by_problem_type = MappingProxyType({
|
102
|
+
MULTICLASS: 95,
|
103
|
+
})
|
95
104
|
|
96
105
|
model_internals_file_name = "model-internals.pkl"
|
97
106
|
|
@@ -22,6 +22,8 @@ class ImagePredictorModel(MultiModalPredictorModel):
|
|
22
22
|
Additionally has special null image handling to improve performance in the presence of null images (aka image path of '')
|
23
23
|
Note: null handling has not been compared to the built-in null handling of MultimodalPredictor yet.
|
24
24
|
"""
|
25
|
+
ag_key = "AG_IMAGE_NN"
|
26
|
+
ag_name = "ImagePredictor"
|
25
27
|
|
26
28
|
def __init__(self, **kwargs):
|
27
29
|
super().__init__(**kwargs)
|
@@ -75,6 +75,9 @@ class _IModelsModel(AbstractModel):
|
|
75
75
|
|
76
76
|
|
77
77
|
class RuleFitModel(_IModelsModel):
|
78
|
+
ag_key = "IM_RULEFIT"
|
79
|
+
ag_name = "RuleFit"
|
80
|
+
|
78
81
|
def get_model(self):
|
79
82
|
try_import_imodels()
|
80
83
|
from imodels import RuleFitClassifier, RuleFitRegressor
|
@@ -86,6 +89,9 @@ class RuleFitModel(_IModelsModel):
|
|
86
89
|
|
87
90
|
|
88
91
|
class GreedyTreeModel(_IModelsModel):
|
92
|
+
ag_key = "IM_GREEDYTREE"
|
93
|
+
ag_name = "GreedyTree"
|
94
|
+
|
89
95
|
def get_model(self):
|
90
96
|
try_import_imodels()
|
91
97
|
from imodels import GreedyTreeClassifier
|
@@ -98,6 +104,9 @@ class GreedyTreeModel(_IModelsModel):
|
|
98
104
|
|
99
105
|
|
100
106
|
class BoostedRulesModel(_IModelsModel):
|
107
|
+
ag_key = "IM_BOOSTEDRULES"
|
108
|
+
ag_name = "BoostedRules"
|
109
|
+
|
101
110
|
def get_model(self):
|
102
111
|
try_import_imodels()
|
103
112
|
from imodels import BoostedRulesClassifier
|
@@ -109,6 +118,9 @@ class BoostedRulesModel(_IModelsModel):
|
|
109
118
|
|
110
119
|
|
111
120
|
class HSTreeModel(_IModelsModel):
|
121
|
+
ag_key = "IM_HSTREE"
|
122
|
+
ag_name = "HierarchicalShrinkageTree"
|
123
|
+
|
112
124
|
def get_model(self):
|
113
125
|
try_import_imodels()
|
114
126
|
from imodels import HSTreeClassifierCV, HSTreeRegressorCV
|
@@ -120,6 +132,9 @@ class HSTreeModel(_IModelsModel):
|
|
120
132
|
|
121
133
|
|
122
134
|
class FigsModel(_IModelsModel):
|
135
|
+
ag_key = "IM_FIGS"
|
136
|
+
ag_name = "Figs"
|
137
|
+
|
123
138
|
def get_model(self):
|
124
139
|
try_import_imodels()
|
125
140
|
from imodels import FIGSClassifier, FIGSRegressor
|
@@ -22,6 +22,9 @@ class KNNModel(AbstractModel):
|
|
22
22
|
"""
|
23
23
|
KNearestNeighbors model (scikit-learn): https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.KNeighborsClassifier.html
|
24
24
|
"""
|
25
|
+
ag_key = "KNN"
|
26
|
+
ag_name = "KNeighbors"
|
27
|
+
ag_priority = 100
|
25
28
|
|
26
29
|
def __init__(self, **kwargs):
|
27
30
|
super().__init__(**kwargs)
|
@@ -7,6 +7,7 @@ import random
|
|
7
7
|
import re
|
8
8
|
import time
|
9
9
|
import warnings
|
10
|
+
from types import MappingProxyType
|
10
11
|
|
11
12
|
import numpy as np
|
12
13
|
import pandas as pd
|
@@ -40,6 +41,12 @@ class LGBModel(AbstractModel):
|
|
40
41
|
Extra hyperparameter options:
|
41
42
|
ag.early_stop : int, specifies the early stopping rounds. Defaults to an adaptive strategy. Recommended to keep default.
|
42
43
|
"""
|
44
|
+
ag_key = "GBM"
|
45
|
+
ag_name = "LightGBM"
|
46
|
+
ag_priority = 90
|
47
|
+
ag_priority_by_problem_type = MappingProxyType({
|
48
|
+
SOFTCLASS: 100
|
49
|
+
})
|
43
50
|
|
44
51
|
def __init__(self, **kwargs):
|
45
52
|
super().__init__(**kwargs)
|
@@ -38,6 +38,9 @@ class LinearModel(AbstractModel):
|
|
38
38
|
|
39
39
|
'regression': https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.Ridge.html#sklearn.linear_model.Ridge
|
40
40
|
"""
|
41
|
+
ag_key = "LR"
|
42
|
+
ag_name = "LinearModel"
|
43
|
+
ag_priority = 30
|
41
44
|
|
42
45
|
def __init__(self, **kwargs):
|
43
46
|
super().__init__(**kwargs)
|
@@ -27,6 +27,9 @@ class RFModel(AbstractModel):
|
|
27
27
|
"""
|
28
28
|
Random Forest model (scikit-learn): https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestClassifier.html
|
29
29
|
"""
|
30
|
+
ag_key = "RF"
|
31
|
+
ag_name = "RandomForest"
|
32
|
+
ag_priority = 80
|
30
33
|
|
31
34
|
def __init__(self, **kwargs):
|
32
35
|
super().__init__(**kwargs)
|
@@ -40,6 +40,8 @@ class TabTransformerModel(AbstractNeuralNetworkModel):
|
|
40
40
|
and applies them to the use case of tabular data. Specifically, this makes TabTransformer suitable for unsupervised
|
41
41
|
training of Tabular data with a subsequent fine-tuning step on labeled data.
|
42
42
|
"""
|
43
|
+
ag_key = "TRANSF"
|
44
|
+
ag_name = "Transformer"
|
43
45
|
|
44
46
|
params_file_name = "tab_trans_params.pth"
|
45
47
|
|
@@ -22,6 +22,9 @@ class TabPFNModel(AbstractModel):
|
|
22
22
|
To use this model, `tabpfn` must be installed.
|
23
23
|
To install TabPFN, you can run `pip install autogluon.tabular[tabpfn]` or `pip install tabpfn`.
|
24
24
|
"""
|
25
|
+
ag_key = "TABPFN"
|
26
|
+
ag_name = "TabPFN"
|
27
|
+
ag_priority = 110
|
25
28
|
|
26
29
|
def __init__(self, **kwargs):
|
27
30
|
super().__init__(**kwargs)
|
@@ -35,6 +35,10 @@ class TabPFNMixModel(AbstractModel):
|
|
35
35
|
|
36
36
|
For more information, refer to the `./_internal/README.md` file.
|
37
37
|
"""
|
38
|
+
ag_key = "TABPFNMIX"
|
39
|
+
ag_name = "TabPFNMix"
|
40
|
+
ag_priority = 45
|
41
|
+
|
38
42
|
weights_file_name = "model.pt"
|
39
43
|
|
40
44
|
def __init__(self, **kwargs):
|
@@ -47,6 +47,9 @@ class TabularNeuralNetTorchModel(AbstractNeuralNetworkModel):
|
|
47
47
|
ag.early_stop : int | str, default = "default"
|
48
48
|
Specifies the early stopping rounds. Defaults to an adaptive strategy. Recommended to keep default.
|
49
49
|
"""
|
50
|
+
ag_key = "NN_TORCH"
|
51
|
+
ag_name = "NeuralNetTorch"
|
52
|
+
ag_priority = 25
|
50
53
|
|
51
54
|
# Constants used throughout this class:
|
52
55
|
unique_category_str = np.nan # string used to represent missing values and unknown categories for categorical features.
|
@@ -19,6 +19,9 @@ logger = logging.getLogger(__name__)
|
|
19
19
|
class TextPredictorModel(MultiModalPredictorModel):
|
20
20
|
"""MultimodalPredictor that doesn't use image features"""
|
21
21
|
|
22
|
+
ag_key = "AG_TEXT_NN"
|
23
|
+
ag_name = "TextPredictor"
|
24
|
+
|
22
25
|
def _get_default_auxiliary_params(self) -> dict:
|
23
26
|
default_auxiliary_params = super()._get_default_auxiliary_params()
|
24
27
|
extra_auxiliary_params = dict(
|
@@ -38,6 +38,9 @@ class VowpalWabbitModel(AbstractModel):
|
|
38
38
|
VowpalWabbit Command Line args: https://github.com/VowpalWabbit/vowpal_wabbit/wiki/Command-line-arguments
|
39
39
|
|
40
40
|
"""
|
41
|
+
ag_key = "VW"
|
42
|
+
ag_name = "VowpalWabbit"
|
43
|
+
ag_priority = 10
|
41
44
|
|
42
45
|
model_internals_file_name = "model-internals.pkl"
|
43
46
|
|
@@ -7,6 +7,9 @@ class XTModel(RFModel):
|
|
7
7
|
"""
|
8
8
|
Extra Trees model (scikit-learn): https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.ExtraTreesClassifier.html#sklearn.ensemble.ExtraTreesClassifier
|
9
9
|
"""
|
10
|
+
ag_key = "XT"
|
11
|
+
ag_name = "ExtraTrees"
|
12
|
+
ag_priority = 60
|
10
13
|
|
11
14
|
def _get_model_type(self):
|
12
15
|
if self.problem_type == REGRESSION:
|
@@ -4330,7 +4330,14 @@ class TabularPredictor:
|
|
4330
4330
|
reduce_children=reduce_children,
|
4331
4331
|
)
|
4332
4332
|
|
4333
|
-
def delete_models(
|
4333
|
+
def delete_models(
|
4334
|
+
self,
|
4335
|
+
models_to_keep: str | list[str] | None = None,
|
4336
|
+
models_to_delete: str | list[str] | None = None,
|
4337
|
+
allow_delete_cascade: bool = False,
|
4338
|
+
delete_from_disk: bool = True,
|
4339
|
+
dry_run: bool | None = None,
|
4340
|
+
):
|
4334
4341
|
"""
|
4335
4342
|
Deletes models from `predictor`.
|
4336
4343
|
This can be helpful to minimize memory usage and disk usage, particularly for model deployment.
|
@@ -4341,13 +4348,13 @@ class TabularPredictor:
|
|
4341
4348
|
|
4342
4349
|
Parameters
|
4343
4350
|
----------
|
4344
|
-
models_to_keep : str or list, default = None
|
4351
|
+
models_to_keep : str or list[str], default = None
|
4345
4352
|
Name of model or models to not delete.
|
4346
4353
|
All models that are not specified and are also not required as a dependency of any model in `models_to_keep` will be deleted.
|
4347
4354
|
Specify `models_to_keep='best'` to keep only the best model and its model dependencies.
|
4348
4355
|
`models_to_delete` must be None if `models_to_keep` is set.
|
4349
4356
|
To see the list of possible model names, use: `predictor.model_names()` or `predictor.leaderboard()`.
|
4350
|
-
models_to_delete : str or list, default = None
|
4357
|
+
models_to_delete : str or list[str], default = None
|
4351
4358
|
Name of model or models to delete.
|
4352
4359
|
All models that are not specified but depend on a model in `models_to_delete` will also be deleted.
|
4353
4360
|
`models_to_keep` must be None if `models_to_delete` is set.
|
@@ -4361,10 +4368,19 @@ class TabularPredictor:
|
|
4361
4368
|
WARNING: This deletes the entire directory for the deleted models, and ALL FILES located there.
|
4362
4369
|
It is highly recommended to first run with `dry_run=True` to understand which directories will be deleted.
|
4363
4370
|
dry_run : bool, default = True
|
4371
|
+
WARNING: Starting in v1.4.0 dry_run will default to False.
|
4364
4372
|
If `True`, then deletions don't occur, and logging statements are printed describing what would have occurred.
|
4365
4373
|
Set `dry_run=False` to perform the deletions.
|
4366
4374
|
|
4367
4375
|
"""
|
4376
|
+
if dry_run is None:
|
4377
|
+
warnings.warn(
|
4378
|
+
f"dry_run was not specified for `TabularPredictor.delete_models`. dry_run prior to version 1.4.0 defaults to True. "
|
4379
|
+
f"Starting in version 1.4, AutoGluon will default dry_run to False. "
|
4380
|
+
f"If you want to maintain the current logic in future versions, explicitly specify `dry_run=True`.",
|
4381
|
+
category=FutureWarning,
|
4382
|
+
)
|
4383
|
+
dry_run = True
|
4368
4384
|
self._assert_is_fit("delete_models")
|
4369
4385
|
if models_to_keep == "best":
|
4370
4386
|
models_to_keep = self.model_best
|
@@ -0,0 +1,66 @@
|
|
1
|
+
from autogluon.core.models import (
|
2
|
+
DummyModel,
|
3
|
+
GreedyWeightedEnsembleModel,
|
4
|
+
SimpleWeightedEnsembleModel,
|
5
|
+
)
|
6
|
+
|
7
|
+
from . import ModelRegister
|
8
|
+
from ..models import (
|
9
|
+
BoostedRulesModel,
|
10
|
+
CatBoostModel,
|
11
|
+
FastTextModel,
|
12
|
+
FigsModel,
|
13
|
+
FTTransformerModel,
|
14
|
+
GreedyTreeModel,
|
15
|
+
HSTreeModel,
|
16
|
+
ImagePredictorModel,
|
17
|
+
KNNModel,
|
18
|
+
LGBModel,
|
19
|
+
LinearModel,
|
20
|
+
MultiModalPredictorModel,
|
21
|
+
NNFastAiTabularModel,
|
22
|
+
RFModel,
|
23
|
+
RuleFitModel,
|
24
|
+
TabPFNMixModel,
|
25
|
+
TabPFNModel,
|
26
|
+
TabularNeuralNetTorchModel,
|
27
|
+
TextPredictorModel,
|
28
|
+
VowpalWabbitModel,
|
29
|
+
XGBoostModel,
|
30
|
+
XTModel,
|
31
|
+
)
|
32
|
+
from ..models.tab_transformer.tab_transformer_model import TabTransformerModel
|
33
|
+
|
34
|
+
|
35
|
+
# When adding a new model officially to AutoGluon, the model class should be added to the bottom of this list.
# Each entry is a model class object (not an instance); the whole list is handed
# to `ModelRegister` below as `model_cls_list`.
REGISTERED_MODEL_CLS_LST = [
    RFModel,
    XTModel,
    KNNModel,
    LGBModel,
    CatBoostModel,
    XGBoostModel,
    TabularNeuralNetTorchModel,
    LinearModel,
    NNFastAiTabularModel,
    TabTransformerModel,
    TextPredictorModel,
    ImagePredictorModel,
    MultiModalPredictorModel,
    FTTransformerModel,
    TabPFNModel,
    TabPFNMixModel,
    FastTextModel,
    VowpalWabbitModel,
    GreedyWeightedEnsembleModel,
    SimpleWeightedEnsembleModel,
    RuleFitModel,
    GreedyTreeModel,
    FigsModel,
    HSTreeModel,
    BoostedRulesModel,
    DummyModel,
]

# TODO: Replace logic in `autogluon.tabular.trainer.model_presets.presets` with `ag_model_register`
# Module-level register instance built once at import time from the list above.
# NOTE(review): presumably this is the shared register users extend with
# `ag_model_register.add(MyCustomModel)` -- confirm against the ModelRegister docs.
ag_model_register = ModelRegister(model_cls_list=REGISTERED_MODEL_CLS_LST)
|
@@ -0,0 +1,146 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from typing import Type
|
4
|
+
|
5
|
+
import pandas as pd
|
6
|
+
|
7
|
+
from autogluon.core.models import AbstractModel
|
8
|
+
|
9
|
+
|
10
|
+
# TODO: Move to core? Maybe TimeSeries can reuse?
|
11
|
+
# TODO: Use this / refer to this in the custom model tutorial
|
12
|
+
# TODO: Add to documentation website
|
13
|
+
# TODO: Test register logic in AG
|
14
|
+
class ModelRegister:
    """
    ModelRegister keeps track of all known model classes to AutoGluon.
    It can provide information such as:
        What model classes and keys are valid to specify in an AutoGluon predictor fit call.
        What a model's name is.
        What a model's key is (such as the key specified by the user in `hyperparameters` to refer to a specific model type).
        What a model's priority is (aka which order to fit a list of models).

    Additionally, users can register custom models to AutoGluon so the key is recognized in `hyperparameters` and is treated with the proper priority and name.
    They can register new models via `ModelRegister.add(model_cls)`.

    Therefore, if a user creates a custom model `MyCustomModel` that inherits from `AbstractModel`, they can set the class attributes in `MyCustomModel`:
        ag_key: The string key that can be specified in `hyperparameters`. Example: "GBM" for LGBModel
        ag_name: The string name that is used in logging and accessing the model. Example: "LightGBM" for LGBModel
        ag_priority: The int priority that is used to order the fitting of models. Higher values will be fit before lower values. Default 0. Example: 90 for LGBModel
        ag_priority_to_problem_type: A dictionary of problem_type to priority that overrides `ag_priority` if specified for a given problem_type. Optional.

    Then they can say `ag_model_register.add(MyCustomModel)`.
    Assuming MyCustomModel.ag_key = "MY_MODEL", they can now do:
    ```
    predictor.fit(..., hyperparameters={"MY_MODEL": ...})
    ```

    Parameters
    ----------
    model_cls_list : list[Type[AbstractModel]], default = None
        Model classes to register on construction, in order. Each one is passed through `add`,
        so the same validation applies. If None, the register starts empty.

    Notes
    -----
    Validation failures raise `AssertionError` explicitly (rather than via the `assert` statement),
    so the checks remain active when Python runs with `-O`.
    """

    def __init__(self, model_cls_list: list[Type[AbstractModel]] | None = None):
        if model_cls_list is None:
            model_cls_list = []
        if not isinstance(model_cls_list, list):
            raise AssertionError(f"`model_cls_list` must be a list, found: {type(model_cls_list).__name__}")
        # Registration order is preserved in the list; the dict gives key -> class lookup.
        self._model_cls_list = []
        self._key_to_cls_map = dict()
        for model_cls in model_cls_list:
            self.add(model_cls)

    def _assert_registered(self, model_cls: Type[AbstractModel]) -> None:
        """Raises AssertionError if `model_cls` is not currently registered."""
        if not self.exists(model_cls):
            raise AssertionError(f"Model class must be registered: {model_cls}")

    def exists(self, model_cls: Type[AbstractModel]) -> bool:
        """Returns True if `model_cls` is currently registered."""
        return model_cls in self._model_cls_list

    def add(self, model_cls: Type[AbstractModel]):
        """
        Adds `model_cls` to the model register

        Raises
        ------
        AssertionError
            If `model_cls` is already registered, if `ag_key` or `ag_name` is None or not a str,
            if `ag_priority` is not an int, or if another registered class already uses the same `ag_key`.
        """
        if self.exists(model_cls):
            raise AssertionError(f"Cannot add model_cls that is already registered: {model_cls}")
        if model_cls.ag_key is None:
            raise AssertionError(
                f"Cannot add model_cls with `ag_key=None`. "
                f"Ensure you set class attribute `ag_key` to a string for your model_cls: {model_cls}"
                f'\n\tFor example, LGBModel sets `ag_key = "GBM"`'
            )
        if model_cls.ag_name is None:
            raise AssertionError(
                f"Cannot add model_cls with `ag_name=None`. "
                f"Ensure you set class attribute `ag_name` to a string for your model_cls: {model_cls}"
                f'\n\tFor example, LGBModel sets `ag_name = "LightGBM"`'
            )
        if not isinstance(model_cls.ag_key, str):
            raise AssertionError(f"`ag_key` must be a str, found {type(model_cls.ag_key).__name__} for model_cls: {model_cls}")
        if not isinstance(model_cls.ag_name, str):
            raise AssertionError(f"`ag_name` must be a str, found {type(model_cls.ag_name).__name__} for model_cls: {model_cls}")
        if not isinstance(model_cls.ag_priority, int):
            raise AssertionError(
                f"`ag_priority` must be an int, found {type(model_cls.ag_priority).__name__} for model_cls: {model_cls}"
            )
        if model_cls.ag_key in self._key_to_cls_map:
            raise AssertionError(
                f"Cannot register a model class that shares a model key with an already registered model class."
                f"\n`model_cls.ag_key` must be unique among registered models:"
                f"\n\t        New Class: {model_cls}"
                f"\n\tConflicting Class: {self._key_to_cls_map[model_cls.ag_key]}"
                f"\n\tConflicting ag_key: {model_cls.ag_key}"
            )
        self._model_cls_list.append(model_cls)
        self._key_to_cls_map[model_cls.ag_key] = model_cls

    def remove(self, model_cls: Type[AbstractModel]):
        """
        Removes `model_cls` from the model register

        Raises
        ------
        AssertionError
            If `model_cls` is not registered.
        """
        if not self.exists(model_cls):
            raise AssertionError(f"Cannot remove model_cls that isn't registered: {model_cls}")
        self._model_cls_list = [m for m in self._model_cls_list if m != model_cls]
        self._key_to_cls_map.pop(model_cls.ag_key)

    @property
    def model_cls_list(self) -> list[Type[AbstractModel]]:
        """The registered model classes, in registration order."""
        return self._model_cls_list

    @property
    def keys(self) -> list[str]:
        """The `ag_key` of every registered model class, in registration order."""
        return [self.key(model_cls) for model_cls in self.model_cls_list]

    def key_to_cls_map(self) -> dict[str, Type[AbstractModel]]:
        """Returns the mapping of `ag_key` -> registered model class."""
        return self._key_to_cls_map

    def key_to_cls(self, key: str) -> Type[AbstractModel]:
        """
        Returns the model class registered under `key`.

        Raises
        ------
        ValueError
            If no registered model class uses `key`.
        """
        if key not in self._key_to_cls_map:
            raise ValueError(
                f"No registered model exists with provided key: {key}"
                f"\n\tValid keys: {list(self.key_to_cls_map().keys())}"
            )
        return self.key_to_cls_map()[key]

    def priority_map(self, problem_type: str | None = None) -> dict[Type[AbstractModel], int]:
        """Returns a mapping of every registered model class to its priority for `problem_type`."""
        return {model_cls: self.priority(model_cls, problem_type=problem_type) for model_cls in self._model_cls_list}

    def key(self, model_cls: Type[AbstractModel]) -> str:
        """Returns `model_cls.ag_key`; `model_cls` must be registered."""
        self._assert_registered(model_cls)
        return model_cls.ag_key

    def name_map(self) -> dict[Type[AbstractModel], str]:
        """Returns a mapping of every registered model class to its `ag_name`."""
        return {model_cls: model_cls.ag_name for model_cls in self._model_cls_list}

    def name(self, model_cls: Type[AbstractModel]) -> str:
        """Returns `model_cls.ag_name`; `model_cls` must be registered."""
        self._assert_registered(model_cls)
        return model_cls.ag_name

    def priority(self, model_cls: Type[AbstractModel], problem_type: str | None = None) -> int:
        """Returns `model_cls.get_ag_priority(problem_type=problem_type)`; `model_cls` must be registered."""
        self._assert_registered(model_cls)
        return model_cls.get_ag_priority(problem_type=problem_type)

    def docstring(self, model_cls: Type[AbstractModel]) -> str:
        """Returns the class docstring of `model_cls`; `model_cls` must be registered."""
        self._assert_registered(model_cls)
        return model_cls.__doc__

    # TODO: Could add a lot of information here to track which features are supported for each model:
    #  ag.early_stop support
    #  refit_full support
    #  GPU support
    #  etc.
    def to_frame(self) -> pd.DataFrame:
        """
        Returns a DataFrame summarizing the registered models.

        The frame is indexed by `ag_key` (one row per registered model class, in registration order)
        with columns `model_cls` (the class `__name__`), `ag_name` and `ag_priority`
        (the priority with no problem_type specified).
        The columns are present even when the register is empty.
        """
        columns = ["model_cls", "ag_name", "ag_priority"]
        cls_dict = {}
        for model_cls in self.model_cls_list:
            cls_dict[self.key(model_cls)] = {
                "model_cls": model_cls.__name__,
                "ag_name": self.name(model_cls),
                "ag_priority": self.priority(model_cls),
            }
        # Passing `index=columns` pins the row labels before transposing, so an empty
        # register still produces the three expected columns instead of a column-less frame.
        df = pd.DataFrame(cls_dict, index=columns).T
        df.index.name = "ag_key"
        return df
|