autogluon.tabular 1.3.2b20250714__py3-none-any.whl → 1.3.2b20250716__py3-none-any.whl

This diff compares two publicly available package versions that were released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in the public registry.
Files changed (35)
  1. autogluon/tabular/models/catboost/catboost_model.py +9 -6
  2. autogluon/tabular/models/catboost/catboost_utils.py +10 -0
  3. autogluon/tabular/models/lgb/lgb_model.py +2 -1
  4. autogluon/tabular/models/mitra/_internal/__init__.py +1 -0
  5. autogluon/tabular/models/mitra/_internal/config/__init__.py +1 -0
  6. autogluon/tabular/models/mitra/_internal/config/config_run.py +3 -3
  7. autogluon/tabular/models/mitra/_internal/config/enums.py +20 -3
  8. autogluon/tabular/models/mitra/_internal/core/__init__.py +1 -0
  9. autogluon/tabular/models/mitra/_internal/core/get_loss.py +22 -23
  10. autogluon/tabular/models/mitra/_internal/core/prediction_metrics.py +11 -13
  11. autogluon/tabular/models/mitra/_internal/core/trainer_finetune.py +69 -75
  12. autogluon/tabular/models/mitra/_internal/data/__init__.py +1 -0
  13. autogluon/tabular/models/mitra/_internal/data/preprocessor.py +57 -57
  14. autogluon/tabular/models/mitra/_internal/models/__init__.py +1 -0
  15. autogluon/tabular/models/mitra/_internal/models/tab2d.py +23 -26
  16. autogluon/tabular/models/mitra/_internal/utils/__init__.py +1 -0
  17. autogluon/tabular/models/mitra/mitra_model.py +64 -24
  18. autogluon/tabular/models/mitra/sklearn_interface.py +52 -42
  19. autogluon/tabular/models/realmlp/realmlp_model.py +11 -3
  20. autogluon/tabular/models/tabicl/tabicl_model.py +4 -1
  21. autogluon/tabular/models/tabm/_tabm_internal.py +4 -3
  22. autogluon/tabular/models/tabm/tabm_model.py +7 -3
  23. autogluon/tabular/models/tabm/tabm_reference.py +21 -19
  24. autogluon/tabular/models/tabpfnv2/tabpfnv2_model.py +10 -9
  25. autogluon/tabular/testing/fit_helper.py +2 -2
  26. autogluon/tabular/version.py +1 -1
  27. {autogluon.tabular-1.3.2b20250714.dist-info → autogluon.tabular-1.3.2b20250716.dist-info}/METADATA +11 -11
  28. {autogluon.tabular-1.3.2b20250714.dist-info → autogluon.tabular-1.3.2b20250716.dist-info}/RECORD +35 -29
  29. /autogluon.tabular-1.3.2b20250714-py3.9-nspkg.pth → /autogluon.tabular-1.3.2b20250716-py3.9-nspkg.pth +0 -0
  30. {autogluon.tabular-1.3.2b20250714.dist-info → autogluon.tabular-1.3.2b20250716.dist-info}/LICENSE +0 -0
  31. {autogluon.tabular-1.3.2b20250714.dist-info → autogluon.tabular-1.3.2b20250716.dist-info}/NOTICE +0 -0
  32. {autogluon.tabular-1.3.2b20250714.dist-info → autogluon.tabular-1.3.2b20250716.dist-info}/WHEEL +0 -0
  33. {autogluon.tabular-1.3.2b20250714.dist-info → autogluon.tabular-1.3.2b20250716.dist-info}/namespace_packages.txt +0 -0
  34. {autogluon.tabular-1.3.2b20250714.dist-info → autogluon.tabular-1.3.2b20250716.dist-info}/top_level.txt +0 -0
  35. {autogluon.tabular-1.3.2b20250714.dist-info → autogluon.tabular-1.3.2b20250716.dist-info}/zip-safe +0 -0
autogluon/tabular/models/mitra/_internal/data/preprocessor.py

@@ -1,23 +1,25 @@
- from typing import Optional, Self
-
  import random
+ from typing import Optional
+
  import numpy as np
  from loguru import logger
- from sklearn.feature_selection import SelectKBest
- from sklearn.preprocessing import QuantileTransformer, StandardScaler, OrdinalEncoder
+ from sklearn.base import BaseEstimator, TransformerMixin
  from sklearn.compose import ColumnTransformer
  from sklearn.decomposition import TruncatedSVD
- from sklearn.pipeline import Pipeline, FeatureUnion
- from sklearn.base import BaseEstimator, TransformerMixin
+ from sklearn.feature_selection import SelectKBest
+ from sklearn.pipeline import FeatureUnion, Pipeline
+ from sklearn.preprocessing import (OrdinalEncoder, QuantileTransformer,
+                                    StandardScaler)

  from ..._internal.config.enums import Task

+
  class NoneTransformer(BaseEstimator, TransformerMixin):
      def fit(self, X, y=None):
          return self
      def transform(self, X):
          return X
-
+
  class Preprocessor():
      """
      This class is used to preprocess the data before it is pushed through the model.
@@ -28,9 +30,9 @@ class Preprocessor():
      """

      def __init__(
-         self,
+         self,
          dim_embedding: Optional[int], # Size of the feature embedding. For some models this is None, which means the embedding does not depend on the number of features
-         n_classes: int, # Actual number of classes in the dataset, assumed to be numbered 0, ..., n_classes - 1
+         n_classes: int, # Actual number of classes in the dataset, assumed to be numbered 0, ..., n_classes - 1
          dim_output: int, # Maximum number of classes the model has been trained on -> size of the output
          use_quantile_transformer: bool,
          use_feature_count_scaling: bool,
@@ -53,8 +55,8 @@ class Preprocessor():
          self.random_mirror_regression = random_mirror_regression
          self.random_mirror_x = random_mirror_x
          self.task = task
-
-     def fit(self, X: np.ndarray, y: np.ndarray) -> Self:
+
+     def fit(self, X: np.ndarray, y: np.ndarray) -> "Preprocessor":
          """
          X: np.ndarray [n_samples, n_features]
          y: np.ndarray [n_samples]
@@ -78,16 +80,16 @@ class Preprocessor():
          if self.use_quantile_transformer:
              # If use quantile transform is off, it means that the preprocessing will happen on the GPU.
              X = self.fit_transform_quantile_transformer(X)
-
+
              self.mean, self.std = self.calc_mean_std(X)
              X = self.normalize_by_mean_std(X, self.mean, self.std)
-
+
          if self.use_random_transforms:
              X = self.transform_tabpfn(X)

          if self.task == Task.CLASSIFICATION and self.shuffle_classes:
              self.determine_shuffle_class_order()
-
+
          if self.shuffle_features:
              self.determine_feature_order(X)

@@ -104,7 +106,7 @@ class Preprocessor():
          X[np.isinf(X)] = 0

          return self
-
+

      def transform_X(self, X: np.ndarray):

@@ -116,12 +118,12 @@ class Preprocessor():
              # If use quantile transform is off, it means that the preprocessing will happen on the GPU.

              X = self.quantile_transformer.transform(X)
-
+
              X = self.normalize_by_mean_std(X, self.mean, self.std)

          if self.use_feature_count_scaling:
              X = self.normalize_by_feature_count(X)
-
+
          if self.use_random_transforms:
              X = self.random_transforms.transform(X)

@@ -140,11 +142,11 @@ class Preprocessor():


      def transform_tabpfn(self, X: np.ndarray):
-
+
          n_samples = X.shape[0]
          n_features = X.shape[1]
-
-         use_config1 = random.random() < 0.5
+
+         use_config1 = random.random() < 0.5
          random_state = random.randint(0, 1000000)

          if use_config1:
@@ -171,12 +173,12 @@ class Preprocessor():
                  ('ordinal', OrdinalEncoder(
                      handle_unknown="use_encoded_value",
                      unknown_value=np.nan
-                 ), [])
+                 ), [])
              ], remainder='passthrough')
-
+
          return self.random_transforms.fit_transform(X)
-
-
+
+
      def transform_y(self, y: np.ndarray):

          if self.task == Task.CLASSIFICATION:
@@ -193,36 +195,34 @@ class Preprocessor():
          if self.task == Task.REGRESSION and self.random_mirror_regression:
              y = self.apply_random_mirror_regression(y)

-         match self.task:
-             case Task.CLASSIFICATION:
-                 y = y.astype(np.int64)
-             case Task.REGRESSION:
-                 y = y.astype(np.float32)
+         if self.task == Task.CLASSIFICATION:
+             y = y.astype(np.int64)
+         elif self.task == Task.REGRESSION:
+             y = y.astype(np.float32)

          return y
-
+

      def inverse_transform_y(self, y: np.ndarray):
          # Function used during the prediction to transform the model output back to the original space
          # For classification, y is assumed to be logits of shape [n_samples, n_classes]

-         match self.task:
-             case Task.CLASSIFICATION:
-                 y = self.extract_correct_classes(y)
+         if self.task == Task.CLASSIFICATION:
+             y = self.extract_correct_classes(y)

-                 if self.shuffle_classes:
-                     y = self.undo_randomize_class_order(y)
+             if self.shuffle_classes:
+                 y = self.undo_randomize_class_order(y)

-             case Task.REGRESSION:
+         elif self.task == Task.REGRESSION:

-                 if self.random_mirror_regression:
-                     y = self.apply_random_mirror_regression(y)
+             if self.random_mirror_regression:
+                 y = self.apply_random_mirror_regression(y)

-                 y = self.undo_normalize_y(y)
+             y = self.undo_normalize_y(y)

          return y

-
+

      def fit_transform_quantile_transformer(self, X: np.ndarray) -> np.ndarray:

@@ -233,12 +233,12 @@ class Preprocessor():

          return X

-
+

      def determine_which_features_are_singular(self, x: np.ndarray) -> None:

          self.singular_features = np.array([ len(np.unique(x_col)) for x_col in x.T ]) == 1
-
+


      def determine_which_features_to_select(self, x: np.ndarray, y: np.ndarray) -> None:
@@ -267,7 +267,7 @@ class Preprocessor():
          x[inds] = np.take(self.pre_nan_mean, inds[1])
          return x

-
+
      def select_features(self, x: np.ndarray) -> np.ndarray:

          if self.dim_embedding is None:
@@ -278,7 +278,7 @@ class Preprocessor():
          x = self.select_k_best.transform(x)

          return x
-
+

      def cutoff_singular_features(self, x: np.ndarray, singular_features: np.ndarray) -> np.ndarray:

@@ -295,7 +295,7 @@ class Preprocessor():
          mean = x.mean(axis=0)
          std = x.std(axis=0) + 1e-6
          return mean, std
-
+

      def normalize_by_mean_std(self, x: np.ndarray, mean: np.ndarray, std: np.ndarray) -> np.ndarray:
          """
@@ -329,23 +329,23 @@ class Preprocessor():
          added_zeros = np.zeros((x.shape[0], dim_embedding - x.shape[1]), dtype=np.float32)
          x = np.concatenate([x, added_zeros], axis=1)
          return x
-
+

      def determine_mix_max_scale(self, y: np.ndarray) -> None:
          self.y_min = y.min()
          self.y_max = y.max()
          assert self.y_min != self.y_max, "y_min and y_max are the same, cannot normalize, regression makes no sense"

-
+
      def normalize_y(self, y: np.ndarray) -> np.ndarray:
          y = (y - self.y_min) / (self.y_max - self.y_min)
          return y
-
+

      def undo_normalize_y(self, y: np.ndarray) -> np.ndarray:
          y = y * (self.y_max - self.y_min) + self.y_min
          return y
-
+

      def determine_regression_mirror(self) -> None:
          self.regression_mirror = np.random.choice([True, False], size=(1,)).item()
@@ -355,7 +355,7 @@ class Preprocessor():
          if self.regression_mirror:
              y = 1 - y
          return y
-
+

      def determine_mirror(self, x: np.ndarray) -> None:

@@ -376,15 +376,15 @@ class Preprocessor():
          else:
              self.new_shuffle_classes = np.arange(self.n_classes)

-
+
      def randomize_class_order(self, y: np.ndarray) -> np.ndarray:

          mapping = { i: self.new_shuffle_classes[i] for i in range(self.n_classes) }
          y = np.array([mapping[i.item()] for i in y], dtype=np.int64)

-         return y
-
-
+         return y
+
+
      def undo_randomize_class_order(self, y_logits: np.ndarray) -> np.ndarray:
          """
          We assume y_logits has shape [n_samples, n_classes]
@@ -393,9 +393,9 @@ class Preprocessor():
          # mapping = {self.new_shuffle_classes[i]: i for i in range(self.n_classes)}
          mapping = {i: self.new_shuffle_classes[i] for i in range(self.n_classes)}
          y = np.concatenate([y_logits[:, mapping[i]:mapping[i]+1] for i in range(self.n_classes)], axis=1)
-
+
          return y
-
+

      def extract_correct_classes(self, y_logits: np.ndarray) -> np.ndarray:
          # Even though our network might be able to support 10 classes,
@@ -417,4 +417,4 @@ class Preprocessor():

          x = x[:, self.new_feature_order]

-         return x
+         return x
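Most of the hunks above are whitespace cleanups plus two compatibility rewrites: the Self return annotation becomes a plain string, and match statements become if/elif chains, both of which avoid syntax that requires newer Python versions. The class-shuffling helpers shown as context are the least obvious part, so here is a tiny standalone illustration of their mapping logic (all values are made up; only the mapping construction and the column gather mirror the code shown above):

import numpy as np

n_classes = 3
new_shuffle_classes = np.array([2, 0, 1])  # hypothetical shuffled class order

# Forward relabeling, as in randomize_class_order
mapping = {i: new_shuffle_classes[i] for i in range(n_classes)}
y = np.array([0, 1, 2, 1])
y_shuffled = np.array([mapping[i.item()] for i in y], dtype=np.int64)  # -> [2, 0, 1, 0]

# Inverse on logits, as in undo_randomize_class_order
y_logits = np.array([[0.1, 0.7, 0.2]])  # model output indexed by shuffled labels
y_orig = np.concatenate([y_logits[:, mapping[i]:mapping[i] + 1] for i in range(n_classes)], axis=1)
print(y_orig)  # [[0.2 0.1 0.7]]

Because mapping[i] is the shuffled label assigned to original class i, gathering column mapping[i] for each i puts the logits back into the original class order.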
autogluon/tabular/models/mitra/_internal/models/__init__.py

@@ -0,0 +1 @@
+ # Model architecture modules for MitraModel
autogluon/tabular/models/mitra/_internal/models/tab2d.py

@@ -1,3 +1,5 @@
+ import json
+ import os
  from typing import Optional, Union

  import einops
@@ -5,11 +7,8 @@ import einx
  import torch
  import torch.nn as nn
  import torch.nn.functional as F
- from safetensors.torch import save_file
  from huggingface_hub import hf_hub_download
- from safetensors.torch import load_file
- import os
- import json
+ from safetensors.torch import load_file, save_file

  # Try to import flash attention, but make it optional
  try:
@@ -24,8 +23,8 @@ from torch.utils.checkpoint import checkpoint
  from ..._internal.config.enums import Task
  from ..._internal.models.base import BaseModel
  from ..._internal.models.embedding import (
-     Tab2DEmbeddingX,
-     Tab2DEmbeddingYClasses,
+     Tab2DEmbeddingX,
+     Tab2DEmbeddingYClasses,
      Tab2DEmbeddingYRegression,
      Tab2DQuantileEmbeddingX,
  )
@@ -64,16 +63,15 @@ class Tab2D(BaseModel):
          self.x_embedding = Tab2DEmbeddingX(dim)


-         match self.task:
-             case Task.CLASSIFICATION:
-                 self.y_embedding = Tab2DEmbeddingYClasses(dim, dim_output) # type: nn.Module
-             case Task.REGRESSION:
-                 if self.dim_output == 1:
-                     self.y_embedding = Tab2DEmbeddingYRegression(dim)
-                 else:
-                     self.y_embedding = Tab2DEmbeddingYClasses(dim, dim_output)
-             case _:
-                 raise ValueError(f"Task {task} not supported")
+         if self.task == Task.CLASSIFICATION:
+             self.y_embedding = Tab2DEmbeddingYClasses(dim, dim_output) # type: nn.Module
+         elif self.task == Task.REGRESSION:
+             if self.dim_output == 1:
+                 self.y_embedding = Tab2DEmbeddingYRegression(dim)
+             else:
+                 self.y_embedding = Tab2DEmbeddingYClasses(dim, dim_output)
+         else:
+             raise ValueError(f"Task {task} not supported")

          self.layers = nn.ModuleList()

@@ -165,18 +163,17 @@ class Tab2D(BaseModel):

          y_query__, x_query__ = einops.unpack(query__, pack_query__, 'b s * c') # (b, n_q, 1, c), (b, n_q, f, c)

-         match self.task:
+         if self.task == Task.REGRESSION:
              # output has shape (batch_size, n_observations_query, n_features, n_classes)
              # we want to remove the n_features dimension, and for regression, the n_classes dimension
-             case Task.REGRESSION:
-                 if self.dim_output == 1:
-                     y_query__ = y_query__[:, :, 0, 0]
-                 else:
-                     y_query__ = y_query__[:, :, 0, :]
-             case Task.CLASSIFICATION:
+             if self.dim_output == 1:
+                 y_query__ = y_query__[:, :, 0, 0]
+             else:
                  y_query__ = y_query__[:, :, 0, :]
-             case _:
-                 raise ValueError(f"Task {self.task} not supported")
+         elif self.task == Task.CLASSIFICATION:
+             y_query__ = y_query__[:, :, 0, :]
+         else:
+             raise ValueError(f"Task {self.task} not supported")

          return y_query__

@@ -664,4 +661,4 @@ class MultiheadAttention(torch.nn.Module):

          output = self.o(output)

-         return output
+         return output
autogluon/tabular/models/mitra/_internal/utils/__init__.py

@@ -0,0 +1 @@
+ # Utility modules for MitraModel
autogluon/tabular/models/mitra/mitra_model.py

@@ -1,8 +1,12 @@
+ import os
+ from typing import List, Optional
+
  import pandas as pd
- from typing import Optional, List
+ import torch
+
  from autogluon.common.utils.resource_utils import ResourceManager
  from autogluon.core.models import AbstractModel
- import os
+

  # TODO: Needs memory usage estimate method
  class MitraModel(AbstractModel):
@@ -67,12 +71,12 @@ class MitraModel(AbstractModel):

      def _set_default_params(self):
          default_params = {
-             "device": "cuda", # "cpu"
+             "device": "cpu",
              "n_estimators": 1,
          }
          for param, val in default_params.items():
              self._set_default_param_value(param, val)
-
+
      def _get_default_auxiliary_params(self) -> dict:
          default_auxiliary_params = super()._get_default_auxiliary_params()
          default_auxiliary_params.update(
@@ -87,7 +91,7 @@ class MitraModel(AbstractModel):
      @property
      def weights_path(self) -> str:
          return os.path.join(self.path, self.weights_file_name)
-
+
      def save(self, path: str = None, verbose=True) -> str:
          _model_weights_list = None
          if self.model is not None:
@@ -98,7 +102,7 @@ class MitraModel(AbstractModel):
                  self.model.trainers[i].model = None
                  self.model.trainers[i].optimizer = None
                  self.model.trainers[i].scheduler_warmup = None
-                 self.model.trainers[i].scheduler_reduce_on_plateau = None
+                 self.model.trainers[i].scheduler_reduce_on_plateau = None
              self._weights_saved = True
          path = super().save(path=path, verbose=verbose)
          if _model_weights_list is not None:
@@ -108,7 +112,7 @@ class MitraModel(AbstractModel):
              for i in range(len(self.model.trainers)):
                  self.model.trainers[i].model = _model_weights_list[i]
          return path
-
+
      @classmethod
      def load(cls, path: str, reset_paths=False, verbose=True):
          model: MitraModel = super().load(path=path, reset_paths=reset_paths, verbose=verbose)
@@ -136,14 +140,20 @@ class MitraModel(AbstractModel):
          return default_ag_args_ensemble

      def _get_default_resources(self) -> tuple[int, int]:
-         # logical=False is faster in training
-         num_cpus = ResourceManager.get_cpu_count_psutil(logical=False)
-         num_gpus = 1
+         # Use only physical cores for better performance based on benchmarks
+         num_cpus = ResourceManager.get_cpu_count(only_physical_cores=True)
+
+         # Only request GPU if CUDA is available
+         if torch.cuda.is_available():
+             num_gpus = 1
+         else:
+             num_gpus = 0
+
          return num_cpus, num_gpus

      def _estimate_memory_usage(self, X: pd.DataFrame, **kwargs) -> int:
          return self.estimate_memory_usage_static(X=X, problem_type=self.problem_type, num_classes=self.num_classes, **kwargs)
-
+
      @classmethod
      def _estimate_memory_usage_static(
          cls,
@@ -157,7 +167,7 @@ class MitraModel(AbstractModel):
              cls._estimate_memory_usage_static_gpu_cpu(X=X, **kwargs),
              cls._estimate_memory_usage_static_gpu_gpu(X=X, **kwargs),
          )
-
+
      @classmethod
      def _estimate_memory_usage_static_cpu_icl(
          cls,
@@ -165,10 +175,18 @@ class MitraModel(AbstractModel):
          X: pd.DataFrame,
          **kwargs,
      ) -> int:
-         cpu_memory_kb = 1.3 * (0.001748 * (X.shape[0]**2) * X.shape[1] + \
-                                0.001206 * X.shape[0] * (X.shape[1]**2) + \
-                                10.3482 * X.shape[0] * X.shape[1] + \
-                                6409698)
+         rows, features = X.shape[0], X.shape[1]
+
+         # For very small datasets, use a more conservative estimate
+         if rows * features < 100:  # Small dataset threshold
+             # Use a simpler linear formula for small datasets
+             cpu_memory_kb = 1.3 * (100 * rows * features + 1000000)  # 1GB base + linear scaling
+         else:
+             # Original formula for larger datasets
+             cpu_memory_kb = 1.3 * (0.001748 * (rows**2) * features + \
+                                    0.001206 * rows * (features**2) + \
+                                    10.3482 * rows * features + \
+                                    6409698)
          return int(cpu_memory_kb * 1e3)

      @classmethod
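To make the revised CPU estimate above concrete, here is a small standalone sketch (the helper name and sample sizes are illustrative, not part of the package) that applies the same piecewise formula for the in-context-learning path and shows why the small-dataset branch matters:

def estimate_cpu_icl_memory_bytes(rows: int, features: int) -> int:
    # Mirrors _estimate_memory_usage_static_cpu_icl from the hunk above.
    if rows * features < 100:  # small-dataset branch: ~1 GB base plus linear scaling
        cpu_memory_kb = 1.3 * (100 * rows * features + 1000000)
    else:  # quadratic formula retained for larger datasets
        cpu_memory_kb = 1.3 * (0.001748 * rows**2 * features
                               + 0.001206 * rows * features**2
                               + 10.3482 * rows * features
                               + 6409698)
    return int(cpu_memory_kb * 1e3)

print(estimate_cpu_icl_memory_bytes(10, 5))      # ~1.3e9 bytes (small branch)
print(estimate_cpu_icl_memory_bytes(10000, 50))  # ~2.6e10 bytes (quadratic branch)

Without the new branch, the constant 6409698 KB term alone would put even a 10 x 5 dataset at roughly 8.3 GB; the small-dataset branch keeps that estimate near 1.3 GB.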
@@ -178,12 +196,20 @@ class MitraModel(AbstractModel):
          X: pd.DataFrame,
          **kwargs,
      ) -> int:
-         cpu_memory_kb = 1.3 * (0.001 * (X.shape[0]**2) * X.shape[1] + \
-                                0.004541 * X.shape[0] * (X.shape[1]**2) + \
-                                46.2974 * X.shape[0] * X.shape[1] + \
-                                5605681)
+         rows, features = X.shape[0], X.shape[1]
+
+         # For very small datasets, use a more conservative estimate
+         if rows * features < 100:  # Small dataset threshold
+             # Use a simpler linear formula for small datasets
+             cpu_memory_kb = 1.3 * (200 * rows * features + 2000000)  # 2GB base + linear scaling
+         else:
+             # Original formula for larger datasets
+             cpu_memory_kb = 1.3 * (0.001 * (rows**2) * features + \
+                                    0.004541 * rows * (features**2) + \
+                                    46.2974 * rows * features + \
+                                    5605681)
          return int(cpu_memory_kb * 1e3)
-
+
      @classmethod
      def _estimate_memory_usage_static_gpu_cpu(
          cls,
@@ -191,7 +217,13 @@ class MitraModel(AbstractModel):
          X: pd.DataFrame,
          **kwargs,
      ) -> int:
-         return int(5 * 1e9)
+         rows, features = X.shape[0], X.shape[1]
+
+         # For very small datasets, use a more conservative estimate
+         if rows * features < 100:  # Small dataset threshold
+             return int(2.5 * 1e9)  # 2.5GB for small datasets
+         else:
+             return int(5 * 1e9)  # 5GB for larger datasets

      @classmethod
      def _estimate_memory_usage_static_gpu_gpu(
@@ -200,7 +232,15 @@ class MitraModel(AbstractModel):
          X: pd.DataFrame,
          **kwargs,
      ) -> int:
-         gpu_memory_mb = 1.3 * (0.05676 * X.shape[0] * X.shape[1] + 3901)
+         rows, features = X.shape[0], X.shape[1]
+
+         # For very small datasets, use a more conservative estimate
+         if rows * features < 100:  # Small dataset threshold
+             # Use a simpler linear formula for small datasets
+             gpu_memory_mb = 1.3 * (10 * rows * features + 2000)  # 2GB base + linear scaling
+         else:
+             # Original formula for larger datasets
+             gpu_memory_mb = 1.3 * (0.05676 * rows * features + 3901)
          return int(gpu_memory_mb * 1e6)

      @classmethod
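The GPU-side estimator above follows the same pattern in megabytes; a matching standalone sketch (again illustrative only, using the coefficients from the hunk):

def estimate_gpu_memory_bytes(rows: int, features: int) -> int:
    # Mirrors _estimate_memory_usage_static_gpu_gpu from the hunk above.
    if rows * features < 100:  # small-dataset branch: ~2 GB base plus linear scaling
        gpu_memory_mb = 1.3 * (10 * rows * features + 2000)
    else:  # linear formula retained for larger datasets
        gpu_memory_mb = 1.3 * (0.05676 * rows * features + 3901)
    return int(gpu_memory_mb * 1e6)

print(estimate_gpu_memory_bytes(10, 5))      # ~3.3e9 bytes (small branch)
print(estimate_gpu_memory_bytes(10000, 50))  # ~4.2e10 bytes (linear branch)

For the same 10 x 5 input the old single formula came out around 5.1 GB, so the new branch trims roughly a third off the requested GPU budget for tiny datasets.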
@@ -208,7 +248,7 @@ class MitraModel(AbstractModel):
          return {
              "can_estimate_memory_usage_static": True,
          }
-
+
      def _more_tags(self) -> dict:
          tags = {"can_refit_full": True}
          return tags