autogluon.tabular 1.5.0b20251228__py3-none-any.whl → 1.5.1b20260116__py3-none-any.whl

This diff shows the contents of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their public registries.

Potentially problematic release.

This version of autogluon.tabular might be problematic.

Files changed (135)
  1. autogluon/tabular/__init__.py +1 -0
  2. autogluon/tabular/configs/config_helper.py +18 -6
  3. autogluon/tabular/configs/feature_generator_presets.py +3 -1
  4. autogluon/tabular/configs/hyperparameter_configs.py +42 -9
  5. autogluon/tabular/configs/presets_configs.py +38 -14
  6. autogluon/tabular/configs/zeroshot/zeroshot_portfolio_2023.py +84 -14
  7. autogluon/tabular/configs/zeroshot/zeroshot_portfolio_2025.py +48 -48
  8. autogluon/tabular/configs/zeroshot/zeroshot_portfolio_cpu_2025_12_18.py +774 -1
  9. autogluon/tabular/configs/zeroshot/zeroshot_portfolio_gpu_2025_12_18.py +421 -1
  10. autogluon/tabular/experimental/_scikit_mixin.py +6 -2
  11. autogluon/tabular/experimental/_tabular_classifier.py +3 -1
  12. autogluon/tabular/experimental/_tabular_regressor.py +3 -1
  13. autogluon/tabular/experimental/plot_leaderboard.py +73 -19
  14. autogluon/tabular/learner/abstract_learner.py +160 -42
  15. autogluon/tabular/learner/default_learner.py +78 -22
  16. autogluon/tabular/models/__init__.py +2 -2
  17. autogluon/tabular/models/_utils/rapids_utils.py +3 -1
  18. autogluon/tabular/models/abstract/abstract_torch_model.py +2 -0
  19. autogluon/tabular/models/automm/automm_model.py +12 -3
  20. autogluon/tabular/models/automm/ft_transformer.py +5 -1
  21. autogluon/tabular/models/catboost/callbacks.py +2 -2
  22. autogluon/tabular/models/catboost/catboost_model.py +93 -29
  23. autogluon/tabular/models/catboost/catboost_softclass_utils.py +4 -1
  24. autogluon/tabular/models/catboost/catboost_utils.py +3 -1
  25. autogluon/tabular/models/ebm/ebm_model.py +8 -13
  26. autogluon/tabular/models/ebm/hyperparameters/parameters.py +1 -0
  27. autogluon/tabular/models/ebm/hyperparameters/searchspaces.py +1 -0
  28. autogluon/tabular/models/fastainn/callbacks.py +20 -3
  29. autogluon/tabular/models/fastainn/hyperparameters/searchspaces.py +11 -1
  30. autogluon/tabular/models/fastainn/quantile_helpers.py +10 -2
  31. autogluon/tabular/models/fastainn/tabular_nn_fastai.py +65 -18
  32. autogluon/tabular/models/fasttext/fasttext_model.py +3 -1
  33. autogluon/tabular/models/image_prediction/image_predictor.py +7 -2
  34. autogluon/tabular/models/knn/knn_model.py +41 -8
  35. autogluon/tabular/models/lgb/callbacks.py +32 -9
  36. autogluon/tabular/models/lgb/hyperparameters/searchspaces.py +3 -1
  37. autogluon/tabular/models/lgb/lgb_model.py +150 -34
  38. autogluon/tabular/models/lgb/lgb_utils.py +12 -4
  39. autogluon/tabular/models/lr/hyperparameters/searchspaces.py +5 -1
  40. autogluon/tabular/models/lr/lr_model.py +40 -10
  41. autogluon/tabular/models/lr/lr_rapids_model.py +22 -13
  42. autogluon/tabular/models/mitra/_internal/__init__.py +1 -1
  43. autogluon/tabular/models/mitra/_internal/config/__init__.py +1 -1
  44. autogluon/tabular/models/mitra/_internal/config/config_pretrain.py +36 -40
  45. autogluon/tabular/models/mitra/_internal/config/config_run.py +2 -14
  46. autogluon/tabular/models/mitra/_internal/config/enums.py +27 -26
  47. autogluon/tabular/models/mitra/_internal/core/__init__.py +1 -1
  48. autogluon/tabular/models/mitra/_internal/core/callbacks.py +14 -21
  49. autogluon/tabular/models/mitra/_internal/core/get_loss.py +10 -12
  50. autogluon/tabular/models/mitra/_internal/core/get_optimizer.py +17 -32
  51. autogluon/tabular/models/mitra/_internal/core/get_scheduler.py +12 -27
  52. autogluon/tabular/models/mitra/_internal/core/prediction_metrics.py +16 -21
  53. autogluon/tabular/models/mitra/_internal/core/trainer_finetune.py +130 -111
  54. autogluon/tabular/models/mitra/_internal/data/__init__.py +1 -1
  55. autogluon/tabular/models/mitra/_internal/data/collator.py +30 -26
  56. autogluon/tabular/models/mitra/_internal/data/dataset_finetune.py +18 -26
  57. autogluon/tabular/models/mitra/_internal/data/dataset_split.py +10 -7
  58. autogluon/tabular/models/mitra/_internal/data/preprocessor.py +70 -100
  59. autogluon/tabular/models/mitra/_internal/models/__init__.py +1 -1
  60. autogluon/tabular/models/mitra/_internal/models/base.py +7 -10
  61. autogluon/tabular/models/mitra/_internal/models/embedding.py +46 -56
  62. autogluon/tabular/models/mitra/_internal/models/tab2d.py +140 -120
  63. autogluon/tabular/models/mitra/_internal/utils/__init__.py +1 -1
  64. autogluon/tabular/models/mitra/_internal/utils/set_seed.py +3 -1
  65. autogluon/tabular/models/mitra/mitra_model.py +16 -11
  66. autogluon/tabular/models/mitra/sklearn_interface.py +178 -162
  67. autogluon/tabular/models/realmlp/realmlp_model.py +28 -15
  68. autogluon/tabular/models/rf/compilers/onnx.py +1 -1
  69. autogluon/tabular/models/rf/rf_model.py +45 -12
  70. autogluon/tabular/models/rf/rf_quantile.py +4 -2
  71. autogluon/tabular/models/tabdpt/tabdpt_model.py +8 -17
  72. autogluon/tabular/models/tabicl/tabicl_model.py +8 -1
  73. autogluon/tabular/models/tabm/_tabm_internal.py +6 -4
  74. autogluon/tabular/models/tabm/rtdl_num_embeddings.py +80 -127
  75. autogluon/tabular/models/tabm/tabm_model.py +8 -4
  76. autogluon/tabular/models/tabm/tabm_reference.py +53 -85
  77. autogluon/tabular/models/tabpfnmix/_internal/core/callbacks.py +7 -16
  78. autogluon/tabular/models/tabpfnmix/_internal/core/collator.py +16 -24
  79. autogluon/tabular/models/tabpfnmix/_internal/core/dataset_split.py +5 -7
  80. autogluon/tabular/models/tabpfnmix/_internal/core/enums.py +0 -2
  81. autogluon/tabular/models/tabpfnmix/_internal/core/get_loss.py +0 -1
  82. autogluon/tabular/models/tabpfnmix/_internal/core/get_optimizer.py +7 -18
  83. autogluon/tabular/models/tabpfnmix/_internal/core/get_scheduler.py +3 -14
  84. autogluon/tabular/models/tabpfnmix/_internal/core/trainer_finetune.py +79 -64
  85. autogluon/tabular/models/tabpfnmix/_internal/core/y_transformer.py +3 -5
  86. autogluon/tabular/models/tabpfnmix/_internal/data/dataset_finetune.py +17 -30
  87. autogluon/tabular/models/tabpfnmix/_internal/data/preprocessor.py +15 -35
  88. autogluon/tabular/models/tabpfnmix/_internal/models/foundation/embedding.py +21 -38
  89. autogluon/tabular/models/tabpfnmix/_internal/models/foundation/foundation_transformer.py +33 -51
  90. autogluon/tabular/models/tabpfnmix/_internal/results/prediction_metrics.py +4 -4
  91. autogluon/tabular/models/tabpfnmix/_internal/tabpfnmix_classifier.py +32 -12
  92. autogluon/tabular/models/tabpfnmix/_internal/tabpfnmix_regressor.py +32 -13
  93. autogluon/tabular/models/tabpfnmix/tabpfnmix_model.py +55 -19
  94. autogluon/tabular/models/tabpfnv2/tabpfnv2_5_model.py +21 -48
  95. autogluon/tabular/models/tabprep/prep_mixin.py +34 -26
  96. autogluon/tabular/models/tabular_nn/compilers/onnx.py +36 -8
  97. autogluon/tabular/models/tabular_nn/torch/tabular_nn_torch.py +130 -36
  98. autogluon/tabular/models/tabular_nn/torch/tabular_torch_dataset.py +8 -4
  99. autogluon/tabular/models/tabular_nn/torch/torch_network_modules.py +26 -5
  100. autogluon/tabular/models/tabular_nn/utils/categorical_encoders.py +41 -24
  101. autogluon/tabular/models/tabular_nn/utils/data_preprocessor.py +33 -8
  102. autogluon/tabular/models/tabular_nn/utils/nn_architecture_utils.py +21 -6
  103. autogluon/tabular/models/xgboost/callbacks.py +9 -3
  104. autogluon/tabular/models/xgboost/xgboost_model.py +59 -11
  105. autogluon/tabular/models/xt/xt_model.py +1 -0
  106. autogluon/tabular/predictor/interpretable_predictor.py +3 -1
  107. autogluon/tabular/predictor/predictor.py +409 -128
  108. autogluon/tabular/registry/__init__.py +1 -1
  109. autogluon/tabular/registry/_ag_model_registry.py +4 -5
  110. autogluon/tabular/registry/_model_registry.py +1 -0
  111. autogluon/tabular/testing/fit_helper.py +55 -15
  112. autogluon/tabular/testing/generate_datasets.py +1 -1
  113. autogluon/tabular/testing/model_fit_helper.py +10 -4
  114. autogluon/tabular/trainer/abstract_trainer.py +644 -230
  115. autogluon/tabular/trainer/auto_trainer.py +19 -8
  116. autogluon/tabular/trainer/model_presets/presets.py +33 -9
  117. autogluon/tabular/trainer/model_presets/presets_distill.py +16 -2
  118. autogluon/tabular/version.py +1 -1
  119. {autogluon_tabular-1.5.0b20251228.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/METADATA +26 -26
  120. {autogluon_tabular-1.5.0b20251228.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/RECORD +127 -135
  121. autogluon/tabular/models/tabpfnv2/rfpfn/__init__.py +0 -20
  122. autogluon/tabular/models/tabpfnv2/rfpfn/configs.py +0 -40
  123. autogluon/tabular/models/tabpfnv2/rfpfn/scoring_utils.py +0 -201
  124. autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_based_decision_tree_tabpfn.py +0 -1464
  125. autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_based_random_forest_tabpfn.py +0 -747
  126. autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_compat.py +0 -863
  127. autogluon/tabular/models/tabpfnv2/rfpfn/utils.py +0 -106
  128. autogluon/tabular/models/tabpfnv2/tabpfnv2_model.py +0 -466
  129. /autogluon.tabular-1.5.0b20251228-py3.11-nspkg.pth → /autogluon.tabular-1.5.1b20260116-py3.11-nspkg.pth +0 -0
  130. {autogluon_tabular-1.5.0b20251228.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/WHEEL +0 -0
  131. {autogluon_tabular-1.5.0b20251228.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/licenses/LICENSE +0 -0
  132. {autogluon_tabular-1.5.0b20251228.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/licenses/NOTICE +0 -0
  133. {autogluon_tabular-1.5.0b20251228.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/namespace_packages.txt +0 -0
  134. {autogluon_tabular-1.5.0b20251228.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/top_level.txt +0 -0
  135. {autogluon_tabular-1.5.0b20251228.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/zip-safe +0 -0
--- a/autogluon/tabular/models/mitra/sklearn_interface.py
+++ b/autogluon/tabular/models/mitra/sklearn_interface.py
@@ -1,9 +1,9 @@
 from __future__ import annotations
 
+import contextlib
 import os
 import time
 from pathlib import Path
-import contextlib
 
 import numpy as np
 import pandas as pd
@@ -18,34 +18,36 @@ from ._internal.models.tab2d import Tab2D
 from ._internal.utils.set_seed import set_seed
 
 # Hyperparameter search space
-DEFAULT_FINE_TUNE = True # [True, False]
-DEFAULT_FINE_TUNE_STEPS = 50 # [50, 60, 70, 80, 90, 100]
-DEFAULT_CLS_METRIC = 'log_loss' # ['log_loss', 'accuracy', 'auc']
-DEFAULT_REG_METRIC = 'mse' # ['mse', 'mae', 'rmse', 'r2']
-SHUFFLE_CLASSES = False # [True, False]
-SHUFFLE_FEATURES = False # [True, False]
-USE_RANDOM_TRANSFORMS = False # [True, False]
-RANDOM_MIRROR_REGRESSION = True # [True, False]
-RANDOM_MIRROR_X = True # [True, False]
-LR = 0.0001 # [0.00001, 0.000025, 0.00005, 0.000075, 0.0001, 0.00025, 0.0005, 0.00075, 0.001]
-PATIENCE = 40 # [30, 35, 40, 45, 50]
-WARMUP_STEPS = 1000 # [500, 750, 1000, 1250, 1500]
-DEFAULT_CLS_MODEL = 'autogluon/mitra-classifier'
-DEFAULT_REG_MODEL = 'autogluon/mitra-regressor'
+DEFAULT_FINE_TUNE = True  # [True, False]
+DEFAULT_FINE_TUNE_STEPS = 50  # [50, 60, 70, 80, 90, 100]
+DEFAULT_CLS_METRIC = "log_loss"  # ['log_loss', 'accuracy', 'auc']
+DEFAULT_REG_METRIC = "mse"  # ['mse', 'mae', 'rmse', 'r2']
+SHUFFLE_CLASSES = False  # [True, False]
+SHUFFLE_FEATURES = False  # [True, False]
+USE_RANDOM_TRANSFORMS = False  # [True, False]
+RANDOM_MIRROR_REGRESSION = True  # [True, False]
+RANDOM_MIRROR_X = True  # [True, False]
+LR = 0.0001  # [0.00001, 0.000025, 0.00005, 0.000075, 0.0001, 0.00025, 0.0005, 0.00075, 0.001]
+PATIENCE = 40  # [30, 35, 40, 45, 50]
+WARMUP_STEPS = 1000  # [500, 750, 1000, 1250, 1500]
+DEFAULT_CLS_MODEL = "autogluon/mitra-classifier"
+DEFAULT_REG_MODEL = "autogluon/mitra-regressor"
 
 # Constants
 SEED = 0
 DEFAULT_MODEL_TYPE = "Tab2D"
 
+
 def _get_default_device():
     """Get the best available device for the current system."""
     if torch.cuda.is_available():
         return "cuda"
-    elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
+    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
         return "mps"  # Apple silicon
     else:
         return "cpu"
 
+
 DEFAULT_DEVICE = _get_default_device()
 DEFAULT_ENSEMBLE = 1
 DEFAULT_DIM = 512
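
The bracketed comments in the hunk above record the candidate values each default was tuned over. As a minimal, illustrative sketch of sampling that documented grid (the dictionary below is transcribed from those comments; the sampler itself is not part of AutoGluon):

import random

# Candidate values transcribed from the search-space comments above.
SEARCH_SPACE = {
    "fine_tune_steps": [50, 60, 70, 80, 90, 100],
    "lr": [0.00001, 0.000025, 0.00005, 0.000075, 0.0001, 0.00025, 0.0005, 0.00075, 0.001],
    "patience": [30, 35, 40, 45, 50],
    "warmup_steps": [500, 750, 1000, 1250, 1500],
}

def sample_config(rng):
    # Draw one configuration uniformly at random from the documented grid.
    return {name: rng.choice(values) for name, values in SEARCH_SPACE.items()}

print(sample_config(random.Random(0)))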
@@ -55,32 +57,34 @@ DEFAULT_CLASSES = 10
 DEFAULT_VALIDATION_SPLIT = 0.2
 USE_HF = True  # Use Hugging Face pretrained models if available
 
+
 class MitraBase(BaseEstimator):
     """Base class for Mitra models with common functionality."""
 
-    def __init__(self,
-                 model_type=DEFAULT_MODEL_TYPE,
-                 n_estimators=DEFAULT_ENSEMBLE,
-                 device=DEFAULT_DEVICE,
-                 fine_tune=DEFAULT_FINE_TUNE,
-                 fine_tune_steps=DEFAULT_FINE_TUNE_STEPS,
-                 metric=DEFAULT_CLS_METRIC,
-                 state_dict=None,
-                 hf_model=None,
-                 patience=PATIENCE,
-                 lr=LR,
-                 warmup_steps=WARMUP_STEPS,
-                 shuffle_classes=SHUFFLE_CLASSES,
-                 shuffle_features=SHUFFLE_FEATURES,
-                 use_random_transforms=USE_RANDOM_TRANSFORMS,
-                 random_mirror_regression=RANDOM_MIRROR_REGRESSION,
-                 random_mirror_x=RANDOM_MIRROR_X,
-                 seed=SEED,
-                 verbose=True,
-                 ):
+    def __init__(
+        self,
+        model_type=DEFAULT_MODEL_TYPE,
+        n_estimators=DEFAULT_ENSEMBLE,
+        device=DEFAULT_DEVICE,
+        fine_tune=DEFAULT_FINE_TUNE,
+        fine_tune_steps=DEFAULT_FINE_TUNE_STEPS,
+        metric=DEFAULT_CLS_METRIC,
+        state_dict=None,
+        hf_model=None,
+        patience=PATIENCE,
+        lr=LR,
+        warmup_steps=WARMUP_STEPS,
+        shuffle_classes=SHUFFLE_CLASSES,
+        shuffle_features=SHUFFLE_FEATURES,
+        use_random_transforms=USE_RANDOM_TRANSFORMS,
+        random_mirror_regression=RANDOM_MIRROR_REGRESSION,
+        random_mirror_x=RANDOM_MIRROR_X,
+        seed=SEED,
+        verbose=True,
+    ):
         """
         Initialize the base Mitra model.
-
+
         Parameters
         ----------
         model_type : str, default="Tab2D"
@@ -125,59 +129,62 @@ class MitraBase(BaseEstimator):
             model_name=ModelName.TAB2D,
             seed=self.seed,
             hyperparams={
-                'dim_embedding': None,
-                'early_stopping_data_split': 'VALID',
-                'early_stopping_max_samples': 2048,
-                'early_stopping_patience': self.patience,
-                'grad_scaler_enabled': False,
-                'grad_scaler_growth_interval': 1000,
-                'grad_scaler_scale_init': 65536.0,
-                'grad_scaler_scale_min': 65536.0,
-                'label_smoothing': 0.0,
-                'lr_scheduler': False,
-                'lr_scheduler_patience': 25,
-                'max_epochs': self.fine_tune_steps if self.fine_tune else 0,
-                'max_samples_query': 1024,
-                'max_samples_support': 8192,
-                'optimizer': 'adamw',
-                'lr': self.lr,
-                'weight_decay': 0.1,
-                'warmup_steps': self.warmup_steps,
-                'path_to_weights': self.state_dict,
-                'precision': 'bfloat16',
-                'random_mirror_regression': self.random_mirror_regression,
-                'random_mirror_x': self.random_mirror_x,
-                'shuffle_classes': self.shuffle_classes,
-                'shuffle_features': self.shuffle_features,
-                'use_random_transforms': self.use_random_transforms,
-                'use_feature_count_scaling': False,
-                'use_pretrained_weights': False,
-                'use_quantile_transformer': False,
-                'budget': time_limit,
-                'metric': self.metric,
+                "dim_embedding": None,
+                "early_stopping_data_split": "VALID",
+                "early_stopping_max_samples": 2048,
+                "early_stopping_patience": self.patience,
+                "grad_scaler_enabled": False,
+                "grad_scaler_growth_interval": 1000,
+                "grad_scaler_scale_init": 65536.0,
+                "grad_scaler_scale_min": 65536.0,
+                "label_smoothing": 0.0,
+                "lr_scheduler": False,
+                "lr_scheduler_patience": 25,
+                "max_epochs": self.fine_tune_steps if self.fine_tune else 0,
+                "max_samples_query": 1024,
+                "max_samples_support": 8192,
+                "optimizer": "adamw",
+                "lr": self.lr,
+                "weight_decay": 0.1,
+                "warmup_steps": self.warmup_steps,
+                "path_to_weights": self.state_dict,
+                "precision": "bfloat16",
+                "random_mirror_regression": self.random_mirror_regression,
+                "random_mirror_x": self.random_mirror_x,
+                "shuffle_classes": self.shuffle_classes,
+                "shuffle_features": self.shuffle_features,
+                "use_random_transforms": self.use_random_transforms,
+                "use_feature_count_scaling": False,
+                "use_pretrained_weights": False,
+                "use_quantile_transformer": False,
+                "budget": time_limit,
+                "metric": self.metric,
             },
         )
 
         cfg.task = task
-        cfg.hyperparams.update({
-            'n_ensembles': self.n_estimators,
-            'dim': DEFAULT_DIM,
-            'dim_output': dim_output,
-            'n_layers': DEFAULT_LAYERS,
-            'n_heads': DEFAULT_HEADS,
-            'regression_loss': 'mse',
-        })
+        cfg.hyperparams.update(
+            {
+                "n_ensembles": self.n_estimators,
+                "dim": DEFAULT_DIM,
+                "dim_output": dim_output,
+                "n_layers": DEFAULT_LAYERS,
+                "n_heads": DEFAULT_HEADS,
+                "regression_loss": "mse",
+            }
+        )
 
         return cfg, Tab2D
 
-
     def _split_data(self, X, y):
         """Split data into training and validation sets."""
-        if hasattr(self, 'task') and self.task == 'classification':
+        if hasattr(self, "task") and self.task == "classification":
             return make_stratified_dataset_split(X, y, seed=self.seed)
         else:
             # For regression, use random split
-            val_indices = np.random.choice(range(len(X)), int(DEFAULT_VALIDATION_SPLIT * len(X)), replace=False).tolist()
+            val_indices = np.random.choice(
                range(len(X)), int(DEFAULT_VALIDATION_SPLIT * len(X)), replace=False
            ).tolist()
             train_indices = [i for i in range(len(X)) if i not in val_indices]
             return X[train_indices], X[val_indices], y[train_indices], y[val_indices]
 
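The regression branch of _split_data above holds out a random DEFAULT_VALIDATION_SPLIT (20%) of rows for validation. A standalone sketch of the same index logic, with an explicit seed added for reproducibility (the diff's version draws from NumPy's global RNG):

import numpy as np

def random_split(X, y, val_frac=0.2, seed=0):
    # Same logic as the regression branch above, but seeded explicitly.
    rng = np.random.RandomState(seed)
    val_indices = rng.choice(range(len(X)), int(val_frac * len(X)), replace=False).tolist()
    train_indices = [i for i in range(len(X)) if i not in val_indices]
    return X[train_indices], X[val_indices], y[train_indices], y[val_indices]

X, y = np.arange(20).reshape(10, 2), np.arange(10)
X_train, X_val, y_train, y_val = random_split(X, y)
print(len(X_train), len(X_val))  # 8 2
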
@@ -188,7 +195,9 @@ class MitraBase(BaseEstimator):
         rng = np.random.RandomState(cfg.seed)
 
         success = False
-        while not (success and cfg.hyperparams["max_samples_support"] > 0 and cfg.hyperparams["max_samples_query"] > 0):
+        while not (
+            success and cfg.hyperparams["max_samples_support"] > 0 and cfg.hyperparams["max_samples_query"] > 0
+        ):
             try:
                 self.trainers.clear()
 
@@ -199,16 +208,18 @@
                     model = Tab2D.from_pretrained(self.hf_model, device=self.device)
                 else:
                     model = Tab2D(
-                        dim=cfg.hyperparams['dim'],
+                        dim=cfg.hyperparams["dim"],
                         dim_output=dim_output,
-                        n_layers=cfg.hyperparams['n_layers'],
-                        n_heads=cfg.hyperparams['n_heads'],
+                        n_layers=cfg.hyperparams["n_layers"],
+                        n_heads=cfg.hyperparams["n_heads"],
                         task=task.upper(),
                         use_pretrained_weights=True,
                         path_to_weights=Path(self.state_dict),
                         device=self.device,
                     )
-                trainer = TrainerFinetune(cfg, model, n_classes=n_classes, device=self.device, rng=rng, verbose=self.verbose)
+                trainer = TrainerFinetune(
+                    cfg, model, n_classes=n_classes, device=self.device, rng=rng, verbose=self.verbose
+                )
 
                 start_time = time.time()
                 trainer.train(X_train, y_train, X_valid, y_valid)
@@ -221,27 +232,25 @@
 
             except torch.cuda.OutOfMemoryError:
                 if cfg.hyperparams["max_samples_support"] >= 2048:
-                    cfg.hyperparams["max_samples_support"] = int(
-                        cfg.hyperparams["max_samples_support"] // 2
+                    cfg.hyperparams["max_samples_support"] = int(cfg.hyperparams["max_samples_support"] // 2)
+                    print(
+                        f"Reducing max_samples_support from {cfg.hyperparams['max_samples_support'] * 2}"
+                        f"to {cfg.hyperparams['max_samples_support']} due to OOM error."
                     )
-                    print(f"Reducing max_samples_support from {cfg.hyperparams['max_samples_support'] * 2}"
-                          f"to {cfg.hyperparams['max_samples_support']} due to OOM error.")
                 else:
-                    cfg.hyperparams["max_samples_support"] = int(
-                        cfg.hyperparams["max_samples_support"] // 2
+                    cfg.hyperparams["max_samples_support"] = int(cfg.hyperparams["max_samples_support"] // 2)
+                    print(
+                        f"Reducing max_samples_support from {cfg.hyperparams['max_samples_support'] * 2}"
+                        f"to {cfg.hyperparams['max_samples_support']} due to OOM error."
                     )
-                    print(f"Reducing max_samples_support from {cfg.hyperparams['max_samples_support'] * 2}"
-                          f"to {cfg.hyperparams['max_samples_support']} due to OOM error.")
-                    cfg.hyperparams["max_samples_query"] = int(
-                        cfg.hyperparams["max_samples_query"] // 2
+                    cfg.hyperparams["max_samples_query"] = int(cfg.hyperparams["max_samples_query"] // 2)
+                    print(
+                        f"Reducing max_samples_query from {cfg.hyperparams['max_samples_query'] * 2}"
+                        f"to {cfg.hyperparams['max_samples_query']} due to OOM error."
                     )
-                    print(f"Reducing max_samples_query from {cfg.hyperparams['max_samples_query'] * 2}"
-                          f"to {cfg.hyperparams['max_samples_query']} due to OOM error.")
 
         if not success:
-            raise RuntimeError(
-                "Failed to train Mitra model after multiple attempts due to out of memory error."
-            )
+            raise RuntimeError("Failed to train Mitra model after multiple attempts due to out of memory error.")
 
         return self
 
@@ -249,26 +258,27 @@
 
 class MitraClassifier(MitraBase, ClassifierMixin):
     """Classifier implementation of Mitra model."""
 
-    def __init__(self,
-                 model_type=DEFAULT_MODEL_TYPE,
-                 n_estimators=DEFAULT_ENSEMBLE,
-                 device=DEFAULT_DEVICE,
-                 fine_tune=DEFAULT_FINE_TUNE,
-                 fine_tune_steps=DEFAULT_FINE_TUNE_STEPS,
-                 metric=DEFAULT_CLS_METRIC,
-                 state_dict=None,
-                 hf_model=DEFAULT_CLS_MODEL,
-                 patience=PATIENCE,
-                 lr=LR,
-                 warmup_steps=WARMUP_STEPS,
-                 shuffle_classes=SHUFFLE_CLASSES,
-                 shuffle_features=SHUFFLE_FEATURES,
-                 use_random_transforms=USE_RANDOM_TRANSFORMS,
-                 random_mirror_regression=RANDOM_MIRROR_REGRESSION,
-                 random_mirror_x=RANDOM_MIRROR_X,
-                 seed=SEED,
-                 verbose=True,
-                 ):
+    def __init__(
+        self,
+        model_type=DEFAULT_MODEL_TYPE,
+        n_estimators=DEFAULT_ENSEMBLE,
+        device=DEFAULT_DEVICE,
+        fine_tune=DEFAULT_FINE_TUNE,
+        fine_tune_steps=DEFAULT_FINE_TUNE_STEPS,
+        metric=DEFAULT_CLS_METRIC,
+        state_dict=None,
+        hf_model=DEFAULT_CLS_MODEL,
+        patience=PATIENCE,
+        lr=LR,
+        warmup_steps=WARMUP_STEPS,
+        shuffle_classes=SHUFFLE_CLASSES,
+        shuffle_features=SHUFFLE_FEATURES,
+        use_random_transforms=USE_RANDOM_TRANSFORMS,
+        random_mirror_regression=RANDOM_MIRROR_REGRESSION,
+        random_mirror_x=RANDOM_MIRROR_X,
+        seed=SEED,
+        verbose=True,
+    ):
         """Initialize the classifier."""
         super().__init__(
             model_type,
@@ -290,19 +300,19 @@ class MitraClassifier(MitraBase, ClassifierMixin):
             seed=seed,
             verbose=verbose,
         )
-        self.task = 'classification'
+        self.task = "classification"
 
-    def fit(self, X, y, X_val = None, y_val = None, time_limit = None):
+    def fit(self, X, y, X_val=None, y_val=None, time_limit=None):
         """
         Fit the ensemble of models.
-
+
         Parameters
         ----------
         X : array-like of shape (n_samples, n_features)
             Training data
         y : array-like of shape (n_samples,)
             Target values
-
+
         Returns
         -------
         self : object
@@ -310,7 +320,6 @@ class MitraClassifier(MitraBase, ClassifierMixin):
         """
 
         with mitra_deterministic_context():
-
             if isinstance(X, pd.DataFrame):
                 X = X.values
             if isinstance(y, pd.Series):
@@ -327,17 +336,26 @@
             else:
                 X_train, X_valid, y_train, y_valid = self._split_data(X, y)
 
-            return self._train_ensemble(X_train, y_train, X_valid, y_valid, self.task, DEFAULT_CLASSES, n_classes=DEFAULT_CLASSES, time_limit=time_limit)
+            return self._train_ensemble(
+                X_train,
+                y_train,
+                X_valid,
+                y_valid,
+                self.task,
+                DEFAULT_CLASSES,
+                n_classes=DEFAULT_CLASSES,
+                time_limit=time_limit,
+            )
 
     def predict(self, X):
         """
         Predict class labels for samples in X.
-
+
         Parameters
         ----------
         X : array-like of shape (n_samples, n_features)
             The input samples
-
+
         Returns
         -------
         y : ndarray of shape (n_samples,)
@@ -352,12 +370,12 @@ class MitraClassifier(MitraBase, ClassifierMixin):
     def predict_proba(self, X):
         """
         Predict class probabilities for samples in X.
-
+
         Parameters
         ----------
         X : array-like of shape (n_samples, n_features)
             The input samples
-
+
         Returns
         -------
         p : ndarray of shape (n_samples, n_classes)
@@ -365,14 +383,13 @@
         """
 
         with mitra_deterministic_context():
-
             if isinstance(X, pd.DataFrame):
                 X = X.values
 
             preds = []
             for trainer in self.trainers:
-                logits = trainer.predict(self.X, self.y, X)[...,:len(np.unique(self.y))] # Remove extra classes
-                preds.append(np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)) # Softmax
+                logits = trainer.predict(self.X, self.y, X)[..., : len(np.unique(self.y))]  # Remove extra classes
+                preds.append(np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True))  # Softmax
             preds = sum(preds) / len(preds)  # Averaging ensemble predictions
 
             return preds
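predict_proba above softmaxes each trainer's logits (truncated to the classes observed in y) and averages the probabilities across the ensemble. One caveat worth a sketch: np.exp on raw logits can overflow for large values; subtracting the per-row maximum first is the standard numerically stable equivalent (illustrative code, not AutoGluon's):

import numpy as np

def stable_softmax(logits):
    # Same result as np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True),
    # but shifting by the row max avoids overflow for large logits.
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)

logits_per_trainer = [np.array([[2.0, 1.0], [0.0, 3.0]]), np.array([[1.5, 1.5], [0.5, 2.5]])]
probs = sum(stable_softmax(l) for l in logits_per_trainer) / len(logits_per_trainer)
print(probs.sum(axis=1))  # each row sums to 1.0
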
@@ -381,26 +398,27 @@ class MitraClassifier(MitraBase, ClassifierMixin):
 
 class MitraRegressor(MitraBase, RegressorMixin):
     """Regressor implementation of Mitra model."""
 
-    def __init__(self,
-                 model_type=DEFAULT_MODEL_TYPE,
-                 n_estimators=DEFAULT_ENSEMBLE,
-                 device=DEFAULT_DEVICE,
-                 fine_tune=DEFAULT_FINE_TUNE,
-                 fine_tune_steps=DEFAULT_FINE_TUNE_STEPS,
-                 metric=DEFAULT_REG_METRIC,
-                 state_dict=None,
-                 hf_model=DEFAULT_REG_MODEL,
-                 patience=PATIENCE,
-                 lr=LR,
-                 warmup_steps=WARMUP_STEPS,
-                 shuffle_classes=SHUFFLE_CLASSES,
-                 shuffle_features=SHUFFLE_FEATURES,
-                 use_random_transforms=USE_RANDOM_TRANSFORMS,
-                 random_mirror_regression=RANDOM_MIRROR_REGRESSION,
-                 random_mirror_x=RANDOM_MIRROR_X,
-                 seed=SEED,
-                 verbose=True,
-                 ):
+    def __init__(
+        self,
+        model_type=DEFAULT_MODEL_TYPE,
+        n_estimators=DEFAULT_ENSEMBLE,
+        device=DEFAULT_DEVICE,
+        fine_tune=DEFAULT_FINE_TUNE,
+        fine_tune_steps=DEFAULT_FINE_TUNE_STEPS,
+        metric=DEFAULT_REG_METRIC,
+        state_dict=None,
+        hf_model=DEFAULT_REG_MODEL,
+        patience=PATIENCE,
+        lr=LR,
+        warmup_steps=WARMUP_STEPS,
+        shuffle_classes=SHUFFLE_CLASSES,
+        shuffle_features=SHUFFLE_FEATURES,
+        use_random_transforms=USE_RANDOM_TRANSFORMS,
+        random_mirror_regression=RANDOM_MIRROR_REGRESSION,
+        random_mirror_x=RANDOM_MIRROR_X,
+        seed=SEED,
+        verbose=True,
+    ):
         """Initialize the regressor."""
         super().__init__(
             model_type,
@@ -422,19 +440,19 @@ class MitraRegressor(MitraBase, RegressorMixin):
             seed=seed,
             verbose=verbose,
         )
-        self.task = 'regression'
+        self.task = "regression"
 
-    def fit(self, X, y, X_val = None, y_val = None, time_limit = None):
+    def fit(self, X, y, X_val=None, y_val=None, time_limit=None):
         """
         Fit the ensemble of models.
-
+
         Parameters
         ----------
         X : array-like of shape (n_samples, n_features)
             Training data
         y : array-like of shape (n_samples,)
             Target values
-
+
         Returns
         -------
         self : object
@@ -442,7 +460,6 @@ class MitraRegressor(MitraBase, RegressorMixin):
         """
 
        with mitra_deterministic_context():
-
            if isinstance(X, pd.DataFrame):
                X = X.values
            if isinstance(y, pd.Series):
@@ -464,12 +481,12 @@
     def predict(self, X):
         """
         Predict regression target for samples in X.
-
+
         Parameters
         ----------
         X : array-like of shape (n_samples, n_features)
             The input samples
-
+
         Returns
         -------
         y : ndarray of shape (n_samples,)
@@ -477,16 +494,15 @@
         """
 
         with mitra_deterministic_context():
-
             if isinstance(X, pd.DataFrame):
                 X = X.values
-
+
             preds = []
             for trainer in self.trainers:
                 preds.append(trainer.predict(self.X, self.y, X))
-
+
             return sum(preds) / len(preds)  # Averaging ensemble predictions
-
+
 
 @contextlib.contextmanager
 def mitra_deterministic_context():
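
Taken together, these hunks reformat a scikit-learn style wrapper around the Mitra foundation model. A hedged usage sketch (the import path is inferred from the sklearn_interface.py entry in the file list; running it requires the autogluon.tabular package and, with the default hf_model, network access to download the pretrained weights):

import numpy as np
from autogluon.tabular.models.mitra.sklearn_interface import MitraClassifier

X = np.random.rand(100, 5)
y = np.random.randint(0, 2, size=100)

# fine_tune=False skips the fine-tuning loop for a quick smoke test.
clf = MitraClassifier(fine_tune=False, device="cpu")
clf.fit(X, y)  # splits off a validation set internally via _split_data
proba = clf.predict_proba(X)  # ensemble-averaged softmax probabilities
print(proba.shape)  # (100, 2)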