autogluon.tabular 1.2.1b20250225__tar.gz → 1.2.1b20250226__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (167)
  1. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/PKG-INFO +1 -1
  2. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/automm/automm_model.py +2 -0
  3. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/automm/ft_transformer.py +3 -0
  4. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/catboost/catboost_model.py +7 -0
  5. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/fastainn/tabular_nn_fastai.py +10 -1
  6. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/fasttext/fasttext_model.py +3 -0
  7. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/image_prediction/image_predictor.py +2 -0
  8. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/imodels/imodels_models.py +15 -0
  9. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/knn/knn_model.py +3 -0
  10. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/lgb/lgb_model.py +7 -0
  11. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/lr/lr_model.py +3 -0
  12. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/rf/rf_model.py +3 -0
  13. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tab_transformer/tab_transformer_model.py +2 -0
  14. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabpfn/tabpfn_model.py +3 -0
  15. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabpfnmix/tabpfnmix_model.py +4 -0
  16. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabular_nn/torch/tabular_nn_torch.py +3 -0
  17. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/text_prediction/text_prediction_v1_model.py +3 -0
  18. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/vowpalwabbit/vowpalwabbit_model.py +3 -0
  19. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/xgboost/xgboost_model.py +3 -0
  20. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/xt/xt_model.py +3 -0
  21. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/predictor/predictor.py +19 -3
  22. autogluon.tabular-1.2.1b20250226/src/autogluon/tabular/register/__init__.py +2 -0
  23. autogluon.tabular-1.2.1b20250226/src/autogluon/tabular/register/_ag_model_register.py +66 -0
  24. autogluon.tabular-1.2.1b20250226/src/autogluon/tabular/register/_model_register.py +146 -0
  25. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/trainer/model_presets/presets.py +10 -116
  26. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/version.py +1 -1
  27. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon.tabular.egg-info/PKG-INFO +1 -1
  28. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon.tabular.egg-info/SOURCES.txt +3 -0
  29. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon.tabular.egg-info/requires.txt +11 -11
  30. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/setup.cfg +0 -0
  31. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/setup.py +0 -0
  32. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/__init__.py +0 -0
  33. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/configs/__init__.py +0 -0
  34. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/configs/config_helper.py +0 -0
  35. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/configs/feature_generator_presets.py +0 -0
  36. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/configs/hyperparameter_configs.py +0 -0
  37. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/configs/presets_configs.py +0 -0
  38. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/configs/zeroshot/__init__.py +0 -0
  39. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/configs/zeroshot/zeroshot_portfolio_2023.py +0 -0
  40. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/experimental/__init__.py +0 -0
  41. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/experimental/_scikit_mixin.py +0 -0
  42. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/experimental/_tabular_classifier.py +0 -0
  43. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/experimental/_tabular_regressor.py +0 -0
  44. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/experimental/plot_leaderboard.py +0 -0
  45. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/learner/__init__.py +0 -0
  46. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/learner/abstract_learner.py +0 -0
  47. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/learner/default_learner.py +0 -0
  48. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/__init__.py +0 -0
  49. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/_utils/__init__.py +0 -0
  50. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/_utils/rapids_utils.py +0 -0
  51. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/_utils/torch_utils.py +0 -0
  52. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/automm/__init__.py +0 -0
  53. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/catboost/__init__.py +0 -0
  54. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/catboost/callbacks.py +0 -0
  55. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/catboost/catboost_softclass_utils.py +0 -0
  56. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/catboost/catboost_utils.py +0 -0
  57. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/catboost/hyperparameters/__init__.py +0 -0
  58. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/catboost/hyperparameters/parameters.py +0 -0
  59. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/catboost/hyperparameters/searchspaces.py +0 -0
  60. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/fastainn/__init__.py +0 -0
  61. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/fastainn/callbacks.py +0 -0
  62. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/fastainn/fastai_helpers.py +0 -0
  63. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/fastainn/hyperparameters/__init__.py +0 -0
  64. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/fastainn/hyperparameters/parameters.py +0 -0
  65. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/fastainn/hyperparameters/searchspaces.py +0 -0
  66. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/fastainn/imports_helper.py +0 -0
  67. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/fastainn/quantile_helpers.py +0 -0
  68. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/fasttext/__init__.py +0 -0
  69. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/fasttext/hyperparameters/__init__.py +0 -0
  70. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/fasttext/hyperparameters/parameters.py +0 -0
  71. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/image_prediction/__init__.py +0 -0
  72. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/imodels/__init__.py +0 -0
  73. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/knn/__init__.py +0 -0
  74. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/knn/_knn_loo_variants.py +0 -0
  75. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/knn/knn_rapids_model.py +0 -0
  76. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/knn/knn_utils.py +0 -0
  77. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/lgb/__init__.py +0 -0
  78. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/lgb/callbacks.py +0 -0
  79. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/lgb/hyperparameters/__init__.py +0 -0
  80. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/lgb/hyperparameters/parameters.py +0 -0
  81. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/lgb/hyperparameters/searchspaces.py +0 -0
  82. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/lgb/lgb_utils.py +0 -0
  83. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/lr/__init__.py +0 -0
  84. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/lr/hyperparameters/__init__.py +0 -0
  85. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/lr/hyperparameters/parameters.py +0 -0
  86. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/lr/hyperparameters/searchspaces.py +0 -0
  87. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/lr/lr_preprocessing_utils.py +0 -0
  88. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/lr/lr_rapids_model.py +0 -0
  89. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/rf/__init__.py +0 -0
  90. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/rf/compilers/__init__.py +0 -0
  91. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/rf/compilers/native.py +0 -0
  92. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/rf/compilers/onnx.py +0 -0
  93. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/rf/rf_quantile.py +0 -0
  94. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/rf/rf_rapids_model.py +0 -0
  95. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tab_transformer/__init__.py +0 -0
  96. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tab_transformer/hyperparameters/__init__.py +0 -0
  97. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tab_transformer/hyperparameters/parameters.py +0 -0
  98. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tab_transformer/hyperparameters/searchspaces.py +0 -0
  99. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tab_transformer/modified_transformer.py +0 -0
  100. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tab_transformer/pretexts.py +0 -0
  101. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tab_transformer/tab_model_base.py +0 -0
  102. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tab_transformer/tab_transformer.py +0 -0
  103. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tab_transformer/tab_transformer_encoder.py +0 -0
  104. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tab_transformer/utils.py +0 -0
  105. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabpfn/__init__.py +0 -0
  106. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabpfnmix/__init__.py +0 -0
  107. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabpfnmix/_internal/__init__.py +0 -0
  108. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabpfnmix/_internal/config/__init__.py +0 -0
  109. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabpfnmix/_internal/config/config_run.py +0 -0
  110. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabpfnmix/_internal/core/__init__.py +0 -0
  111. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabpfnmix/_internal/core/callbacks.py +0 -0
  112. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabpfnmix/_internal/core/collator.py +0 -0
  113. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabpfnmix/_internal/core/dataset_split.py +0 -0
  114. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabpfnmix/_internal/core/enums.py +0 -0
  115. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabpfnmix/_internal/core/get_loss.py +0 -0
  116. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabpfnmix/_internal/core/get_optimizer.py +0 -0
  117. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabpfnmix/_internal/core/get_scheduler.py +0 -0
  118. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabpfnmix/_internal/core/trainer_finetune.py +0 -0
  119. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabpfnmix/_internal/core/y_transformer.py +0 -0
  120. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabpfnmix/_internal/data/__init__.py +0 -0
  121. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabpfnmix/_internal/data/dataset_finetune.py +0 -0
  122. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabpfnmix/_internal/data/preprocessor.py +0 -0
  123. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabpfnmix/_internal/models/__init__.py +0 -0
  124. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabpfnmix/_internal/models/foundation/__init__.py +0 -0
  125. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabpfnmix/_internal/models/foundation/embedding.py +0 -0
  126. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabpfnmix/_internal/models/foundation/foundation_transformer.py +0 -0
  127. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabpfnmix/_internal/results/__init__.py +0 -0
  128. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabpfnmix/_internal/results/prediction_metrics.py +0 -0
  129. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabpfnmix/_internal/tabpfnmix_classifier.py +0 -0
  130. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabpfnmix/_internal/tabpfnmix_regressor.py +0 -0
  131. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabular_nn/__init__.py +0 -0
  132. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabular_nn/compilers/__init__.py +0 -0
  133. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabular_nn/compilers/native.py +0 -0
  134. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabular_nn/compilers/onnx.py +0 -0
  135. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabular_nn/hyperparameters/__init__.py +0 -0
  136. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabular_nn/hyperparameters/parameters.py +0 -0
  137. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabular_nn/hyperparameters/searchspaces.py +0 -0
  138. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabular_nn/torch/__init__.py +0 -0
  139. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabular_nn/torch/tabular_torch_dataset.py +0 -0
  140. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabular_nn/torch/torch_network_modules.py +0 -0
  141. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabular_nn/utils/__init__.py +0 -0
  142. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabular_nn/utils/categorical_encoders.py +0 -0
  143. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabular_nn/utils/data_preprocessor.py +0 -0
  144. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/tabular_nn/utils/nn_architecture_utils.py +0 -0
  145. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/text_prediction/__init__.py +0 -0
  146. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/vowpalwabbit/__init__.py +0 -0
  147. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/vowpalwabbit/vowpalwabbit_utils.py +0 -0
  148. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/xgboost/__init__.py +0 -0
  149. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/xgboost/callbacks.py +0 -0
  150. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/xgboost/hyperparameters/__init__.py +0 -0
  151. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/xgboost/hyperparameters/parameters.py +0 -0
  152. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/xgboost/hyperparameters/searchspaces.py +0 -0
  153. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/xgboost/xgboost_utils.py +0 -0
  154. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/models/xt/__init__.py +0 -0
  155. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/predictor/__init__.py +0 -0
  156. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/predictor/interpretable_predictor.py +0 -0
  157. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/trainer/__init__.py +0 -0
  158. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/trainer/abstract_trainer.py +0 -0
  159. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/trainer/auto_trainer.py +0 -0
  160. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/trainer/model_presets/__init__.py +0 -0
  161. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/trainer/model_presets/presets_distill.py +0 -0
  162. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/tuning/__init__.py +0 -0
  163. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon/tabular/tuning/feature_pruner.py +0 -0
  164. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon.tabular.egg-info/dependency_links.txt +0 -0
  165. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon.tabular.egg-info/namespace_packages.txt +0 -0
  166. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon.tabular.egg-info/top_level.txt +0 -0
  167. {autogluon.tabular-1.2.1b20250225 → autogluon.tabular-1.2.1b20250226}/src/autogluon.tabular.egg-info/zip-safe +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: autogluon.tabular
