autogluon.tabular 1.3.2b20250713__py3-none-any.whl → 1.3.2b20250715__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- autogluon/tabular/models/__init__.py +1 -0
- autogluon/tabular/models/catboost/catboost_model.py +9 -6
- autogluon/tabular/models/catboost/catboost_utils.py +10 -0
- autogluon/tabular/models/lgb/lgb_model.py +2 -1
- autogluon/tabular/models/mitra/__init__.py +0 -0
- autogluon/tabular/models/mitra/_internal/config/config_pretrain.py +190 -0
- autogluon/tabular/models/mitra/_internal/config/config_run.py +32 -0
- autogluon/tabular/models/mitra/_internal/config/enums.py +145 -0
- autogluon/tabular/models/mitra/_internal/core/callbacks.py +94 -0
- autogluon/tabular/models/mitra/_internal/core/get_loss.py +55 -0
- autogluon/tabular/models/mitra/_internal/core/get_optimizer.py +108 -0
- autogluon/tabular/models/mitra/_internal/core/get_scheduler.py +67 -0
- autogluon/tabular/models/mitra/_internal/core/prediction_metrics.py +134 -0
- autogluon/tabular/models/mitra/_internal/core/trainer_finetune.py +367 -0
- autogluon/tabular/models/mitra/_internal/data/collator.py +46 -0
- autogluon/tabular/models/mitra/_internal/data/dataset_finetune.py +132 -0
- autogluon/tabular/models/mitra/_internal/data/dataset_split.py +53 -0
- autogluon/tabular/models/mitra/_internal/data/preprocessor.py +420 -0
- autogluon/tabular/models/mitra/_internal/models/base.py +21 -0
- autogluon/tabular/models/mitra/_internal/models/embedding.py +182 -0
- autogluon/tabular/models/mitra/_internal/models/tab2d.py +667 -0
- autogluon/tabular/models/mitra/_internal/utils/set_seed.py +15 -0
- autogluon/tabular/models/mitra/mitra_model.py +214 -0
- autogluon/tabular/models/mitra/sklearn_interface.py +462 -0
- autogluon/tabular/registry/_ag_model_registry.py +2 -0
- autogluon/tabular/testing/fit_helper.py +2 -2
- autogluon/tabular/version.py +1 -1
- {autogluon.tabular-1.3.2b20250713.dist-info → autogluon.tabular-1.3.2b20250715.dist-info}/METADATA +21 -12
- {autogluon.tabular-1.3.2b20250713.dist-info → autogluon.tabular-1.3.2b20250715.dist-info}/RECORD +36 -16
- /autogluon.tabular-1.3.2b20250713-py3.9-nspkg.pth → /autogluon.tabular-1.3.2b20250715-py3.9-nspkg.pth +0 -0
- {autogluon.tabular-1.3.2b20250713.dist-info → autogluon.tabular-1.3.2b20250715.dist-info}/LICENSE +0 -0
- {autogluon.tabular-1.3.2b20250713.dist-info → autogluon.tabular-1.3.2b20250715.dist-info}/NOTICE +0 -0
- {autogluon.tabular-1.3.2b20250713.dist-info → autogluon.tabular-1.3.2b20250715.dist-info}/WHEEL +0 -0
- {autogluon.tabular-1.3.2b20250713.dist-info → autogluon.tabular-1.3.2b20250715.dist-info}/namespace_packages.txt +0 -0
- {autogluon.tabular-1.3.2b20250713.dist-info → autogluon.tabular-1.3.2b20250715.dist-info}/top_level.txt +0 -0
- {autogluon.tabular-1.3.2b20250713.dist-info → autogluon.tabular-1.3.2b20250715.dist-info}/zip-safe +0 -0
autogluon/tabular/models/__init__.py
@@ -23,6 +23,7 @@ from .tabicl.tabicl_model import TabICLModel
 from .tabm.tabm_model import TabMModel
 from .tabpfnv2.tabpfnv2_model import TabPFNV2Model
 from .tabpfnmix.tabpfnmix_model import TabPFNMixModel
+from .mitra.mitra_model import MitraModel
 from .tabular_nn.torch.tabular_nn_torch import TabularNeuralNetTorchModel
 from .text_prediction.text_prediction_v1_model import TextPredictorModel
 from .xgboost.xgboost_model import XGBoostModel
autogluon/tabular/models/catboost/catboost_model.py
@@ -13,13 +13,13 @@ from autogluon.common.features.types import R_BOOL, R_CATEGORY, R_FLOAT, R_INT
 from autogluon.common.utils.pandas_utils import get_approximate_df_mem_usage
 from autogluon.common.utils.resource_utils import ResourceManager
 from autogluon.common.utils.try_import import try_import_catboost
-from autogluon.core.constants import MULTICLASS, PROBLEM_TYPES_CLASSIFICATION, QUANTILE, SOFTCLASS
+from autogluon.core.constants import MULTICLASS, PROBLEM_TYPES_CLASSIFICATION, REGRESSION, QUANTILE, SOFTCLASS
 from autogluon.core.models import AbstractModel
 from autogluon.core.models._utils import get_early_stopping_rounds
 from autogluon.core.utils.exceptions import TimeLimitExceeded
 
 from .callbacks import EarlyStoppingCallback, MemoryCheckCallback, TimeCheckCallback
-from .catboost_utils import get_catboost_metric_from_ag_metric
+from .catboost_utils import get_catboost_metric_from_ag_metric, CATBOOST_EVAL_METRIC_TO_LOSS_FUNCTION
 from .hyperparameters.parameters import get_param_baseline
 from .hyperparameters.searchspaces import get_default_searchspace
 
@@ -131,11 +131,14 @@ class CatBoostModel(AbstractModel):
             # FIXME: This is extremely slow due to unoptimized metric / objective sent to CatBoost
             from .catboost_softclass_utils import SoftclassCustomMetric, SoftclassObjective
 
-            params…
+            params.setdefault("loss_function", SoftclassObjective.SoftLogLossObjective())
             params["eval_metric"] = SoftclassCustomMetric.SoftLogLossMetric()
-        elif self.problem_type…
-            # …
-            params…
+        elif self.problem_type in [REGRESSION, QUANTILE]:
+            # Choose appropriate loss_function that is as close as possible to the eval_metric
+            params.setdefault(
+                "loss_function",
+                CATBOOST_EVAL_METRIC_TO_LOSS_FUNCTION.get(params["eval_metric"], params["eval_metric"])
+            )
 
         model_type = CatBoostClassifier if self.problem_type in PROBLEM_TYPES_CLASSIFICATION else CatBoostRegressor
         num_rows_train = len(X)
autogluon/tabular/models/catboost/catboost_utils.py
@@ -6,6 +6,13 @@ logger = logging.getLogger(__name__)
 
 
 CATBOOST_QUANTILE_PREFIX = "Quantile:"
+# Mapping from non-optimizable eval_metric to optimizable loss_function.
+# See https://catboost.ai/docs/en/concepts/loss-functions-regression#usage-information
+CATBOOST_EVAL_METRIC_TO_LOSS_FUNCTION = {
+    "MedianAbsoluteError": "MAE",
+    "SMAPE": "MAPE",
+    "R2": "RMSE",
+}
 
 
 # TODO: Add weight support?
@@ -65,7 +72,10 @@ def get_catboost_metric_from_ag_metric(metric, problem_type, quantile_levels=None):
             mean_squared_error="RMSE",
             root_mean_squared_error="RMSE",
             mean_absolute_error="MAE",
+            mean_absolute_percentage_error="MAPE",
+            # Non-optimizable metrics, see CATBOOST_EVAL_METRIC_TO_LOSS_FUNCTION
             median_absolute_error="MedianAbsoluteError",
+            symmetric_mean_absolute_percentage_error="SMAPE",
             r2="R2",
         )
         metric_class = metric_map.get(metric.name, "RMSE")
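For reference, a minimal sketch (illustrative, not part of the diff) of how the new mapping combines with the setdefault call added in catboost_model.py: metrics that CatBoost can evaluate but not optimize stay in place as eval_metric, while a closely related optimizable objective is substituted as loss_function unless the user already set one.

# Mirrors the REGRESSION/QUANTILE branch added above; the params dict here is hypothetical.
CATBOOST_EVAL_METRIC_TO_LOSS_FUNCTION = {
    "MedianAbsoluteError": "MAE",
    "SMAPE": "MAPE",
    "R2": "RMSE",
}

params = {"eval_metric": "R2"}
params.setdefault(
    "loss_function",
    CATBOOST_EVAL_METRIC_TO_LOSS_FUNCTION.get(params["eval_metric"], params["eval_metric"]),
)
# params == {"eval_metric": "R2", "loss_function": "RMSE"}; an explicitly provided loss_function is left untouched.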
autogluon/tabular/models/lgb/lgb_model.py
@@ -281,7 +281,8 @@ class LGBModel(AbstractModel):
             train_params["params"]["metric"] = f'{stopping_metric},{train_params["params"]["metric"]}'
 
         if self.problem_type == SOFTCLASS:
-            train_params["…
+            train_params["params"]["objective"] = lgb_utils.softclass_lgbobj
+            train_params["params"]["num_classes"] = self.num_classes
         elif self.problem_type == QUANTILE:
             train_params["params"]["quantile_levels"] = self.quantile_levels
         if seed_val is not None:
autogluon/tabular/models/mitra/__init__.py
File without changes (empty file)
autogluon/tabular/models/mitra/_internal/config/config_pretrain.py
@@ -0,0 +1,190 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Optional
+import yaml
+import os
+
+import torch
+from omegaconf import DictConfig, OmegaConf
+
+from ..._internal.config.enums import GeneratorName, ModelName, LossName, Task
+
+@dataclass
+class ConfigData():
+    generator: GeneratorName
+    min_samples_support: int
+    max_samples_support: int
+    n_samples_query: int
+    min_features: int
+    max_features: int
+    max_classes: int
+    sample_multinomial_categorical: bool
+    sample_multinomial_label: bool
+    generator_hyperparams: dict
+    task: Task
+
+    def __post_init__(self):
+
+        assert self.min_samples_support <= self.max_samples_support
+        assert self.min_features <= self.max_features
+
+@dataclass
+class ConfigModel():
+    name: ModelName
+    hyperparams: dict
+
+
+@dataclass
+class ConfigPreprocessing():
+    use_quantile_transformer: bool
+    use_feature_count_scaling: bool
+
+@dataclass
+class ConfigGradScaler():
+    enabled: bool
+    scale_init: float
+    scale_min: float
+    growth_interval: int
+
+
+    def __post_init__(self):
+        assert self.scale_init >= self.scale_min, "Scale init must be greater than scale min"
+        assert self.scale_min >= 1, "Scale min lower than 1 makes no sense for mixed precision training"
+        assert type(self.scale_init) == float, "Scale init must be a float, otherwise gradscaler will return an error"
+        assert type(self.scale_min) == float, "Scale min must be a float, otherwise gradscaler will return an error"
+
+@dataclass
+class ConfigOptim():
+    steps: int
+    log_every_n_steps: int
+    eval_every_n_steps: int
+    batch_size: int
+    gradient_accumulation_steps: int
+    lr: float
+    weight_decay: float
+    beta1: float
+    beta2: float
+    warmup_steps: int
+    cosine_scheduler: bool
+    max_grad_norm: float
+    label_smoothing: float
+    regression_loss: LossName
+    use_pretrained_weights: bool
+    path_to_weights: str
+    resume_states: bool
+    path_to_states: str
+    precision: str
+    grad_scaler: ConfigGradScaler
+
+    @classmethod
+    def from_hydra(cls, cfg_hydra: DictConfig) -> Self:
+
+        grad_scaler = ConfigGradScaler(**cfg_hydra.grad_scaler)
+        cfg_dict: dict = OmegaConf.to_container(cfg_hydra) # type: ignore
+        del cfg_dict["grad_scaler"]
+
+        regression_loss = LossName[cfg_dict["regression_loss"]]
+        del cfg_dict["regression_loss"]
+
+        return cls(
+            grad_scaler=grad_scaler,
+            regression_loss=regression_loss,
+            **cfg_dict
+        )
+
+    def __post_init__(self):
+        assert hasattr(torch, self.precision), f"Precision {self.precision} not supported by torch"
+
+class ConfigSaveLoadMixin(yaml.YAMLObject):
+
+    def save(self, path: Path) -> None:
+
+        path.parent.mkdir(parents=True, exist_ok=True)
+
+        with open(path, 'w') as f:
+            yaml.dump(self, f, default_flow_style=False)
+
+
+    @classmethod
+    def load(cls, path: Path) -> Self:
+
+        with open(path, 'r') as f:
+            # It's unsafe, but not unsafer than the pickle module
+            config = yaml.unsafe_load(f)
+
+        return config
+
+@dataclass
+class ConfigPretrain(ConfigSaveLoadMixin):
+    run_name: str
+    output_dir: Path
+    seed: int
+    devices: list[torch.device]
+    device: torch.device
+    max_cpus_per_device: Optional[int]
+    use_ddp: bool
+    workers_per_gpu: int
+    model: ConfigModel
+    data: ConfigData
+    optim: ConfigOptim
+    preprocessing: ConfigPreprocessing
+    load_from_file: bool
+    load_path_x: str
+    load_path_y: str
+    save_file: bool
+    save_file_only: bool
+    save_path_x: str
+    save_path_y: str
+    number_of_runs: int
+
+    @classmethod
+    def from_hydra(cls, cfg_hydra: DictConfig):
+
+        assert not os.path.exists(cfg_hydra.output_dir), f'Output directory {cfg_hydra.output_dir} already exists! Please change to a new folder.'
+
+        output_dir = Path(cfg_hydra.output_dir)
+
+        devices = [torch.device(device) for device in cfg_hydra.devices]
+
+        # Initialize device to cpu, DDP will overwrite this
+        device = torch.device("cpu")
+
+        return cls(
+            run_name=cfg_hydra.run_name,
+            output_dir=output_dir,
+            devices=devices,
+            device=device,
+            max_cpus_per_device=cfg_hydra.max_cpus_per_device,
+            use_ddp=len(devices) > 1,
+            seed=cfg_hydra.seed,
+            workers_per_gpu=cfg_hydra.workers_per_gpu,
+            model = ConfigModel(
+                name = ModelName[cfg_hydra.model.name],
+                hyperparams = OmegaConf.to_container(cfg_hydra.model.hyperparams),
+            ),
+            data = ConfigData(
+                generator=GeneratorName(cfg_hydra.data.generator),
+                min_samples_support=cfg_hydra.data.min_samples_support,
+                max_samples_support=cfg_hydra.data.max_samples_support,
+                n_samples_query=cfg_hydra.data.n_samples_query,
+                min_features=cfg_hydra.data.min_features,
+                max_features=cfg_hydra.data.max_features,
+                max_classes=cfg_hydra.data.max_classes,
+                task=Task[cfg_hydra.data.task],
+                sample_multinomial_categorical=cfg_hydra.data.sample_multinomial_categorical,
+                sample_multinomial_label=cfg_hydra.data.sample_multinomial_label,
+                generator_hyperparams=OmegaConf.to_container(cfg_hydra.data.generator_hyperparams), # type: ignore
+            ),
+            optim = ConfigOptim.from_hydra(cfg_hydra.optim),
+            preprocessing = ConfigPreprocessing(**cfg_hydra.preprocessing),
+            load_from_file = cfg_hydra.load_from_file,
+            load_path_x = cfg_hydra.load_path_x,
+            load_path_y = cfg_hydra.load_path_y,
+            save_file = cfg_hydra.save_file,
+            save_file_only = cfg_hydra.save_file_only,
+            save_path_x = cfg_hydra.save_path_x,
+            save_path_y = cfg_hydra.save_path_y,
+            number_of_runs = cfg_hydra.number_of_runs,
+        )
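A small usage sketch of the config dataclasses above (the field names come from the diff; all values here are made up for illustration), showing what the __post_init__ checks on ConfigData enforce:

cfg_data = ConfigData(
    generator=GeneratorName.MIX,
    min_samples_support=128,
    max_samples_support=1024,
    n_samples_query=256,
    min_features=2,
    max_features=100,
    max_classes=10,
    sample_multinomial_categorical=False,
    sample_multinomial_label=False,
    generator_hyperparams={},
    task=Task.CLASSIFICATION,
)
# Inverting the bounds (e.g. min_samples_support=2048 with max_samples_support=1024)
# would raise an AssertionError in __post_init__.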
autogluon/tabular/models/mitra/_internal/config/config_run.py
@@ -0,0 +1,32 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Self
+
+import torch
+
+from ..._internal.config.config_pretrain import ConfigSaveLoadMixin
+from ..._internal.config.enums import ModelName
+
+@dataclass
+class ConfigRun(ConfigSaveLoadMixin):
+    device: torch.device
+    seed: int
+    model_name: ModelName
+    hyperparams: dict
+
+    @classmethod
+    def create(
+        cls,
+        device: torch.device,
+        seed: int,
+        model_name: ModelName,
+        hyperparams: dict
+    ) -> Self:
+
+        return cls(
+            device=device,
+            seed=seed,
+            model_name=model_name,
+            hyperparams=hyperparams
+        )
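A hedged usage sketch: ConfigRun.create is a thin constructor, and save/load come from ConfigSaveLoadMixin. The hyperparameter contents and paths below are illustrative, not defaults shipped by the package.

import torch
from pathlib import Path

cfg = ConfigRun.create(
    device=torch.device("cpu"),
    seed=0,
    model_name=ModelName.TAB2D,
    hyperparams={"regression_loss": "mse", "label_smoothing": 0.0},
)
cfg.save(Path("run_configs/config_run.yaml"))          # YAML serialization via ConfigSaveLoadMixin
cfg_again = ConfigRun.load(Path("run_configs/config_run.yaml"))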
autogluon/tabular/models/mitra/_internal/config/enums.py
@@ -0,0 +1,145 @@
+from enum import IntEnum, StrEnum
+
+
+class Task(StrEnum):
+    CLASSIFICATION = "classification"
+    REGRESSION = "regression"
+
+
+class FeatureType(StrEnum):
+    NUMERICAL = "numerical"
+    CATEGORICAL = "categorical"
+    MIXED = "mixed"
+
+
+class SearchType(StrEnum):
+    DEFAULT = "default"
+    RANDOM = "random"
+
+
+class DatasetSize(IntEnum):
+    SMALL = 1000
+    MEDIUM = 10000
+    LARGE = 50000
+
+
+class DataSplit(StrEnum):
+    TRAIN = "train"
+    VALID = "valid"
+    TEST = "test"
+
+
+class Phase(StrEnum):
+    TRAINING = "training"
+    VALIDATION = "validation"
+    TESTING = "testing"
+
+
+class ModelName(StrEnum):
+    PLACEHOLDER = "_placeholder_" # This is a placeholder for the current running model
+    FT_TRANSFORMER = "FT-Transformer"
+    TABPFN = "TabPFN"
+    FOUNDATION = "Foundation"
+    FOUNDATION_FLASH = "FoundationFlash"
+    TAB2D = "Tab2D"
+    TAB2D_COL_ROW = "Tab2D_COL_ROW"
+    TAB2D_SDPA = "Tab2D_SDPA"
+    SAINT = "SAINT"
+    MLP = "MLP"
+    MLP_RTDL = "MLP-rtdl"
+    RESNET = "Resnet"
+    RANDOM_FOREST = "RandomForest"
+    XGBOOST = "XGBoost"
+    CATBOOST = "CatBoost"
+    LIGHTGBM = "LightGBM"
+    GRADIENT_BOOSTING_TREE = "GradientBoostingTree"
+    HIST_GRADIENT_BOOSTING_TREE = "HistGradientBoostingTree"
+    LOGISTIC_REGRESSION = "LogisticRegression"
+    LINEAR_REGRESSION = "LinearRegression"
+    DECISION_TREE = "DecisionTree"
+    KNN = "KNN"
+    STG = "STG"
+    SVM = "SVM"
+    TABNET = "TabNet"
+    TABTRANSFORMER = "TabTransformer"
+    DEEPFM = "DeepFM"
+    VIME = "VIME"
+    DANET = "DANet"
+    NODE = "NODE"
+    AUTOGLUON = "AutoGluon"
+
+
+class ModelClass(StrEnum):
+    BASE = 'base'
+    GBDT = 'GBDT'
+    NN = 'NN'
+    ICLT = 'ICLT'
+
+
+class DownstreamTask(StrEnum):
+    ZEROSHOT = "zeroshot"
+    FINETUNE = "finetune"
+
+
+
+class BenchmarkName(StrEnum):
+    DEBUG_CLASSIFICATION = "debug_classification"
+    DEBUG_REGRESSION = "debug_regression"
+    DEBUG_TABZILLA = "debug_tabzilla"
+
+    CATEGORICAL_CLASSIFICATION = "categorical_classification"
+    NUMERICAL_CLASSIFICATION = "numerical_classification"
+    CATEGORICAL_REGRESSION = "categorical_regression"
+    NUMERICAL_REGRESSION = "numerical_regression"
+    CATEGORICAL_CLASSIFICATION_LARGE = "categorical_classification_large"
+    NUMERICAL_CLASSIFICATION_LARGE = "numerical_classification_large"
+    CATEGORICAL_REGRESSION_LARGE = "categorical_regression_large"
+    NUMERICAL_REGRESSION_LARGE = "numerical_regression_large"
+
+    TABZILLA_HARD = "tabzilla_hard"
+    TABZILLA_HARD_MAX_TEN_CLASSES = "tabzilla_hard_max_ten_classes"
+    TABZILLA_HAS_COMPLETED_RUNS = "tabzilla_has_completed_runs"
+
+
+class BenchmarkOrigin(StrEnum):
+    TABZILLA = "tabzilla"
+    WHYTREES = "whytrees"
+
+
+class GeneratorName(StrEnum):
+    TABPFN = 'tabpfn'
+    TREE = 'tree'
+    RANDOMFOREST = 'randomforest'
+    NEIGHBOR = 'neighbor'
+    MIX = 'mix'
+    PERLIN = 'perlin'
+    MIX_7 = 'mix_7'
+    MIX_6 = 'mix_6'
+    MIX_5 = 'mix_5'
+    MIX_5_GP = 'mix_5_gp'
+    MIX_4 = 'mix_4'
+    MIX_4_AG = 'mix_4_ag'
+    LR = 'lr'
+    POLY = 'poly'
+    SAMPLE_RF = 'sample_rf'
+    SAMPLE_GP = 'sample_gp'
+    TABREPO = 'tabrepo'
+    MIX_4_TABREPO = 'mix_4_tabrepo'
+    MIX_4_TABPFNV2 = 'mix_4_tabpfnv2'
+
+
+class MetricName(StrEnum):
+    ACCURACY = "accuracy"
+    F1 = "f1"
+    AUC = "auc"
+    MSE = "mse"
+    MAE = "mae"
+    R2 = "r2"
+    LOG_LOSS = "log_loss"
+    RMSE = "rmse"
+
+
+class LossName(StrEnum):
+    CROSS_ENTROPY = "cross_entropy"
+    MSE = "mse"
+    MAE = "mae"
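Worth noting how these enums are looked up in config_pretrain.py above: Task and ModelName are resolved by member name, GeneratorName by value, and StrEnum members (Python 3.11+) compare equal to their string values. A small illustrative check:

assert Task["CLASSIFICATION"] is Task.CLASSIFICATION   # name lookup, as in Task[cfg_hydra.data.task]
assert GeneratorName("mix") is GeneratorName.MIX       # value lookup, as in GeneratorName(cfg_hydra.data.generator)
assert Task.REGRESSION == "regression"                 # StrEnum members behave as plain strings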
autogluon/tabular/models/mitra/_internal/core/callbacks.py
@@ -0,0 +1,94 @@
+import numpy as np
+import torch
+
+
+class EarlyStopping():
+
+    def __init__(self, patience=10, delta=0.0001, metric='log_loss'):
+
+        self.patience = patience
+        self.counter = 0
+        self.best_score = None
+        self.early_stop = False
+        self.delta = delta
+        self.metric = metric
+
+
+    def __call__(self, val_loss):
+
+        # smaller is better for these metrics
+        if self.metric in ["log_loss", "mse", "mae", "rmse"]:
+            score = -val_loss
+        # larger is better for these metrics
+        elif self.metric in ["accuracy", "roc_auc", "r2"]:
+            score = val_loss
+        else:
+            raise ValueError(f"Unsupported metric: {self.metric}. Supported metrics are: log_loss, mse, mae, rmse, accuracy, roc_auc, r2.")
+
+        if self.best_score is None:
+            self.best_score = score
+        elif score < self.best_score + self.delta:
+            self.counter += 1
+            if self.counter >= self.patience:
+                self.early_stop = True
+        else:
+            self.best_score = score
+            self.counter = 0
+
+    def we_should_stop(self):
+        return self.early_stop
+
+
+class Checkpoint():
+
+    def __init__(self):
+        self.curr_best_loss = np.inf
+        self.best_model: dict
+
+    def reset(self, net: torch.nn.Module):
+        self.curr_best_loss = np.inf
+        self.best_model = net.state_dict()
+        for key in self.best_model:
+            self.best_model[key] = self.best_model[key].to('cpu')
+
+
+    def __call__(self, net: torch.nn.Module, loss: float):
+
+        if loss < self.curr_best_loss:
+            self.curr_best_loss = loss
+            self.best_model = net.state_dict()
+            for key in self.best_model:
+                self.best_model[key] = self.best_model[key].to('cpu')
+
+
+    def set_to_best(self, net):
+        net.load_state_dict(self.best_model)
+
+
+class EpochStatistics():
+
+    def __init__(self) -> None:
+        self.n = 0
+        self.loss = 0
+        self.score = 0
+
+    def update(self, loss, score, n):
+        self.n += n
+        self.loss += loss * n
+        self.score += score * n
+
+    def get(self):
+        return self.loss / self.n, self.score / self.n
+
+class TrackOutput():
+
+    def __init__(self) -> None:
+        self.y_true: list[np.ndarray] = []
+        self.y_pred: list[np.ndarray] = []
+
+    def update(self, y_true: np.ndarray, y_pred: np.ndarray):
+        self.y_true.append(y_true)
+        self.y_pred.append(y_pred)
+
+    def get(self):
+        return np.concatenate(self.y_true, axis=0), np.concatenate(self.y_pred, axis=0)
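A minimal sketch of how a training loop would presumably drive EarlyStopping (the loop and loss values are illustrative, not taken from the diff):

stopper = EarlyStopping(patience=3, metric="log_loss")
for val_loss in [0.70, 0.65, 0.66, 0.66, 0.67]:   # hypothetical per-evaluation validation losses
    stopper(val_loss)
    if stopper.we_should_stop():
        break
# After three consecutive evaluations without improving on 0.65 (beyond delta), early_stop becomes True.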
autogluon/tabular/models/mitra/_internal/core/get_loss.py
@@ -0,0 +1,55 @@
+import torch
+import einops
+
+from ..._internal.config.config_pretrain import ConfigPretrain
+from ..._internal.config.config_run import ConfigRun
+from ..._internal.config.enums import LossName, Task
+
+class CrossEntropyLossExtraBatch(torch.nn.Module):
+
+    def __init__(self, label_smoothing: float):
+        super().__init__()
+
+        self.loss = torch.nn.CrossEntropyLoss(label_smoothing=label_smoothing)
+
+
+    def forward(self, input, target):
+        """
+        Input has shape (batch_size, num_samples, num_classes)
+        Target has shape (batch_size, num_samples)
+
+        Compared to the original CrossEntropyLoss, accepts (batch_size, num_samples) as batch
+        """
+
+        input = einops.rearrange(input, 'b s c -> (b s) c')
+        target = einops.rearrange(target, 'b s -> (b s)')
+
+        return self.loss(input, target)
+
+def get_loss(cfg: ConfigRun):
+
+    match (cfg.task, cfg.hyperparams['regression_loss']):
+        case (Task.REGRESSION, LossName.MSE):
+            return torch.nn.MSELoss()
+        case (Task.REGRESSION, LossName.MAE):
+            return torch.nn.L1Loss()
+        case (Task.REGRESSION, LossName.CROSS_ENTROPY):
+            return CrossEntropyLossExtraBatch(cfg.hyperparams['label_smoothing'])
+        case (Task.CLASSIFICATION, _):
+            return CrossEntropyLossExtraBatch(cfg.hyperparams['label_smoothing'])
+        case (_, _):
+            raise ValueError(f"Unsupported task {cfg.task} and (regression) loss {cfg.hyperparams['regression_loss']}")
+
+def get_loss_pretrain(cfg: ConfigPretrain):
+
+    match (cfg.data.task, cfg.optim.regression_loss):
+        case (Task.REGRESSION, LossName.MSE):
+            return torch.nn.MSELoss()
+        case (Task.REGRESSION, LossName.MAE):
+            return torch.nn.L1Loss()
+        case (Task.REGRESSION, LossName.CROSS_ENTROPY):
+            return CrossEntropyLossExtraBatch(cfg.optim.label_smoothing)
+        case (Task.CLASSIFICATION, _):
+            return CrossEntropyLossExtraBatch(cfg.optim.label_smoothing)
+        case (_, _):
+            raise ValueError(f"Unsupported task {cfg.data.task} and (regression) loss {cfg.optim.regression_loss}")
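A short shape check (illustrative, not part of the package) of what CrossEntropyLossExtraBatch does with the extra batch dimension before delegating to torch.nn.CrossEntropyLoss:

import torch

loss_fn = CrossEntropyLossExtraBatch(label_smoothing=0.0)
logits = torch.randn(2, 8, 3)            # (batch_size, num_samples, num_classes)
targets = torch.randint(0, 3, (2, 8))    # (batch_size, num_samples)
loss = loss_fn(logits, targets)          # scalar; equivalent to CrossEntropyLoss on (16, 3) logits vs (16,) targets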