autogluon.tabular 1.3.2b20250715__py3-none-any.whl → 1.3.2b20250717__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31)
  1. autogluon/tabular/models/mitra/_internal/__init__.py +1 -0
  2. autogluon/tabular/models/mitra/_internal/config/__init__.py +1 -0
  3. autogluon/tabular/models/mitra/_internal/config/config_run.py +3 -3
  4. autogluon/tabular/models/mitra/_internal/config/enums.py +19 -2
  5. autogluon/tabular/models/mitra/_internal/core/__init__.py +1 -0
  6. autogluon/tabular/models/mitra/_internal/core/get_loss.py +22 -23
  7. autogluon/tabular/models/mitra/_internal/core/prediction_metrics.py +10 -12
  8. autogluon/tabular/models/mitra/_internal/core/trainer_finetune.py +68 -74
  9. autogluon/tabular/models/mitra/_internal/data/__init__.py +1 -0
  10. autogluon/tabular/models/mitra/_internal/data/preprocessor.py +56 -56
  11. autogluon/tabular/models/mitra/_internal/models/__init__.py +1 -0
  12. autogluon/tabular/models/mitra/_internal/models/tab2d.py +22 -25
  13. autogluon/tabular/models/mitra/_internal/utils/__init__.py +1 -0
  14. autogluon/tabular/models/mitra/mitra_model.py +80 -24
  15. autogluon/tabular/models/mitra/sklearn_interface.py +121 -80
  16. autogluon/tabular/models/realmlp/realmlp_model.py +11 -3
  17. autogluon/tabular/models/tabicl/tabicl_model.py +3 -1
  18. autogluon/tabular/models/tabm/_tabm_internal.py +4 -3
  19. autogluon/tabular/models/tabm/tabm_model.py +6 -3
  20. autogluon/tabular/models/tabm/tabm_reference.py +21 -19
  21. autogluon/tabular/models/tabpfnv2/tabpfnv2_model.py +10 -9
  22. autogluon/tabular/version.py +1 -1
  23. {autogluon.tabular-1.3.2b20250715.dist-info → autogluon.tabular-1.3.2b20250717.dist-info}/METADATA +10 -10
  24. {autogluon.tabular-1.3.2b20250715.dist-info → autogluon.tabular-1.3.2b20250717.dist-info}/RECORD +31 -25
  25. /autogluon.tabular-1.3.2b20250715-py3.9-nspkg.pth → /autogluon.tabular-1.3.2b20250717-py3.9-nspkg.pth +0 -0
  26. {autogluon.tabular-1.3.2b20250715.dist-info → autogluon.tabular-1.3.2b20250717.dist-info}/LICENSE +0 -0
  27. {autogluon.tabular-1.3.2b20250715.dist-info → autogluon.tabular-1.3.2b20250717.dist-info}/NOTICE +0 -0
  28. {autogluon.tabular-1.3.2b20250715.dist-info → autogluon.tabular-1.3.2b20250717.dist-info}/WHEEL +0 -0
  29. {autogluon.tabular-1.3.2b20250715.dist-info → autogluon.tabular-1.3.2b20250717.dist-info}/namespace_packages.txt +0 -0
  30. {autogluon.tabular-1.3.2b20250715.dist-info → autogluon.tabular-1.3.2b20250717.dist-info}/top_level.txt +0 -0
  31. {autogluon.tabular-1.3.2b20250715.dist-info → autogluon.tabular-1.3.2b20250717.dist-info}/zip-safe +0 -0
@@ -1,23 +1,25 @@
1
- from typing import Optional, Self
2
-
3
1
  import random
2
+ from typing import Optional
3
+
4
4
  import numpy as np
5
5
  from loguru import logger
6
- from sklearn.feature_selection import SelectKBest
7
- from sklearn.preprocessing import QuantileTransformer, StandardScaler, OrdinalEncoder
6
+ from sklearn.base import BaseEstimator, TransformerMixin
8
7
  from sklearn.compose import ColumnTransformer
9
8
  from sklearn.decomposition import TruncatedSVD