3
- Version: 1.2.1b20250225
3
+ Version: 1.2.1b20250226
4
4
  Summary: Fast and Accurate ML in 3 Lines of Code
5
5
  Home-page: https://github.com/autogluon/autogluon
6
6
  Author: AutoGluon Community
@@ -26,6 +26,8 @@ logger = logging.getLogger(__name__)
26
26
 
27
27
 
28
28
  class MultiModalPredictorModel(AbstractModel):
29
+ ag_key = "AG_AUTOMM"
30
+ ag_name = "MultiModalPredictor"
29
31
  _NN_MODEL_NAME = "automm_model"
30
32
 
31
33
  def __init__(self, **kwargs):
@@ -12,6 +12,9 @@ logger = logging.getLogger(__name__)
12
12
 
13
13
  # TODO: Add unit tests
14
14
  class FTTransformerModel(MultiModalPredictorModel):
15
+ ag_key = "FT_TRANSFORMER"
16
+ ag_name = "FTTransformer"
17
+
15
18
  def __init__(self, **kwargs):
16
19
  """Wrapper of autogluon.multimodal.MultiModalPredictor.
17
20
 
@@ -2,6 +2,7 @@ import logging
2
2
  import math
3
3
  import os
4
4
  import time
5
+ from types import MappingProxyType
5
6
 
6
7
  import numpy as np
7
8
  import pandas as pd
@@ -30,6 +31,12 @@ class CatBoostModel(AbstractModel):
30
31
 
31
32
  Hyperparameter options: https://catboost.ai/en/docs/references/training-parameters
32
33
  """
