autogluon.tabular 1.3.2b20250610__py3-none-any.whl → 1.4.1b20251214__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (92)
  1. autogluon/tabular/configs/config_helper.py +1 -1
  2. autogluon/tabular/configs/hyperparameter_configs.py +2 -265
  3. autogluon/tabular/configs/pipeline_presets.py +130 -0
  4. autogluon/tabular/configs/presets_configs.py +51 -26
  5. autogluon/tabular/configs/zeroshot/zeroshot_portfolio_2023.py +0 -1
  6. autogluon/tabular/configs/zeroshot/zeroshot_portfolio_2025.py +310 -0
  7. autogluon/tabular/models/__init__.py +6 -1
  8. autogluon/tabular/models/_utils/rapids_utils.py +1 -1
  9. autogluon/tabular/models/automm/automm_model.py +2 -0
  10. autogluon/tabular/models/automm/ft_transformer.py +4 -1
  11. autogluon/tabular/models/catboost/callbacks.py +3 -2
  12. autogluon/tabular/models/catboost/catboost_model.py +15 -9
  13. autogluon/tabular/models/catboost/catboost_utils.py +17 -3
  14. autogluon/tabular/models/ebm/__init__.py +0 -0
  15. autogluon/tabular/models/ebm/ebm_model.py +259 -0
  16. autogluon/tabular/models/ebm/hyperparameters/__init__.py +0 -0
  17. autogluon/tabular/models/ebm/hyperparameters/parameters.py +39 -0
  18. autogluon/tabular/models/ebm/hyperparameters/searchspaces.py +72 -0
  19. autogluon/tabular/models/fastainn/tabular_nn_fastai.py +7 -5
  20. autogluon/tabular/models/knn/knn_model.py +7 -3
  21. autogluon/tabular/models/lgb/lgb_model.py +60 -21
  22. autogluon/tabular/models/lr/lr_model.py +6 -1
  23. autogluon/tabular/models/lr/lr_preprocessing_utils.py +6 -7
  24. autogluon/tabular/models/lr/lr_rapids_model.py +45 -5
  25. autogluon/tabular/models/mitra/__init__.py +0 -0
  26. autogluon/tabular/models/mitra/_internal/__init__.py +1 -0
  27. autogluon/tabular/models/mitra/_internal/config/__init__.py +1 -0
  28. autogluon/tabular/models/mitra/_internal/config/config_pretrain.py +190 -0
  29. autogluon/tabular/models/mitra/_internal/config/config_run.py +32 -0
  30. autogluon/tabular/models/mitra/_internal/config/enums.py +162 -0
  31. autogluon/tabular/models/mitra/_internal/core/__init__.py +1 -0
  32. autogluon/tabular/models/mitra/_internal/core/callbacks.py +94 -0
  33. autogluon/tabular/models/mitra/_internal/core/get_loss.py +54 -0
  34. autogluon/tabular/models/mitra/_internal/core/get_optimizer.py +108 -0
  35. autogluon/tabular/models/mitra/_internal/core/get_scheduler.py +67 -0
  36. autogluon/tabular/models/mitra/_internal/core/prediction_metrics.py +132 -0
  37. autogluon/tabular/models/mitra/_internal/core/trainer_finetune.py +373 -0
  38. autogluon/tabular/models/mitra/_internal/data/__init__.py +1 -0
  39. autogluon/tabular/models/mitra/_internal/data/collator.py +46 -0
  40. autogluon/tabular/models/mitra/_internal/data/dataset_finetune.py +136 -0
  41. autogluon/tabular/models/mitra/_internal/data/dataset_split.py +57 -0
  42. autogluon/tabular/models/mitra/_internal/data/preprocessor.py +420 -0
  43. autogluon/tabular/models/mitra/_internal/models/__init__.py +1 -0
  44. autogluon/tabular/models/mitra/_internal/models/base.py +21 -0
  45. autogluon/tabular/models/mitra/_internal/models/embedding.py +182 -0
  46. autogluon/tabular/models/mitra/_internal/models/tab2d.py +667 -0
  47. autogluon/tabular/models/mitra/_internal/utils/__init__.py +1 -0
  48. autogluon/tabular/models/mitra/_internal/utils/set_seed.py +15 -0
  49. autogluon/tabular/models/mitra/mitra_model.py +380 -0
  50. autogluon/tabular/models/mitra/sklearn_interface.py +494 -0
  51. autogluon/tabular/models/realmlp/__init__.py +0 -0
  52. autogluon/tabular/models/realmlp/realmlp_model.py +360 -0
  53. autogluon/tabular/models/rf/rf_model.py +11 -6
  54. autogluon/tabular/models/tabicl/__init__.py +0 -0
  55. autogluon/tabular/models/tabicl/tabicl_model.py +179 -0
  56. autogluon/tabular/models/tabm/__init__.py +0 -0
  57. autogluon/tabular/models/tabm/_tabm_internal.py +545 -0
  58. autogluon/tabular/models/tabm/rtdl_num_embeddings.py +810 -0
  59. autogluon/tabular/models/tabm/tabm_model.py +356 -0
  60. autogluon/tabular/models/tabm/tabm_reference.py +631 -0
  61. autogluon/tabular/models/tabpfnmix/tabpfnmix_model.py +13 -7
  62. autogluon/tabular/models/tabpfnv2/__init__.py +0 -0
  63. autogluon/tabular/models/tabpfnv2/rfpfn/__init__.py +20 -0
  64. autogluon/tabular/models/tabpfnv2/rfpfn/configs.py +40 -0
  65. autogluon/tabular/models/tabpfnv2/rfpfn/scoring_utils.py +201 -0
  66. autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_based_decision_tree_tabpfn.py +1464 -0
  67. autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_based_random_forest_tabpfn.py +747 -0
  68. autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_compat.py +863 -0
  69. autogluon/tabular/models/tabpfnv2/rfpfn/utils.py +106 -0
  70. autogluon/tabular/models/tabpfnv2/tabpfnv2_model.py +388 -0
  71. autogluon/tabular/models/tabular_nn/hyperparameters/parameters.py +1 -3
  72. autogluon/tabular/models/tabular_nn/torch/tabular_nn_torch.py +5 -5
  73. autogluon/tabular/models/xgboost/xgboost_model.py +10 -3
  74. autogluon/tabular/predictor/predictor.py +147 -84
  75. autogluon/tabular/registry/_ag_model_registry.py +12 -2
  76. autogluon/tabular/testing/fit_helper.py +57 -27
  77. autogluon/tabular/testing/generate_datasets.py +7 -0
  78. autogluon/tabular/trainer/abstract_trainer.py +3 -1
  79. autogluon/tabular/trainer/model_presets/presets.py +10 -1
  80. autogluon/tabular/version.py +1 -1
  81. autogluon.tabular-1.4.1b20251214-py3.11-nspkg.pth +1 -0
  82. {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info}/METADATA +112 -57
  83. {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info}/RECORD +89 -40
  84. {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info}/WHEEL +1 -1
  85. autogluon/tabular/models/tabpfn/__init__.py +0 -1
  86. autogluon/tabular/models/tabpfn/tabpfn_model.py +0 -153
  87. autogluon.tabular-1.3.2b20250610-py3.9-nspkg.pth +0 -1
  88. {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info/licenses}/LICENSE +0 -0
  89. {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info/licenses}/NOTICE +0 -0
  90. {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info}/namespace_packages.txt +0 -0
  91. {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info}/top_level.txt +0 -0
  92. {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info}/zip-safe +0 -0
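The file list above shows the main theme of this release: several new tabular model families are added (EBM, Mitra, RealMLP, TabICL, TabM, TabPFNv2) while the old TabPFN v1 integration is removed. As rough orientation only, the sketch below shows how new model families are generally enabled through TabularPredictor's hyperparameters argument; the registry keys used here ("MITRA", "TABPFNV2", "REALMLP") are assumptions inferred from the new file names, not values taken from this diff (the authoritative mapping lives in autogluon/tabular/registry/_ag_model_registry.py).

# Hypothetical sketch, not part of the diff: enabling some of the newly added models.
from autogluon.tabular import TabularPredictor

# train_data: a pandas DataFrame that contains the "target" column (placeholder).
predictor = TabularPredictor(label="target").fit(
    train_data,
    hyperparameters={
        "MITRA": {},     # assumed key for models/mitra/mitra_model.py
        "TABPFNV2": {},  # assumed key for models/tabpfnv2/tabpfnv2_model.py
        "REALMLP": {},   # assumed key for models/realmlp/realmlp_model.py
    },
)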
autogluon/tabular/models/lr/lr_rapids_model.py
@@ -1,7 +1,5 @@
 import logging
 
-import numpy as np
-
 from autogluon.common.utils.try_import import try_import_rapids_cuml
 from autogluon.core.constants import REGRESSION
 
@@ -51,10 +49,52 @@ class LinearRapidsModel(RapidsModelMixin, LinearModel):
 
     def _preprocess(self, X, **kwargs):
         X = super()._preprocess(X=X, **kwargs)
-        if not isinstance(X, np.ndarray):
+        if hasattr(X, 'toarray'): # Check if it's a sparse matrix
             X = X.toarray()
         return X
 
     def _fit(self, X, y, **kwargs):
-        kwargs.pop("sample_weight", None) # sample_weight is not supported
-        super()._fit(X=X, y=y, **kwargs)
+        """
+        Custom fit method for RAPIDS cuML models that handles parameter compatibility
+        and bypasses sklearn-specific incremental training approach.
+        """
+        # Preprocess data
+        X = self.preprocess(X, is_train=True)
+        if self.problem_type == 'binary':
+            y = y.astype(int).values
+
+        # Create cuML model with filtered parameters
+        model_cls = self._get_model_type()
+
+        # Comprehensive parameter filtering for cuML compatibility
+        cuml_incompatible_params = {
+            # AutoGluon-specific preprocessing parameters
+            'vectorizer_dict_size', 'proc.ngram_range', 'proc.skew_threshold',
+            'proc.impute_strategy', 'handle_text',
+            # sklearn-specific parameters not supported by cuML
+            'n_jobs', 'warm_start', 'multi_class', 'dual', 'intercept_scaling',
+            'class_weight', 'random_state', 'verbose',
+            # Parameters that need conversion or special handling
+            'penalty', 'C'
+        }
+
+        # Filter out incompatible parameters
+        filtered_params = {k: v for k, v in self.params.items()
+                           if k not in cuml_incompatible_params}
+
+        # Handle parameter conversions for cuML
+        if self.problem_type == REGRESSION:
+            # Convert sklearn's C parameter to cuML's alpha
+            if 'C' in self.params:
+                filtered_params['alpha'] = 1.0 / self.params['C']
+        else:
+            # For classification, keep C parameter
+            if 'C' in self.params:
+                filtered_params['C'] = self.params['C']
+
+        # Create and fit cuML model - let cuML handle its own error messages
+        self.model = model_cls(**filtered_params)
+        self.model.fit(X, y)
+
+        # Add missing sklearn-compatible attributes for AutoGluon compatibility
+        self.model.n_iter_ = None # cuML doesn't track iterations like sklearn
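For context on the regression branch above: sklearn-style linear models express regularization strength as C, while cuML's regressors take a ridge penalty alpha, and the two are reciprocal. A minimal standalone sketch of the same filtering and conversion logic (the params dict is invented for illustration; this is not AutoGluon code):

# Standalone illustration of the parameter translation performed in _fit above.
def translate_params_for_cuml(params: dict, is_regression: bool) -> dict:
    dropped = {
        "vectorizer_dict_size", "proc.ngram_range", "proc.skew_threshold",
        "proc.impute_strategy", "handle_text",
        "n_jobs", "warm_start", "multi_class", "dual", "intercept_scaling",
        "class_weight", "random_state", "verbose", "penalty", "C",
    }
    out = {k: v for k, v in params.items() if k not in dropped}
    if is_regression and "C" in params:
        out["alpha"] = 1.0 / params["C"]  # ridge strength is the inverse of C
    elif "C" in params:
        out["C"] = params["C"]  # classifiers keep C as-is
    return out

print(translate_params_for_cuml({"C": 0.5, "n_jobs": -1, "fit_intercept": True}, is_regression=True))
# {'fit_intercept': True, 'alpha': 2.0}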
autogluon/tabular/models/mitra/__init__.py
File without changes
autogluon/tabular/models/mitra/_internal/__init__.py
@@ -0,0 +1 @@
+# Internal modules for MitraModel
autogluon/tabular/models/mitra/_internal/config/__init__.py
@@ -0,0 +1 @@
+# Configuration modules for MitraModel
autogluon/tabular/models/mitra/_internal/config/config_pretrain.py
@@ -0,0 +1,190 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Optional
+import yaml
+import os
+
+import torch
+from omegaconf import DictConfig, OmegaConf
+
+from ..._internal.config.enums import GeneratorName, ModelName, LossName, Task
+
+@dataclass
+class ConfigData():
+    generator: GeneratorName
+    min_samples_support: int
+    max_samples_support: int
+    n_samples_query: int
+    min_features: int
+    max_features: int
+    max_classes: int
+    sample_multinomial_categorical: bool
+    sample_multinomial_label: bool
+    generator_hyperparams: dict
+    task: Task
+
+    def __post_init__(self):
+
+        assert self.min_samples_support <= self.max_samples_support
+        assert self.min_features <= self.max_features
+
+@dataclass
+class ConfigModel():
+    name: ModelName
+    hyperparams: dict
+
+
+@dataclass
+class ConfigPreprocessing():
+    use_quantile_transformer: bool
+    use_feature_count_scaling: bool
+
+@dataclass
+class ConfigGradScaler():
+    enabled: bool
+    scale_init: float
+    scale_min: float
+    growth_interval: int
+
+
+    def __post_init__(self):
+        assert self.scale_init >= self.scale_min, "Scale init must be greater than scale min"
+        assert self.scale_min >= 1, "Scale min lower than 1 makes no sense for mixed precision training"
+        assert type(self.scale_init) == float, "Scale init must be a float, otherwise gradscaler will return an error"
+        assert type(self.scale_min) == float, "Scale min must be a float, otherwise gradscaler will return an error"
+
+@dataclass
+class ConfigOptim():
+    steps: int
+    log_every_n_steps: int
+    eval_every_n_steps: int
+    batch_size: int
+    gradient_accumulation_steps: int
+    lr: float
+    weight_decay: float
+    beta1: float
+    beta2: float
+    warmup_steps: int
+    cosine_scheduler: bool
+    max_grad_norm: float
+    label_smoothing: float
+    regression_loss: LossName
+    use_pretrained_weights: bool
+    path_to_weights: str
+    resume_states: bool
+    path_to_states: str
+    precision: str
+    grad_scaler: ConfigGradScaler
+
+    @classmethod
+    def from_hydra(cls, cfg_hydra: DictConfig) -> Self:
+
+        grad_scaler = ConfigGradScaler(**cfg_hydra.grad_scaler)
+        cfg_dict: dict = OmegaConf.to_container(cfg_hydra) # type: ignore
+        del cfg_dict["grad_scaler"]
+
+        regression_loss = LossName[cfg_dict["regression_loss"]]
+        del cfg_dict["regression_loss"]
+
+        return cls(
+            grad_scaler=grad_scaler,
+            regression_loss=regression_loss,
+            **cfg_dict
+        )
+
+    def __post_init__(self):
+        assert hasattr(torch, self.precision), f"Precision {self.precision} not supported by torch"
+
+class ConfigSaveLoadMixin(yaml.YAMLObject):
+
+    def save(self, path: Path) -> None:
+
+        path.parent.mkdir(parents=True, exist_ok=True)
+
+        with open(path, 'w') as f:
+            yaml.dump(self, f, default_flow_style=False)
+
+
+    @classmethod
+    def load(cls, path: Path) -> Self:
+
+        with open(path, 'r') as f:
+            # It's unsafe, but not unsafer than the pickle module
+            config = yaml.unsafe_load(f)
+
+        return config
+
+@dataclass
+class ConfigPretrain(ConfigSaveLoadMixin):
+    run_name: str
+    output_dir: Path
+    seed: int
+    devices: list[torch.device]
+    device: torch.device
+    max_cpus_per_device: Optional[int]
+    use_ddp: bool
+    workers_per_gpu: int
+    model: ConfigModel
+    data: ConfigData
+    optim: ConfigOptim
+    preprocessing: ConfigPreprocessing
+    load_from_file: bool
+    load_path_x: str
+    load_path_y: str
+    save_file: bool
+    save_file_only: bool
+    save_path_x: str
+    save_path_y: str
+    number_of_runs: int
+
+    @classmethod
+    def from_hydra(cls, cfg_hydra: DictConfig):
+
+        assert not os.path.exists(cfg_hydra.output_dir), f'Output directory {cfg_hydra.output_dir} already exists! Please change to a new folder.'
+
+        output_dir = Path(cfg_hydra.output_dir)
+
+        devices = [torch.device(device) for device in cfg_hydra.devices]
+
+        # Initialize device to cpu, DDP will overwrite this
+        device = torch.device("cpu")
+
+        return cls(
+            run_name=cfg_hydra.run_name,
+            output_dir=output_dir,
+            devices=devices,
+            device=device,
+            max_cpus_per_device=cfg_hydra.max_cpus_per_device,
+            use_ddp=len(devices) > 1,
+            seed=cfg_hydra.seed,
+            workers_per_gpu=cfg_hydra.workers_per_gpu,
+            model = ConfigModel(
+                name = ModelName[cfg_hydra.model.name],
+                hyperparams = OmegaConf.to_container(cfg_hydra.model.hyperparams),
+            ),
+            data = ConfigData(
+                generator=GeneratorName(cfg_hydra.data.generator),
+                min_samples_support=cfg_hydra.data.min_samples_support,
+                max_samples_support=cfg_hydra.data.max_samples_support,
+                n_samples_query=cfg_hydra.data.n_samples_query,
+                min_features=cfg_hydra.data.min_features,
+                max_features=cfg_hydra.data.max_features,
+                max_classes=cfg_hydra.data.max_classes,
+                task=Task[cfg_hydra.data.task],
+                sample_multinomial_categorical=cfg_hydra.data.sample_multinomial_categorical,
+                sample_multinomial_label=cfg_hydra.data.sample_multinomial_label,
+                generator_hyperparams=OmegaConf.to_container(cfg_hydra.data.generator_hyperparams), # type: ignore
+            ),
+            optim = ConfigOptim.from_hydra(cfg_hydra.optim),
+            preprocessing = ConfigPreprocessing(**cfg_hydra.preprocessing),
+            load_from_file = cfg_hydra.load_from_file,
+            load_path_x = cfg_hydra.load_path_x,
+            load_path_y = cfg_hydra.load_path_y,
+            save_file = cfg_hydra.save_file,
+            save_file_only = cfg_hydra.save_file_only,
+            save_path_x = cfg_hydra.save_path_x,
+            save_path_y = cfg_hydra.save_path_y,
+            number_of_runs = cfg_hydra.number_of_runs,
+        )
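These dataclasses are populated from Hydra/OmegaConf configs through their from_hydra classmethods. A hedged sketch of that pattern for ConfigOptim follows; every field value below is invented purely to exercise the constructor (the real defaults live in the pretraining configs, which are not shipped in this wheel):

from omegaconf import OmegaConf

from autogluon.tabular.models.mitra._internal.config.config_pretrain import ConfigOptim

# Invented values: regression_loss is looked up by name in LossName, and the
# grad_scaler sub-config is split off into a ConfigGradScaler instance.
cfg_hydra = OmegaConf.create({
    "steps": 1000, "log_every_n_steps": 10, "eval_every_n_steps": 100,
    "batch_size": 64, "gradient_accumulation_steps": 1,
    "lr": 1e-4, "weight_decay": 0.0, "beta1": 0.9, "beta2": 0.95,
    "warmup_steps": 100, "cosine_scheduler": True, "max_grad_norm": 1.0,
    "label_smoothing": 0.0, "regression_loss": "MSE",
    "use_pretrained_weights": False, "path_to_weights": "",
    "resume_states": False, "path_to_states": "",
    "precision": "bfloat16",  # __post_init__ only checks hasattr(torch, precision)
    "grad_scaler": {"enabled": True, "scale_init": 65536.0, "scale_min": 128.0, "growth_interval": 1000},
})

optim_cfg = ConfigOptim.from_hydra(cfg_hydra)
assert optim_cfg.grad_scaler.enabled and optim_cfg.lr == 1e-4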
autogluon/tabular/models/mitra/_internal/config/config_run.py
@@ -0,0 +1,32 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+import torch
+
+from ..._internal.config.config_pretrain import ConfigSaveLoadMixin
+from ..._internal.config.enums import ModelName
+
+
+@dataclass
+class ConfigRun(ConfigSaveLoadMixin):
+    device: torch.device
+    seed: int
+    model_name: ModelName
+    hyperparams: dict
+
+    @classmethod
+    def create(
+        cls,
+        device: torch.device,
+        seed: int,
+        model_name: ModelName,
+        hyperparams: dict
+    ) -> "ConfigRun":
+
+        return cls(
+            device=device,
+            seed=seed,
+            model_name=model_name,
+            hyperparams=hyperparams
+        )
autogluon/tabular/models/mitra/_internal/config/enums.py
@@ -0,0 +1,162 @@
+from enum import IntEnum
+
+try:
+    from enum import StrEnum
+except ImportError:
+    # StrEnum is not available in Python < 3.11, so we create a compatible version
+    from enum import Enum
+    class StrEnum(str, Enum):
+        """
+        Enum where members are also (and must be) strings
+        """
+        def __new__(cls, value):
+            if not isinstance(value, str):
+                raise TypeError(f"{value!r} is not a string")
+            return super().__new__(cls, value)
+
+        def __str__(self):
+            return self.value
+
+
+class Task(StrEnum):
+    CLASSIFICATION = "classification"
+    REGRESSION = "regression"
+
+
+class FeatureType(StrEnum):
+    NUMERICAL = "numerical"
+    CATEGORICAL = "categorical"
+    MIXED = "mixed"
+
+
+class SearchType(StrEnum):
+    DEFAULT = "default"
+    RANDOM = "random"
+
+
+class DatasetSize(IntEnum):
+    SMALL = 1000
+    MEDIUM = 10000
+    LARGE = 50000
+
+
+class DataSplit(StrEnum):
+    TRAIN = "train"
+    VALID = "valid"
+    TEST = "test"
+
+
+class Phase(StrEnum):
+    TRAINING = "training"
+    VALIDATION = "validation"
+    TESTING = "testing"
+
+
+class ModelName(StrEnum):
+    PLACEHOLDER = "_placeholder_" # This is a placeholder for the current running model
+    FT_TRANSFORMER = "FT-Transformer"
+    TABPFN = "TabPFN"
+    FOUNDATION = "Foundation"
+    FOUNDATION_FLASH = "FoundationFlash"
+    TAB2D = "Tab2D"
+    TAB2D_COL_ROW = "Tab2D_COL_ROW"
+    TAB2D_SDPA = "Tab2D_SDPA"
+    SAINT = "SAINT"
+    MLP = "MLP"
+    MLP_RTDL = "MLP-rtdl"
+    RESNET = "Resnet"
+    RANDOM_FOREST = "RandomForest"
+    XGBOOST = "XGBoost"
+    CATBOOST = "CatBoost"
+    LIGHTGBM = "LightGBM"
+    GRADIENT_BOOSTING_TREE = "GradientBoostingTree"
+    HIST_GRADIENT_BOOSTING_TREE = "HistGradientBoostingTree"
+    LOGISTIC_REGRESSION = "LogisticRegression"
+    LINEAR_REGRESSION = "LinearRegression"
+    DECISION_TREE = "DecisionTree"
+    KNN = "KNN"
+    STG = "STG"
+    SVM = "SVM"
+    TABNET = "TabNet"
+    TABTRANSFORMER = "TabTransformer"
+    DEEPFM = "DeepFM"
+    VIME = "VIME"
+    DANET = "DANet"
+    NODE = "NODE"
+    AUTOGLUON = "AutoGluon"
+
+
+class ModelClass(StrEnum):
+    BASE = 'base'
+    GBDT = 'GBDT'
+    NN = 'NN'
+    ICLT = 'ICLT'
+
+
+class DownstreamTask(StrEnum):
+    ZEROSHOT = "zeroshot"
+    FINETUNE = "finetune"
+
+
+
+class BenchmarkName(StrEnum):
+    DEBUG_CLASSIFICATION = "debug_classification"
+    DEBUG_REGRESSION = "debug_regression"
+    DEBUG_TABZILLA = "debug_tabzilla"
+
+    CATEGORICAL_CLASSIFICATION = "categorical_classification"
+    NUMERICAL_CLASSIFICATION = "numerical_classification"
+    CATEGORICAL_REGRESSION = "categorical_regression"
+    NUMERICAL_REGRESSION = "numerical_regression"
+    CATEGORICAL_CLASSIFICATION_LARGE = "categorical_classification_large"
+    NUMERICAL_CLASSIFICATION_LARGE = "numerical_classification_large"
+    CATEGORICAL_REGRESSION_LARGE = "categorical_regression_large"
+    NUMERICAL_REGRESSION_LARGE = "numerical_regression_large"
+
+    TABZILLA_HARD = "tabzilla_hard"
+    TABZILLA_HARD_MAX_TEN_CLASSES = "tabzilla_hard_max_ten_classes"
+    TABZILLA_HAS_COMPLETED_RUNS = "tabzilla_has_completed_runs"
+
+
+class BenchmarkOrigin(StrEnum):
+    TABZILLA = "tabzilla"
+    WHYTREES = "whytrees"
+
+
+class GeneratorName(StrEnum):
+    TABPFN = 'tabpfn'
+    TREE = 'tree'
+    RANDOMFOREST = 'randomforest'
+    NEIGHBOR = 'neighbor'
+    MIX = 'mix'
+    PERLIN = 'perlin'
+    MIX_7 = 'mix_7'
+    MIX_6 = 'mix_6'
+    MIX_5 = 'mix_5'
+    MIX_5_GP = 'mix_5_gp'
+    MIX_4 = 'mix_4'
+    MIX_4_AG = 'mix_4_ag'
+    LR = 'lr'
+    POLY = 'poly'
+    SAMPLE_RF = 'sample_rf'
+    SAMPLE_GP = 'sample_gp'
+    TABREPO = 'tabrepo'
+    MIX_4_TABREPO = 'mix_4_tabrepo'
+    MIX_4_TABPFNV2 = 'mix_4_tabpfnv2'
+
+
+class MetricName(StrEnum):
+    ACCURACY = "accuracy"
+    F1 = "f1"
+    AUC = "auc"
+    MSE = "mse"
+    MAE = "mae"
+    R2 = "r2"
+    LOG_LOSS = "log_loss"
+    RMSE = "rmse"
+
+
+class LossName(StrEnum):
+    CROSS_ENTROPY = "cross_entropy"
+    MSE = "mse"
+    MAE = "mae"
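A brief illustration of what the StrEnum fallback above guarantees on Python versions before 3.11 (the members used below are defined in this file; behavior is identical with the stdlib StrEnum on 3.11+):

from autogluon.tabular.models.mitra._internal.config.enums import ModelName, Task

assert Task.REGRESSION == "regression"                 # str mixin: equal to the raw string value
assert str(Task.REGRESSION) == "regression"            # __str__ returns the value, not "Task.REGRESSION"
assert Task("classification") is Task.CLASSIFICATION   # lookup by value
assert ModelName["TAB2D"] is ModelName.TAB2D           # lookup by member name, as done in from_hydra
print(f"task={Task.CLASSIFICATION}")                   # -> task=classification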
autogluon/tabular/models/mitra/_internal/core/__init__.py
@@ -0,0 +1 @@
+# Core modules for MitraModel
autogluon/tabular/models/mitra/_internal/core/callbacks.py
@@ -0,0 +1,94 @@
+import numpy as np
+import torch
+
+
+class EarlyStopping():
+
+    def __init__(self, patience=10, delta=0.0001, metric='log_loss'):
+
+        self.patience = patience
+        self.counter = 0
+        self.best_score = None
+        self.early_stop = False
+        self.delta = delta
+        self.metric = metric
+
+
+    def __call__(self, val_loss):
+
+        # smaller is better for these metrics
+        if self.metric in ["log_loss", "mse", "mae", "rmse"]:
+            score = -val_loss
+        # larger is better for these metrics
+        elif self.metric in ["accuracy", "roc_auc", "r2"]:
+            score = val_loss
+        else:
+            raise ValueError(f"Unsupported metric: {self.metric}. Supported metrics are: log_loss, mse, mae, rmse, accuracy, roc_auc, r2.")
+
+        if self.best_score is None:
+            self.best_score = score
+        elif score < self.best_score + self.delta:
+            self.counter += 1
+            if self.counter >= self.patience:
+                self.early_stop = True
+        else:
+            self.best_score = score
+            self.counter = 0
+
+    def we_should_stop(self):
+        return self.early_stop
+
+
+class Checkpoint():
+
+    def __init__(self):
+        self.curr_best_loss = np.inf
+        self.best_model: dict
+
+    def reset(self, net: torch.nn.Module):
+        self.curr_best_loss = np.inf
+        self.best_model = net.state_dict()
+        for key in self.best_model:
+            self.best_model[key] = self.best_model[key].to('cpu')
+
+
+    def __call__(self, net: torch.nn.Module, loss: float):
+
+        if loss < self.curr_best_loss:
+            self.curr_best_loss = loss
+            self.best_model = net.state_dict()
+            for key in self.best_model:
+                self.best_model[key] = self.best_model[key].to('cpu')
+
+
+    def set_to_best(self, net):
+        net.load_state_dict(self.best_model)
+
+
+class EpochStatistics():
+
+    def __init__(self) -> None:
+        self.n = 0
+        self.loss = 0
+        self.score = 0
+
+    def update(self, loss, score, n):
+        self.n += n
+        self.loss += loss * n
+        self.score += score * n
+
+    def get(self):
+        return self.loss / self.n, self.score / self.n
+
+class TrackOutput():
+
+    def __init__(self) -> None:
+        self.y_true: list[np.ndarray] = []
+        self.y_pred: list[np.ndarray] = []
+
+    def update(self, y_true: np.ndarray, y_pred: np.ndarray):
+        self.y_true.append(y_true)
+        self.y_pred.append(y_pred)
+
+    def get(self):
+        return np.concatenate(self.y_true, axis=0), np.concatenate(self.y_pred, axis=0)
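A compact sketch of how these helpers are meant to interact in a validation loop; the network, loss numbers, and epoch structure below are stand-ins, not code from the Mitra trainer:

import torch

from autogluon.tabular.models.mitra._internal.core.callbacks import (
    Checkpoint, EarlyStopping, EpochStatistics,
)

net = torch.nn.Linear(4, 2)            # placeholder model
early_stopping = EarlyStopping(patience=10, metric="log_loss")
checkpoint = Checkpoint()
checkpoint.reset(net)                  # start tracking from the initial weights

for epoch in range(100):
    stats = EpochStatistics()
    # In the real loop this is called once per validation batch:
    stats.update(loss=0.7 - 0.01 * epoch, score=0.5, n=32)  # dummy numbers
    val_loss, val_score = stats.get()

    checkpoint(net, val_loss)          # keep a CPU copy of the best weights so far
    early_stopping(val_loss)           # "log_loss": lower is better
    if early_stopping.we_should_stop():
        break

checkpoint.set_to_best(net)            # restore the best weights at the end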
autogluon/tabular/models/mitra/_internal/core/get_loss.py
@@ -0,0 +1,54 @@
+import einops
+import torch
+
+from ..._internal.config.config_pretrain import ConfigPretrain
+from ..._internal.config.config_run import ConfigRun
+from ..._internal.config.enums import LossName, Task
+
+
+class CrossEntropyLossExtraBatch(torch.nn.Module):
+
+    def __init__(self, label_smoothing: float):
+        super().__init__()
+
+        self.loss = torch.nn.CrossEntropyLoss(label_smoothing=label_smoothing)
+
+
+    def forward(self, input, target):
+        """
+        Input has shape (batch_size, num_samples, num_classes)
+        Target has shape (batch_size, num_samples)
+
+        Compared to the original CrossEntropyLoss, accepts (batch_size, num_samples) as batch
+        """
+
+        input = einops.rearrange(input, 'b s c -> (b s) c')
+        target = einops.rearrange(target, 'b s -> (b s)')
+
+        return self.loss(input, target)
+
+def get_loss(cfg: ConfigRun):
+
+    if cfg.task == Task.REGRESSION and cfg.hyperparams['regression_loss'] == LossName.MSE:
+        return torch.nn.MSELoss()
+    elif cfg.task == Task.REGRESSION and cfg.hyperparams['regression_loss'] == LossName.MAE:
+        return torch.nn.L1Loss()
+    elif cfg.task == Task.REGRESSION and cfg.hyperparams['regression_loss'] == LossName.CROSS_ENTROPY:
+        return CrossEntropyLossExtraBatch(cfg.hyperparams['label_smoothing'])
+    elif cfg.task == Task.CLASSIFICATION:
+        return CrossEntropyLossExtraBatch(cfg.hyperparams['label_smoothing'])
+    else:
+        raise ValueError(f"Unsupported task {cfg.task} and (regression) loss {cfg.hyperparams['regression_loss']}")
+
+def get_loss_pretrain(cfg: ConfigPretrain):
+
+    if cfg.data.task == Task.REGRESSION and cfg.optim.regression_loss == LossName.MSE:
+        return torch.nn.MSELoss()
+    elif cfg.data.task == Task.REGRESSION and cfg.optim.regression_loss == LossName.MAE:
+        return torch.nn.L1Loss()
+    elif cfg.data.task == Task.REGRESSION and cfg.optim.regression_loss == LossName.CROSS_ENTROPY:
+        return CrossEntropyLossExtraBatch(cfg.optim.label_smoothing)
+    elif cfg.data.task == Task.CLASSIFICATION:
+        return CrossEntropyLossExtraBatch(cfg.optim.label_smoothing)
+    else:
+        raise ValueError(f"Unsupported task {cfg.data.task} and (regression) loss {cfg.optim.regression_loss}")
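To make the shape handling in forward() concrete, here is a small self-contained check of the batched cross-entropy against a manually flattened reference (random tensors; no dependencies beyond torch and the module itself):

import torch

from autogluon.tabular.models.mitra._internal.core.get_loss import CrossEntropyLossExtraBatch

batch_size, num_samples, num_classes = 2, 5, 3
logits = torch.randn(batch_size, num_samples, num_classes)
targets = torch.randint(num_classes, (batch_size, num_samples))

loss_fn = CrossEntropyLossExtraBatch(label_smoothing=0.0)
loss = loss_fn(logits, targets)

# Equivalent to collapsing the batch and sample dimensions by hand:
reference = torch.nn.functional.cross_entropy(
    logits.reshape(-1, num_classes), targets.reshape(-1)
)
assert torch.allclose(loss, reference)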