10
- from sklearn.pipeline import Pipeline, FeatureUnion
11
- from sklearn.base import BaseEstimator, TransformerMixin
9
+ from sklearn.feature_selection import SelectKBest
10
+ from sklearn.pipeline import FeatureUnion, Pipeline
11
+ from sklearn.preprocessing import (OrdinalEncoder, QuantileTransformer,
12
+ StandardScaler)
12
13
 
13
14
  from ..._internal.config.enums import Task
14
15
 
16
+
15
17
  class NoneTransformer(BaseEstimator, TransformerMixin):
16
18
  def fit(self, X, y=None):
17
19
  return self
18
20
  def transform(self, X):
19
21
  return X
20
-
22
+
21
23
  class Preprocessor():
22
24
  """
23
25
  This class is used to preprocess the data before it is pushed through the model.
@@ -28,9 +30,9 @@ class Preprocessor():
28
30
  """
29
31
 
30
32
  def __init__(
31
- self,
33
+ self,
32
34
  dim_embedding: Optional[int], # Size of the feature embedding. For some models this is None, which means the embedding does not depend on the number of features
33
- n_classes: int, # Actual number of classes in the dataset, assumed to be numbered 0, ..., n_classes - 1
35
+ n_classes: int, # Actual number of classes in the dataset, assumed to be numbered 0, ..., n_classes - 1
34
36
  dim_output: int, # Maximum number of classes the model has been trained on -> size of the output
35
37
  use_quantile_transformer: bool,
36
38
  use_feature_count_scaling: bool,
@@ -53,8 +55,8 @@ class Preprocessor():
53
55
  self.random_mirror_regression = random_mirror_regression
54
56
  self.random_mirror_x = random_mirror_x
55
57
  self.task = task
56
-
57
- def fit(self, X: np.ndarray, y: np.ndarray) -> Self:
58
+
59
+ def fit(self, X: np.ndarray, y: np.ndarray) -> "Preprocessor":
58
60
  """
59
61
  X: np.ndarray [n_samples, n_features]
60
62
  y: np.ndarray [n_samples]
@@ -78,16 +80,16 @@ class Preprocessor():
78
80
  if self.use_quantile_transformer:
79
81
  # If use quantile transform is off, it means that the preprocessing will happen on the GPU.
80
82
  X = self.fit_transform_quantile_transformer(X)
81
-
83
+
82
84
  self.mean, self.std = self.calc_mean_std(X)
83
85
  X = self.normalize_by_mean_std(X, self.mean, self.std)
84
-
86
+
85
87
  if self.use_random_transforms:
86
88
  X = self.transform_tabpfn(X)
87
89
 
88
90
  if self.task == Task.CLASSIFICATION and self.shuffle_classes:
89
91
  self.determine_shuffle_class_order()
90
-
92
+
91
93
  if self.shuffle_features:
92
94
  self.determine_feature_order(X)
93
95
 
@@ -104,7 +106,7 @@ class Preprocessor():
104
106
  X[np.isinf(X)] = 0
105
107
 
106
108
  return self
107
-
109
+
108
110
 
109
111
  def transform_X(self, X: np.ndarray):
110
112
 
@@ -116,12 +118,12 @@ class Preprocessor():
116
118
  # If use quantile transform is off, it means that the preprocessing will happen on the GPU.
117
119
 
118
120
  X = self.quantile_transformer.transform(X)
119
-
121
+
120
122
  X = self.normalize_by_mean_std(X, self.mean, self.std)
121
123
 
122
124
  if self.use_feature_count_scaling:
123
125
  X = self.normalize_by_feature_count(X)
124
-
126
+
125
127
  if self.use_random_transforms:
126
128
  X = self.random_transforms.transform(X)
127
129
 
@@ -140,11 +142,11 @@ class Preprocessor():
140
142
 
141
143
 
142
144
  def transform_tabpfn(self, X: np.ndarray):
143
-
145
+
144
146
  n_samples = X.shape[0]
145
147
  n_features = X.shape[1]
146
-
147
- use_config1 = random.random() < 0.5
148
+
149
+ use_config1 = random.random() < 0.5
148
150
  random_state = random.randint(0, 1000000)
149
151
 
150
152
  if use_config1:
@@ -171,12 +173,12 @@ class Preprocessor():
171
173
  ('ordinal', OrdinalEncoder(
172
174
  handle_unknown="use_encoded_value",
173
175
  unknown_value=np.nan
174
- ), [])
176
+ ), [])
175
177
  ], remainder='passthrough')
176
-
178
+
177
179
  return self.random_transforms.fit_transform(X)
178
-
179
-
180
+
181
+
180
182
  def transform_y(self, y: np.ndarray):
181
183
 
182
184
  if self.task == Task.CLASSIFICATION:
@@ -193,36 +195,34 @@ class Preprocessor():
193
195
  if self.task == Task.REGRESSION and self.random_mirror_regression:
194
196
  y = self.apply_random_mirror_regression(y)
195
197
 
196
- match self.task:
197
- case Task.CLASSIFICATION:
198
- y = y.astype(np.int64)
199
- case Task.REGRESSION:
200
- y = y.astype(np.float32)
198
+ if self.task == Task.CLASSIFICATION:
199
+ y = y.astype(np.int64)
200
+ elif self.task == Task.REGRESSION:
201
+ y = y.astype(np.float32)
201
202
 
202
203
  return y
203
-
204
+
204
205
 
205
206
  def inverse_transform_y(self, y: np.ndarray):
206
207
  # Function used during the prediction to transform the model output back to the original space
207
208
  # For classification, y is assumed to be logits of shape [n_samples, n_classes]
208
209
 
209
- match self.task:
210
- case Task.CLASSIFICATION:
211
- y = self.extract_correct_classes(y)
210
+ if self.task == Task.CLASSIFICATION:
211
+ y = self.extract_correct_classes(y)
212
212
 
213
- if self.shuffle_classes:
214
- y = self.undo_randomize_class_order(y)
213
+ if self.shuffle_classes:
214
+ y = self.undo_randomize_class_order(y)
215
215
 
216
- case Task.REGRESSION:
216
+ elif self.task == Task.REGRESSION:
217
217
 
218
- if self.random_mirror_regression:
219
- y = self.apply_random_mirror_regression(y)
218
+ if self.random_mirror_regression:
219
+ y = self.apply_random_mirror_regression(y)
220
220
 
221
- y = self.undo_normalize_y(y)
221
+ y = self.undo_normalize_y(y)
222
222
 
223
223
  return y
224
224
 
225
-
225
+
226
226
 
227
227
  def fit_transform_quantile_transformer(self, X: np.ndarray) -> np.ndarray:
228
228
 
@@ -233,12 +233,12 @@ class Preprocessor():
233
233
 
234
234
  return X
235
235
 
236
-
236
+
237
237
 
238
238
  def determine_which_features_are_singular(self, x: np.ndarray) -> None:
239
239
 
240
240
  self.singular_features = np.array([ len(np.unique(x_col)) for x_col in x.T ]) == 1
241
-
241
+
242
242
 
243
243
 
244
244
  def determine_which_features_to_select(self, x: np.ndarray, y: np.ndarray) -> None:
@@ -267,7 +267,7 @@ class Preprocessor():
267
267
  x[inds] = np.take(self.pre_nan_mean, inds[1])
268
268
  return x
269
269
 
270
-
270
+
271
271
  def select_features(self, x: np.ndarray) -> np.ndarray:
272
272
 
273
273
  if self.dim_embedding is None:
@@ -278,7 +278,7 @@ class Preprocessor():
278
278
  x = self.select_k_best.transform(x)
279
279
 
280
280
  return x
281
-
281
+
282
282
 
283
283
  def cutoff_singular_features(self, x: np.ndarray, singular_features: np.ndarray) -> np.ndarray:
284
284
 
@@ -295,7 +295,7 @@ class Preprocessor():
295
295
  mean = x.mean(axis=0)
296
296
  std = x.std(axis=0) + 1e-6
297
297
  return mean, std
298
-
298
+
299
299
 
300
300
  def normalize_by_mean_std(self, x: np.ndarray, mean: np.ndarray, std: np.ndarray) -> np.ndarray:
301
301
  """
@@ -329,23 +329,23 @@ class Preprocessor():
329
329
  added_zeros = np.zeros((x.shape[0], dim_embedding - x.shape[1]), dtype=np.float32)
330
330
  x = np.concatenate([x, added_zeros], axis=1)
331
331
  return x
332
-
332
+
333
333
 
334
334
  def determine_mix_max_scale(self, y: np.ndarray) -> None:
335
335
  self.y_min = y.min()
336
336
  self.y_max = y.max()
337
337
  assert self.y_min != self.y_max, "y_min and y_max are the same, cannot normalize, regression makes no sense"
338
338
 
339
-
339
+
340
340
  def normalize_y(self, y: np.ndarray) -> np.ndarray:
341
341
  y = (y - self.y_min) / (self.y_max - self.y_min)
342
342
  return y
343
-
343
+
344
344
 
345
345
  def undo_normalize_y(self, y: np.ndarray) -> np.ndarray:
346
346
  y = y * (self.y_max - self.y_min) + self.y_min
347
347
  return y
348
-
348
+
349
349
 
350
350
  def determine_regression_mirror(self) -> None:
351
351
  self.regression_mirror = np.random.choice([True, False], size=(1,)).item()
@@ -355,7 +355,7 @@ class Preprocessor():
355
355
  if self.regression_mirror:
356
356
  y = 1 - y
357
357
  return y
358
-
358
+
359
359
 
360
360
  def determine_mirror(self, x: np.ndarray) -> None:
361
361
 
@@ -376,15 +376,15 @@ class Preprocessor():
376
376
  else:
377
377
  self.new_shuffle_classes = np.arange(self.n_classes)
378
378
 
379
-
379
+
380
380
  def randomize_class_order(self, y: np.ndarray) -> np.ndarray:
381
381
 
382
382
  mapping = { i: self.new_shuffle_classes[i] for i in range(self.n_classes) }
383
383
  y = np.array([mapping[i.item()] for i in y], dtype=np.int64)
384
384
 
385
- return y
386
-
387
-
385
+ return y
386
+
387
+
388
388
  def undo_randomize_class_order(self, y_logits: np.ndarray) -> np.ndarray:
389
389
  """
390
390
  We assume y_logits has shape [n_samples, n_classes]
@@ -393,9 +393,9 @@ class Preprocessor():
393
393
  # mapping = {self.new_shuffle_classes[i]: i for i in range(self.n_classes)}
394
394
  mapping = {i: self.new_shuffle_classes[i] for i in range(self.n_classes)}
395
395
  y = np.concatenate([y_logits[:, mapping[i]:mapping[i]+1] for i in range(self.n_classes)], axis=1)
396
-
396
+
397
397
  return y
398
-
398
+
399
399
 
400
400
  def extract_correct_classes(self, y_logits: np.ndarray) -> np.ndarray:
401
401
  # Even though our network might be able to support 10 classes,
@@ -0,0 +1 @@
1
+ # Model architecture modules for MitraModel
@@ -1,3 +1,5 @@
1
+ import json
2
+ import os
1
3
  from typing import Optional, Union
2
4
 
3
5
  import einops
@@ -5,11 +7,8 @@ import einx
5
7
  import torch
6
8
  import torch.nn as nn
7
9
  import torch.nn.functional as F
8
- from safetensors.torch import save_file
9
10
  from huggingface_hub import hf_hub_download
10
- from safetensors.torch import load_file
11
- import os
12
- import json
11
+ from safetensors.torch import load_file, save_file
13
12
 
14
13
  # Try to import flash attention, but make it optional
15
14
  try:
@@ -24,8 +23,8 @@ from torch.utils.checkpoint import checkpoint
24
23
  from ..._internal.config.enums import Task
25
24
  from ..._internal.models.base import BaseModel
26
25
  from ..._internal.models.embedding import (
27
- Tab2DEmbeddingX,
28
- Tab2DEmbeddingYClasses,
26
+ Tab2DEmbeddingX,
27
+ Tab2DEmbeddingYClasses,
29
28
  Tab2DEmbeddingYRegression,
30
29
  Tab2DQuantileEmbeddingX,
31
30
  )
@@ -64,16 +63,15 @@ class Tab2D(BaseModel):
64
63
  self.x_embedding = Tab2DEmbeddingX(dim)
65
64
 
66
65
 
67
- match self.task:
68
- case Task.CLASSIFICATION:
69
- self.y_embedding = Tab2DEmbeddingYClasses(dim, dim_output) # type: nn.Module
70
- case Task.REGRESSION:
71
- if self.dim_output == 1:
72
- self.y_embedding = Tab2DEmbeddingYRegression(dim)
73
- else:
74
- self.y_embedding = Tab2DEmbeddingYClasses(dim, dim_output)
75
- case _:
76
- raise ValueError(f"Task {task} not supported")
66
+ if self.task == Task.CLASSIFICATION:
67
+ self.y_embedding = Tab2DEmbeddingYClasses(dim, dim_output) # type: nn.Module
68
+ elif self.task == Task.REGRESSION:
69
+ if self.dim_output == 1:
70
+ self.y_embedding = Tab2DEmbeddingYRegression(dim)
71
+ else:
72
+ self.y_embedding = Tab2DEmbeddingYClasses(dim, dim_output)
73
+ else:
74
+ raise ValueError(f"Task {task} not supported")
77
75
 
78
76
  self.layers = nn.ModuleList()
79
77
 
@@ -165,18 +163,17 @@ class Tab2D(BaseModel):
165
163
 
166
164
  y_query__, x_query__ = einops.unpack(query__, pack_query__, 'b s * c') # (b, n_q, 1, c), (b, n_q, f, c)
167
165
 
168
- match self.task:
166
+ if self.task == Task.REGRESSION:
169
167
  # output has shape (batch_size, n_observations_query, n_features, n_classes)
170
168
  # we want to remove the n_features dimension, and for regression, the n_classes dimension
171
- case Task.REGRESSION:
172
- if self.dim_output == 1:
173
- y_query__ = y_query__[:, :, 0, 0]
174
- else:
175
- y_query__ = y_query__[:, :, 0, :]
176
- case Task.CLASSIFICATION:
169
+ if self.dim_output == 1:
170
+ y_query__ = y_query__[:, :, 0, 0]
171
+ else:
177
172
  y_query__ = y_query__[:, :, 0, :]
178
- case _:
179
- raise ValueError(f"Task {self.task} not supported")
173
+ elif self.task == Task.CLASSIFICATION:
174
+ y_query__ = y_query__[:, :, 0, :]
175
+ else:
176
+ raise ValueError(f"Task {self.task} not supported")
180
177
 
181
178
  return y_query__
182
179
 
@@ -0,0 +1 @@
1
+ # Utility modules for MitraModel
@@ -1,8 +1,18 @@
1
+ # TODO: To ensure deterministic operations we need to set torch.use_deterministic_algorithms(True)
2
+ # and os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'. The CUBLAS environment variable configures
3
+ # the workspace size for certain CUBLAS operations to ensure reproducibility when using CUDA >= 10.2.
4
+ # Both settings are required to ensure deterministic behavior in operations such as matrix multiplications.
5
+ import os
6
+ os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
7
+
8
+ import os
9
+ from typing import List, Optional
10
+
1
11
  import pandas as pd
2
- from typing import Optional, List
12
+
3
13
  from autogluon.common.utils.resource_utils import ResourceManager
4
14
  from autogluon.core.models import AbstractModel
5
- import os
15
+
6
16
 
7
17
  # TODO: Needs memory usage estimate method
8
18
  class MitraModel(AbstractModel):
@@ -37,6 +47,17 @@ class MitraModel(AbstractModel):
37
47
  num_cpus: int = 1,
38
48
  **kwargs,
39
49
  ):
50
+
51
+ # TODO: Reset the number of threads based on the specified num_cpus
52
+ need_to_reset_torch_threads = False
53
+ torch_threads_og = None
54
+ if num_cpus is not None and isinstance(num_cpus, (int, float)):
55
+ torch_threads_og = torch.get_num_threads()
56
+ if torch_threads_og != num_cpus:
57
+ # reset torch threads back to original value after fit
58
+ torch.set_num_threads(num_cpus)
59
+ need_to_reset_torch_threads = True
60
+
40
61
  model_cls = self.get_model_cls()
41
62
 
42
63
  hyp = self._get_model_params()
@@ -65,14 +86,17 @@ class MitraModel(AbstractModel):
65
86
  time_limit=time_limit,
66
87
  )
67
88
 
89
+ if need_to_reset_torch_threads:
90
+ torch.set_num_threads(torch_threads_og)
91
+
68
92
  def _set_default_params(self):
69
93
  default_params = {
70
- "device": "cuda", # "cpu"
94
+ "device": "cpu",
71
95
  "n_estimators": 1,
72
96
  }
73
97
  for param, val in default_params.items():
74
98
  self._set_default_param_value(param, val)
75
-
99
+
76
100
  def _get_default_auxiliary_params(self) -> dict:
77
101
  default_auxiliary_params = super()._get_default_auxiliary_params()
78
102
  default_auxiliary_params.update(
@@ -87,7 +111,7 @@ class MitraModel(AbstractModel):
87
111
  @property
88
112
  def weights_path(self) -> str:
89
113
  return os.path.join(self.path, self.weights_file_name)
90
-
114
+
91
115
  def save(self, path: str = None, verbose=True) -> str:
92
116
  _model_weights_list = None
93
117
  if self.model is not None:
@@ -98,7 +122,7 @@ class MitraModel(AbstractModel):
98
122
  self.model.trainers[i].model = None
99
123
  self.model.trainers[i].optimizer = None
100
124
  self.model.trainers[i].scheduler_warmup = None
101
- self.model.trainers[i].scheduler_reduce_on_plateau = None
125
+ self.model.trainers[i].scheduler_reduce_on_plateau = None
102
126
  self._weights_saved = True
103
127
  path = super().save(path=path, verbose=verbose)
104
128
  if _model_weights_list is not None:
@@ -108,7 +132,7 @@ class MitraModel(AbstractModel):
108
132
  for i in range(len(self.model.trainers)):
109
133
  self.model.trainers[i].model = _model_weights_list[i]
110
134
  return path
111
-
135
+
112
136
  @classmethod
113
137
  def load(cls, path: str, reset_paths=False, verbose=True):
114
138
  model: MitraModel = super().load(path=path, reset_paths=reset_paths, verbose=verbose)
@@ -136,14 +160,16 @@ class MitraModel(AbstractModel):
136
160
  return default_ag_args_ensemble
137
161
 
138
162
  def _get_default_resources(self) -> tuple[int, int]:
139
- # logical=False is faster in training
140
- num_cpus = ResourceManager.get_cpu_count_psutil(logical=False)
141
- num_gpus = 1
163
+ # Use only physical cores for better performance based on benchmarks
164
+ num_cpus = ResourceManager.get_cpu_count(only_physical_cores=True)
165
+
166
+ num_gpus = min(1, ResourceManager.get_gpu_count_torch(cuda_only=True))
167
+
142
168
  return num_cpus, num_gpus
143
169
 
144
170
  def _estimate_memory_usage(self, X: pd.DataFrame, **kwargs) -> int:
145
171
  return self.estimate_memory_usage_static(X=X, problem_type=self.problem_type, num_classes=self.num_classes, **kwargs)
146
-
172
+
147
173
  @classmethod
148
174
  def _estimate_memory_usage_static(
149
175
  cls,
@@ -157,7 +183,7 @@ class MitraModel(AbstractModel):
157
183
  cls._estimate_memory_usage_static_gpu_cpu(X=X, **kwargs),
158
184
  cls._estimate_memory_usage_static_gpu_gpu(X=X, **kwargs),
159
185
  )
160
-
186
+
161
187
  @classmethod
162
188
  def _estimate_memory_usage_static_cpu_icl(
163
189
  cls,
@@ -165,10 +191,18 @@ class MitraModel(AbstractModel):
165
191
  X: pd.DataFrame,
166
192
  **kwargs,
167
193
  ) -> int:
168
- cpu_memory_kb = 1.3 * (0.001748 * (X.shape[0]**2) * X.shape[1] + \
169
- 0.001206 * X.shape[0] * (X.shape[1]**2) + \
170
- 10.3482 * X.shape[0] * X.shape[1] + \
171
- 6409698)
194
+ rows, features = X.shape[0], X.shape[1]
195
+
196
+ # For very small datasets, use a more conservative estimate
197
+ if rows * features < 100: # Small dataset threshold
198
+ # Use a simpler linear formula for small datasets
199
+ cpu_memory_kb = 1.3 * (100 * rows * features + 1000000) # 1GB base + linear scaling
200
+ else:
201
+ # Original formula for larger datasets
202
+ cpu_memory_kb = 1.3 * (0.001748 * (rows**2) * features + \
203
+ 0.001206 * rows * (features**2) + \
204
+ 10.3482 * rows * features + \
205
+ 6409698)
172
206
  return int(cpu_memory_kb * 1e3)
173
207
 
174
208
  @classmethod
@@ -178,12 +212,20 @@ class MitraModel(AbstractModel):
178
212
  X: pd.DataFrame,
179
213
  **kwargs,
180
214
  ) -> int:
181
- cpu_memory_kb = 1.3 * (0.001 * (X.shape[0]**2) * X.shape[1] + \
182
- 0.004541 * X.shape[0] * (X.shape[1]**2) + \
183
- 46.2974 * X.shape[0] * X.shape[1] + \
184
- 5605681)
215
+ rows, features = X.shape[0], X.shape[1]
216
+
217
+ # For very small datasets, use a more conservative estimate
218
+ if rows * features < 100: # Small dataset threshold
219
+ # Use a simpler linear formula for small datasets
220
+ cpu_memory_kb = 1.3 * (200 * rows * features + 2000000) # 2GB base + linear scaling
221
+ else:
222
+ # Original formula for larger datasets
223
+ cpu_memory_kb = 1.3 * (0.001 * (rows**2) * features + \
224
+ 0.004541 * rows * (features**2) + \
225
+ 46.2974 * rows * features + \
226
+ 5605681)
185
227
  return int(cpu_memory_kb * 1e3)
186
-
228
+
187
229
  @classmethod
188
230
  def _estimate_memory_usage_static_gpu_cpu(
189
231
  cls,
@@ -191,7 +233,13 @@ class MitraModel(AbstractModel):
191
233
  X: pd.DataFrame,
192
234
  **kwargs,
193
235
  ) -> int:
194
- return int(5 * 1e9)
236
+ rows, features = X.shape[0], X.shape[1]
237
+
238
+ # For very small datasets, use a more conservative estimate
239
+ if rows * features < 100: # Small dataset threshold
240
+ return int(2.5 * 1e9) # 2.5GB for small datasets
241
+ else:
242
+ return int(5 * 1e9) # 5GB for larger datasets
195
243
 
196
244
  @classmethod
197
245
  def _estimate_memory_usage_static_gpu_gpu(
@@ -200,7 +248,15 @@ class MitraModel(AbstractModel):
200
248
  X: pd.DataFrame,
201
249
  **kwargs,
202
250
  ) -> int:
203
- gpu_memory_mb = 1.3 * (0.05676 * X.shape[0] * X.shape[1] + 3901)
251
+ rows, features = X.shape[0], X.shape[1]
252
+
253
+ # For very small datasets, use a more conservative estimate
254
+ if rows * features < 100: # Small dataset threshold
255
+ # Use a simpler linear formula for small datasets
256
+ gpu_memory_mb = 1.3 * (10 * rows * features + 2000) # 2GB base + linear scaling
257
+ else:
258
+ # Original formula for larger datasets
259
+ gpu_memory_mb = 1.3 * (0.05676 * rows * features + 3901)
204
260
  return int(gpu_memory_mb * 1e6)
205
261
 
206
262
  @classmethod
@@ -208,7 +264,7 @@ class MitraModel(AbstractModel):
208
264
  return {
209
265
  "can_estimate_memory_usage_static": True,
210
266
  }
211
-
267
+
212
268
  def _more_tags(self) -> dict:
213
269
  tags = {"can_refit_full": True}
214
270
  return tags