34
+ ag_key = "CAT"
35
+ ag_name = "CatBoost"
36
+ ag_priority = 70
37
+ ag_priority_by_problem_type = MappingProxyType({
38
+ SOFTCLASS: 60
39
+ })
33
40
 
34
41
  def __init__(self, **kwargs):
35
42
  super().__init__(**kwargs)
@@ -8,6 +8,7 @@ import warnings
8
8
  from builtins import classmethod
9
9
  from functools import partial
10
10
  from pathlib import Path
11
+ from types import MappingProxyType
11
12
  from typing import Union
12
13
 
13
14
  import numpy as np
@@ -28,7 +29,7 @@ from autogluon.common.features.types import (
28
29
  from autogluon.common.utils.pandas_utils import get_approximate_df_mem_usage
29
30
  from autogluon.common.utils.resource_utils import ResourceManager
30
31
  from autogluon.common.utils.try_import import try_import_fastai
31
- from autogluon.core.constants import BINARY, QUANTILE, REGRESSION
32
+ from autogluon.core.constants import BINARY, MULTICLASS, QUANTILE, REGRESSION
32
33
  from autogluon.core.hpo.constants import RAY_BACKEND
33
34
  from autogluon.core.models import AbstractModel
34
35
  from autogluon.core.utils.exceptions import TimeLimitExceeded
@@ -92,6 +93,14 @@ class NNFastAiTabularModel(AbstractModel):
92
93
  'early.stopping.min_delta': 0.0001,
93
94
  'early.stopping.patience': 10,
94
95
  """
96
+ ag_key = "FASTAI"
97
+ ag_name = "NeuralNetFastAI"
98
+ ag_priority = 50
99
+ # Increase priority for multiclass since neural networks
100
+ # scale better than trees as a function of n_classes.
101
+ ag_priority_by_problem_type = MappingProxyType({
102
+ MULTICLASS: 95,
103
+ })
95
104
 
96
105
  model_internals_file_name = "model-internals.pkl"
97
106
 
@@ -21,6 +21,9 @@ logger = logging.getLogger(__name__)
21
21
 
22
22
 
23
23
  class FastTextModel(AbstractModel):
24
+ ag_key = "FASTTEXT"
25
+ ag_name = "FastText"
26
+
24
27
  model_bin_file_name = "fasttext.ftz"
25
28
 
26
29
  def __init__(self, **kwargs):
@@ -22,6 +22,8 @@ class ImagePredictorModel(MultiModalPredictorModel):
22
22
  Additionally has special null image handling to improve performance in the presence of null images (aka image path of '')
23
23
  Note: null handling has not been compared to the built-in null handling of MultimodalPredictor yet.
24
24
  """
25
+ ag_key = "AG_IMAGE_NN"
26
+ ag_name = "ImagePredictor"
25
27
 
26
28
  def __init__(self, **kwargs):
27
29
  super().__init__(**kwargs)
@@ -75,6 +75,9 @@ class _IModelsModel(AbstractModel):
75
75
 
76
76
 
77
77
  class RuleFitModel(_IModelsModel):
78
+ ag_key = "IM_RULEFIT"
79
+ ag_name = "RuleFit"
80
+
78
81
  def get_model(self):
79
82
  try_import_imodels()
80
83
  from imodels import RuleFitClassifier, RuleFitRegressor
@@ -86,6 +89,9 @@ class RuleFitModel(_IModelsModel):
86
89
 
87
90
 
88
91
  class GreedyTreeModel(_IModelsModel):
92
+ ag_key = "IM_GREEDYTREE"
93
+ ag_name = "GreedyTree"
94
+
89
95
  def get_model(self):
90
96
  try_import_imodels()
91
97
  from imodels import GreedyTreeClassifier
@@ -98,6 +104,9 @@ class GreedyTreeModel(_IModelsModel):
98
104
 
99
105
 
100
106
  class BoostedRulesModel(_IModelsModel):
107
+ ag_key = "IM_BOOSTEDRULES"
108
+ ag_name = "BoostedRules"
109
+
101
110
  def get_model(self):
102
111
  try_import_imodels()
103
112
  from imodels import BoostedRulesClassifier
@@ -109,6 +118,9 @@ class BoostedRulesModel(_IModelsModel):
109
118
 
110
119
 
111
120
  class HSTreeModel(_IModelsModel):
121
+ ag_key = "IM_HSTREE"
122
+ ag_name = "HierarchicalShrinkageTree"
123
+
112
124
  def get_model(self):
113
125
  try_import_imodels()
114
126
  from imodels import HSTreeClassifierCV, HSTreeRegressorCV
@@ -120,6 +132,9 @@ class HSTreeModel(_IModelsModel):
120
132
 
121
133
 
122
134
  class FigsModel(_IModelsModel):
135
+ ag_key = "IM_FIGS"
136
+ ag_name = "Figs"
137
+
123
138
  def get_model(self):
124
139
  try_import_imodels()
125
140
  from imodels import FIGSClassifier, FIGSRegressor
@@ -22,6 +22,9 @@ class KNNModel(AbstractModel):
22
22
  """
23
23
  KNearestNeighbors model (scikit-learn): https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.KNeighborsClassifier.html
24
24
  """
25
+ ag_key = "KNN"
26
+ ag_name = "KNeighbors"
27
+ ag_priority = 100
25
28
 
26
29
  def __init__(self, **kwargs):
27
30
  super().__init__(**kwargs)
@@ -7,6 +7,7 @@ import random
7
7
  import re
8
8
  import time
9
9
  import warnings
10
+ from types import MappingProxyType
10
11
 
11
12
  import numpy as np
12
13
  import pandas as pd
@@ -40,6 +41,12 @@ class LGBModel(AbstractModel):
40
41
  Extra hyperparameter options:
41
42
  ag.early_stop : int, specifies the early stopping rounds. Defaults to an adaptive strategy. Recommended to keep default.
42
43
  """
44
+ ag_key = "GBM"
45
+ ag_name = "LightGBM"
46
+ ag_priority = 90
47
+ ag_priority_by_problem_type = MappingProxyType({
48
+ SOFTCLASS: 100
49
+ })
43
50
 
44
51
  def __init__(self, **kwargs):
45
52
  super().__init__(**kwargs)
@@ -38,6 +38,9 @@ class LinearModel(AbstractModel):
38
38
 
39
39
  'regression': https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.Ridge.html#sklearn.linear_model.Ridge
40
40
  """
41
+ ag_key = "LR"
42
+ ag_name = "LinearModel"
43
+ ag_priority = 30
41
44
 
42
45
  def __init__(self, **kwargs):
43
46
  super().__init__(**kwargs)
@@ -27,6 +27,9 @@ class RFModel(AbstractModel):
27
27
  """
28
28
  Random Forest model (scikit-learn): https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestClassifier.html
29
29
  """
30
+ ag_key = "RF"
31
+ ag_name = "RandomForest"
32
+ ag_priority = 80
30
33
 
31
34
  def __init__(self, **kwargs):
32
35
  super().__init__(**kwargs)
@@ -40,6 +40,8 @@ class TabTransformerModel(AbstractNeuralNetworkModel):
40
40
  and applies them to the use case of tabular data. Specifically, this makes TabTransformer suitable for unsupervised
41
41
  training of Tabular data with a subsequent fine-tuning step on labeled data.
42
42
  """
43
+ ag_key = "TRANSF"
44
+ ag_name = "Transformer"
43
45
 
44
46
  params_file_name = "tab_trans_params.pth"
45
47
 
@@ -22,6 +22,9 @@ class TabPFNModel(AbstractModel):
22
22
  To use this model, `tabpfn` must be installed.
23
23
  To install TabPFN, you can run `pip install autogluon.tabular[tabpfn]` or `pip install tabpfn`.
24
24
  """
25
+ ag_key = "TABPFN"
26
+ ag_name = "TabPFN"
27
+ ag_priority = 110
25
28
 
26
29
  def __init__(self, **kwargs):
27
30
  super().__init__(**kwargs)
@@ -35,6 +35,10 @@ class TabPFNMixModel(AbstractModel):
35
35
 
36
36
  For more information, refer to the `./_internals/README.md` file.
37
37
  """
38
+ ag_key = "TABPFNMIX"
39
+ ag_name = "TabPFNMix"
40
+ ag_priority = 45
41
+
38
42
  weights_file_name = "model.pt"
39
43
 
40
44
  def __init__(self, **kwargs):
@@ -47,6 +47,9 @@ class TabularNeuralNetTorchModel(AbstractNeuralNetworkModel):
47
47
  ag.early_stop : int | str, default = "default"
48
48
  Specifies the early stopping rounds. Defaults to an adaptive strategy. Recommended to keep default.
49
49
  """
50
+ ag_key = "NN_TORCH"
51
+ ag_name = "NeuralNetTorch"
52
+ ag_priority = 25
50
53
 
51
54
  # Constants used throughout this class:
52
55
  unique_category_str = np.nan # string used to represent missing values and unknown categories for categorical features.
@@ -19,6 +19,9 @@ logger = logging.getLogger(__name__)
19
19
  class TextPredictorModel(MultiModalPredictorModel):
20
20
  """MultimodalPredictor that doesn't use image features"""
21
21
 
22
+ ag_key = "AG_TEXT_NN"
23
+ ag_name = "TextPredictor"
24
+
22
25
  def _get_default_auxiliary_params(self) -> dict:
23
26
  default_auxiliary_params = super()._get_default_auxiliary_params()
24
27
  extra_auxiliary_params = dict(
@@ -38,6 +38,9 @@ class VowpalWabbitModel(AbstractModel):
38
38
  VowpalWabbit Command Line args: https://github.com/VowpalWabbit/vowpal_wabbit/wiki/Command-line-arguments
39
39
 
40
40
  """
41
+ ag_key = "VW"
42
+ ag_name = "VowpalWabbit"
43
+ ag_priority = 10
41
44
 
42
45
  model_internals_file_name = "model-internals.pkl"
43
46
 
@@ -27,6 +27,9 @@ class XGBoostModel(AbstractModel):
27
27
 
28
28
  Hyperparameter options: https://xgboost.readthedocs.io/en/latest/parameter.html
29
29
  """
30
+ ag_key = "XGB"
31
+ ag_name = "XGBoost"
32
+ ag_priority = 40
30
33
 
31
34
  def __init__(self, **kwargs):
32
35
  super().__init__(**kwargs)
@@ -7,6 +7,9 @@ class XTModel(RFModel):
7
7
  """
8
8
  Extra Trees model (scikit-learn): https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.ExtraTreesClassifier.html#sklearn.ensemble.ExtraTreesClassifier
9
9
  """
10
+ ag_key = "XT"
11
+ ag_name = "ExtraTrees"
12
+ ag_priority = 60
10
13
 
11
14
  def _get_model_type(self):
12
15
  if self.problem_type == REGRESSION:
@@ -4330,7 +4330,14 @@ class TabularPredictor:
4330
4330
  reduce_children=reduce_children,
4331
4331
  )
4332
4332
 
4333
- def delete_models(self, models_to_keep=None, models_to_delete=None, allow_delete_cascade=False, delete_from_disk=True, dry_run=True):
4333
+ def delete_models(
4334
+ self,
4335
+ models_to_keep: str | list[str] | None = None,
4336
+ models_to_delete: str | list[str] | None = None,
4337
+ allow_delete_cascade: bool = False,
4338
+ delete_from_disk: bool = True,
4339
+ dry_run: bool | None = None,
4340
+ ):
4334
4341
  """
4335
4342
  Deletes models from `predictor`.
4336
4343
  This can be helpful to minimize memory usage and disk usage, particularly for model deployment.
@@ -4341,13 +4348,13 @@ class TabularPredictor:
4341
4348
 
4342
4349
  Parameters
4343
4350
  ----------
4344
- models_to_keep : str or list, default = None
4351
+ models_to_keep : str or list[str], default = None
4345
4352
  Name of model or models to not delete.
4346
4353
  All models that are not specified and are also not required as a dependency of any model in `models_to_keep` will be deleted.
4347
4354
  Specify `models_to_keep='best'` to keep only the best model and its model dependencies.
4348
4355
  `models_to_delete` must be None if `models_to_keep` is set.
4349
4356
  To see the list of possible model names, use: `predictor.model_names()` or `predictor.leaderboard()`.
4350
- models_to_delete : str or list, default = None
4357
+ models_to_delete : str or list[str], default = None
4351
4358
  Name of model or models to delete.
4352
4359
  All models that are not specified but depend on a model in `models_to_delete` will also be deleted.
4353
4360
  `models_to_keep` must be None if `models_to_delete` is set.
@@ -4361,10 +4368,19 @@ class TabularPredictor:
4361
4368
  WARNING: This deletes the entire directory for the deleted models, and ALL FILES located there.
4362
4369
  It is highly recommended to first run with `dry_run=True` to understand which directories will be deleted.
4363
4370
  dry_run : bool, default = True
4371
+ WARNING: Starting in v1.4.0 dry_run will default to False.
4364
4372
  If `True`, then deletions don't occur, and logging statements are printed describing what would have occurred.
4365
4373
  Set `dry_run=False` to perform the deletions.
4366
4374
 
4367
4375
  """
4376
+ if dry_run is None:
4377
+ warnings.warn(
4378
+ f"dry_run was not specified for `TabularPredictor.delete_models`. dry_run prior to version 1.4.0 defaults to True. "
4379
+ f"Starting in version 1.4, AutoGluon will default dry_run to False. "
4380
+ f"If you want to maintain the current logic in future versions, explicitly specify `dry_run=True`.",
4381
+ category=FutureWarning,
4382
+ )
4383
+ dry_run = True
4368
4384
  self._assert_is_fit("delete_models")
4369
4385
  if models_to_keep == "best":
4370
4386
  models_to_keep = self.model_best
@@ -0,0 +1,2 @@
1
+ from ._model_register import ModelRegister
2
+ from ._ag_model_register import ag_model_register
@@ -0,0 +1,66 @@
1
+ from autogluon.core.models import (
2
+ DummyModel,
3
+ GreedyWeightedEnsembleModel,
4
+ SimpleWeightedEnsembleModel,
5
+ )
6
+
7
+ from . import ModelRegister
8
+ from ..models import (
9
+ BoostedRulesModel,
10
+ CatBoostModel,
11
+ FastTextModel,
12
+ FigsModel,
13
+ FTTransformerModel,
14
+ GreedyTreeModel,
15
+ HSTreeModel,
16
+ ImagePredictorModel,
17
+ KNNModel,
18
+ LGBModel,
19
+ LinearModel,
20
+ MultiModalPredictorModel,
21
+ NNFastAiTabularModel,
22
+ RFModel,
23
+ RuleFitModel,
24
+ TabPFNMixModel,
25
+ TabPFNModel,
26
+ TabularNeuralNetTorchModel,
27
+ TextPredictorModel,
28
+ VowpalWabbitModel,
29
+ XGBoostModel,
30
+ XTModel,
31
+ )
32
+ from ..models.tab_transformer.tab_transformer_model import TabTransformerModel
33
+
34
+
35
# When adding a new model officially to AutoGluon, the model class should be added to the bottom of this list.
# Every class listed here is registered at import time by the `ModelRegister(...)` call below, which calls
# `ModelRegister.add` on each class in list order. Each class must therefore define a non-None string `ag_key`
# (unique among all listed classes), a non-None string `ag_name`, and an int `ag_priority`; otherwise `add`
# raises AssertionError and importing this module fails.
REGISTERED_MODEL_CLS_LST = [
    RFModel,
    XTModel,
    KNNModel,
    LGBModel,
    CatBoostModel,
    XGBoostModel,
    TabularNeuralNetTorchModel,
    LinearModel,
    NNFastAiTabularModel,
    TabTransformerModel,
    TextPredictorModel,
    ImagePredictorModel,
    MultiModalPredictorModel,
    FTTransformerModel,
    TabPFNModel,
    TabPFNMixModel,
    FastTextModel,
    VowpalWabbitModel,
    GreedyWeightedEnsembleModel,
    SimpleWeightedEnsembleModel,
    RuleFitModel,
    GreedyTreeModel,
    FigsModel,
    HSTreeModel,
    BoostedRulesModel,
    DummyModel,
]

# TODO: Replace logic in `autogluon.tabular.trainer.model_presets.presets` with `ag_model_register`
# Module-level register holding the built-in model classes above, in the order listed.
# Custom models can be added later via `ag_model_register.add(MyCustomModel)`.
ag_model_register = ModelRegister(model_cls_list=REGISTERED_MODEL_CLS_LST)
@@ -0,0 +1,146 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Type
4
+
5
+ import pandas as pd
6
+
7
+ from autogluon.core.models import AbstractModel
8
+
9
+
10
+ # TODO: Move to core? Maybe TimeSeries can reuse?
11
+ # TODO: Use this / refer to this in the custom model tutorial
12
+ # TODO: Add to documentation website
13
+ # TODO: Test register logic in AG
14
class ModelRegister:
    """
    ModelRegister keeps track of all known model classes to AutoGluon.
    It can provide information such as:
        What model classes and keys are valid to specify in an AutoGluon predictor fit call.
        What a model's name is.
        What a model's key is (such as the key specified by the user in `hyperparameters` to refer to a specific model type).
        What a model's priority is (aka which order to fit a list of models).

    Additionally, users can register custom models to AutoGluon so the key is recognized in `hyperparameters` and is treated with the proper priority and name.
    They can register new models via `ModelRegister.add(model_cls)`.

    Therefore, if a user creates a custom model `MyCustomModel` that inherits from `AbstractModel`, they can set the class attributes in `MyCustomModel`:
        ag_key: The string key that can be specified in `hyperparameters`. Example: "GBM" for LGBModel
        ag_name: The string name that is used in logging and accessing the model. Example: "LightGBM" for LGBModel
        ag_priority: The int priority that is used to order the fitting of models. Higher values will be fit before lower values. Default 0. Example: 90 for LGBModel
        ag_priority_to_problem_type: A dictionary of problem_type to priority that overrides `ag_priority` if specified for a given problem_type. Optional.

    Then they can say `ag_model_register.add(MyCustomModel)`.
    Assuming MyCustomModel.ag_key = "MY_MODEL", they can now do:
    ```
    predictor.fit(..., hyperparameters={"MY_MODEL": ...})
    ```

    Invalid registrations (missing/incorrectly typed attributes, duplicate classes or keys) and queries
    about unregistered classes raise ``AssertionError``. The checks use explicit ``raise`` statements
    rather than ``assert`` so that they still run under ``python -O``.
    """

    def __init__(self, model_cls_list: list[Type[AbstractModel]] | None = None):
        """
        Parameters
        ----------
        model_cls_list : list of model classes, default = None
            Model classes to register, in order. Each class is validated via `add`.
            None is treated as an empty list.
        """
        if model_cls_list is None:
            model_cls_list = []
        if not isinstance(model_cls_list, list):
            raise AssertionError(f"`model_cls_list` must be a list, but got: {type(model_cls_list)}")
        self._model_cls_list = []
        self._key_to_cls_map = dict()
        for model_cls in model_cls_list:
            self.add(model_cls)

    def exists(self, model_cls: Type[AbstractModel]) -> bool:
        """Return True if `model_cls` is currently registered."""
        return model_cls in self._model_cls_list

    def _assert_registered(self, model_cls: Type[AbstractModel]) -> None:
        """Raise AssertionError if `model_cls` is not registered (explicit raise so it survives `python -O`)."""
        if not self.exists(model_cls):
            raise AssertionError(f"Model class must be registered: {model_cls}")

    def add(self, model_cls: Type[AbstractModel]):
        """
        Adds `model_cls` to the model register

        Raises
        ------
        AssertionError
            If `model_cls` is already registered, if its `ag_key` or `ag_name` is None or not a str,
            if its `ag_priority` is not an int, or if its `ag_key` is already used by another registered class.
        """
        if self.exists(model_cls):
            raise AssertionError(f"Cannot add model_cls that is already registered: {model_cls}")
        # getattr so that classes lacking the attributes entirely get the descriptive error below
        # instead of an opaque AttributeError.
        ag_key = getattr(model_cls, "ag_key", None)
        ag_name = getattr(model_cls, "ag_name", None)
        ag_priority = getattr(model_cls, "ag_priority", None)
        if ag_key is None:
            raise AssertionError(
                f"Cannot add model_cls with `ag_key=None`. "
                f"Ensure you set class attribute `ag_key` to a string for your model_cls: {model_cls}"
                f'\n\tFor example, LGBModel sets `ag_key = "GBM"`'
            )
        if ag_name is None:
            raise AssertionError(
                f"Cannot add model_cls with `ag_name=None`. "
                f"Ensure you set class attribute `ag_name` to a string for your model_cls: {model_cls}"
                f'\n\tFor example, LGBModel sets `ag_name = "LightGBM"`'
            )
        if not isinstance(ag_key, str):
            raise AssertionError(f"`ag_key` must be a str, but got {type(ag_key)} for model_cls: {model_cls}")
        if not isinstance(ag_name, str):
            raise AssertionError(f"`ag_name` must be a str, but got {type(ag_name)} for model_cls: {model_cls}")
        if not isinstance(ag_priority, int):
            raise AssertionError(f"`ag_priority` must be an int, but got {type(ag_priority)} for model_cls: {model_cls}")
        if ag_key in self._key_to_cls_map:
            raise AssertionError(
                f"Cannot register a model class that shares a model key with an already registered model class."
                f"\n`model_cls.ag_key` must be unique among registered models:"
                f"\n\t New Class: {model_cls}"
                f"\n\tConflicting Class: {self._key_to_cls_map[ag_key]}"
                f"\n\tConflicting ag_key: {ag_key}"
            )
        self._model_cls_list.append(model_cls)
        self._key_to_cls_map[ag_key] = model_cls

    def remove(self, model_cls: Type[AbstractModel]):
        """
        Removes `model_cls` from the model register

        Raises
        ------
        AssertionError
            If `model_cls` is not registered.
        """
        if not self.exists(model_cls):
            raise AssertionError(f"Cannot remove model_cls that isn't registered: {model_cls}")
        self._model_cls_list = [m for m in self._model_cls_list if m != model_cls]
        self._key_to_cls_map.pop(model_cls.ag_key)

    @property
    def model_cls_list(self) -> list[Type[AbstractModel]]:
        """Registered model classes in registration order."""
        return self._model_cls_list

    @property
    def keys(self) -> list[str]:
        """`ag_key` of every registered model class, in registration order."""
        return [self.key(model_cls) for model_cls in self.model_cls_list]

    def key_to_cls_map(self) -> dict[str, Type[AbstractModel]]:
        """Mapping of `ag_key` -> registered model class."""
        return self._key_to_cls_map

    def key_to_cls(self, key: str) -> Type[AbstractModel]:
        """
        Return the registered model class whose `ag_key` equals `key`.

        Raises
        ------
        ValueError
            If no registered model class uses `key`.
        """
        if key not in self._key_to_cls_map:
            raise ValueError(
                f"No registered model exists with provided key: {key}"
                f"\n\tValid keys: {list(self._key_to_cls_map.keys())}"
            )
        return self._key_to_cls_map[key]

    def priority_map(self, problem_type: str | None = None) -> dict[Type[AbstractModel], int]:
        """Mapping of registered model class -> priority for `problem_type` (see `priority`)."""
        return {model_cls: self.priority(model_cls, problem_type=problem_type) for model_cls in self._model_cls_list}

    def key(self, model_cls: Type[AbstractModel]) -> str:
        """Return the `ag_key` of a registered model class."""
        self._assert_registered(model_cls)
        return model_cls.ag_key

    def name_map(self) -> dict[Type[AbstractModel], str]:
        """Mapping of registered model class -> `ag_name`."""
        return {model_cls: model_cls.ag_name for model_cls in self._model_cls_list}

    def name(self, model_cls: Type[AbstractModel]) -> str:
        """Return the `ag_name` of a registered model class."""
        self._assert_registered(model_cls)
        return model_cls.ag_name

    def priority(self, model_cls: Type[AbstractModel], problem_type: str | None = None) -> int:
        """Return the priority of a registered model class, as reported by `model_cls.get_ag_priority(problem_type=...)`."""
        self._assert_registered(model_cls)
        return model_cls.get_ag_priority(problem_type=problem_type)

    def docstring(self, model_cls: Type[AbstractModel]) -> str | None:
        """Return the class docstring of a registered model class (None if the class has no docstring)."""
        self._assert_registered(model_cls)
        return model_cls.__doc__

    # TODO: Could add a lot of information here to track which features are supported for each model:
    #  ag.early_stop support
    #  refit_full support
    #  GPU support
    #  etc.
    def to_frame(self) -> pd.DataFrame:
        """
        Summarize the registered model classes as a DataFrame.

        Returns
        -------
        pd.DataFrame
            One row per registered class in registration order, indexed by `ag_key` (index name "ag_key"),
            with columns "model_cls" (the class `__name__`), "ag_name", and "ag_priority"
            (priority computed without a problem_type). The columns are present even when nothing is registered.
        """
        columns = ["model_cls", "ag_name", "ag_priority"]
        model_classes = self.model_cls_list
        # Build row-wise: transposing a column-wise frame of mixed str/int values would force every column
        # to object dtype, and an empty dict would produce a frame with no columns at all.
        rows = [[model_cls.__name__, self.name(model_cls), self.priority(model_cls)] for model_cls in model_classes]
        index = [self.key(model_cls) for model_cls in model_classes]
        df = pd.DataFrame(rows, index=index, columns=columns)
        df.index.name = "ag_key"
        return df