autogluon.tabular 1.4.1b20250822__tar.gz → 1.4.1b20250823__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (206) hide show
  1. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/PKG-INFO +2 -1
  2. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/setup.py +5 -0
  3. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/__init__.py +1 -0
  4. autogluon.tabular-1.4.1b20250823/src/autogluon/tabular/models/ebm/ebm_model.py +263 -0
  5. autogluon.tabular-1.4.1b20250823/src/autogluon/tabular/models/ebm/hyperparameters/parameters.py +39 -0
  6. autogluon.tabular-1.4.1b20250823/src/autogluon/tabular/models/ebm/hyperparameters/searchspaces.py +72 -0
  7. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/mitra/mitra_model.py +16 -0
  8. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/mitra/sklearn_interface.py +8 -21
  9. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/predictor/predictor.py +1 -0
  10. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/registry/_ag_model_registry.py +2 -0
  11. autogluon.tabular-1.4.1b20250823/src/autogluon/tabular/trainer/model_presets/__init__.py +0 -0
  12. autogluon.tabular-1.4.1b20250823/src/autogluon/tabular/tuning/__init__.py +0 -0
  13. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/version.py +1 -1
  14. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon.tabular.egg-info/PKG-INFO +2 -1
  15. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon.tabular.egg-info/SOURCES.txt +5 -0
  16. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon.tabular.egg-info/requires.txt +31 -26
  17. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/setup.cfg +0 -0
  18. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/__init__.py +0 -0
  19. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/configs/__init__.py +0 -0
  20. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/configs/config_helper.py +0 -0
  21. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/configs/feature_generator_presets.py +0 -0
  22. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/configs/hyperparameter_configs.py +0 -0
  23. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/configs/pipeline_presets.py +0 -0
  24. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/configs/presets_configs.py +0 -0
  25. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/configs/zeroshot/__init__.py +0 -0
  26. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/configs/zeroshot/zeroshot_portfolio_2023.py +0 -0
  27. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/configs/zeroshot/zeroshot_portfolio_2025.py +0 -0
  28. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/experimental/__init__.py +0 -0
  29. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/experimental/_scikit_mixin.py +0 -0
  30. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/experimental/_tabular_classifier.py +0 -0
  31. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/experimental/_tabular_regressor.py +0 -0
  32. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/experimental/plot_leaderboard.py +0 -0
  33. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/learner/__init__.py +0 -0
  34. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/learner/abstract_learner.py +0 -0
  35. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/learner/default_learner.py +0 -0
  36. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/_utils/__init__.py +0 -0
  37. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/_utils/rapids_utils.py +0 -0
  38. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/_utils/torch_utils.py +0 -0
  39. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/automm/__init__.py +0 -0
  40. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/automm/automm_model.py +0 -0
  41. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/automm/ft_transformer.py +0 -0
  42. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/catboost/__init__.py +0 -0
  43. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/catboost/callbacks.py +0 -0
  44. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/catboost/catboost_model.py +0 -0
  45. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/catboost/catboost_softclass_utils.py +0 -0
  46. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/catboost/catboost_utils.py +0 -0
  47. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/catboost/hyperparameters/__init__.py +0 -0
  48. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/catboost/hyperparameters/parameters.py +0 -0
  49. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/catboost/hyperparameters/searchspaces.py +0 -0
  50. {autogluon.tabular-1.4.1b20250822/src/autogluon/tabular/models/fastainn → autogluon.tabular-1.4.1b20250823/src/autogluon/tabular/models/ebm}/__init__.py +0 -0
  51. {autogluon.tabular-1.4.1b20250822/src/autogluon/tabular/models/fastainn → autogluon.tabular-1.4.1b20250823/src/autogluon/tabular/models/ebm}/hyperparameters/__init__.py +0 -0
  52. {autogluon.tabular-1.4.1b20250822/src/autogluon/tabular/models/fasttext → autogluon.tabular-1.4.1b20250823/src/autogluon/tabular/models/fastainn}/__init__.py +0 -0
  53. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/fastainn/callbacks.py +0 -0
  54. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/fastainn/fastai_helpers.py +0 -0
  55. {autogluon.tabular-1.4.1b20250822/src/autogluon/tabular/models/fasttext → autogluon.tabular-1.4.1b20250823/src/autogluon/tabular/models/fastainn}/hyperparameters/__init__.py +0 -0
  56. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/fastainn/hyperparameters/parameters.py +0 -0
  57. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/fastainn/hyperparameters/searchspaces.py +0 -0
  58. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/fastainn/imports_helper.py +0 -0
  59. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/fastainn/quantile_helpers.py +0 -0
  60. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/fastainn/tabular_nn_fastai.py +0 -0
  61. {autogluon.tabular-1.4.1b20250822/src/autogluon/tabular/models/image_prediction → autogluon.tabular-1.4.1b20250823/src/autogluon/tabular/models/fasttext}/__init__.py +0 -0
  62. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/fasttext/fasttext_model.py +0 -0
  63. {autogluon.tabular-1.4.1b20250822/src/autogluon/tabular/models/imodels → autogluon.tabular-1.4.1b20250823/src/autogluon/tabular/models/fasttext/hyperparameters}/__init__.py +0 -0
  64. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/fasttext/hyperparameters/parameters.py +0 -0
  65. {autogluon.tabular-1.4.1b20250822/src/autogluon/tabular/models/knn → autogluon.tabular-1.4.1b20250823/src/autogluon/tabular/models/image_prediction}/__init__.py +0 -0
  66. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/image_prediction/image_predictor.py +0 -0
  67. {autogluon.tabular-1.4.1b20250822/src/autogluon/tabular/models/lgb → autogluon.tabular-1.4.1b20250823/src/autogluon/tabular/models/imodels}/__init__.py +0 -0
  68. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/imodels/imodels_models.py +0 -0
  69. {autogluon.tabular-1.4.1b20250822/src/autogluon/tabular/models/lgb/hyperparameters → autogluon.tabular-1.4.1b20250823/src/autogluon/tabular/models/knn}/__init__.py +0 -0
  70. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/knn/_knn_loo_variants.py +0 -0
  71. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/knn/knn_model.py +0 -0
  72. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/knn/knn_rapids_model.py +0 -0
  73. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/knn/knn_utils.py +0 -0
  74. {autogluon.tabular-1.4.1b20250822/src/autogluon/tabular/models/lr → autogluon.tabular-1.4.1b20250823/src/autogluon/tabular/models/lgb}/__init__.py +0 -0
  75. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/lgb/callbacks.py +0 -0
  76. {autogluon.tabular-1.4.1b20250822/src/autogluon/tabular/models/lr → autogluon.tabular-1.4.1b20250823/src/autogluon/tabular/models/lgb}/hyperparameters/__init__.py +0 -0
  77. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/lgb/hyperparameters/parameters.py +0 -0
  78. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/lgb/hyperparameters/searchspaces.py +0 -0
  79. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/lgb/lgb_model.py +0 -0
  80. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/lgb/lgb_utils.py +0 -0
  81. {autogluon.tabular-1.4.1b20250822/src/autogluon/tabular/models/mitra → autogluon.tabular-1.4.1b20250823/src/autogluon/tabular/models/lr}/__init__.py +0 -0
  82. {autogluon.tabular-1.4.1b20250822/src/autogluon/tabular/models/realmlp → autogluon.tabular-1.4.1b20250823/src/autogluon/tabular/models/lr/hyperparameters}/__init__.py +0 -0
  83. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/lr/hyperparameters/parameters.py +0 -0
  84. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/lr/hyperparameters/searchspaces.py +0 -0
  85. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/lr/lr_model.py +0 -0
  86. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/lr/lr_preprocessing_utils.py +0 -0
  87. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/lr/lr_rapids_model.py +0 -0
  88. {autogluon.tabular-1.4.1b20250822/src/autogluon/tabular/models/rf → autogluon.tabular-1.4.1b20250823/src/autogluon/tabular/models/mitra}/__init__.py +0 -0
  89. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/mitra/_internal/__init__.py +0 -0
  90. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/mitra/_internal/config/__init__.py +0 -0
  91. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/mitra/_internal/config/config_pretrain.py +0 -0
  92. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/mitra/_internal/config/config_run.py +0 -0
  93. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/mitra/_internal/config/enums.py +0 -0
  94. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/mitra/_internal/core/__init__.py +0 -0
  95. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/mitra/_internal/core/callbacks.py +0 -0
  96. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/mitra/_internal/core/get_loss.py +0 -0
  97. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/mitra/_internal/core/get_optimizer.py +0 -0
  98. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/mitra/_internal/core/get_scheduler.py +0 -0
  99. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/mitra/_internal/core/prediction_metrics.py +0 -0
  100. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/mitra/_internal/core/trainer_finetune.py +0 -0
  101. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/mitra/_internal/data/__init__.py +0 -0
  102. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/mitra/_internal/data/collator.py +0 -0
  103. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/mitra/_internal/data/dataset_finetune.py +0 -0
  104. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/mitra/_internal/data/dataset_split.py +0 -0
  105. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/mitra/_internal/data/preprocessor.py +0 -0
  106. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/mitra/_internal/models/__init__.py +0 -0
  107. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/mitra/_internal/models/base.py +0 -0
  108. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/mitra/_internal/models/embedding.py +0 -0
  109. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/mitra/_internal/models/tab2d.py +0 -0
  110. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/mitra/_internal/utils/__init__.py +0 -0
  111. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/mitra/_internal/utils/set_seed.py +0 -0
  112. {autogluon.tabular-1.4.1b20250822/src/autogluon/tabular/models/rf/compilers → autogluon.tabular-1.4.1b20250823/src/autogluon/tabular/models/realmlp}/__init__.py +0 -0
  113. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/realmlp/realmlp_model.py +0 -0
  114. {autogluon.tabular-1.4.1b20250822/src/autogluon/tabular/models/tabicl → autogluon.tabular-1.4.1b20250823/src/autogluon/tabular/models/rf}/__init__.py +0 -0
  115. {autogluon.tabular-1.4.1b20250822/src/autogluon/tabular/models/tabm → autogluon.tabular-1.4.1b20250823/src/autogluon/tabular/models/rf/compilers}/__init__.py +0 -0
  116. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/rf/compilers/native.py +0 -0
  117. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/rf/compilers/onnx.py +0 -0
  118. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/rf/rf_model.py +0 -0
  119. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/rf/rf_quantile.py +0 -0
  120. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/rf/rf_rapids_model.py +0 -0
  121. {autogluon.tabular-1.4.1b20250822/src/autogluon/tabular/models/tabpfnmix → autogluon.tabular-1.4.1b20250823/src/autogluon/tabular/models/tabicl}/__init__.py +0 -0
  122. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/tabicl/tabicl_model.py +0 -0
  123. {autogluon.tabular-1.4.1b20250822/src/autogluon/tabular/models/tabpfnmix/_internal → autogluon.tabular-1.4.1b20250823/src/autogluon/tabular/models/tabm}/__init__.py +0 -0
  124. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/tabm/_tabm_internal.py +0 -0
  125. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/tabm/rtdl_num_embeddings.py +0 -0
  126. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/tabm/tabm_model.py +0 -0
  127. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/tabm/tabm_reference.py +0 -0
  128. {autogluon.tabular-1.4.1b20250822/src/autogluon/tabular/models/tabpfnmix/_internal/config → autogluon.tabular-1.4.1b20250823/src/autogluon/tabular/models/tabpfnmix}/__init__.py +0 -0
  129. {autogluon.tabular-1.4.1b20250822/src/autogluon/tabular/models/tabpfnmix/_internal/core → autogluon.tabular-1.4.1b20250823/src/autogluon/tabular/models/tabpfnmix/_internal}/__init__.py +0 -0
  130. {autogluon.tabular-1.4.1b20250822/src/autogluon/tabular/models/tabpfnmix/_internal/data → autogluon.tabular-1.4.1b20250823/src/autogluon/tabular/models/tabpfnmix/_internal/config}/__init__.py +0 -0
  131. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/tabpfnmix/_internal/config/config_run.py +0 -0
  132. {autogluon.tabular-1.4.1b20250822/src/autogluon/tabular/models/tabpfnmix/_internal/models → autogluon.tabular-1.4.1b20250823/src/autogluon/tabular/models/tabpfnmix/_internal/core}/__init__.py +0 -0
  133. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/tabpfnmix/_internal/core/callbacks.py +0 -0
  134. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/tabpfnmix/_internal/core/collator.py +0 -0
  135. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/tabpfnmix/_internal/core/dataset_split.py +0 -0
  136. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/tabpfnmix/_internal/core/enums.py +0 -0
  137. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/tabpfnmix/_internal/core/get_loss.py +0 -0
  138. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/tabpfnmix/_internal/core/get_optimizer.py +0 -0
  139. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/tabpfnmix/_internal/core/get_scheduler.py +0 -0
  140. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/tabpfnmix/_internal/core/trainer_finetune.py +0 -0
  141. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/tabpfnmix/_internal/core/y_transformer.py +0 -0
  142. {autogluon.tabular-1.4.1b20250822/src/autogluon/tabular/models/tabpfnmix/_internal/models/foundation → autogluon.tabular-1.4.1b20250823/src/autogluon/tabular/models/tabpfnmix/_internal/data}/__init__.py +0 -0
  143. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/tabpfnmix/_internal/data/dataset_finetune.py +0 -0
  144. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/tabpfnmix/_internal/data/preprocessor.py +0 -0
  145. {autogluon.tabular-1.4.1b20250822/src/autogluon/tabular/models/tabpfnmix/_internal/results → autogluon.tabular-1.4.1b20250823/src/autogluon/tabular/models/tabpfnmix/_internal/models}/__init__.py +0 -0
  146. {autogluon.tabular-1.4.1b20250822/src/autogluon/tabular/models/tabpfnv2 → autogluon.tabular-1.4.1b20250823/src/autogluon/tabular/models/tabpfnmix/_internal/models/foundation}/__init__.py +0 -0
  147. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/tabpfnmix/_internal/models/foundation/embedding.py +0 -0
  148. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/tabpfnmix/_internal/models/foundation/foundation_transformer.py +0 -0
  149. {autogluon.tabular-1.4.1b20250822/src/autogluon/tabular/models/tabular_nn → autogluon.tabular-1.4.1b20250823/src/autogluon/tabular/models/tabpfnmix/_internal/results}/__init__.py +0 -0
  150. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/tabpfnmix/_internal/results/prediction_metrics.py +0 -0
  151. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/tabpfnmix/_internal/tabpfnmix_classifier.py +0 -0
  152. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/tabpfnmix/_internal/tabpfnmix_regressor.py +0 -0
  153. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/tabpfnmix/tabpfnmix_model.py +0 -0
  154. {autogluon.tabular-1.4.1b20250822/src/autogluon/tabular/models/tabular_nn/compilers → autogluon.tabular-1.4.1b20250823/src/autogluon/tabular/models/tabpfnv2}/__init__.py +0 -0
  155. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/tabpfnv2/rfpfn/__init__.py +0 -0
  156. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/tabpfnv2/rfpfn/configs.py +0 -0
  157. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/tabpfnv2/rfpfn/scoring_utils.py +0 -0
  158. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_based_decision_tree_tabpfn.py +0 -0
  159. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_based_random_forest_tabpfn.py +0 -0
  160. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_compat.py +0 -0
  161. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/tabpfnv2/rfpfn/utils.py +0 -0
  162. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/tabpfnv2/tabpfnv2_model.py +0 -0
  163. {autogluon.tabular-1.4.1b20250822/src/autogluon/tabular/models/tabular_nn/hyperparameters → autogluon.tabular-1.4.1b20250823/src/autogluon/tabular/models/tabular_nn}/__init__.py +0 -0
  164. {autogluon.tabular-1.4.1b20250822/src/autogluon/tabular/models/tabular_nn/torch → autogluon.tabular-1.4.1b20250823/src/autogluon/tabular/models/tabular_nn/compilers}/__init__.py +0 -0
  165. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/tabular_nn/compilers/native.py +0 -0
  166. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/tabular_nn/compilers/onnx.py +0 -0
  167. {autogluon.tabular-1.4.1b20250822/src/autogluon/tabular/models/tabular_nn/utils → autogluon.tabular-1.4.1b20250823/src/autogluon/tabular/models/tabular_nn/hyperparameters}/__init__.py +0 -0
  168. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/tabular_nn/hyperparameters/parameters.py +0 -0
  169. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/tabular_nn/hyperparameters/searchspaces.py +0 -0
  170. {autogluon.tabular-1.4.1b20250822/src/autogluon/tabular/models/text_prediction → autogluon.tabular-1.4.1b20250823/src/autogluon/tabular/models/tabular_nn/torch}/__init__.py +0 -0
  171. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/tabular_nn/torch/tabular_nn_torch.py +0 -0
  172. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/tabular_nn/torch/tabular_torch_dataset.py +0 -0
  173. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/tabular_nn/torch/torch_network_modules.py +0 -0
  174. {autogluon.tabular-1.4.1b20250822/src/autogluon/tabular/models/xgboost → autogluon.tabular-1.4.1b20250823/src/autogluon/tabular/models/tabular_nn/utils}/__init__.py +0 -0
  175. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/tabular_nn/utils/categorical_encoders.py +0 -0
  176. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/tabular_nn/utils/data_preprocessor.py +0 -0
  177. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/tabular_nn/utils/nn_architecture_utils.py +0 -0
  178. {autogluon.tabular-1.4.1b20250822/src/autogluon/tabular/models/xgboost/hyperparameters → autogluon.tabular-1.4.1b20250823/src/autogluon/tabular/models/text_prediction}/__init__.py +0 -0
  179. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/text_prediction/text_prediction_v1_model.py +0 -0
  180. {autogluon.tabular-1.4.1b20250822/src/autogluon/tabular/models/xt → autogluon.tabular-1.4.1b20250823/src/autogluon/tabular/models/xgboost}/__init__.py +0 -0
  181. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/xgboost/callbacks.py +0 -0
  182. {autogluon.tabular-1.4.1b20250822/src/autogluon/tabular/trainer/model_presets → autogluon.tabular-1.4.1b20250823/src/autogluon/tabular/models/xgboost/hyperparameters}/__init__.py +0 -0
  183. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/xgboost/hyperparameters/parameters.py +0 -0
  184. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/xgboost/hyperparameters/searchspaces.py +0 -0
  185. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/xgboost/xgboost_model.py +0 -0
  186. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/xgboost/xgboost_utils.py +0 -0
  187. {autogluon.tabular-1.4.1b20250822/src/autogluon/tabular/tuning → autogluon.tabular-1.4.1b20250823/src/autogluon/tabular/models/xt}/__init__.py +0 -0
  188. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/models/xt/xt_model.py +0 -0
  189. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/predictor/__init__.py +0 -0
  190. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/predictor/interpretable_predictor.py +0 -0
  191. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/registry/__init__.py +0 -0
  192. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/registry/_model_registry.py +0 -0
  193. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/testing/__init__.py +0 -0
  194. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/testing/fit_helper.py +0 -0
  195. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/testing/generate_datasets.py +0 -0
  196. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/testing/model_fit_helper.py +0 -0
  197. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/trainer/__init__.py +0 -0
  198. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/trainer/abstract_trainer.py +0 -0
  199. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/trainer/auto_trainer.py +0 -0
  200. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/trainer/model_presets/presets.py +0 -0
  201. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/trainer/model_presets/presets_distill.py +0 -0
  202. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon/tabular/tuning/feature_pruner.py +0 -0
  203. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon.tabular.egg-info/dependency_links.txt +0 -0
  204. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon.tabular.egg-info/namespace_packages.txt +0 -0
  205. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon.tabular.egg-info/top_level.txt +0 -0
  206. {autogluon.tabular-1.4.1b20250822 → autogluon.tabular-1.4.1b20250823}/src/autogluon.tabular.egg-info/zip-safe +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: autogluon.tabular
