autogluon.tabular 1.5.0b20251228__py3-none-any.whl → 1.5.1b20260116__py3-none-any.whl
This diff compares publicly available package versions as they appear in their respective public registries. It is provided for informational purposes only.
Potentially problematic release: this version of autogluon.tabular has been flagged as potentially problematic.
- autogluon/tabular/__init__.py +1 -0
- autogluon/tabular/configs/config_helper.py +18 -6
- autogluon/tabular/configs/feature_generator_presets.py +3 -1
- autogluon/tabular/configs/hyperparameter_configs.py +42 -9
- autogluon/tabular/configs/presets_configs.py +38 -14
- autogluon/tabular/configs/zeroshot/zeroshot_portfolio_2023.py +84 -14
- autogluon/tabular/configs/zeroshot/zeroshot_portfolio_2025.py +48 -48
- autogluon/tabular/configs/zeroshot/zeroshot_portfolio_cpu_2025_12_18.py +774 -1
- autogluon/tabular/configs/zeroshot/zeroshot_portfolio_gpu_2025_12_18.py +421 -1
- autogluon/tabular/experimental/_scikit_mixin.py +6 -2
- autogluon/tabular/experimental/_tabular_classifier.py +3 -1
- autogluon/tabular/experimental/_tabular_regressor.py +3 -1
- autogluon/tabular/experimental/plot_leaderboard.py +73 -19
- autogluon/tabular/learner/abstract_learner.py +160 -42
- autogluon/tabular/learner/default_learner.py +78 -22
- autogluon/tabular/models/__init__.py +2 -2
- autogluon/tabular/models/_utils/rapids_utils.py +3 -1
- autogluon/tabular/models/abstract/abstract_torch_model.py +2 -0
- autogluon/tabular/models/automm/automm_model.py +12 -3
- autogluon/tabular/models/automm/ft_transformer.py +5 -1
- autogluon/tabular/models/catboost/callbacks.py +2 -2
- autogluon/tabular/models/catboost/catboost_model.py +93 -29
- autogluon/tabular/models/catboost/catboost_softclass_utils.py +4 -1
- autogluon/tabular/models/catboost/catboost_utils.py +3 -1
- autogluon/tabular/models/ebm/ebm_model.py +8 -13
- autogluon/tabular/models/ebm/hyperparameters/parameters.py +1 -0
- autogluon/tabular/models/ebm/hyperparameters/searchspaces.py +1 -0
- autogluon/tabular/models/fastainn/callbacks.py +20 -3
- autogluon/tabular/models/fastainn/hyperparameters/searchspaces.py +11 -1
- autogluon/tabular/models/fastainn/quantile_helpers.py +10 -2
- autogluon/tabular/models/fastainn/tabular_nn_fastai.py +65 -18
- autogluon/tabular/models/fasttext/fasttext_model.py +3 -1
- autogluon/tabular/models/image_prediction/image_predictor.py +7 -2
- autogluon/tabular/models/knn/knn_model.py +41 -8
- autogluon/tabular/models/lgb/callbacks.py +32 -9
- autogluon/tabular/models/lgb/hyperparameters/searchspaces.py +3 -1
- autogluon/tabular/models/lgb/lgb_model.py +150 -34
- autogluon/tabular/models/lgb/lgb_utils.py +12 -4
- autogluon/tabular/models/lr/hyperparameters/searchspaces.py +5 -1
- autogluon/tabular/models/lr/lr_model.py +40 -10
- autogluon/tabular/models/lr/lr_rapids_model.py +22 -13
- autogluon/tabular/models/mitra/_internal/__init__.py +1 -1
- autogluon/tabular/models/mitra/_internal/config/__init__.py +1 -1
- autogluon/tabular/models/mitra/_internal/config/config_pretrain.py +36 -40
- autogluon/tabular/models/mitra/_internal/config/config_run.py +2 -14
- autogluon/tabular/models/mitra/_internal/config/enums.py +27 -26
- autogluon/tabular/models/mitra/_internal/core/__init__.py +1 -1
- autogluon/tabular/models/mitra/_internal/core/callbacks.py +14 -21
- autogluon/tabular/models/mitra/_internal/core/get_loss.py +10 -12
- autogluon/tabular/models/mitra/_internal/core/get_optimizer.py +17 -32
- autogluon/tabular/models/mitra/_internal/core/get_scheduler.py +12 -27
- autogluon/tabular/models/mitra/_internal/core/prediction_metrics.py +16 -21
- autogluon/tabular/models/mitra/_internal/core/trainer_finetune.py +130 -111
- autogluon/tabular/models/mitra/_internal/data/__init__.py +1 -1
- autogluon/tabular/models/mitra/_internal/data/collator.py +30 -26
- autogluon/tabular/models/mitra/_internal/data/dataset_finetune.py +18 -26
- autogluon/tabular/models/mitra/_internal/data/dataset_split.py +10 -7
- autogluon/tabular/models/mitra/_internal/data/preprocessor.py +70 -100
- autogluon/tabular/models/mitra/_internal/models/__init__.py +1 -1
- autogluon/tabular/models/mitra/_internal/models/base.py +7 -10
- autogluon/tabular/models/mitra/_internal/models/embedding.py +46 -56
- autogluon/tabular/models/mitra/_internal/models/tab2d.py +140 -120
- autogluon/tabular/models/mitra/_internal/utils/__init__.py +1 -1
- autogluon/tabular/models/mitra/_internal/utils/set_seed.py +3 -1
- autogluon/tabular/models/mitra/mitra_model.py +16 -11
- autogluon/tabular/models/mitra/sklearn_interface.py +178 -162
- autogluon/tabular/models/realmlp/realmlp_model.py +28 -15
- autogluon/tabular/models/rf/compilers/onnx.py +1 -1
- autogluon/tabular/models/rf/rf_model.py +45 -12
- autogluon/tabular/models/rf/rf_quantile.py +4 -2
- autogluon/tabular/models/tabdpt/tabdpt_model.py +8 -17
- autogluon/tabular/models/tabicl/tabicl_model.py +8 -1
- autogluon/tabular/models/tabm/_tabm_internal.py +6 -4
- autogluon/tabular/models/tabm/rtdl_num_embeddings.py +80 -127
- autogluon/tabular/models/tabm/tabm_model.py +8 -4
- autogluon/tabular/models/tabm/tabm_reference.py +53 -85
- autogluon/tabular/models/tabpfnmix/_internal/core/callbacks.py +7 -16
- autogluon/tabular/models/tabpfnmix/_internal/core/collator.py +16 -24
- autogluon/tabular/models/tabpfnmix/_internal/core/dataset_split.py +5 -7
- autogluon/tabular/models/tabpfnmix/_internal/core/enums.py +0 -2
- autogluon/tabular/models/tabpfnmix/_internal/core/get_loss.py +0 -1
- autogluon/tabular/models/tabpfnmix/_internal/core/get_optimizer.py +7 -18
- autogluon/tabular/models/tabpfnmix/_internal/core/get_scheduler.py +3 -14
- autogluon/tabular/models/tabpfnmix/_internal/core/trainer_finetune.py +79 -64
- autogluon/tabular/models/tabpfnmix/_internal/core/y_transformer.py +3 -5
- autogluon/tabular/models/tabpfnmix/_internal/data/dataset_finetune.py +17 -30
- autogluon/tabular/models/tabpfnmix/_internal/data/preprocessor.py +15 -35
- autogluon/tabular/models/tabpfnmix/_internal/models/foundation/embedding.py +21 -38
- autogluon/tabular/models/tabpfnmix/_internal/models/foundation/foundation_transformer.py +33 -51
- autogluon/tabular/models/tabpfnmix/_internal/results/prediction_metrics.py +4 -4
- autogluon/tabular/models/tabpfnmix/_internal/tabpfnmix_classifier.py +32 -12
- autogluon/tabular/models/tabpfnmix/_internal/tabpfnmix_regressor.py +32 -13
- autogluon/tabular/models/tabpfnmix/tabpfnmix_model.py +55 -19
- autogluon/tabular/models/tabpfnv2/tabpfnv2_5_model.py +21 -48
- autogluon/tabular/models/tabprep/prep_mixin.py +34 -26
- autogluon/tabular/models/tabular_nn/compilers/onnx.py +36 -8
- autogluon/tabular/models/tabular_nn/torch/tabular_nn_torch.py +130 -36
- autogluon/tabular/models/tabular_nn/torch/tabular_torch_dataset.py +8 -4
- autogluon/tabular/models/tabular_nn/torch/torch_network_modules.py +26 -5
- autogluon/tabular/models/tabular_nn/utils/categorical_encoders.py +41 -24
- autogluon/tabular/models/tabular_nn/utils/data_preprocessor.py +33 -8
- autogluon/tabular/models/tabular_nn/utils/nn_architecture_utils.py +21 -6
- autogluon/tabular/models/xgboost/callbacks.py +9 -3
- autogluon/tabular/models/xgboost/xgboost_model.py +59 -11
- autogluon/tabular/models/xt/xt_model.py +1 -0
- autogluon/tabular/predictor/interpretable_predictor.py +3 -1
- autogluon/tabular/predictor/predictor.py +409 -128
- autogluon/tabular/registry/__init__.py +1 -1
- autogluon/tabular/registry/_ag_model_registry.py +4 -5
- autogluon/tabular/registry/_model_registry.py +1 -0
- autogluon/tabular/testing/fit_helper.py +55 -15
- autogluon/tabular/testing/generate_datasets.py +1 -1
- autogluon/tabular/testing/model_fit_helper.py +10 -4
- autogluon/tabular/trainer/abstract_trainer.py +644 -230
- autogluon/tabular/trainer/auto_trainer.py +19 -8
- autogluon/tabular/trainer/model_presets/presets.py +33 -9
- autogluon/tabular/trainer/model_presets/presets_distill.py +16 -2
- autogluon/tabular/version.py +1 -1
- {autogluon_tabular-1.5.0b20251228.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/METADATA +26 -26
- {autogluon_tabular-1.5.0b20251228.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/RECORD +127 -135
- autogluon/tabular/models/tabpfnv2/rfpfn/__init__.py +0 -20
- autogluon/tabular/models/tabpfnv2/rfpfn/configs.py +0 -40
- autogluon/tabular/models/tabpfnv2/rfpfn/scoring_utils.py +0 -201
- autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_based_decision_tree_tabpfn.py +0 -1464
- autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_based_random_forest_tabpfn.py +0 -747
- autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_compat.py +0 -863
- autogluon/tabular/models/tabpfnv2/rfpfn/utils.py +0 -106
- autogluon/tabular/models/tabpfnv2/tabpfnv2_model.py +0 -466
- /autogluon.tabular-1.5.0b20251228-py3.11-nspkg.pth → /autogluon.tabular-1.5.1b20260116-py3.11-nspkg.pth +0 -0
- {autogluon_tabular-1.5.0b20251228.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/WHEEL +0 -0
- {autogluon_tabular-1.5.0b20251228.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/licenses/LICENSE +0 -0
- {autogluon_tabular-1.5.0b20251228.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/licenses/NOTICE +0 -0
- {autogluon_tabular-1.5.0b20251228.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/namespace_packages.txt +0 -0
- {autogluon_tabular-1.5.0b20251228.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/top_level.txt +0 -0
- {autogluon_tabular-1.5.0b20251228.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/zip-safe +0 -0
@@ -1,9 +1,9 @@
 from __future__ import annotations
 
+import contextlib
 import os
 import time
 from pathlib import Path
-import contextlib
 
 import numpy as np
 import pandas as pd
@@ -18,34 +18,36 @@ from ._internal.models.tab2d import Tab2D
 from ._internal.utils.set_seed import set_seed
 
 # Hyperparameter search space
-DEFAULT_FINE_TUNE = True
-DEFAULT_FINE_TUNE_STEPS = 50
-DEFAULT_CLS_METRIC = "log_loss"
-DEFAULT_REG_METRIC = "mse"
-SHUFFLE_CLASSES = False
-SHUFFLE_FEATURES = False
-USE_RANDOM_TRANSFORMS = False
-RANDOM_MIRROR_REGRESSION = True
-RANDOM_MIRROR_X = True
-LR = 0.0001
-PATIENCE = 40
-WARMUP_STEPS = 1000
-DEFAULT_CLS_MODEL = "autogluon/mitra-classifier"
-DEFAULT_REG_MODEL = "autogluon/mitra-regressor"
+DEFAULT_FINE_TUNE = True  # [True, False]
+DEFAULT_FINE_TUNE_STEPS = 50  # [50, 60, 70, 80, 90, 100]
+DEFAULT_CLS_METRIC = "log_loss"  # ['log_loss', 'accuracy', 'auc']
+DEFAULT_REG_METRIC = "mse"  # ['mse', 'mae', 'rmse', 'r2']
+SHUFFLE_CLASSES = False  # [True, False]
+SHUFFLE_FEATURES = False  # [True, False]
+USE_RANDOM_TRANSFORMS = False  # [True, False]
+RANDOM_MIRROR_REGRESSION = True  # [True, False]
+RANDOM_MIRROR_X = True  # [True, False]
+LR = 0.0001  # [0.00001, 0.000025, 0.00005, 0.000075, 0.0001, 0.00025, 0.0005, 0.00075, 0.001]
+PATIENCE = 40  # [30, 35, 40, 45, 50]
+WARMUP_STEPS = 1000  # [500, 750, 1000, 1250, 1500]
+DEFAULT_CLS_MODEL = "autogluon/mitra-classifier"
+DEFAULT_REG_MODEL = "autogluon/mitra-regressor"
 
 # Constants
 SEED = 0
 DEFAULT_MODEL_TYPE = "Tab2D"
 
+
 def _get_default_device():
     """Get the best available device for the current system."""
     if torch.cuda.is_available():
         return "cuda"
-    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
         return "mps"  # Apple silicon
     else:
         return "cpu"
 
+
 DEFAULT_DEVICE = _get_default_device()
 DEFAULT_ENSEMBLE = 1
 DEFAULT_DIM = 512
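The bracketed lists added as comments above read as per-constant candidate grids. A minimal sketch of drawing one configuration from those grids (the SEARCH_SPACE dict is a hypothetical helper, not part of the package):

import random

# Candidate grids copied from the comments above.
SEARCH_SPACE = {
    "fine_tune_steps": [50, 60, 70, 80, 90, 100],
    "lr": [0.00001, 0.000025, 0.00005, 0.000075, 0.0001, 0.00025, 0.0005, 0.00075, 0.001],
    "patience": [30, 35, 40, 45, 50],
    "warmup_steps": [500, 750, 1000, 1250, 1500],
}

# One random-search draw over the documented grids.
sampled = {name: random.choice(grid) for name, grid in SEARCH_SPACE.items()}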
@@ -55,32 +57,34 @@ DEFAULT_CLASSES = 10
 DEFAULT_VALIDATION_SPLIT = 0.2
 USE_HF = True  # Use Hugging Face pretrained models if available
 
+
 class MitraBase(BaseEstimator):
     """Base class for Mitra models with common functionality."""
 
-    def __init__(
+    def __init__(
+        self,
+        model_type=DEFAULT_MODEL_TYPE,
+        n_estimators=DEFAULT_ENSEMBLE,
+        device=DEFAULT_DEVICE,
+        fine_tune=DEFAULT_FINE_TUNE,
+        fine_tune_steps=DEFAULT_FINE_TUNE_STEPS,
+        metric=DEFAULT_CLS_METRIC,
+        state_dict=None,
+        hf_model=None,
+        patience=PATIENCE,
+        lr=LR,
+        warmup_steps=WARMUP_STEPS,
+        shuffle_classes=SHUFFLE_CLASSES,
+        shuffle_features=SHUFFLE_FEATURES,
+        use_random_transforms=USE_RANDOM_TRANSFORMS,
+        random_mirror_regression=RANDOM_MIRROR_REGRESSION,
+        random_mirror_x=RANDOM_MIRROR_X,
+        seed=SEED,
+        verbose=True,
+    ):
         """
         Initialize the base Mitra model.
 
         Parameters
         ----------
         model_type : str, default="Tab2D"
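Spelling every hyperparameter out as a named __init__ argument is what makes the class cooperate with scikit-learn: BaseEstimator.get_params() introspects the __init__ signature and expects a same-named attribute for each argument. A minimal sketch of that contract, using a toy estimator rather than MitraBase:

from sklearn.base import BaseEstimator

class ToyEstimator(BaseEstimator):
    def __init__(self, lr=0.0001, patience=40):
        # Attribute names must match the argument names for get_params().
        self.lr = lr
        self.patience = patience

print(ToyEstimator(lr=0.00025).get_params())  # {'lr': 0.00025, 'patience': 40}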
@@ -125,59 +129,62 @@ class MitraBase(BaseEstimator):
             model_name=ModelName.TAB2D,
             seed=self.seed,
             hyperparams={
+                "dim_embedding": None,
+                "early_stopping_data_split": "VALID",
+                "early_stopping_max_samples": 2048,
+                "early_stopping_patience": self.patience,
+                "grad_scaler_enabled": False,
+                "grad_scaler_growth_interval": 1000,
+                "grad_scaler_scale_init": 65536.0,
+                "grad_scaler_scale_min": 65536.0,
+                "label_smoothing": 0.0,
+                "lr_scheduler": False,
+                "lr_scheduler_patience": 25,
+                "max_epochs": self.fine_tune_steps if self.fine_tune else 0,
+                "max_samples_query": 1024,
+                "max_samples_support": 8192,
+                "optimizer": "adamw",
+                "lr": self.lr,
+                "weight_decay": 0.1,
+                "warmup_steps": self.warmup_steps,
+                "path_to_weights": self.state_dict,
+                "precision": "bfloat16",
+                "random_mirror_regression": self.random_mirror_regression,
+                "random_mirror_x": self.random_mirror_x,
+                "shuffle_classes": self.shuffle_classes,
+                "shuffle_features": self.shuffle_features,
+                "use_random_transforms": self.use_random_transforms,
+                "use_feature_count_scaling": False,
+                "use_pretrained_weights": False,
+                "use_quantile_transformer": False,
+                "budget": time_limit,
+                "metric": self.metric,
             },
         )
 
         cfg.task = task
-        cfg.hyperparams.update(
+        cfg.hyperparams.update(
+            {
+                "n_ensembles": self.n_estimators,
+                "dim": DEFAULT_DIM,
+                "dim_output": dim_output,
+                "n_layers": DEFAULT_LAYERS,
+                "n_heads": DEFAULT_HEADS,
+                "regression_loss": "mse",
+            }
+        )
 
         return cfg, Tab2D
 
-
     def _split_data(self, X, y):
         """Split data into training and validation sets."""
-        if hasattr(self, "task") and self.task == "classification":
+        if hasattr(self, "task") and self.task == "classification":
             return make_stratified_dataset_split(X, y, seed=self.seed)
         else:
             # For regression, use random split
-            val_indices = np.random.choice(range(len(X)), int(DEFAULT_VALIDATION_SPLIT * len(X)), replace=False).tolist()
+            val_indices = np.random.choice(
+                range(len(X)), int(DEFAULT_VALIDATION_SPLIT * len(X)), replace=False
+            ).tolist()
             train_indices = [i for i in range(len(X)) if i not in val_indices]
             return X[train_indices], X[val_indices], y[train_indices], y[val_indices]
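_split_data stratifies classification targets and falls back to a plain random holdout for regression. A self-contained sketch of the regression branch, assuming the 20% DEFAULT_VALIDATION_SPLIT defined above:

import numpy as np

X = np.arange(40, dtype=float).reshape(20, 2)
y = X[:, 0] * 0.5

# Sample 20% of the row indices without replacement for validation,
# then keep the remaining rows for training.
val_indices = np.random.choice(range(len(X)), int(0.2 * len(X)), replace=False).tolist()
train_indices = [i for i in range(len(X)) if i not in val_indices]
X_train, X_val, y_train, y_val = X[train_indices], X[val_indices], y[train_indices], y[val_indices]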
@@ -188,7 +195,9 @@ class MitraBase(BaseEstimator):
         rng = np.random.RandomState(cfg.seed)
 
         success = False
-        while not (success and cfg.hyperparams["max_samples_support"] > 0 and cfg.hyperparams["max_samples_query"] > 0):
+        while not (
+            success and cfg.hyperparams["max_samples_support"] > 0 and cfg.hyperparams["max_samples_query"] > 0
+        ):
             try:
                 self.trainers.clear()
@@ -199,16 +208,18 @@
                     model = Tab2D.from_pretrained(self.hf_model, device=self.device)
                 else:
                     model = Tab2D(
-                        dim=cfg.hyperparams["dim"],
+                        dim=cfg.hyperparams["dim"],
                         dim_output=dim_output,
-                        n_layers=cfg.hyperparams["n_layers"],
-                        n_heads=cfg.hyperparams["n_heads"],
+                        n_layers=cfg.hyperparams["n_layers"],
+                        n_heads=cfg.hyperparams["n_heads"],
                         task=task.upper(),
                         use_pretrained_weights=True,
                         path_to_weights=Path(self.state_dict),
                         device=self.device,
                     )
-                trainer = TrainerFinetune(cfg, model, n_classes=n_classes, device=self.device, rng=rng, verbose=self.verbose)
+                trainer = TrainerFinetune(
+                    cfg, model, n_classes=n_classes, device=self.device, rng=rng, verbose=self.verbose
+                )
 
                 start_time = time.time()
                 trainer.train(X_train, y_train, X_valid, y_valid)
@@ -221,27 +232,25 @@
 
             except torch.cuda.OutOfMemoryError:
                 if cfg.hyperparams["max_samples_support"] >= 2048:
-                    cfg.hyperparams["max_samples_support"] = int(
-                        cfg.hyperparams["max_samples_support"] // 2
+                    cfg.hyperparams["max_samples_support"] = int(cfg.hyperparams["max_samples_support"] // 2)
+                    print(
+                        f"Reducing max_samples_support from {cfg.hyperparams['max_samples_support'] * 2}"
+                        f"to {cfg.hyperparams['max_samples_support']} due to OOM error."
                     )
-                    print(f"Reducing max_samples_support from {cfg.hyperparams['max_samples_support'] * 2}"
-                          f"to {cfg.hyperparams['max_samples_support']} due to OOM error.")
                 else:
-                    cfg.hyperparams["max_samples_support"] = int(
-                        cfg.hyperparams["max_samples_support"] // 2
+                    cfg.hyperparams["max_samples_support"] = int(cfg.hyperparams["max_samples_support"] // 2)
+                    print(
+                        f"Reducing max_samples_support from {cfg.hyperparams['max_samples_support'] * 2}"
+                        f"to {cfg.hyperparams['max_samples_support']} due to OOM error."
                     )
-                    cfg.hyperparams["max_samples_query"] = int(
+                    cfg.hyperparams["max_samples_query"] = int(cfg.hyperparams["max_samples_query"] // 2)
+                    print(
+                        f"Reducing max_samples_query from {cfg.hyperparams['max_samples_query'] * 2}"
+                        f"to {cfg.hyperparams['max_samples_query']} due to OOM error."
                     )
-                    print(f"Reducing max_samples_query from {cfg.hyperparams['max_samples_query'] * 2}"
-                          f"to {cfg.hyperparams['max_samples_query']} due to OOM error.")
 
         if not success:
-            raise RuntimeError(
-                "Failed to train Mitra model after multiple attempts due to out of memory error."
-            )
+            raise RuntimeError("Failed to train Mitra model after multiple attempts due to out of memory error.")
 
         return self
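The except branch above is a halving backoff: each CUDA OOM cuts the support/query sample budgets in half and retries until training fits or a budget reaches zero. The same pattern in isolation (train_fn is a hypothetical callable, not a package API):

import torch

def fit_with_backoff(train_fn, max_samples_support=8192, max_samples_query=1024):
    # Halve both sample budgets on every CUDA OOM and retry.
    while max_samples_support > 0 and max_samples_query > 0:
        try:
            return train_fn(max_samples_support, max_samples_query)
        except torch.cuda.OutOfMemoryError:
            max_samples_support //= 2
            max_samples_query //= 2
    raise RuntimeError("Out of memory even at the smallest sample budgets.")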
@@ -249,26 +258,27 @@
 class MitraClassifier(MitraBase, ClassifierMixin):
     """Classifier implementation of Mitra model."""
 
-    def __init__(
+    def __init__(
+        self,
+        model_type=DEFAULT_MODEL_TYPE,
+        n_estimators=DEFAULT_ENSEMBLE,
+        device=DEFAULT_DEVICE,
+        fine_tune=DEFAULT_FINE_TUNE,
+        fine_tune_steps=DEFAULT_FINE_TUNE_STEPS,
+        metric=DEFAULT_CLS_METRIC,
+        state_dict=None,
+        hf_model=DEFAULT_CLS_MODEL,
+        patience=PATIENCE,
+        lr=LR,
+        warmup_steps=WARMUP_STEPS,
+        shuffle_classes=SHUFFLE_CLASSES,
+        shuffle_features=SHUFFLE_FEATURES,
+        use_random_transforms=USE_RANDOM_TRANSFORMS,
+        random_mirror_regression=RANDOM_MIRROR_REGRESSION,
+        random_mirror_x=RANDOM_MIRROR_X,
+        seed=SEED,
+        verbose=True,
+    ):
         """Initialize the classifier."""
         super().__init__(
             model_type,
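With hf_model left at its default, weights appear to come from the autogluon/mitra-classifier repo via the Tab2D.from_pretrained branch shown earlier; a state_dict path routes through path_to_weights instead. A hypothetical construction of each variant (the local file name is illustrative):

# Default: pretrained weights resolved from the Hugging Face Hub.
clf = MitraClassifier()

# Hypothetical alternative: load weights from a local checkpoint instead.
clf_local = MitraClassifier(hf_model=None, state_dict="weights/mitra_classifier.pt")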
@@ -290,19 +300,19 @@ class MitraClassifier(MitraBase, ClassifierMixin):
             seed=seed,
             verbose=verbose,
         )
-        self.task = "classification"
+        self.task = "classification"
 
-    def fit(self, X, y, X_val=None, y_val=None):
+    def fit(self, X, y, X_val=None, y_val=None, time_limit=None):
         """
         Fit the ensemble of models.
 
         Parameters
         ----------
         X : array-like of shape (n_samples, n_features)
             Training data
         y : array-like of shape (n_samples,)
             Target values
 
         Returns
         -------
         self : object
@@ -310,7 +320,6 @@ class MitraClassifier(MitraBase, ClassifierMixin):
         """
 
         with mitra_deterministic_context():
-
             if isinstance(X, pd.DataFrame):
                 X = X.values
             if isinstance(y, pd.Series):
@@ -327,17 +336,26 @@
             else:
                 X_train, X_valid, y_train, y_valid = self._split_data(X, y)
 
-            return self._train_ensemble(
+            return self._train_ensemble(
+                X_train,
+                y_train,
+                X_valid,
+                y_valid,
+                self.task,
+                DEFAULT_CLASSES,
+                n_classes=DEFAULT_CLASSES,
+                time_limit=time_limit,
+            )
 
     def predict(self, X):
         """
         Predict class labels for samples in X.
 
         Parameters
         ----------
         X : array-like of shape (n_samples, n_features)
             The input samples
 
         Returns
         -------
         y : ndarray of shape (n_samples,)
@@ -352,12 +370,12 @@
     def predict_proba(self, X):
         """
         Predict class probabilities for samples in X.
 
         Parameters
         ----------
         X : array-like of shape (n_samples, n_features)
             The input samples
 
         Returns
         -------
         p : ndarray of shape (n_samples, n_classes)
@@ -365,14 +383,13 @@
         """
 
         with mitra_deterministic_context():
-
             if isinstance(X, pd.DataFrame):
                 X = X.values
 
             preds = []
             for trainer in self.trainers:
-                logits = trainer.predict(self.X, self.y, X)[..., : len(np.unique(self.y))]
-                preds.append(np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True))
+                logits = trainer.predict(self.X, self.y, X)[..., : len(np.unique(self.y))]  # Remove extra classes
+                preds.append(np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True))  # Softmax
             preds = sum(preds) / len(preds)  # Averaging ensemble predictions
 
             return preds
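predict_proba turns each trainer's logits into probabilities with a plain softmax before averaging. A numerically safer, mathematically equivalent variant subtracts the per-row maximum first; a sketch, not the shipped code:

import numpy as np

def stable_softmax(logits: np.ndarray) -> np.ndarray:
    # Shifting by the row max leaves softmax unchanged but keeps np.exp
    # from overflowing on large logits.
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)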
@@ -381,26 +398,27 @@
 class MitraRegressor(MitraBase, RegressorMixin):
     """Regressor implementation of Mitra model."""
 
-    def __init__(
+    def __init__(
+        self,
+        model_type=DEFAULT_MODEL_TYPE,
+        n_estimators=DEFAULT_ENSEMBLE,
+        device=DEFAULT_DEVICE,
+        fine_tune=DEFAULT_FINE_TUNE,
+        fine_tune_steps=DEFAULT_FINE_TUNE_STEPS,
+        metric=DEFAULT_REG_METRIC,
+        state_dict=None,
+        hf_model=DEFAULT_REG_MODEL,
+        patience=PATIENCE,
+        lr=LR,
+        warmup_steps=WARMUP_STEPS,
+        shuffle_classes=SHUFFLE_CLASSES,
+        shuffle_features=SHUFFLE_FEATURES,
+        use_random_transforms=USE_RANDOM_TRANSFORMS,
+        random_mirror_regression=RANDOM_MIRROR_REGRESSION,
+        random_mirror_x=RANDOM_MIRROR_X,
+        seed=SEED,
+        verbose=True,
+    ):
         """Initialize the regressor."""
         super().__init__(
             model_type,
@@ -422,19 +440,19 @@ class MitraRegressor(MitraBase, RegressorMixin):
             seed=seed,
             verbose=verbose,
         )
-        self.task = "regression"
+        self.task = "regression"
 
-    def fit(self, X, y, X_val=None, y_val=None):
+    def fit(self, X, y, X_val=None, y_val=None, time_limit=None):
         """
         Fit the ensemble of models.
 
         Parameters
         ----------
         X : array-like of shape (n_samples, n_features)
             Training data
         y : array-like of shape (n_samples,)
             Target values
 
         Returns
         -------
         self : object
@@ -442,7 +460,6 @@ class MitraRegressor(MitraBase, RegressorMixin):
         """
 
         with mitra_deterministic_context():
-
             if isinstance(X, pd.DataFrame):
                 X = X.values
             if isinstance(y, pd.Series):
@@ -464,12 +481,12 @@
     def predict(self, X):
         """
         Predict regression target for samples in X.
 
         Parameters
         ----------
         X : array-like of shape (n_samples, n_features)
             The input samples
 
         Returns
         -------
         y : ndarray of shape (n_samples,)
@@ -477,16 +494,15 @@
         """
 
         with mitra_deterministic_context():
-
             if isinstance(X, pd.DataFrame):
                 X = X.values
 
             preds = []
             for trainer in self.trainers:
                 preds.append(trainer.predict(self.X, self.y, X))
 
             return sum(preds) / len(preds)  # Averaging ensemble predictions
 
 
 @contextlib.contextmanager
 def mitra_deterministic_context():