autogluon.tabular 1.3.2b20250610__py3-none-any.whl → 1.4.1b20251214__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- autogluon/tabular/configs/config_helper.py +1 -1
- autogluon/tabular/configs/hyperparameter_configs.py +2 -265
- autogluon/tabular/configs/pipeline_presets.py +130 -0
- autogluon/tabular/configs/presets_configs.py +51 -26
- autogluon/tabular/configs/zeroshot/zeroshot_portfolio_2023.py +0 -1
- autogluon/tabular/configs/zeroshot/zeroshot_portfolio_2025.py +310 -0
- autogluon/tabular/models/__init__.py +6 -1
- autogluon/tabular/models/_utils/rapids_utils.py +1 -1
- autogluon/tabular/models/automm/automm_model.py +2 -0
- autogluon/tabular/models/automm/ft_transformer.py +4 -1
- autogluon/tabular/models/catboost/callbacks.py +3 -2
- autogluon/tabular/models/catboost/catboost_model.py +15 -9
- autogluon/tabular/models/catboost/catboost_utils.py +17 -3
- autogluon/tabular/models/ebm/__init__.py +0 -0
- autogluon/tabular/models/ebm/ebm_model.py +259 -0
- autogluon/tabular/models/ebm/hyperparameters/__init__.py +0 -0
- autogluon/tabular/models/ebm/hyperparameters/parameters.py +39 -0
- autogluon/tabular/models/ebm/hyperparameters/searchspaces.py +72 -0
- autogluon/tabular/models/fastainn/tabular_nn_fastai.py +7 -5
- autogluon/tabular/models/knn/knn_model.py +7 -3
- autogluon/tabular/models/lgb/lgb_model.py +60 -21
- autogluon/tabular/models/lr/lr_model.py +6 -1
- autogluon/tabular/models/lr/lr_preprocessing_utils.py +6 -7
- autogluon/tabular/models/lr/lr_rapids_model.py +45 -5
- autogluon/tabular/models/mitra/__init__.py +0 -0
- autogluon/tabular/models/mitra/_internal/__init__.py +1 -0
- autogluon/tabular/models/mitra/_internal/config/__init__.py +1 -0
- autogluon/tabular/models/mitra/_internal/config/config_pretrain.py +190 -0
- autogluon/tabular/models/mitra/_internal/config/config_run.py +32 -0
- autogluon/tabular/models/mitra/_internal/config/enums.py +162 -0
- autogluon/tabular/models/mitra/_internal/core/__init__.py +1 -0
- autogluon/tabular/models/mitra/_internal/core/callbacks.py +94 -0
- autogluon/tabular/models/mitra/_internal/core/get_loss.py +54 -0
- autogluon/tabular/models/mitra/_internal/core/get_optimizer.py +108 -0
- autogluon/tabular/models/mitra/_internal/core/get_scheduler.py +67 -0
- autogluon/tabular/models/mitra/_internal/core/prediction_metrics.py +132 -0
- autogluon/tabular/models/mitra/_internal/core/trainer_finetune.py +373 -0
- autogluon/tabular/models/mitra/_internal/data/__init__.py +1 -0
- autogluon/tabular/models/mitra/_internal/data/collator.py +46 -0
- autogluon/tabular/models/mitra/_internal/data/dataset_finetune.py +136 -0
- autogluon/tabular/models/mitra/_internal/data/dataset_split.py +57 -0
- autogluon/tabular/models/mitra/_internal/data/preprocessor.py +420 -0
- autogluon/tabular/models/mitra/_internal/models/__init__.py +1 -0
- autogluon/tabular/models/mitra/_internal/models/base.py +21 -0
- autogluon/tabular/models/mitra/_internal/models/embedding.py +182 -0
- autogluon/tabular/models/mitra/_internal/models/tab2d.py +667 -0
- autogluon/tabular/models/mitra/_internal/utils/__init__.py +1 -0
- autogluon/tabular/models/mitra/_internal/utils/set_seed.py +15 -0
- autogluon/tabular/models/mitra/mitra_model.py +380 -0
- autogluon/tabular/models/mitra/sklearn_interface.py +494 -0
- autogluon/tabular/models/realmlp/__init__.py +0 -0
- autogluon/tabular/models/realmlp/realmlp_model.py +360 -0
- autogluon/tabular/models/rf/rf_model.py +11 -6
- autogluon/tabular/models/tabicl/__init__.py +0 -0
- autogluon/tabular/models/tabicl/tabicl_model.py +179 -0
- autogluon/tabular/models/tabm/__init__.py +0 -0
- autogluon/tabular/models/tabm/_tabm_internal.py +545 -0
- autogluon/tabular/models/tabm/rtdl_num_embeddings.py +810 -0
- autogluon/tabular/models/tabm/tabm_model.py +356 -0
- autogluon/tabular/models/tabm/tabm_reference.py +631 -0
- autogluon/tabular/models/tabpfnmix/tabpfnmix_model.py +13 -7
- autogluon/tabular/models/tabpfnv2/__init__.py +0 -0
- autogluon/tabular/models/tabpfnv2/rfpfn/__init__.py +20 -0
- autogluon/tabular/models/tabpfnv2/rfpfn/configs.py +40 -0
- autogluon/tabular/models/tabpfnv2/rfpfn/scoring_utils.py +201 -0
- autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_based_decision_tree_tabpfn.py +1464 -0
- autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_based_random_forest_tabpfn.py +747 -0
- autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_compat.py +863 -0
- autogluon/tabular/models/tabpfnv2/rfpfn/utils.py +106 -0
- autogluon/tabular/models/tabpfnv2/tabpfnv2_model.py +388 -0
- autogluon/tabular/models/tabular_nn/hyperparameters/parameters.py +1 -3
- autogluon/tabular/models/tabular_nn/torch/tabular_nn_torch.py +5 -5
- autogluon/tabular/models/xgboost/xgboost_model.py +10 -3
- autogluon/tabular/predictor/predictor.py +147 -84
- autogluon/tabular/registry/_ag_model_registry.py +12 -2
- autogluon/tabular/testing/fit_helper.py +57 -27
- autogluon/tabular/testing/generate_datasets.py +7 -0
- autogluon/tabular/trainer/abstract_trainer.py +3 -1
- autogluon/tabular/trainer/model_presets/presets.py +10 -1
- autogluon/tabular/version.py +1 -1
- autogluon.tabular-1.4.1b20251214-py3.11-nspkg.pth +1 -0
- {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info}/METADATA +112 -57
- {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info}/RECORD +89 -40
- {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info}/WHEEL +1 -1
- autogluon/tabular/models/tabpfn/__init__.py +0 -1
- autogluon/tabular/models/tabpfn/tabpfn_model.py +0 -153
- autogluon.tabular-1.3.2b20250610-py3.9-nspkg.pth +0 -1
- {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info/licenses}/LICENSE +0 -0
- {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info/licenses}/NOTICE +0 -0
- {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info}/namespace_packages.txt +0 -0
- {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info}/top_level.txt +0 -0
- {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info}/zip-safe +0 -0
{autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info}/RECORD
RENAMED
|
@@ -1,13 +1,15 @@
|
|
|
1
|
-
autogluon.tabular-1.
|
|
1
|
+
autogluon.tabular-1.4.1b20251214-py3.11-nspkg.pth,sha256=kAlKxjI5mE3Pwwqphu2maN5OBQk8W8ew70e_qbI1c6A,482
|
|
2
2
|
autogluon/tabular/__init__.py,sha256=2OXpJCvENRHubBTYNIPpHX93WWuFZzsJBtTZbNVHVas,400
|
|
3
|
-
autogluon/tabular/version.py,sha256
|
|
3
|
+
autogluon/tabular/version.py,sha256=CYNrmn5KfprHt9fJxfHv56jQZ5Q-BPbHMgM1kWYZDD8,91
|
|
4
4
|
autogluon/tabular/configs/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
5
|
-
autogluon/tabular/configs/config_helper.py,sha256=
|
|
5
|
+
autogluon/tabular/configs/config_helper.py,sha256=Rby5gRhuY5IlZWdKbtsmzbSt948B97qxwQ2f1MbH_38,21070
|
|
6
6
|
autogluon/tabular/configs/feature_generator_presets.py,sha256=EV5Ym8VW15q92MwOUpTi7wZFS2QooM51fLg3RdUsn-M,1223
|
|
7
|
-
autogluon/tabular/configs/hyperparameter_configs.py,sha256=
|
|
8
|
-
autogluon/tabular/configs/
|
|
7
|
+
autogluon/tabular/configs/hyperparameter_configs.py,sha256=aQ1rrF8P0MX4Ic5M33O96JtKV-K7YpDrgJmWhYmEyug,6848
|
|
8
|
+
autogluon/tabular/configs/pipeline_presets.py,sha256=ccrT3C56pYHW8x8VB_Q9zAu_eCxlgNQpt7TXpVUzDfE,4761
|
|
9
|
+
autogluon/tabular/configs/presets_configs.py,sha256=_C9wTfKVRyoomtYa04RqNyw1CEOYc_5Q3QKejqDp754,7674
|
|
9
10
|
autogluon/tabular/configs/zeroshot/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
10
|
-
autogluon/tabular/configs/zeroshot/zeroshot_portfolio_2023.py,sha256=
|
|
11
|
+
autogluon/tabular/configs/zeroshot/zeroshot_portfolio_2023.py,sha256=6yd84vPqOk-6sLCoM_e_PlphrR2NZUjliS7L1SMKMug,29777
|
|
12
|
+
autogluon/tabular/configs/zeroshot/zeroshot_portfolio_2025.py,sha256=NXwfqZLQLx4kdvRqF6deFDdhZZKxbfgpUurdB0kqOh8,11996
|
|
11
13
|
autogluon/tabular/experimental/__init__.py,sha256=PpkdMSv_pPZted1XRIuzcFWKjM-66VMUukTnCcoiW0s,100
|
|
12
14
|
autogluon/tabular/experimental/_scikit_mixin.py,sha256=cKeCmtURAXZnhQGrkCBw5rmACCQF7biAWTT3qX8bM2Q,2281
|
|
13
15
|
autogluon/tabular/experimental/_tabular_classifier.py,sha256=7lGoFdvkHiZS3VpcXo97q4ENV9qyIVDExlWkm0wzL3s,2527
|
|
@@ -16,27 +18,32 @@ autogluon/tabular/experimental/plot_leaderboard.py,sha256=BN_kB-zmOZNUYWyI7z9pF6
|
|
|
16
18
|
autogluon/tabular/learner/__init__.py,sha256=Hhmk5WpKQHohVmI-veOaKMelKJpIdzeXrmw_DPn3DTU,63
|
|
17
19
|
autogluon/tabular/learner/abstract_learner.py,sha256=0kf0huvg0nphe-lrdKtNTzdIFr14jzJPsfZDRBkKo3g,55253
|
|
18
20
|
autogluon/tabular/learner/default_learner.py,sha256=hjdKbcFtIQxQ3-k1LiGOo-w5sLxIIQAyFLs3-R35aw0,24781
|
|
19
|
-
autogluon/tabular/models/__init__.py,sha256=
|
|
21
|
+
autogluon/tabular/models/__init__.py,sha256=grZ23UfuNZ_LxoNdl-yjIUmq71TeovT5CJPhbatiqvg,1252
|
|
20
22
|
autogluon/tabular/models/_utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
21
|
-
autogluon/tabular/models/_utils/rapids_utils.py,sha256=
|
|
23
|
+
autogluon/tabular/models/_utils/rapids_utils.py,sha256=9A2Y10Owva6zhcLkBVQ_T4tOAMDp1idSMzDWhl_QyBI,1083
|
|
22
24
|
autogluon/tabular/models/_utils/torch_utils.py,sha256=dxs_KMMAOmNkRNjYf_hrzqaHIfkqn1xoKRKqCFbQ1Rk,537
|
|
23
25
|
autogluon/tabular/models/automm/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
24
|
-
autogluon/tabular/models/automm/automm_model.py,sha256=
|
|
25
|
-
autogluon/tabular/models/automm/ft_transformer.py,sha256=
|
|
26
|
+
autogluon/tabular/models/automm/automm_model.py,sha256=MoydDuPEd5atbUPlVDzWLTKLB7EchcPdSVVncxA9jEM,11355
|
|
27
|
+
autogluon/tabular/models/automm/ft_transformer.py,sha256=X-IEi5uKme7SoRcHnPjGTByzrjCB85I7RpB0hS36TLQ,3897
|
|
26
28
|
autogluon/tabular/models/catboost/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
27
|
-
autogluon/tabular/models/catboost/callbacks.py,sha256=
|
|
28
|
-
autogluon/tabular/models/catboost/catboost_model.py,sha256=
|
|
29
|
+
autogluon/tabular/models/catboost/callbacks.py,sha256=QvyiynQoxjvfYaYwGNSF5N3gc_wqI9mi1nQiawL0EJ4,7194
|
|
30
|
+
autogluon/tabular/models/catboost/catboost_model.py,sha256=tAT_eklRJDARJsbS72-Nn8PxLmKgIvffzjjrTI1XMXM,18041
|
|
29
31
|
autogluon/tabular/models/catboost/catboost_softclass_utils.py,sha256=UiW0SUb3hFueW5qYtQn6Sbk7Wg7BWN4jqKWeFtbMvgU,3919
|
|
30
|
-
autogluon/tabular/models/catboost/catboost_utils.py,sha256=
|
|
32
|
+
autogluon/tabular/models/catboost/catboost_utils.py,sha256=zJMIsbgyW_JH0eULhUeu_TWR0Qfmf34CnED7c7NvXBw,3899
|
|
31
33
|
autogluon/tabular/models/catboost/hyperparameters/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
32
34
|
autogluon/tabular/models/catboost/hyperparameters/parameters.py,sha256=Hxi4mPTc2ML9GdpW0TalkDgtsYJLwpEcd-LiyLOsmlA,956
|
|
33
35
|
autogluon/tabular/models/catboost/hyperparameters/searchspaces.py,sha256=Oe86ixuvd1xJCdSHs2Oh5Ifx0501YJBsdyL2l9Z4nxM,1458
|
|
36
|
+
autogluon/tabular/models/ebm/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
37
|
+
autogluon/tabular/models/ebm/ebm_model.py,sha256=PyocCEPxByB-E5gRCZitI5gsP6DVYlxmRx8bbZ31guA,8524
|
|
38
|
+
autogluon/tabular/models/ebm/hyperparameters/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
39
|
+
autogluon/tabular/models/ebm/hyperparameters/parameters.py,sha256=IbDv3Ufx8CGHvejqSbAggZKlMq5X9k0Ggclm_DCoiII,1080
|
|
40
|
+
autogluon/tabular/models/ebm/hyperparameters/searchspaces.py,sha256=G6zgHERKt_KJlVfZ06tFKw2aOUuM7DdDyCm0s5RBXoc,2191
|
|
34
41
|
autogluon/tabular/models/fastainn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
35
42
|
autogluon/tabular/models/fastainn/callbacks.py,sha256=3WvOEwqd1YAVInooKsFOTzAkCLeIXjEelsglYwfofq0,4788
|
|
36
43
|
autogluon/tabular/models/fastainn/fastai_helpers.py,sha256=gGYzyrAFl8hi8GnsemZNLGZn5xr7cyJXdFl08PIlza4,1393
|
|
37
44
|
autogluon/tabular/models/fastainn/imports_helper.py,sha256=ICxA8ty47-oZu0Q9AjKCQe8uVi340Iu0NFruxvJPrbA,330
|
|
38
45
|
autogluon/tabular/models/fastainn/quantile_helpers.py,sha256=d89GKvSRBgOy9EqcDI83MK5sqPRxP6JJ3BmPLmKnB0o,1808
|
|
39
|
-
autogluon/tabular/models/fastainn/tabular_nn_fastai.py,sha256=
|
|
46
|
+
autogluon/tabular/models/fastainn/tabular_nn_fastai.py,sha256=FqT6xqhU2XoTWJ0yY_ZmT3JI6ranl63vpdPkn6JFbos,29666
|
|
40
47
|
autogluon/tabular/models/fastainn/hyperparameters/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
41
48
|
autogluon/tabular/models/fastainn/hyperparameters/parameters.py,sha256=DkQwAZZ7CuODKoljr-yrkx-uFxBSPRxkKuvPdwO-UhQ,2069
|
|
42
49
|
autogluon/tabular/models/fastainn/hyperparameters/searchspaces.py,sha256=5qdknZDrHtdPdrhSqjamYQrCxvupXvlN3bVGEPgs48E,1660
|
|
@@ -50,34 +57,67 @@ autogluon/tabular/models/imodels/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRk
|
|
|
50
57
|
autogluon/tabular/models/imodels/imodels_models.py,sha256=89uQwbRAtqcUvPwYsKnER8SUMIbwkGZUd9spoG_mP10,4878
|
|
51
58
|
autogluon/tabular/models/knn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
52
59
|
autogluon/tabular/models/knn/_knn_loo_variants.py,sha256=-n2znYS7OBA0bZvtei6JZiEMRWp4GX-Qp64uheaHyhQ,4562
|
|
53
|
-
autogluon/tabular/models/knn/knn_model.py,sha256=
|
|
60
|
+
autogluon/tabular/models/knn/knn_model.py,sha256=I7wPRy38oD03f_3KN7Q_CyoJJucDPrPQyJqjgovmx8Q,14061
|
|
54
61
|
autogluon/tabular/models/knn/knn_rapids_model.py,sha256=0FFApNZFH8nyrDqlBSUV7jO-2fLe0-h_UHp1GsyQJ8E,1550
|
|
55
62
|
autogluon/tabular/models/knn/knn_utils.py,sha256=XU1cxVXp1BAoQnja2_KmSIn9_q9gZkjAya7-9b0uStk,7455
|
|
56
63
|
autogluon/tabular/models/lgb/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
57
64
|
autogluon/tabular/models/lgb/callbacks.py,sha256=KJB1KmebA88qHT206KSfvm5NamGuv5lRzy7O9dOwW-M,12243
|
|
58
|
-
autogluon/tabular/models/lgb/lgb_model.py,sha256=
|
|
65
|
+
autogluon/tabular/models/lgb/lgb_model.py,sha256=kRIcBBIDMJ2inaZeJXO5uhAG0qUigwYseJoFQ7jzqQE,27415
|
|
59
66
|
autogluon/tabular/models/lgb/lgb_utils.py,sha256=jzTDTzP-z7gcBGZyy1_0YkyTOLbU5DLeRqtil4FCZPI,7382
|
|
60
67
|
autogluon/tabular/models/lgb/hyperparameters/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
61
68
|
autogluon/tabular/models/lgb/hyperparameters/parameters.py,sha256=LLEQ-Ns3HElWBsFJx3ogRV7L6qw_nXlcl7EyO0C0fVQ,1336
|
|
62
69
|
autogluon/tabular/models/lgb/hyperparameters/searchspaces.py,sha256=tvNNR7niWz_B-PndYQXb6vVNABxSfBYRHj6ZVQJ1x2E,1930
|
|
63
70
|
autogluon/tabular/models/lr/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
64
|
-
autogluon/tabular/models/lr/lr_model.py,sha256=
|
|
65
|
-
autogluon/tabular/models/lr/lr_preprocessing_utils.py,sha256=
|
|
66
|
-
autogluon/tabular/models/lr/lr_rapids_model.py,sha256=
|
|
71
|
+
autogluon/tabular/models/lr/lr_model.py,sha256=2A6e8Itw-PgjOLjVXeo8bJwFQuVSGYwJNVxhHxFQXlw,15732
|
|
72
|
+
autogluon/tabular/models/lr/lr_preprocessing_utils.py,sha256=tgb75V6zHfMJh8m9GDs5404ItdfwNakqykTk0qjBtFE,1045
|
|
73
|
+
autogluon/tabular/models/lr/lr_rapids_model.py,sha256=XIB1KCPPfBZMxTRC3Wc1Dsl5NTMQSM_m8Uc2igyTLX8,3939
|
|
67
74
|
autogluon/tabular/models/lr/hyperparameters/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
68
75
|
autogluon/tabular/models/lr/hyperparameters/parameters.py,sha256=Hr5YC13zjbt3CfCbzGj8iXUIuDn-Q7FvDT2uSuiSVlM,1414
|
|
69
76
|
autogluon/tabular/models/lr/hyperparameters/searchspaces.py,sha256=Igywc-B6qJ9EBLdasrDhW-Ot5FGirIzbXLwv5HRe5Xo,276
|
|
77
|
+
autogluon/tabular/models/mitra/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
78
|
+
autogluon/tabular/models/mitra/mitra_model.py,sha256=TzjozU19zQLU09S2tM8Sfe7TiTBSDDjld-tVt5L1JGQ,13954
|
|
79
|
+
autogluon/tabular/models/mitra/sklearn_interface.py,sha256=vyg8kkmYKzEJRWiehEqEsgZeOCV20tnZAZaaaJkwDuA,17739
|
|
80
|
+
autogluon/tabular/models/mitra/_internal/__init__.py,sha256=dN2dz1pGMgQTFiSf9oYbyq23iJUxV8QNlOX3qw3KUO4,35
|
|
81
|
+
autogluon/tabular/models/mitra/_internal/config/__init__.py,sha256=Exu_Sx6-K-D5peDQ_TibsjZpqAALs2-9IXfq8hu1mwU,40
|
|
82
|
+
autogluon/tabular/models/mitra/_internal/config/config_pretrain.py,sha256=CeaD96EcDX69LdcLTYGlFmYLdBNINEJXRMWmJ6LbhTg,6038
|
|
83
|
+
autogluon/tabular/models/mitra/_internal/config/config_run.py,sha256=CVna6KOwmF-rIxcyH3mHm63jvM1C6RdFbRLgUGEXDn0,677
|
|
84
|
+
autogluon/tabular/models/mitra/_internal/config/enums.py,sha256=hlyhgXHvHZKgYK1z3DHSHxEsuCHOE7Y2AdokjOG8SWs,3930
|
|
85
|
+
autogluon/tabular/models/mitra/_internal/core/__init__.py,sha256=hgy4uzJfTQFt9hVlbSrOZU9LSUbLM-uZUnG04f1CUcs,31
|
|
86
|
+
autogluon/tabular/models/mitra/_internal/core/callbacks.py,sha256=xYkJUXiGzLvpWcj6a_wRJUK7f_zgjd1BLA8nH6Hc884,2605
|
|
87
|
+
autogluon/tabular/models/mitra/_internal/core/get_loss.py,sha256=hv0t7zvyZ-DgA5PbKpbX_ayq8tEvuW_nJhbudMDqkDk,2243
|
|
88
|
+
autogluon/tabular/models/mitra/_internal/core/get_optimizer.py,sha256=UgGO6lduVZTKZmYAmE207o2Dqs4e3_hyzaoSOQ0iK6A,3412
|
|
89
|
+
autogluon/tabular/models/mitra/_internal/core/get_scheduler.py,sha256=2lzdAxDOYZNq76pmK-FjCOX5MX6cqUSMjqVu8BX9jfY,2238
|
|
90
|
+
autogluon/tabular/models/mitra/_internal/core/prediction_metrics.py,sha256=fai0VnDm0mNjJzx8e1JXdB77PKQsmfbtn8zybD9_qD0,4394
|
|
91
|
+
autogluon/tabular/models/mitra/_internal/core/trainer_finetune.py,sha256=tI8sN9mv3PtEBdmDxcBgzderZ7YQdtn6MxtOWAc8or8,17908
|
|
92
|
+
autogluon/tabular/models/mitra/_internal/data/__init__.py,sha256=u4ZTvTQNIHqqxilkVqTmYShI2jFMCOyMdv1GRExvtj0,42
|
|
93
|
+
autogluon/tabular/models/mitra/_internal/data/collator.py,sha256=o2F7ODs_eUnV947lCQTx9RugrANidCdiwnZWtdVNJnE,2300
|
|
94
|
+
autogluon/tabular/models/mitra/_internal/data/dataset_finetune.py,sha256=AYxyQ1NJZ3pAp6ny-Y_hqw_4VtyW5X1AABchf7pVsSM,4340
|
|
95
|
+
autogluon/tabular/models/mitra/_internal/data/dataset_split.py,sha256=0uvfyiKrzipde4ZcCDwTE1E3zHelE8xbuNvCeL38J5c,2033
|
|
96
|
+
autogluon/tabular/models/mitra/_internal/data/preprocessor.py,sha256=zx2pWrpDaGSSawPaj7ieRjFOtct_Fyh8LYjo_YtlNG0,13821
|
|
97
|
+
autogluon/tabular/models/mitra/_internal/models/__init__.py,sha256=K0vh5pyrntXp-o7gWNgQ0ZvDbxgeQuRgb6u8ecdjFhA,45
|
|
98
|
+
autogluon/tabular/models/mitra/_internal/models/base.py,sha256=PKpMPT5OT9JFnmYPnhzFUeZPwdNM1e-k97_gW8GZq0Y,468
|
|
99
|
+
autogluon/tabular/models/mitra/_internal/models/embedding.py,sha256=74O6cGWhUyHxg4-wiQwy4sPeDYQze2ekI9H5mLUtSLg,6223
|
|
100
|
+
autogluon/tabular/models/mitra/_internal/models/tab2d.py,sha256=o_S572-nKrhwxmEFaDSTvTLE7KztOvQmARRrc7CIeCY,25783
|
|
101
|
+
autogluon/tabular/models/mitra/_internal/utils/__init__.py,sha256=0mhykAqjMmcEc8Y2od_DMPMk8f66LZHWM7qFdUrPddU,34
|
|
102
|
+
autogluon/tabular/models/mitra/_internal/utils/set_seed.py,sha256=UnXzYfhmfT_tNAofKtLkKpwB9b6HVf9cpI4mKvoBuNM,340
|
|
103
|
+
autogluon/tabular/models/realmlp/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
104
|
+
autogluon/tabular/models/realmlp/realmlp_model.py,sha256=3pe_yhOGW8cbX3KgNs25s3FP0P3FzVSAS-hd4jMFjDg,14573
|
|
70
105
|
autogluon/tabular/models/rf/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
71
|
-
autogluon/tabular/models/rf/rf_model.py,sha256=
|
|
106
|
+
autogluon/tabular/models/rf/rf_model.py,sha256=auvNHx0qD9Pz8rS6yNIuG9cHzFNquv8fOVS7FWZNIAw,21721
|
|
72
107
|
autogluon/tabular/models/rf/rf_quantile.py,sha256=2S8FE8po9lMnZaeKuVkzOUFOcdil46ZbFqm49OuvNZY,36460
|
|
73
108
|
autogluon/tabular/models/rf/rf_rapids_model.py,sha256=3s-8M11dzCl_2Lu5iB3H8YjHLgyP_SElrm_4w_HfmqY,2028
|
|
74
109
|
autogluon/tabular/models/rf/compilers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
75
110
|
autogluon/tabular/models/rf/compilers/native.py,sha256=HhaqQRkVuf9UEEJPsHcdYCmuWBMYtyqRwwB_N2qxG2M,1313
|
|
76
111
|
autogluon/tabular/models/rf/compilers/onnx.py,sha256=pvaZWdl2JJaE2pFU0mFugzhnybePqe0x1-5oLOvogA0,4318
|
|
77
|
-
autogluon/tabular/models/
|
|
78
|
-
autogluon/tabular/models/
|
|
112
|
+
autogluon/tabular/models/tabicl/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
113
|
+
autogluon/tabular/models/tabicl/tabicl_model.py,sha256=_Eq3g9babdC17kyvAA0rIqtZEtiRGwM2XngkbWevXpU,6283
|
|
114
|
+
autogluon/tabular/models/tabm/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
115
|
+
autogluon/tabular/models/tabm/_tabm_internal.py,sha256=fRQ-s5PN94kWqf3LRDen7su_fd-d332YKxdms30FoZM,21066
|
|
116
|
+
autogluon/tabular/models/tabm/rtdl_num_embeddings.py,sha256=XssNMaUM0E0G8Grzl_VkVsLt2FcMf3I4cplfvQdVum0,30156
|
|
117
|
+
autogluon/tabular/models/tabm/tabm_model.py,sha256=_SGc7R87ug9m8KGd_BgC9maJ7sjOAlYB9vtg1omwOto,13640
|
|
118
|
+
autogluon/tabular/models/tabm/tabm_reference.py,sha256=byyP6lcJjA4THbP1VDTgJkj62zyz2S3mEvxWB-kFROw,21944
|
|
79
119
|
autogluon/tabular/models/tabpfnmix/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
80
|
-
autogluon/tabular/models/tabpfnmix/tabpfnmix_model.py,sha256=
|
|
120
|
+
autogluon/tabular/models/tabpfnmix/tabpfnmix_model.py,sha256=NAuV3rJia-UNnFwiFU5tkz6vzZ2lokQ_12vUJ3E6wAA,16498
|
|
81
121
|
autogluon/tabular/models/tabpfnmix/_internal/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
82
122
|
autogluon/tabular/models/tabpfnmix/_internal/tabpfnmix_classifier.py,sha256=_WIO_YQBUCfprKYLHxUNEICPb5XWZw4zbw00DuiTk_s,3426
|
|
83
123
|
autogluon/tabular/models/tabpfnmix/_internal/tabpfnmix_regressor.py,sha256=J6JvrK6L6y3s-Ah6sHQdjSK0mwAMP-Wy3RRBwzB0AoA,3196
|
|
@@ -102,15 +142,24 @@ autogluon/tabular/models/tabpfnmix/_internal/models/foundation/embedding.py,sha2
|
|
|
102
142
|
autogluon/tabular/models/tabpfnmix/_internal/models/foundation/foundation_transformer.py,sha256=bhNpGIA5BKqIVX-kDW4bZLgsOB_A8iNsnpgoyyBLR98,5383
|
|
103
143
|
autogluon/tabular/models/tabpfnmix/_internal/results/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
104
144
|
autogluon/tabular/models/tabpfnmix/_internal/results/prediction_metrics.py,sha256=1tRPHyViSSLJ7BkQJi6wai-PwXJ56od86Dy1WWKWZq4,1743
|
|
145
|
+
autogluon/tabular/models/tabpfnv2/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
146
|
+
autogluon/tabular/models/tabpfnv2/tabpfnv2_model.py,sha256=nXZcq4SMV54dciOKFM57Suc9eVyXQXy-2iN6moRt2b8,14801
|
|
147
|
+
autogluon/tabular/models/tabpfnv2/rfpfn/__init__.py,sha256=yE5XAhGxKEFV0JcelZ_JTQZIWGlVEVUQ9a-lxcH_Esc,585
|
|
148
|
+
autogluon/tabular/models/tabpfnv2/rfpfn/configs.py,sha256=lzBY9kKOeBZACVrtRDPHF4ATs9g1rxyNnIs2CMjE20c,1175
|
|
149
|
+
autogluon/tabular/models/tabpfnv2/rfpfn/scoring_utils.py,sha256=uvHsfvnnMdg4tP3_7zAilktkw7nr65LaqfVKXabXAow,6785
|
|
150
|
+
autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_based_decision_tree_tabpfn.py,sha256=-KQNm_HYWem6HWUsdbnIX4lKe-eW0PQAXZUny2kqego,55582
|
|
151
|
+
autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_based_random_forest_tabpfn.py,sha256=FRJSelTtDaKnpsKKHphjy2rJrFX302miSdHZ0YqHxCQ,28045
|
|
152
|
+
autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_compat.py,sha256=jv2ZHsGwcO4Inhxtol_tig3NoXZQR649dhmW_Kv69QY,29607
|
|
153
|
+
autogluon/tabular/models/tabpfnv2/rfpfn/utils.py,sha256=vjMQsNaZZcW1BBf0hduSCtrNCtSd467xfkhsbHspUog,3489
|
|
105
154
|
autogluon/tabular/models/tabular_nn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
106
155
|
autogluon/tabular/models/tabular_nn/compilers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
107
156
|
autogluon/tabular/models/tabular_nn/compilers/native.py,sha256=W8d8cqBj7U-KVhfGK3hdtGj8JJm3lXr_SecU0615Gbs,1330
|
|
108
157
|
autogluon/tabular/models/tabular_nn/compilers/onnx.py,sha256=3mj9_5p6YMOuKbYk7FBQ2Ijhm1kGzfqq6cyyKLUKLOo,14804
|
|
109
158
|
autogluon/tabular/models/tabular_nn/hyperparameters/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
110
|
-
autogluon/tabular/models/tabular_nn/hyperparameters/parameters.py,sha256=
|
|
159
|
+
autogluon/tabular/models/tabular_nn/hyperparameters/parameters.py,sha256=kGvfuDZa9wDCCTEeytVLKhOAeR0pCcoVNJcWjketmBI,6375
|
|
111
160
|
autogluon/tabular/models/tabular_nn/hyperparameters/searchspaces.py,sha256=pT9cJ3MaWPnaQwAf47Yz6f0-L9qDBknahERbggAp52U,2810
|
|
112
161
|
autogluon/tabular/models/tabular_nn/torch/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
113
|
-
autogluon/tabular/models/tabular_nn/torch/tabular_nn_torch.py,sha256=
|
|
162
|
+
autogluon/tabular/models/tabular_nn/torch/tabular_nn_torch.py,sha256=TGVMv_ClKh0iYVVCqgd19DE-1fXk_VODpsXIMvzI3Sw,42978
|
|
114
163
|
autogluon/tabular/models/tabular_nn/torch/tabular_torch_dataset.py,sha256=RdnQGZSrvY1iuJB4JTANniH3Dorw-DP0Em_JK3_h7RM,13497
|
|
115
164
|
autogluon/tabular/models/tabular_nn/torch/torch_network_modules.py,sha256=Qc3PwXTD8A7PgXi6EGuaBCrN3jsFAXDLCW7i6tE5wYI,11338
|
|
116
165
|
autogluon/tabular/models/tabular_nn/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
@@ -121,7 +170,7 @@ autogluon/tabular/models/text_prediction/__init__.py,sha256=47DEQpj8HBSa-_TImW-5
|
|
|
121
170
|
autogluon/tabular/models/text_prediction/text_prediction_v1_model.py,sha256=PBN7F98qgEAO6U76rV_hxZfAmKr_XpVKjElOdBvfX8c,1090
|
|
122
171
|
autogluon/tabular/models/xgboost/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
123
172
|
autogluon/tabular/models/xgboost/callbacks.py,sha256=PuRQUg3AEjgvFa-dpstRFoEVM9jHDe5W4XYSdDPRqoE,7009
|
|
124
|
-
autogluon/tabular/models/xgboost/xgboost_model.py,sha256=
|
|
173
|
+
autogluon/tabular/models/xgboost/xgboost_model.py,sha256=tKVLvBnuTbDaFwBRVDZ5ADo4PjBF2FDR93Ib86WYTMM,15630
|
|
125
174
|
autogluon/tabular/models/xgboost/xgboost_utils.py,sha256=FVqZ8h4JAe_pifSvNx83cLZHwsuzTXylrrcan07AoNo,5757
|
|
126
175
|
autogluon/tabular/models/xgboost/hyperparameters/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
127
176
|
autogluon/tabular/models/xgboost/hyperparameters/parameters.py,sha256=ay6bVVpiPzftbtz6TTS76w7j4vjDjzHFpuf2Bjf6Zu4,1673
|
|
@@ -130,27 +179,27 @@ autogluon/tabular/models/xt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMp
|
|
|
130
179
|
autogluon/tabular/models/xt/xt_model.py,sha256=qOHJ5h1lHI7uYJfbl0BWm-29R3MNp2WeZB9ptcq5Xis,1003
|
|
131
180
|
autogluon/tabular/predictor/__init__.py,sha256=zCMgjxQlWpDWnr1l1xjBCiK3rWC3N3RoD8UXBnazT74,107
|
|
132
181
|
autogluon/tabular/predictor/interpretable_predictor.py,sha256=5UeKgnMFsfY65tiO3kxfHBPr03lyswLrgdtjPhI0Y7Q,6934
|
|
133
|
-
autogluon/tabular/predictor/predictor.py,sha256=
|
|
182
|
+
autogluon/tabular/predictor/predictor.py,sha256=fjw7CQALXZ7AR18ryLm4xWwDzRBeUnrmNubPS8U_pmQ,361223
|
|
134
183
|
autogluon/tabular/registry/__init__.py,sha256=vZpzX4Xve7bfA9crt5LxjgQv9PPfxbi1E1U6Im0Y_xU,93
|
|
135
|
-
autogluon/tabular/registry/_ag_model_registry.py,sha256=
|
|
184
|
+
autogluon/tabular/registry/_ag_model_registry.py,sha256=2Zx5qxXvOdXIbL1FKslNh2M_JM2YG_7GvsCMFF11wDY,1578
|
|
136
185
|
autogluon/tabular/registry/_model_registry.py,sha256=Rl8Q7BLzaif4hxNxJF20xGE02vrWwh2ZuUaTmA-UJnE,6824
|
|
137
186
|
autogluon/tabular/testing/__init__.py,sha256=XrEGLmMdmRT6QHNR13M9wna57LO4O3Q4tt27Ca8omAc,79
|
|
138
|
-
autogluon/tabular/testing/fit_helper.py,sha256=
|
|
139
|
-
autogluon/tabular/testing/generate_datasets.py,sha256=
|
|
187
|
+
autogluon/tabular/testing/fit_helper.py,sha256=pj3P0ENMDhr04laxsLL0_IDX-8msMFo9Wn5XSLFCaqI,21092
|
|
188
|
+
autogluon/tabular/testing/generate_datasets.py,sha256=nvcAmI-tOh5fwx_ZTx2aRa1n7CsXb96wbR-xqNy1C5w,3884
|
|
140
189
|
autogluon/tabular/testing/model_fit_helper.py,sha256=ZjWpw2nyeFnsrccmkfQtx3qbA8HJx282XX2rwdS-LIs,3808
|
|
141
190
|
autogluon/tabular/trainer/__init__.py,sha256=PW_PGL-tWoQzx3ES2S53bQEZOtsRWTYiM9QdOqsk0dI,38
|
|
142
|
-
autogluon/tabular/trainer/abstract_trainer.py,sha256=
|
|
191
|
+
autogluon/tabular/trainer/abstract_trainer.py,sha256=9FiBqOV2h8era6KfydFSqhTlh7RnHkvlvzqsZuij7nE,232527
|
|
143
192
|
autogluon/tabular/trainer/auto_trainer.py,sha256=ZQgQKFT1iHzzun5o5ojdq5pSQmr9ctTkNhe2r9OPOr0,8731
|
|
144
193
|
autogluon/tabular/trainer/model_presets/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
145
|
-
autogluon/tabular/trainer/model_presets/presets.py,sha256=
|
|
194
|
+
autogluon/tabular/trainer/model_presets/presets.py,sha256=hoWADaOG576Q_XLV1nY_ju1OWi7EJwHay4jjljqt_E0,16546
|
|
146
195
|
autogluon/tabular/trainer/model_presets/presets_distill.py,sha256=MnFC2GJc6RmDBNAGbsO2XMfo3PjR8cUrZoilWW8gTYQ,3295
|
|
147
196
|
autogluon/tabular/tuning/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
148
197
|
autogluon/tabular/tuning/feature_pruner.py,sha256=9iNku8gVbYEkjuKlyITPJDicsNkoraaQOlINQq9iZlQ,6877
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
198
|
+
autogluon_tabular-1.4.1b20251214.dist-info/licenses/LICENSE,sha256=CeipvOyAZxBGUsFoaFqwkx54aPnIKEtm9a5u2uXxEws,10142
|
|
199
|
+
autogluon_tabular-1.4.1b20251214.dist-info/licenses/NOTICE,sha256=7nPQuj8Kp-uXsU0S5so3-2dNU5EctS5hDXvvzzehd7E,114
|
|
200
|
+
autogluon_tabular-1.4.1b20251214.dist-info/METADATA,sha256=XbmjT9lmPhMkbhK6fgfIBrv6zMe1EMZ3wvoTx_Waons,17015
|
|
201
|
+
autogluon_tabular-1.4.1b20251214.dist-info/WHEEL,sha256=SmOxYU7pzNKBqASvQJ7DjX3XGUF92lrGhMb3R6_iiqI,91
|
|
202
|
+
autogluon_tabular-1.4.1b20251214.dist-info/namespace_packages.txt,sha256=giERA4R78OkJf2ijn5slgjURlhRPzfLr7waIcGkzYAo,10
|
|
203
|
+
autogluon_tabular-1.4.1b20251214.dist-info/top_level.txt,sha256=giERA4R78OkJf2ijn5slgjURlhRPzfLr7waIcGkzYAo,10
|
|
204
|
+
autogluon_tabular-1.4.1b20251214.dist-info/zip-safe,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
|
|
205
|
+
autogluon_tabular-1.4.1b20251214.dist-info/RECORD,,
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
|
|
@@ -1,153 +0,0 @@
|
|
|
1
|
-
from __future__ import annotations
|
|
2
|
-
|
|
3
|
-
import numpy as np
|
|
4
|
-
import pandas as pd
|
|
5
|
-
|
|
6
|
-
from autogluon.core.constants import BINARY, MULTICLASS
|
|
7
|
-
from autogluon.core.models import AbstractModel
|
|
8
|
-
from autogluon.core.utils import generate_train_test_split
|
|
9
|
-
from autogluon.features.generators import LabelEncoderFeatureGenerator
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
class TabPFNModel(AbstractModel):
|
|
13
|
-
"""
|
|
14
|
-
AutoGluon model wrapper to the TabPFN model: https://github.com/automl/TabPFN
|
|
15
|
-
|
|
16
|
-
Paper: "TabPFN: A Transformer That Solves Small Tabular Classification Problems in a Second"
|
|
17
|
-
Authors: Noah Hollmann, Samuel Müller, Katharina Eggensperger, and Frank Hutter
|
|
18
|
-
|
|
19
|
-
TabPFN is a viable model option when inference speed is not a concern,
|
|
20
|
-
and the number of rows of training data is less than 10,000.
|
|
21
|
-
|
|
22
|
-
Additionally, TabPFN is only available for classification tasks with up to 10 classes and 100 features.
|
|
23
|
-
|
|
24
|
-
To use this model, `tabpfn` must be installed.
|
|
25
|
-
To install TabPFN, you can run `pip install autogluon.tabular[tabpfn]` or `pip install tabpfn`.
|
|
26
|
-
"""
|
|
27
|
-
ag_key = "TABPFN"
|
|
28
|
-
ag_name = "TabPFN"
|
|
29
|
-
ag_priority = 110
|
|
30
|
-
|
|
31
|
-
def __init__(self, **kwargs):
|
|
32
|
-
super().__init__(**kwargs)
|
|
33
|
-
self._feature_generator = None
|
|
34
|
-
|
|
35
|
-
def _fit(self, X: pd.DataFrame, y: pd.Series, **kwargs):
|
|
36
|
-
from tabpfn import TabPFNClassifier
|
|
37
|
-
|
|
38
|
-
ag_params = self._get_ag_params()
|
|
39
|
-
sample_rows = ag_params.get("sample_rows")
|
|
40
|
-
max_features = ag_params.get("max_features")
|
|
41
|
-
max_classes = ag_params.get("max_classes")
|
|
42
|
-
if max_classes is not None and self.num_classes > max_classes:
|
|
43
|
-
# TODO: Move to earlier stage when problem_type is checked
|
|
44
|
-
raise AssertionError(f"Max allowed classes for the model is {max_classes}, " f"but found {self.num_classes} classes.")
|
|
45
|
-
|
|
46
|
-
# TODO: Make sample_rows generic
|
|
47
|
-
if sample_rows is not None and len(X) > sample_rows:
|
|
48
|
-
X, y = self._subsample_train(X=X, y=y, num_rows=sample_rows)
|
|
49
|
-
X = self.preprocess(X)
|
|
50
|
-
num_features = X.shape[1]
|
|
51
|
-
# TODO: Make max_features generic
|
|
52
|
-
if max_features is not None and num_features > max_features:
|
|
53
|
-
raise AssertionError(f"Max allowed features for the model is {max_features}, " f"but found {num_features} features.")
|
|
54
|
-
hyp = self._get_model_params()
|
|
55
|
-
N_ensemble_configurations = hyp.get("N_ensemble_configurations")
|
|
56
|
-
self.model = TabPFNClassifier(device="cpu", N_ensemble_configurations=N_ensemble_configurations).fit( # TODO: Add GPU option
|
|
57
|
-
X, y, overwrite_warning=True
|
|
58
|
-
)
|
|
59
|
-
|
|
60
|
-
# TODO: Make this generic by creating a generic `preprocess_train` and putting this logic prior to `_preprocess`.
|
|
61
|
-
def _subsample_train(self, X: pd.DataFrame, y: pd.Series, num_rows: int, random_state=0) -> (pd.DataFrame, pd.Series):
|
|
62
|
-
num_rows_to_drop = len(X) - num_rows
|
|
63
|
-
X, _, y, _ = generate_train_test_split(
|
|
64
|
-
X=X,
|
|
65
|
-
y=y,
|
|
66
|
-
problem_type=self.problem_type,
|
|
67
|
-
test_size=num_rows_to_drop,
|
|
68
|
-
random_state=random_state,
|
|
69
|
-
min_cls_count_train=1,
|
|
70
|
-
)
|
|
71
|
-
return X, y
|
|
72
|
-
|
|
73
|
-
def _preprocess(self, X: pd.DataFrame, **kwargs) -> np.ndarray:
|
|
74
|
-
"""
|
|
75
|
-
Converts categorical to label encoded integers
|
|
76
|
-
Keeps missing values, as TabPFN automatically handles missing values internally.
|
|
77
|
-
"""
|
|
78
|
-
X = super()._preprocess(X, **kwargs)
|
|
79
|
-
if self._feature_generator is None:
|
|
80
|
-
self._feature_generator = LabelEncoderFeatureGenerator(verbosity=0)
|
|
81
|
-
self._feature_generator.fit(X=X)
|
|
82
|
-
if self._feature_generator.features_in:
|
|
83
|
-
X = X.copy()
|
|
84
|
-
X[self._feature_generator.features_in] = self._feature_generator.transform(X=X)
|
|
85
|
-
X = X.to_numpy(dtype=np.float32)
|
|
86
|
-
return X
|
|
87
|
-
|
|
88
|
-
def _set_default_params(self):
|
|
89
|
-
"""
|
|
90
|
-
By default, we only use 1 ensemble configurations to speed up inference times.
|
|
91
|
-
Increase the value to improve model quality while linearly increasing inference time.
|
|
92
|
-
|
|
93
|
-
Model quality improvement diminishes significantly beyond `N_ensemble_configurations=8`.
|
|
94
|
-
"""
|
|
95
|
-
default_params = {
|
|
96
|
-
"N_ensemble_configurations": 1,
|
|
97
|
-
}
|
|
98
|
-
for param, val in default_params.items():
|
|
99
|
-
self._set_default_param_value(param, val)
|
|
100
|
-
|
|
101
|
-
@classmethod
|
|
102
|
-
def supported_problem_types(cls) -> list[str] | None:
|
|
103
|
-
return ["binary", "multiclass"]
|
|
104
|
-
|
|
105
|
-
def _get_default_auxiliary_params(self) -> dict:
|
|
106
|
-
"""
|
|
107
|
-
TabPFN was originally learned on synthetic datasets with 1024 rows, and struggles to
|
|
108
|
-
leverage additional rows effectively beyond a certain point.
|
|
109
|
-
|
|
110
|
-
In the TabPFN paper, performance appeared to stagnate around 4000 rows of training data (Figure 10).
|
|
111
|
-
Thus, we set `sample_rows=4096`, to only use that many rows of training data, even if more is available.
|
|
112
|
-
|
|
113
|
-
TODO: TabPFN scales poorly on large datasets, so we set `max_rows=20000`.
|
|
114
|
-
Not implemented yet, first move this logic to the trainer level to avoid `refit_full` edge-case crashes.
|
|
115
|
-
TabPFN only works on datasets with at most 100 features, so we set `max_features=100`.
|
|
116
|
-
TabPFN only works on datasets with at most 10 classes, so we set `max_classes=10`.
|
|
117
|
-
"""
|
|
118
|
-
default_auxiliary_params = super()._get_default_auxiliary_params()
|
|
119
|
-
default_auxiliary_params.update(
|
|
120
|
-
{
|
|
121
|
-
"sample_rows": 4096,
|
|
122
|
-
# 'max_rows': 20000,
|
|
123
|
-
"max_features": 100,
|
|
124
|
-
"max_classes": 10,
|
|
125
|
-
}
|
|
126
|
-
)
|
|
127
|
-
return default_auxiliary_params
|
|
128
|
-
|
|
129
|
-
# FIXME: Enabling parallel bagging TabPFN creates a lot of warnings / potential failures from Ray
|
|
130
|
-
# TODO: Consider not setting `max_sets=1`, and only setting it in the preset hyperparameter definition.
|
|
131
|
-
@classmethod
|
|
132
|
-
def _get_default_ag_args_ensemble(cls, **kwargs) -> dict:
|
|
133
|
-
"""
|
|
134
|
-
Set max_sets to 1 when bagging, otherwise inference time could become extremely slow.
|
|
135
|
-
Set fold_fitting_strategy to sequential_local, as parallel folding causing many warnings / potential errors from Ray.
|
|
136
|
-
"""
|
|
137
|
-
default_ag_args_ensemble = super()._get_default_ag_args_ensemble(**kwargs)
|
|
138
|
-
extra_ag_args_ensemble = {
|
|
139
|
-
"max_sets": 1,
|
|
140
|
-
"fold_fitting_strategy": "sequential_local",
|
|
141
|
-
}
|
|
142
|
-
default_ag_args_ensemble.update(extra_ag_args_ensemble)
|
|
143
|
-
return default_ag_args_ensemble
|
|
144
|
-
|
|
145
|
-
def _ag_params(self) -> set:
|
|
146
|
-
return {"sample_rows", "max_features", "max_classes"}
|
|
147
|
-
|
|
148
|
-
def _more_tags(self) -> dict:
|
|
149
|
-
"""
|
|
150
|
-
Because TabPFN doesn't use validation data for early stopping, it supports refit_full natively.
|
|
151
|
-
"""
|
|
152
|
-
tags = {"can_refit_full": True}
|
|
153
|
-
return tags
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
import sys, types, os;has_mfs = sys.version_info > (3, 5);p = os.path.join(sys._getframe(1).f_locals['sitedir'], *('autogluon',));importlib = has_mfs and __import__('importlib.util');has_mfs and __import__('importlib.machinery');m = has_mfs and sys.modules.setdefault('autogluon', importlib.util.module_from_spec(importlib.machinery.PathFinder.find_spec('autogluon', [os.path.dirname(p)])));m = m or sys.modules.setdefault('autogluon', types.ModuleType('autogluon'));mp = (m or []) and m.__dict__.setdefault('__path__',[]);(p not in mp) and mp.append(p)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info}/zip-safe
RENAMED
|
File without changes
|