3
- Version: 1.4.1b20250822
3
+ Version: 1.4.1b20250823
4
4
  Summary: Fast and Accurate ML in 3 Lines of Code
5
5
  Home-page: https://github.com/autogluon/autogluon
6
6
  Author: AutoGluon Community
@@ -38,6 +38,7 @@ Provides-Extra: lightgbm
38
38
  Provides-Extra: catboost
39
39
  Provides-Extra: xgboost
40
40
  Provides-Extra: realmlp
41
+ Provides-Extra: interpret
41
42
  Provides-Extra: fastai
42
43
  Provides-Extra: tabm
43
44
  Provides-Extra: tabpfn
@@ -49,6 +49,9 @@ extras_require = {
49
49
  "realmlp": [
50
50
  "pytabkit>=1.6,<1.7",
51
51
  ],
52
+ "interpret": [
53
+ "interpret-core>=0.7.2,<0.8",
54
+ ],
52
55
  "fastai": [
53
56
  "spacy<3.9",
54
57
  "torch", # version range defined in `core/_setup_utils.py`
@@ -129,6 +132,7 @@ extras_require["all"] = all_requires
129
132
 
130
133
  tabarena_requires = copy.deepcopy(all_requires)
131
134
  for extra_package in [
135
+ "interpret",
132
136
  "tabicl",
133
137
  "tabpfn",
134
138
  "realmlp",
@@ -139,6 +143,7 @@ extras_require["tabarena"] = tabarena_requires
139
143
 
140
144
  test_requires = []
141
145
  for test_package in [
146
+ "interpret",
142
147
  "tabicl", # Currently has unnecessary extra dependencies such as xgboost and wandb
143
148
  "tabpfn",
144
149
  "realmlp", # Will consider to put as part of `all_requires` once part of a portfolio
@@ -3,6 +3,7 @@ from autogluon.core.models.abstract.abstract_model import AbstractModel
3
3
  from .automm.automm_model import MultiModalPredictorModel
4
4
  from .automm.ft_transformer import FTTransformerModel
5
5
  from .catboost.catboost_model import CatBoostModel
6
+ from .ebm.ebm_model import EBMModel
6
7
  from .fastainn.tabular_nn_fastai import NNFastAiTabularModel
7
8
  from .fasttext.fasttext_model import FastTextModel
8
9
  from .image_prediction.image_predictor import ImagePredictorModel
@@ -0,0 +1,263 @@
1
+ from __future__ import annotations
2
+
3
+ import time
4
+ import warnings
5
+ from typing import TYPE_CHECKING
6
+
7
+ import numpy as np
8
+ import pandas as pd
9
+ from autogluon.core.constants import BINARY, MULTICLASS, REGRESSION
10
+ from autogluon.core.models import AbstractModel
11
+
12
+ from .hyperparameters.parameters import get_param_baseline
13
+ from .hyperparameters.searchspaces import get_default_searchspace
14
+
15
+ if TYPE_CHECKING:
16
+ from autogluon.core.metrics import Scorer
17
+
18
+
19
class EbmCallback:
    """Callback that asks an EBM to stop boosting once a time budget has been used.

    The deadline is armed on the first invocation (not at construction). From then
    on, every call returns ``True`` (stop) once ``time.monotonic()`` has passed the
    deadline. Any positional/keyword arguments supplied by the caller are ignored.
    """

    def __init__(self, seconds: float):
        self.seconds = seconds
        # Deadline on the time.monotonic() clock; stays unset until the first call.
        self.end_time: float | None = None

    def __call__(self, *args, **kwargs):
        now = time.monotonic()
        if self.end_time is not None:
            return now > self.end_time
        # First invocation: arm the deadline and let boosting continue.
        self.end_time = now + self.seconds
        return False
31
+
32
+
33
class EBMModel(AbstractModel):
    """
    The Explainable Boosting Machine (EBM) is a glass-box generalized additive model
    with automatic interaction detection (https://interpret.ml/docs). EBMs are
    designed to be highly interpretable while achieving accuracy comparable to
    black-box models on a wide range of tabular datasets.

    Requires the 'interpret' or 'interpret-core' package. Install via:

        pip install interpret


    Paper: InterpretML: A Unified Framework for Machine Learning Interpretability

    Authors: H. Nori, S. Jenkins, P. Koch, and R. Caruana 2019

    Codebase: https://github.com/interpretml/interpret

    License: MIT

    .. versionadded:: 1.5.0
    """

    ag_key = "EBM"
    ag_name = "EBM"
    ag_priority = 35

    def _fit(
        self,
        X: pd.DataFrame,
        y: pd.Series,
        X_val: pd.DataFrame | None = None,
        y_val: pd.Series | None = None,
        time_limit: float | None = None,
        sample_weight: np.ndarray | None = None,
        sample_weight_val: np.ndarray | None = None,
        num_cpus: int | str = "auto",
        **kwargs,
    ):
        """Fit the underlying interpret EBM estimator and store it in ``self.model``.

        EBM's ``fit`` takes no separate validation set. When ``X_val`` is given, the
        validation rows are appended to the training rows and flagged through the
        ``bags`` argument (see ``_merge_train_and_val``).

        ``num_cpus`` and ``time_limit`` (seconds) are forwarded to
        ``construct_ebm_params``, which pins ``n_jobs`` to 1 and turns
        ``time_limit`` into an ``EbmCallback``.
        """
        # Preprocess data.
        X = self.preprocess(X)
        if X_val is not None:
            X_val = self.preprocess(X_val)

        # Fall back to the preprocessed frame's columns when no feature list is known.
        features = self._features
        if features is None:
            features = X.columns

        params = construct_ebm_params(
            self.problem_type,
            self._get_model_params(),
            features,
            self.stopping_metric,
            num_cpus,
            time_limit,
        )

        # Init Class
        model_cls = get_class_from_problem_type(self.problem_type)
        self.model = model_cls(random_state=self.random_seed, **params)

        fit_X, fit_y, fit_sample_weight, bags = self._merge_train_and_val(
            X, y, X_val, y_val, sample_weight, sample_weight_val
        )

        with warnings.catch_warnings():  # try to filter joblib warnings
            warnings.filterwarnings(
                "ignore",
                category=UserWarning,
                message=".*resource_tracker: process died.*",
            )
            self.model.fit(fit_X, fit_y, sample_weight=fit_sample_weight, bags=bags)

    @staticmethod
    def _merge_train_and_val(X, y, X_val, y_val, sample_weight, sample_weight_val):
        """Stack training and validation data into the single dataset EBM's ``fit`` expects.

        Returns ``(fit_X, fit_y, fit_sample_weight, bags)``.

        - Without validation data, the inputs are returned unchanged with ``bags=None``.
        - Otherwise the rows are concatenated (training first, index reset).
          ``bags`` is then an ``int8`` array of shape ``(n_rows, 1)`` holding ``1``
          for training rows and ``-1`` for validation rows.
        - Validation weights are only used when training weights are given.
        """
        if X_val is None:
            return X, y, sample_weight, None

        fit_X = pd.concat([X, X_val], ignore_index=True)
        fit_y = pd.concat([y, y_val], ignore_index=True)

        fit_sample_weight = None
        if sample_weight is not None:
            if sample_weight_val is None:
                # Training is weighted but validation is not: give validation rows unit
                # weight. np.hstack([weights, None]) would instead append a single None
                # (object dtype), producing a weight vector whose length matches
                # neither fit_X nor fit_y.
                sample_weight_val = np.ones(len(X_val))
            fit_sample_weight = np.hstack([sample_weight, sample_weight_val])

        # One bag column: 1 marks a training row, -1 a validation row (interpret's `bags` convention).
        bags = np.full((len(fit_X), 1), 1, np.int8)
        bags[len(X) :, 0] = -1
        return fit_X, fit_y, fit_sample_weight, bags

    def _get_random_seed_from_hyperparameters(
        self, hyperparameters: dict
    ) -> int | None | str:
        """Return the ``random_state`` hyperparameter, or ``"N/A"`` when it is not set."""
        return hyperparameters.get("random_state", "N/A")

    def _set_default_params(self):
        """Register the problem-type-specific defaults from ``get_param_baseline``."""
        default_params = get_param_baseline(problem_type=self.problem_type, num_classes=self.num_classes)
        for param, val in default_params.items():
            self._set_default_param_value(param, val)

    def _get_default_searchspace(self):
        """Return the default hyperparameter search space for this problem type."""
        return get_default_searchspace(problem_type=self.problem_type, num_classes=self.num_classes)

    def _get_default_auxiliary_params(self) -> dict:
        """Extend the base auxiliary params, restricting ``valid_raw_types`` to int, float and category."""
        default_auxiliary_params = super()._get_default_auxiliary_params()
        extra_auxiliary_params = {
            "valid_raw_types": ["int", "float", "category"],
        }
        default_auxiliary_params.update(extra_auxiliary_params)
        return default_auxiliary_params

    @classmethod
    def supported_problem_types(cls) -> list[str] | None:
        """EBM handles binary, multiclass and regression problems."""
        return [BINARY, MULTICLASS, REGRESSION]

    @classmethod
    def _class_tags(cls) -> dict:
        return {"can_estimate_memory_usage_static": True}

    def _more_tags(self) -> dict:
        """EBMs support refit full."""
        return {"can_refit_full": True}

    def _estimate_memory_usage(self, X: pd.DataFrame, y: pd.Series | None = None, **kwargs) -> int:
        """Estimate peak fit memory (bytes) using this instance's params and features."""
        return self.estimate_memory_usage_static(
            X=X,
            y=y,
            hyperparameters=self._get_model_params(),
            problem_type=self.problem_type,
            num_classes=self.num_classes,
            features=self._features,
            **kwargs,
        )

    @classmethod
    def _estimate_memory_usage_static(
        cls,
        *,
        X: pd.DataFrame,
        y: pd.Series | None = None,
        hyperparameters: dict | None = None,
        problem_type: str = "infer",
        num_classes: int = 1,
        features=None,
        **kwargs,
    ) -> int:
        """Returns the expected peak memory usage in bytes of the EBM model during fit.

        The estimate is a fixed 400 MB baseline plus the estimator's own
        ``estimate_mem`` with ``data_multiplier=2.0``. The multiplier accounts for
        ``_fit`` possibly concatenating validation rows onto X.
        """
        # TODO: we can improve the memory estimate slightly by using num_classes if y is None

        if features is None:
            features = X.columns

        model_cls = get_class_from_problem_type(problem_type)
        params = construct_ebm_params(problem_type, hyperparameters, features)
        baseline_memory_bytes = 400_000_000  # 400 MB baseline memory

        # assuming we call pd.concat([X, X_val], ignore_index=True), then X size will be doubled
        return baseline_memory_bytes + model_cls(**params).estimate_mem(
            X, y, data_multiplier=2.0
        )

    def _validate_fit_memory_usage(self, mem_error_threshold: float = 1, **kwargs):
        # Given the good mem estimates with overhead, we set the threshold to 1.
        return super()._validate_fit_memory_usage(
            mem_error_threshold=mem_error_threshold, **kwargs
        )
191
+
192
+
193
def construct_ebm_params(
    problem_type,
    hyperparameters=None,
    features=None,
    stopping_metric=None,
    num_cpus=-1,
    time_limit=None,
):
    """Build the keyword arguments for an interpret EBM estimator constructor.

    Parameters
    ----------
    problem_type : str
        AutoGluon problem type. Only used to pick the ``objective`` when
        ``stopping_metric`` is given.
    hyperparameters : dict | None
        Hyperparameters to forward to the estimator. The caller's dict is never mutated.
        Two AutoGluon-only keys are consumed here and not forwarded:
        ``continuous_columns`` and ``nominal_columns``. Each may be a single column
        name, an iterable of names, or None.
    features : iterable of column names | None
        Passed through as ``feature_names``. When given, ``feature_types`` gets one
        entry per feature: "continuous", "nominal" or "auto". A column listed in both
        user lists is treated as continuous.
    stopping_metric : Scorer | None
        When not None, sets ``objective`` via ``get_metric_from_ag_metric``.
    num_cpus : int | str
        Currently unused; ``n_jobs`` is pinned to 1 (see below).
    time_limit : float | None
        When not None, installs an ``EbmCallback`` with this many seconds.

    Returns
    -------
    dict
        The defaults below, overridden by any remaining ``hyperparameters`` entries.
    """
    # Copy so that popping the AutoGluon-only keys below never mutates the caller's dict.
    hyperparameters = {} if hyperparameters is None else hyperparameters.copy()

    def _as_column_set(columns) -> set:
        # Accept None, a single column name, or an iterable of names. A bare string must
        # not be used for membership directly: `"a" in "ab"` is a substring match.
        if columns is None:
            return set()
        if isinstance(columns, str):
            return {columns}
        return set(columns)

    # The user can specify nominal and continuous columns.
    continuous_columns = _as_column_set(hyperparameters.pop("continuous_columns", []))
    nominal_columns = _as_column_set(hyperparameters.pop("nominal_columns", []))

    feature_types = None
    if features is not None:
        feature_types = []
        for c in features:
            if c in continuous_columns:
                f_type = "continuous"
            elif c in nominal_columns:
                f_type = "nominal"
            else:
                f_type = "auto"
            feature_types.append(f_type)

    # Default parameters for EBM
    params = {
        "outer_bags": 1,  # AutoGluon ensemble creates outer bags, no need for this overhead.
        "n_jobs": 1,  # EBM only parallelizes across outer bags currently, so ignore num_cpus
        "feature_names": features,
        "feature_types": feature_types,
    }
    if stopping_metric is not None:
        params["objective"] = get_metric_from_ag_metric(
            metric=stopping_metric, problem_type=problem_type
        )
    if time_limit is not None:
        params["callback"] = EbmCallback(time_limit)

    # User hyperparameters take precedence over the defaults above.
    params.update(hyperparameters)
    return params
238
+
239
+
240
def get_class_from_problem_type(problem_type: str):
    """Return the interpret EBM estimator class for an AutoGluon problem type.

    ``interpret`` is imported inside this function, so the optional dependency is
    only loaded when an estimator class is actually requested.

    Raises
    ------
    ValueError
        If ``problem_type`` is not binary, multiclass or regression.
    """
    if problem_type == REGRESSION:
        from interpret.glassbox import ExplainableBoostingRegressor

        return ExplainableBoostingRegressor
    if problem_type in (BINARY, MULTICLASS):
        from interpret.glassbox import ExplainableBoostingClassifier

        return ExplainableBoostingClassifier
    raise ValueError(f"Unsupported problem type: {problem_type}")
252
+
253
+
254
def get_metric_from_ag_metric(*, metric: Scorer, problem_type: str):
    """Map AutoGluon metric to EBM metric for early stopping.

    NOTE: ``metric`` is currently not inspected. The returned EBM objective depends
    only on ``problem_type``: "log_loss" for classification, "rmse" for regression.

    Raises
    ------
    AssertionError
        If ``problem_type`` is not binary, multiclass or regression.
    """
    if problem_type in [BINARY, MULTICLASS]:
        return "log_loss"
    if problem_type == REGRESSION:
        return "rmse"
    raise AssertionError(f"EBM does not support {problem_type} problem type.")
@@ -0,0 +1,39 @@
1
+ from autogluon.core.constants import BINARY, MULTICLASS, REGRESSION, SOFTCLASS
2
+
3
def get_param_baseline(problem_type, num_classes=None):
    """Return the default EBM hyperparameters for ``problem_type``.

    Multiclass and softclass share the multiclass defaults, and regression has its
    own. Binary and any unrecognized problem type fall back to the binary defaults.
    """
    if problem_type in (MULTICLASS, SOFTCLASS):
        return get_param_multiclass_baseline(num_classes=num_classes)
    if problem_type == REGRESSION:
        return get_param_regression_baseline()
    return get_param_binary_baseline()
14
+
15
+
16
def get_base_params():
    """Return a fresh dict of EBM hyperparameters shared by every problem type (currently none)."""
    return {}
19
+
20
+
21
def get_param_binary_baseline():
    """Return the default EBM hyperparameters for binary classification.

    The base parameters are merged with binary-specific overrides; none are
    defined yet.
    """
    binary_overrides = {}
    return {**get_base_params(), **binary_overrides}
26
+
27
+
28
def get_param_multiclass_baseline(num_classes):
    """Return the default EBM hyperparameters for multiclass classification.

    ``num_classes`` is accepted for interface symmetry but is not used. The base
    parameters are merged with multiclass-specific overrides; none are defined yet.
    """
    multiclass_overrides = {}
    return {**get_base_params(), **multiclass_overrides}
33
+
34
+
35
def get_param_regression_baseline():
    """Return the default EBM hyperparameters for regression.

    The base parameters are merged with regression-specific overrides; none are
    defined yet.
    """
    regression_overrides = {}
    return {**get_base_params(), **regression_overrides}
@@ -0,0 +1,72 @@
1
+ """Default hyperparameter search spaces used in EBM model"""
2
+
3
+ from autogluon.common import space
4
+ from autogluon.core.constants import BINARY, MULTICLASS, REGRESSION
5
+
6
def get_default_searchspace(problem_type, num_classes=None):
    """Return the default EBM search space for ``problem_type``.

    Multiclass and regression have dedicated builders. Binary and any other problem
    type use the binary search space.
    """
    if problem_type == MULTICLASS:
        return get_searchspace_multiclass_baseline(num_classes=num_classes)
    if problem_type == REGRESSION:
        return get_searchspace_regression_baseline()
    return get_searchspace_binary_baseline()
15
+
16
+
17
def get_base_searchspace():
    """Return the EBM hyperparameter search space shared by all problem types.

    Keys are EBM estimator constructor arguments; values are AutoGluon ``space``
    definitions.
    """
    # Candidate values for EBM's `interactions` argument: the integer 0 or "<k>x"
    # strings (see interpret's `interactions` docs for how these are interpreted).
    interaction_choices = (
        0,
        "0.5x", "1x", "1.5x", "2x", "2.5x", "3x", "3.5x", "4x", "4.5x", "5x",
        "6x", "7x", "8x", "9x", "10x",
        "15x", "20x", "25x",
    )
    return dict(
        max_leaves=space.Int(2, 3, default=2),
        smoothing_rounds=space.Int(0, 1000, default=200),
        learning_rate=space.Real(0.0025, 0.2, default=0.02, log=True),
        interactions=space.Categorical(*interaction_choices),
        interaction_smoothing_rounds=space.Int(0, 200, default=90),
        min_hessian=space.Real(1e-10, 1e-2, default=1e-4, log=True),
        min_samples_leaf=space.Int(2, 20, default=4),
        gain_scale=space.Real(0.5, 5.0, default=5.0, log=True),
        min_cat_samples=space.Int(5, 20, default=10),
        cat_smooth=space.Real(5.0, 100.0, default=10.0, log=True),
        missing=space.Categorical("separate", "low", "high", "gain"),
    )
52
+
53
+
54
def get_searchspace_multiclass_baseline(num_classes):
    """Return the EBM search space for multiclass classification.

    ``num_classes`` is accepted for interface symmetry but is not used. The shared
    search space is merged with multiclass-specific overrides; none are defined yet.
    """
    multiclass_overrides = {}
    return {**get_base_searchspace(), **multiclass_overrides}
59
+
60
+
61
def get_searchspace_binary_baseline():
    """Return the EBM search space for binary classification.

    The shared search space is merged with binary-specific overrides; none are
    defined yet.
    """
    binary_overrides = {}
    return {**get_base_searchspace(), **binary_overrides}
66
+
67
+
68
def get_searchspace_regression_baseline():
    """Return the EBM search space for regression.

    The shared search space is merged with regression-specific overrides; none are
    defined yet.
    """
    regression_overrides = {}
    return {**get_base_searchspace(), **regression_overrides}
@@ -116,6 +116,22 @@ class MitraModel(AbstractModel):
116
116
 
117
117
  hyp = self._get_model_params()
118
118
 
119
+ hf_cls_model = hyp.pop("hf_cls_model", None)
120
+ hf_reg_model = hyp.pop("hf_reg_model", None)
121
+ if self.problem_type in ["binary", "multiclass"]:
122
+ hf_model = hf_cls_model
123
+ elif self.problem_type == "regression":
124
+ hf_model = hf_reg_model
125
+ else:
126
+ raise AssertionError(f"Unsupported problem_type: {self.problem_type}")
127
+ if hf_model is None:
128
+ hf_model = hyp.pop("hf_general_model", None)
129
+ if hf_model is None:
130
+ hf_model = hyp.pop("hf_model", None)
131
+ if hf_model is not None:
132
+ logger.log(30, f"\tCustom hf_model specified: {hf_model}")
133
+ hyp["hf_model"] = hf_model
134
+
119
135
  if hyp.get("device", None) is None:
120
136
  if num_gpus == 0:
121
137
  hyp["device"] = "cpu"
@@ -30,7 +30,6 @@ RANDOM_MIRROR_X = True # [True, False]
30
30
  LR = 0.0001 # [0.00001, 0.000025, 0.00005, 0.000075, 0.0001, 0.00025, 0.0005, 0.00075, 0.001]
31
31
  PATIENCE = 40 # [30, 35, 40, 45, 50]
32
32
  WARMUP_STEPS = 1000 # [500, 750, 1000, 1250, 1500]
33
- DEFAULT_GENERAL_MODEL = 'autogluon/mitra-classifier'
34
33
  DEFAULT_CLS_MODEL = 'autogluon/mitra-classifier'
35
34
  DEFAULT_REG_MODEL = 'autogluon/mitra-regressor'
36
35
 
@@ -67,9 +66,7 @@ class MitraBase(BaseEstimator):
67
66
  fine_tune_steps=DEFAULT_FINE_TUNE_STEPS,
68
67
  metric=DEFAULT_CLS_METRIC,
69
68
  state_dict=None,
70
- hf_general_model=DEFAULT_GENERAL_MODEL,
71
- hf_cls_model=DEFAULT_CLS_MODEL,
72
- hf_reg_model=DEFAULT_REG_MODEL,
69
+ hf_model=None,
73
70
  patience=PATIENCE,
74
71
  lr=LR,
75
72
  warmup_steps=WARMUP_STEPS,
@@ -104,9 +101,7 @@ class MitraBase(BaseEstimator):
104
101
  self.fine_tune_steps = fine_tune_steps
105
102
  self.metric = metric
106
103
  self.state_dict = state_dict
107
- self.hf_general_model = hf_general_model
108
- self.hf_cls_model = hf_cls_model
109
- self.hf_reg_model = hf_reg_model
104
+ self.hf_model = hf_model
110
105
  self.patience = patience
111
106
  self.lr = lr
112
107
  self.warmup_steps = warmup_steps
@@ -200,20 +195,8 @@ class MitraBase(BaseEstimator):
200
195
  self.train_time = 0
201
196
  for _ in range(self.n_estimators):
202
197
  if USE_HF:
203
- if task == 'classification':
204
- if self.hf_cls_model is not None:
205
- model = Tab2D.from_pretrained(self.hf_cls_model, device=self.device)
206
- elif self.hf_general_model is not None:
207
- model = Tab2D.from_pretrained(self.hf_general_model, device=self.device)
208
- else:
209
- model = Tab2D.from_pretrained("autogluon/mitra-classifier", device=self.device)
210
- elif task == 'regression':
211
- if self.hf_reg_model is not None:
212
- model = Tab2D.from_pretrained(self.hf_reg_model, device=self.device)
213
- elif self.hf_general_model is not None:
214
- model = Tab2D.from_pretrained(self.hf_general_model, device=self.device)
215
- else:
216
- model = Tab2D.from_pretrained("autogluon/mitra-regressor", device=self.device)
198
+ assert self.hf_model is not None, f"hf_model must not be None."
199
+ model = Tab2D.from_pretrained(self.hf_model, device=self.device)
217
200
  else:
218
201
  model = Tab2D(
219
202
  dim=cfg.hyperparams['dim'],
@@ -274,6 +257,7 @@ class MitraClassifier(MitraBase, ClassifierMixin):
274
257
  fine_tune_steps=DEFAULT_FINE_TUNE_STEPS,
275
258
  metric=DEFAULT_CLS_METRIC,
276
259
  state_dict=None,
260
+ hf_model=DEFAULT_CLS_MODEL,
277
261
  patience=PATIENCE,
278
262
  lr=LR,
279
263
  warmup_steps=WARMUP_STEPS,
@@ -294,6 +278,7 @@ class MitraClassifier(MitraBase, ClassifierMixin):
294
278
  fine_tune_steps,
295
279
  metric,
296
280
  state_dict,
281
+ hf_model=hf_model,
297
282
  patience=patience,
298
283
  lr=lr,
299
284
  warmup_steps=warmup_steps,
@@ -404,6 +389,7 @@ class MitraRegressor(MitraBase, RegressorMixin):
404
389
  fine_tune_steps=DEFAULT_FINE_TUNE_STEPS,
405
390
  metric=DEFAULT_REG_METRIC,
406
391
  state_dict=None,
392
+ hf_model=DEFAULT_REG_MODEL,
407
393
  patience=PATIENCE,
408
394
  lr=LR,
409
395
  warmup_steps=WARMUP_STEPS,
@@ -424,6 +410,7 @@ class MitraRegressor(MitraBase, RegressorMixin):
424
410
  fine_tune_steps,
425
411
  metric,
426
412
  state_dict,
413
+ hf_model=hf_model,
427
414
  patience=patience,
428
415
  lr=lr,
429
416
  warmup_steps=warmup_steps,
@@ -525,6 +525,7 @@ class TabularPredictor:
525
525
  'GBM' (LightGBM)
526
526
  'CAT' (CatBoost)
527
527
  'XGB' (XGBoost)
528
+ 'EBM' (Explainable Boosting Machine)
528
529
  'REALMLP' (RealMLP)
529
530
  'TABM' (TabM)
530
531
  'MITRA' (Mitra)
@@ -8,6 +8,7 @@ from . import ModelRegistry
8
8
  from ..models import (
9
9
  BoostedRulesModel,
10
10
  CatBoostModel,
11
+ EBMModel,
11
12
  FastTextModel,
12
13
  FigsModel,
13
14
  FTTransformerModel,
@@ -64,6 +65,7 @@ REGISTERED_MODEL_CLS_LST = [
64
65
  HSTreeModel,
65
66
  BoostedRulesModel,
66
67
  DummyModel,
68
+ EBMModel,
67
69
  ]
68
70
 
69
71
  # TODO: Replace logic in `autogluon.tabular.trainer.model_presets.presets` with `ag_model_registry`
@@ -1,4 +1,4 @@
1
1
  """This is the autogluon version file."""
2
2
 
3
- __version__ = "1.4.1b20250822"
3
+ __version__ = "1.4.1b20250823"
4
4
  __lite__ = False
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: autogluon.tabular
3
- Version: 1.4.1b20250822
3
+ Version: 1.4.1b20250823
4
4
  Summary: Fast and Accurate ML in 3 Lines of Code
5
5
  Home-page: https://github.com/autogluon/autogluon
6
6
  Author: AutoGluon Community
@@ -38,6 +38,7 @@ Provides-Extra: lightgbm
38
38
  Provides-Extra: catboost
39
39
  Provides-Extra: xgboost
40
40
  Provides-Extra: realmlp
41
+ Provides-Extra: interpret
41
42
  Provides-Extra: fastai
42
43
  Provides-Extra: tabm
43
44
  Provides-Extra: tabpfn
@@ -42,6 +42,11 @@ src/autogluon/tabular/models/catboost/catboost_utils.py
42
42
  src/autogluon/tabular/models/catboost/hyperparameters/__init__.py
43
43
  src/autogluon/tabular/models/catboost/hyperparameters/parameters.py
44
44
  src/autogluon/tabular/models/catboost/hyperparameters/searchspaces.py
45
+ src/autogluon/tabular/models/ebm/__init__.py
46
+ src/autogluon/tabular/models/ebm/ebm_model.py
47
+ src/autogluon/tabular/models/ebm/hyperparameters/__init__.py
48
+ src/autogluon/tabular/models/ebm/hyperparameters/parameters.py
49
+ src/autogluon/tabular/models/ebm/hyperparameters/searchspaces.py
45
50
  src/autogluon/tabular/models/fastainn/__init__.py
46
51
  src/autogluon/tabular/models/fastainn/callbacks.py
47
52
  src/autogluon/tabular/models/fastainn/fastai_helpers.py