autogluon.tabular 1.3.2b20250610__py3-none-any.whl → 1.4.1b20251214__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- autogluon/tabular/configs/config_helper.py +1 -1
- autogluon/tabular/configs/hyperparameter_configs.py +2 -265
- autogluon/tabular/configs/pipeline_presets.py +130 -0
- autogluon/tabular/configs/presets_configs.py +51 -26
- autogluon/tabular/configs/zeroshot/zeroshot_portfolio_2023.py +0 -1
- autogluon/tabular/configs/zeroshot/zeroshot_portfolio_2025.py +310 -0
- autogluon/tabular/models/__init__.py +6 -1
- autogluon/tabular/models/_utils/rapids_utils.py +1 -1
- autogluon/tabular/models/automm/automm_model.py +2 -0
- autogluon/tabular/models/automm/ft_transformer.py +4 -1
- autogluon/tabular/models/catboost/callbacks.py +3 -2
- autogluon/tabular/models/catboost/catboost_model.py +15 -9
- autogluon/tabular/models/catboost/catboost_utils.py +17 -3
- autogluon/tabular/models/ebm/__init__.py +0 -0
- autogluon/tabular/models/ebm/ebm_model.py +259 -0
- autogluon/tabular/models/ebm/hyperparameters/__init__.py +0 -0
- autogluon/tabular/models/ebm/hyperparameters/parameters.py +39 -0
- autogluon/tabular/models/ebm/hyperparameters/searchspaces.py +72 -0
- autogluon/tabular/models/fastainn/tabular_nn_fastai.py +7 -5
- autogluon/tabular/models/knn/knn_model.py +7 -3
- autogluon/tabular/models/lgb/lgb_model.py +60 -21
- autogluon/tabular/models/lr/lr_model.py +6 -1
- autogluon/tabular/models/lr/lr_preprocessing_utils.py +6 -7
- autogluon/tabular/models/lr/lr_rapids_model.py +45 -5
- autogluon/tabular/models/mitra/__init__.py +0 -0
- autogluon/tabular/models/mitra/_internal/__init__.py +1 -0
- autogluon/tabular/models/mitra/_internal/config/__init__.py +1 -0
- autogluon/tabular/models/mitra/_internal/config/config_pretrain.py +190 -0
- autogluon/tabular/models/mitra/_internal/config/config_run.py +32 -0
- autogluon/tabular/models/mitra/_internal/config/enums.py +162 -0
- autogluon/tabular/models/mitra/_internal/core/__init__.py +1 -0
- autogluon/tabular/models/mitra/_internal/core/callbacks.py +94 -0
- autogluon/tabular/models/mitra/_internal/core/get_loss.py +54 -0
- autogluon/tabular/models/mitra/_internal/core/get_optimizer.py +108 -0
- autogluon/tabular/models/mitra/_internal/core/get_scheduler.py +67 -0
- autogluon/tabular/models/mitra/_internal/core/prediction_metrics.py +132 -0
- autogluon/tabular/models/mitra/_internal/core/trainer_finetune.py +373 -0
- autogluon/tabular/models/mitra/_internal/data/__init__.py +1 -0
- autogluon/tabular/models/mitra/_internal/data/collator.py +46 -0
- autogluon/tabular/models/mitra/_internal/data/dataset_finetune.py +136 -0
- autogluon/tabular/models/mitra/_internal/data/dataset_split.py +57 -0
- autogluon/tabular/models/mitra/_internal/data/preprocessor.py +420 -0
- autogluon/tabular/models/mitra/_internal/models/__init__.py +1 -0
- autogluon/tabular/models/mitra/_internal/models/base.py +21 -0
- autogluon/tabular/models/mitra/_internal/models/embedding.py +182 -0
- autogluon/tabular/models/mitra/_internal/models/tab2d.py +667 -0
- autogluon/tabular/models/mitra/_internal/utils/__init__.py +1 -0
- autogluon/tabular/models/mitra/_internal/utils/set_seed.py +15 -0
- autogluon/tabular/models/mitra/mitra_model.py +380 -0
- autogluon/tabular/models/mitra/sklearn_interface.py +494 -0
- autogluon/tabular/models/realmlp/__init__.py +0 -0
- autogluon/tabular/models/realmlp/realmlp_model.py +360 -0
- autogluon/tabular/models/rf/rf_model.py +11 -6
- autogluon/tabular/models/tabicl/__init__.py +0 -0
- autogluon/tabular/models/tabicl/tabicl_model.py +179 -0
- autogluon/tabular/models/tabm/__init__.py +0 -0
- autogluon/tabular/models/tabm/_tabm_internal.py +545 -0
- autogluon/tabular/models/tabm/rtdl_num_embeddings.py +810 -0
- autogluon/tabular/models/tabm/tabm_model.py +356 -0
- autogluon/tabular/models/tabm/tabm_reference.py +631 -0
- autogluon/tabular/models/tabpfnmix/tabpfnmix_model.py +13 -7
- autogluon/tabular/models/tabpfnv2/__init__.py +0 -0
- autogluon/tabular/models/tabpfnv2/rfpfn/__init__.py +20 -0
- autogluon/tabular/models/tabpfnv2/rfpfn/configs.py +40 -0
- autogluon/tabular/models/tabpfnv2/rfpfn/scoring_utils.py +201 -0
- autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_based_decision_tree_tabpfn.py +1464 -0
- autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_based_random_forest_tabpfn.py +747 -0
- autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_compat.py +863 -0
- autogluon/tabular/models/tabpfnv2/rfpfn/utils.py +106 -0
- autogluon/tabular/models/tabpfnv2/tabpfnv2_model.py +388 -0
- autogluon/tabular/models/tabular_nn/hyperparameters/parameters.py +1 -3
- autogluon/tabular/models/tabular_nn/torch/tabular_nn_torch.py +5 -5
- autogluon/tabular/models/xgboost/xgboost_model.py +10 -3
- autogluon/tabular/predictor/predictor.py +147 -84
- autogluon/tabular/registry/_ag_model_registry.py +12 -2
- autogluon/tabular/testing/fit_helper.py +57 -27
- autogluon/tabular/testing/generate_datasets.py +7 -0
- autogluon/tabular/trainer/abstract_trainer.py +3 -1
- autogluon/tabular/trainer/model_presets/presets.py +10 -1
- autogluon/tabular/version.py +1 -1
- autogluon.tabular-1.4.1b20251214-py3.11-nspkg.pth +1 -0
- {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info}/METADATA +112 -57
- {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info}/RECORD +89 -40
- {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info}/WHEEL +1 -1
- autogluon/tabular/models/tabpfn/__init__.py +0 -1
- autogluon/tabular/models/tabpfn/tabpfn_model.py +0 -153
- autogluon.tabular-1.3.2b20250610-py3.9-nspkg.pth +0 -1
- {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info/licenses}/LICENSE +0 -0
- {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info/licenses}/NOTICE +0 -0
- {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info}/namespace_packages.txt +0 -0
- {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info}/top_level.txt +0 -0
- {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info}/zip-safe +0 -0
autogluon/tabular/models/lr/lr_rapids_model.py
@@ -1,7 +1,5 @@
 import logging
 
-import numpy as np
-
 from autogluon.common.utils.try_import import try_import_rapids_cuml
 from autogluon.core.constants import REGRESSION
 
@@ -51,10 +49,52 @@ class LinearRapidsModel(RapidsModelMixin, LinearModel):
 
     def _preprocess(self, X, **kwargs):
         X = super()._preprocess(X=X, **kwargs)
-        if
+        if hasattr(X, 'toarray'):  # Check if it's a sparse matrix
             X = X.toarray()
         return X
 
     def _fit(self, X, y, **kwargs):
-
-
+        """
+        Custom fit method for RAPIDS cuML models that handles parameter compatibility
+        and bypasses sklearn-specific incremental training approach.
+        """
+        # Preprocess data
+        X = self.preprocess(X, is_train=True)
+        if self.problem_type == 'binary':
+            y = y.astype(int).values
+
+        # Create cuML model with filtered parameters
+        model_cls = self._get_model_type()
+
+        # Comprehensive parameter filtering for cuML compatibility
+        cuml_incompatible_params = {
+            # AutoGluon-specific preprocessing parameters
+            'vectorizer_dict_size', 'proc.ngram_range', 'proc.skew_threshold',
+            'proc.impute_strategy', 'handle_text',
+            # sklearn-specific parameters not supported by cuML
+            'n_jobs', 'warm_start', 'multi_class', 'dual', 'intercept_scaling',
+            'class_weight', 'random_state', 'verbose',
+            # Parameters that need conversion or special handling
+            'penalty', 'C'
+        }
+
+        # Filter out incompatible parameters
+        filtered_params = {k: v for k, v in self.params.items()
+                           if k not in cuml_incompatible_params}
+
+        # Handle parameter conversions for cuML
+        if self.problem_type == REGRESSION:
+            # Convert sklearn's C parameter to cuML's alpha
+            if 'C' in self.params:
+                filtered_params['alpha'] = 1.0 / self.params['C']
+        else:
+            # For classification, keep C parameter
+            if 'C' in self.params:
+                filtered_params['C'] = self.params['C']
+
+        # Create and fit cuML model - let cuML handle its own error messages
+        self.model = model_cls(**filtered_params)
+        self.model.fit(X, y)
+
+        # Add missing sklearn-compatible attributes for AutoGluon compatibility
+        self.model.n_iter_ = None  # cuML doesn't track iterations like sklearn
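For context, the heart of the new `_fit` above is the translation between sklearn-style and cuML-style hyperparameters. A minimal standalone sketch of that filtering and C-to-alpha conversion (illustrative values and a trimmed parameter set, not the shipped code):

# Hypothetical sketch of the filtering/conversion performed in the new _fit.
cuml_incompatible_params = {'n_jobs', 'random_state', 'penalty', 'C'}
params = {'C': 0.5, 'n_jobs': -1, 'fit_intercept': True}

filtered_params = {k: v for k, v in params.items() if k not in cuml_incompatible_params}
if 'C' in params:
    # Regression case: sklearn's inverse regularization strength C becomes cuML's alpha.
    filtered_params['alpha'] = 1.0 / params['C']   # 0.5 -> 2.0
print(filtered_params)  # {'fit_intercept': True, 'alpha': 2.0}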
autogluon/tabular/models/mitra/__init__.py
File without changes

autogluon/tabular/models/mitra/_internal/__init__.py
@@ -0,0 +1 @@
+# Internal modules for MitraModel
autogluon/tabular/models/mitra/_internal/config/__init__.py
@@ -0,0 +1 @@
+# Configuration modules for MitraModel
autogluon/tabular/models/mitra/_internal/config/config_pretrain.py
@@ -0,0 +1,190 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Optional
+import yaml
+import os
+
+import torch
+from omegaconf import DictConfig, OmegaConf
+
+from ..._internal.config.enums import GeneratorName, ModelName, LossName, Task
+
+@dataclass
+class ConfigData():
+    generator: GeneratorName
+    min_samples_support: int
+    max_samples_support: int
+    n_samples_query: int
+    min_features: int
+    max_features: int
+    max_classes: int
+    sample_multinomial_categorical: bool
+    sample_multinomial_label: bool
+    generator_hyperparams: dict
+    task: Task
+
+    def __post_init__(self):
+
+        assert self.min_samples_support <= self.max_samples_support
+        assert self.min_features <= self.max_features
+
+@dataclass
+class ConfigModel():
+    name: ModelName
+    hyperparams: dict
+
+
+@dataclass
+class ConfigPreprocessing():
+    use_quantile_transformer: bool
+    use_feature_count_scaling: bool
+
+@dataclass
+class ConfigGradScaler():
+    enabled: bool
+    scale_init: float
+    scale_min: float
+    growth_interval: int
+
+
+    def __post_init__(self):
+        assert self.scale_init >= self.scale_min, "Scale init must be greater than scale min"
+        assert self.scale_min >= 1, "Scale min lower than 1 makes no sense for mixed precision training"
+        assert type(self.scale_init) == float, "Scale init must be a float, otherwise gradscaler will return an error"
+        assert type(self.scale_min) == float, "Scale min must be a float, otherwise gradscaler will return an error"
+
+@dataclass
+class ConfigOptim():
+    steps: int
+    log_every_n_steps: int
+    eval_every_n_steps: int
+    batch_size: int
+    gradient_accumulation_steps: int
+    lr: float
+    weight_decay: float
+    beta1: float
+    beta2: float
+    warmup_steps: int
+    cosine_scheduler: bool
+    max_grad_norm: float
+    label_smoothing: float
+    regression_loss: LossName
+    use_pretrained_weights: bool
+    path_to_weights: str
+    resume_states: bool
+    path_to_states: str
+    precision: str
+    grad_scaler: ConfigGradScaler
+
+    @classmethod
+    def from_hydra(cls, cfg_hydra: DictConfig) -> Self:
+
+        grad_scaler = ConfigGradScaler(**cfg_hydra.grad_scaler)
+        cfg_dict: dict = OmegaConf.to_container(cfg_hydra)  # type: ignore
+        del cfg_dict["grad_scaler"]
+
+        regression_loss = LossName[cfg_dict["regression_loss"]]
+        del cfg_dict["regression_loss"]
+
+        return cls(
+            grad_scaler=grad_scaler,
+            regression_loss=regression_loss,
+            **cfg_dict
+        )
+
+    def __post_init__(self):
+        assert hasattr(torch, self.precision), f"Precision {self.precision} not supported by torch"
+
+class ConfigSaveLoadMixin(yaml.YAMLObject):
+
+    def save(self, path: Path) -> None:
+
+        path.parent.mkdir(parents=True, exist_ok=True)
+
+        with open(path, 'w') as f:
+            yaml.dump(self, f, default_flow_style=False)
+
+
+    @classmethod
+    def load(cls, path: Path) -> Self:
+
+        with open(path, 'r') as f:
+            # It's unsafe, but not unsafer than the pickle module
+            config = yaml.unsafe_load(f)
+
+        return config
+
+@dataclass
+class ConfigPretrain(ConfigSaveLoadMixin):
+    run_name: str
+    output_dir: Path
+    seed: int
+    devices: list[torch.device]
+    device: torch.device
+    max_cpus_per_device: Optional[int]
+    use_ddp: bool
+    workers_per_gpu: int
+    model: ConfigModel
+    data: ConfigData
+    optim: ConfigOptim
+    preprocessing: ConfigPreprocessing
+    load_from_file: bool
+    load_path_x: str
+    load_path_y: str
+    save_file: bool
+    save_file_only: bool
+    save_path_x: str
+    save_path_y: str
+    number_of_runs: int
+
+    @classmethod
+    def from_hydra(cls, cfg_hydra: DictConfig):
+
+        assert not os.path.exists(cfg_hydra.output_dir), f'Output directory {cfg_hydra.output_dir} already exists! Please change to a new folder.'
+
+        output_dir = Path(cfg_hydra.output_dir)
+
+        devices = [torch.device(device) for device in cfg_hydra.devices]
+
+        # Initialize device to cpu, DDP will overwrite this
+        device = torch.device("cpu")
+
+        return cls(
+            run_name=cfg_hydra.run_name,
+            output_dir=output_dir,
+            devices=devices,
+            device=device,
+            max_cpus_per_device=cfg_hydra.max_cpus_per_device,
+            use_ddp=len(devices) > 1,
+            seed=cfg_hydra.seed,
+            workers_per_gpu=cfg_hydra.workers_per_gpu,
+            model = ConfigModel(
+                name = ModelName[cfg_hydra.model.name],
+                hyperparams = OmegaConf.to_container(cfg_hydra.model.hyperparams),
+            ),
+            data = ConfigData(
+                generator=GeneratorName(cfg_hydra.data.generator),
+                min_samples_support=cfg_hydra.data.min_samples_support,
+                max_samples_support=cfg_hydra.data.max_samples_support,
+                n_samples_query=cfg_hydra.data.n_samples_query,
+                min_features=cfg_hydra.data.min_features,
+                max_features=cfg_hydra.data.max_features,
+                max_classes=cfg_hydra.data.max_classes,
+                task=Task[cfg_hydra.data.task],
+                sample_multinomial_categorical=cfg_hydra.data.sample_multinomial_categorical,
+                sample_multinomial_label=cfg_hydra.data.sample_multinomial_label,
+                generator_hyperparams=OmegaConf.to_container(cfg_hydra.data.generator_hyperparams),  # type: ignore
+            ),
+            optim = ConfigOptim.from_hydra(cfg_hydra.optim),
+            preprocessing = ConfigPreprocessing(**cfg_hydra.preprocessing),
+            load_from_file = cfg_hydra.load_from_file,
+            load_path_x = cfg_hydra.load_path_x,
+            load_path_y = cfg_hydra.load_path_y,
+            save_file = cfg_hydra.save_file,
+            save_file_only = cfg_hydra.save_file_only,
+            save_path_x = cfg_hydra.save_path_x,
+            save_path_y = cfg_hydra.save_path_y,
+            number_of_runs = cfg_hydra.number_of_runs,
+        )
autogluon/tabular/models/mitra/_internal/config/config_run.py
@@ -0,0 +1,32 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+import torch
+
+from ..._internal.config.config_pretrain import ConfigSaveLoadMixin
+from ..._internal.config.enums import ModelName
+
+
+@dataclass
+class ConfigRun(ConfigSaveLoadMixin):
+    device: torch.device
+    seed: int
+    model_name: ModelName
+    hyperparams: dict
+
+    @classmethod
+    def create(
+        cls,
+        device: torch.device,
+        seed: int,
+        model_name: ModelName,
+        hyperparams: dict
+    ) -> "ConfigRun":
+
+        return cls(
+            device=device,
+            seed=seed,
+            model_name=model_name,
+            hyperparams=hyperparams
+        )
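Since ConfigRun inherits the YAML helpers from ConfigSaveLoadMixin, a run configuration can be written to disk and read back. A hedged usage sketch, assuming the new modules are importable under the file paths listed above (not part of the diff):

# Hypothetical round-trip through the save/load helpers defined in config_pretrain.py.
from pathlib import Path
import torch
from autogluon.tabular.models.mitra._internal.config.config_run import ConfigRun
from autogluon.tabular.models.mitra._internal.config.enums import ModelName

cfg = ConfigRun.create(
    device=torch.device("cpu"),
    seed=0,
    model_name=ModelName.TAB2D,
    hyperparams={"lr": 1e-4},
)
cfg.save(Path("run_config.yaml"))                     # yaml.dump of the dataclass
restored = ConfigRun.load(Path("run_config.yaml"))    # yaml.unsafe_load, as noted in the source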
autogluon/tabular/models/mitra/_internal/config/enums.py
@@ -0,0 +1,162 @@
+from enum import IntEnum
+
+try:
+    from enum import StrEnum
+except ImportError:
+    # StrEnum is not available in Python < 3.11, so we create a compatible version
+    from enum import Enum
+    class StrEnum(str, Enum):
+        """
+        Enum where members are also (and must be) strings
+        """
+        def __new__(cls, value):
+            if not isinstance(value, str):
+                raise TypeError(f"{value!r} is not a string")
+            return super().__new__(cls, value)
+
+        def __str__(self):
+            return self.value
+
+
+class Task(StrEnum):
+    CLASSIFICATION = "classification"
+    REGRESSION = "regression"
+
+
+class FeatureType(StrEnum):
+    NUMERICAL = "numerical"
+    CATEGORICAL = "categorical"
+    MIXED = "mixed"
+
+
+class SearchType(StrEnum):
+    DEFAULT = "default"
+    RANDOM = "random"
+
+
+class DatasetSize(IntEnum):
+    SMALL = 1000
+    MEDIUM = 10000
+    LARGE = 50000
+
+
+class DataSplit(StrEnum):
+    TRAIN = "train"
+    VALID = "valid"
+    TEST = "test"
+
+
+class Phase(StrEnum):
+    TRAINING = "training"
+    VALIDATION = "validation"
+    TESTING = "testing"
+
+
+class ModelName(StrEnum):
+    PLACEHOLDER = "_placeholder_"  # This is a placeholder for the current running model
+    FT_TRANSFORMER = "FT-Transformer"
+    TABPFN = "TabPFN"
+    FOUNDATION = "Foundation"
+    FOUNDATION_FLASH = "FoundationFlash"
+    TAB2D = "Tab2D"
+    TAB2D_COL_ROW = "Tab2D_COL_ROW"
+    TAB2D_SDPA = "Tab2D_SDPA"
+    SAINT = "SAINT"
+    MLP = "MLP"
+    MLP_RTDL = "MLP-rtdl"
+    RESNET = "Resnet"
+    RANDOM_FOREST = "RandomForest"
+    XGBOOST = "XGBoost"
+    CATBOOST = "CatBoost"
+    LIGHTGBM = "LightGBM"
+    GRADIENT_BOOSTING_TREE = "GradientBoostingTree"
+    HIST_GRADIENT_BOOSTING_TREE = "HistGradientBoostingTree"
+    LOGISTIC_REGRESSION = "LogisticRegression"
+    LINEAR_REGRESSION = "LinearRegression"
+    DECISION_TREE = "DecisionTree"
+    KNN = "KNN"
+    STG = "STG"
+    SVM = "SVM"
+    TABNET = "TabNet"
+    TABTRANSFORMER = "TabTransformer"
+    DEEPFM = "DeepFM"
+    VIME = "VIME"
+    DANET = "DANet"
+    NODE = "NODE"
+    AUTOGLUON = "AutoGluon"
+
+
+class ModelClass(StrEnum):
+    BASE = 'base'
+    GBDT = 'GBDT'
+    NN = 'NN'
+    ICLT = 'ICLT'
+
+
+class DownstreamTask(StrEnum):
+    ZEROSHOT = "zeroshot"
+    FINETUNE = "finetune"
+
+
+
+class BenchmarkName(StrEnum):
+    DEBUG_CLASSIFICATION = "debug_classification"
+    DEBUG_REGRESSION = "debug_regression"
+    DEBUG_TABZILLA = "debug_tabzilla"
+
+    CATEGORICAL_CLASSIFICATION = "categorical_classification"
+    NUMERICAL_CLASSIFICATION = "numerical_classification"
+    CATEGORICAL_REGRESSION = "categorical_regression"
+    NUMERICAL_REGRESSION = "numerical_regression"
+    CATEGORICAL_CLASSIFICATION_LARGE = "categorical_classification_large"
+    NUMERICAL_CLASSIFICATION_LARGE = "numerical_classification_large"
+    CATEGORICAL_REGRESSION_LARGE = "categorical_regression_large"
+    NUMERICAL_REGRESSION_LARGE = "numerical_regression_large"
+
+    TABZILLA_HARD = "tabzilla_hard"
+    TABZILLA_HARD_MAX_TEN_CLASSES = "tabzilla_hard_max_ten_classes"
+    TABZILLA_HAS_COMPLETED_RUNS = "tabzilla_has_completed_runs"
+
+
+class BenchmarkOrigin(StrEnum):
+    TABZILLA = "tabzilla"
+    WHYTREES = "whytrees"
+
+
+class GeneratorName(StrEnum):
+    TABPFN = 'tabpfn'
+    TREE = 'tree'
+    RANDOMFOREST = 'randomforest'
+    NEIGHBOR = 'neighbor'
+    MIX = 'mix'
+    PERLIN = 'perlin'
+    MIX_7 = 'mix_7'
+    MIX_6 = 'mix_6'
+    MIX_5 = 'mix_5'
+    MIX_5_GP = 'mix_5_gp'
+    MIX_4 = 'mix_4'
+    MIX_4_AG = 'mix_4_ag'
+    LR = 'lr'
+    POLY = 'poly'
+    SAMPLE_RF = 'sample_rf'
+    SAMPLE_GP = 'sample_gp'
+    TABREPO = 'tabrepo'
+    MIX_4_TABREPO = 'mix_4_tabrepo'
+    MIX_4_TABPFNV2 = 'mix_4_tabpfnv2'
+
+
+class MetricName(StrEnum):
+    ACCURACY = "accuracy"
+    F1 = "f1"
+    AUC = "auc"
+    MSE = "mse"
+    MAE = "mae"
+    R2 = "r2"
+    LOG_LOSS = "log_loss"
+    RMSE = "rmse"
+
+
+class LossName(StrEnum):
+    CROSS_ENTROPY = "cross_entropy"
+    MSE = "mse"
+    MAE = "mae"
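The StrEnum fallback above makes enum members behave as plain strings on Python < 3.11, matching the stdlib StrEnum on 3.11+. A small illustrative check (the import path mirrors the new file layout and is an assumption):

# Illustrative only; assumes the enums module above is importable under this path.
from autogluon.tabular.models.mitra._internal.config.enums import Task, LossName

assert Task.CLASSIFICATION == "classification"   # members compare equal to plain strings
assert Task("regression") is Task.REGRESSION     # lookup by value works as usual
assert f"loss={LossName.MSE}" == "loss=mse"      # __str__ returns the raw value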
autogluon/tabular/models/mitra/_internal/core/__init__.py
@@ -0,0 +1 @@
+# Core modules for MitraModel
autogluon/tabular/models/mitra/_internal/core/callbacks.py
@@ -0,0 +1,94 @@
+import numpy as np
+import torch
+
+
+class EarlyStopping():
+
+    def __init__(self, patience=10, delta=0.0001, metric='log_loss'):
+
+        self.patience = patience
+        self.counter = 0
+        self.best_score = None
+        self.early_stop = False
+        self.delta = delta
+        self.metric = metric
+
+
+    def __call__(self, val_loss):
+
+        # smaller is better for these metrics
+        if self.metric in ["log_loss", "mse", "mae", "rmse"]:
+            score = -val_loss
+        # larger is better for these metrics
+        elif self.metric in ["accuracy", "roc_auc", "r2"]:
+            score = val_loss
+        else:
+            raise ValueError(f"Unsupported metric: {self.metric}. Supported metrics are: log_loss, mse, mae, rmse, accuracy, roc_auc, r2.")
+
+        if self.best_score is None:
+            self.best_score = score
+        elif score < self.best_score + self.delta:
+            self.counter += 1
+            if self.counter >= self.patience:
+                self.early_stop = True
+        else:
+            self.best_score = score
+            self.counter = 0
+
+    def we_should_stop(self):
+        return self.early_stop
+
+
+class Checkpoint():
+
+    def __init__(self):
+        self.curr_best_loss = np.inf
+        self.best_model: dict
+
+    def reset(self, net: torch.nn.Module):
+        self.curr_best_loss = np.inf
+        self.best_model = net.state_dict()
+        for key in self.best_model:
+            self.best_model[key] = self.best_model[key].to('cpu')
+
+
+    def __call__(self, net: torch.nn.Module, loss: float):
+
+        if loss < self.curr_best_loss:
+            self.curr_best_loss = loss
+            self.best_model = net.state_dict()
+            for key in self.best_model:
+                self.best_model[key] = self.best_model[key].to('cpu')
+
+
+    def set_to_best(self, net):
+        net.load_state_dict(self.best_model)
+
+
+class EpochStatistics():
+
+    def __init__(self) -> None:
+        self.n = 0
+        self.loss = 0
+        self.score = 0
+
+    def update(self, loss, score, n):
+        self.n += n
+        self.loss += loss * n
+        self.score += score * n
+
+    def get(self):
+        return self.loss / self.n, self.score / self.n
+
+class TrackOutput():
+
+    def __init__(self) -> None:
+        self.y_true: list[np.ndarray] = []
+        self.y_pred: list[np.ndarray] = []
+
+    def update(self, y_true: np.ndarray, y_pred: np.ndarray):
+        self.y_true.append(y_true)
+        self.y_pred.append(y_pred)
+
+    def get(self):
+        return np.concatenate(self.y_true, axis=0), np.concatenate(self.y_pred, axis=0)
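The EarlyStopping and Checkpoint classes above are plain callables, so a training loop can drive them directly. A hedged usage sketch, where `net`, `max_epochs` and `validate` are placeholders and not part of the diff:

# Hypothetical validation-loop usage of the callbacks above.
early_stopping = EarlyStopping(patience=10, metric='log_loss')
checkpoint = Checkpoint()
checkpoint.reset(net)                     # net: a torch.nn.Module

for epoch in range(max_epochs):
    val_loss = validate(net)              # assumed helper returning a float loss
    checkpoint(net, val_loss)             # keeps the best state_dict on CPU
    early_stopping(val_loss)              # lower log_loss counts as better
    if early_stopping.we_should_stop():
        break

checkpoint.set_to_best(net)               # restore the best weights before evaluation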
autogluon/tabular/models/mitra/_internal/core/get_loss.py
@@ -0,0 +1,54 @@
+import einops
+import torch
+
+from ..._internal.config.config_pretrain import ConfigPretrain
+from ..._internal.config.config_run import ConfigRun
+from ..._internal.config.enums import LossName, Task
+
+
+class CrossEntropyLossExtraBatch(torch.nn.Module):
+
+    def __init__(self, label_smoothing: float):
+        super().__init__()
+
+        self.loss = torch.nn.CrossEntropyLoss(label_smoothing=label_smoothing)
+
+
+    def forward(self, input, target):
+        """
+        Input has shape (batch_size, num_samples, num_classes)
+        Target has shape (batch_size, num_samples)
+
+        Compared to the original CrossEntropyLoss, accepts (batch_size, num_samples) as batch
+        """
+
+        input = einops.rearrange(input, 'b s c -> (b s) c')
+        target = einops.rearrange(target, 'b s -> (b s)')
+
+        return self.loss(input, target)
+
+def get_loss(cfg: ConfigRun):
+
+    if cfg.task == Task.REGRESSION and cfg.hyperparams['regression_loss'] == LossName.MSE:
+        return torch.nn.MSELoss()
+    elif cfg.task == Task.REGRESSION and cfg.hyperparams['regression_loss'] == LossName.MAE:
+        return torch.nn.L1Loss()
+    elif cfg.task == Task.REGRESSION and cfg.hyperparams['regression_loss'] == LossName.CROSS_ENTROPY:
+        return CrossEntropyLossExtraBatch(cfg.hyperparams['label_smoothing'])
+    elif cfg.task == Task.CLASSIFICATION:
+        return CrossEntropyLossExtraBatch(cfg.hyperparams['label_smoothing'])
+    else:
+        raise ValueError(f"Unsupported task {cfg.task} and (regression) loss {cfg.hyperparams['regression_loss']}")
+
+def get_loss_pretrain(cfg: ConfigPretrain):
+
+    if cfg.data.task == Task.REGRESSION and cfg.optim.regression_loss == LossName.MSE:
+        return torch.nn.MSELoss()
+    elif cfg.data.task == Task.REGRESSION and cfg.optim.regression_loss == LossName.MAE:
+        return torch.nn.L1Loss()
+    elif cfg.data.task == Task.REGRESSION and cfg.optim.regression_loss == LossName.CROSS_ENTROPY:
+        return CrossEntropyLossExtraBatch(cfg.optim.label_smoothing)
+    elif cfg.data.task == Task.CLASSIFICATION:
+        return CrossEntropyLossExtraBatch(cfg.optim.label_smoothing)
+    else:
+        raise ValueError(f"Unsupported task {cfg.data.task} and (regression) loss {cfg.optim.regression_loss}")
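CrossEntropyLossExtraBatch above simply flattens the extra leading batch dimension before delegating to the standard loss. A small hedged example of the shapes it accepts (assumes `einops` is installed and the class is importable from the new module; the import path is an assumption):

# Illustrative only: (batch, samples, classes) is flattened to (batch*samples, classes).
import torch
from autogluon.tabular.models.mitra._internal.core.get_loss import CrossEntropyLossExtraBatch

loss_fn = CrossEntropyLossExtraBatch(label_smoothing=0.0)
logits = torch.randn(2, 5, 3)            # (batch_size, num_samples, num_classes)
target = torch.randint(0, 3, (2, 5))     # (batch_size, num_samples)
loss = loss_fn(logits, target)           # scalar tensor, same as CrossEntropyLoss on (10, 3) / (10,)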