autogluon.tabular 1.3.2b20250718__py3-none-any.whl → 1.3.2b20250719__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,18 +1,23 @@
1
- # TODO: To ensure deterministic operations we need to set torch.use_deterministic_algorithms(True)
1
+ # TODO: To ensure deterministic operations we need to set torch.use_deterministic_algorithms(True)
2
2
  # and os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'. The CUBLAS environment variable configures
3
3
  # the workspace size for certain CUBLAS operations to ensure reproducibility when using CUDA >= 10.2.
4
4
  # Both settings are required to ensure deterministic behavior in operations such as matrix multiplications.
5
5
  import os
6
- os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
6
+
7
+ os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
7
8
 
8
9
  import os
9
10
  from typing import List, Optional
10
11
 
11
12
  import pandas as pd
13
+ import torch
14
+ import logging
12
15
 
13
16
  from autogluon.common.utils.resource_utils import ResourceManager
14
17
  from autogluon.core.models import AbstractModel
15
18
 
19
+ logger = logging.getLogger(__name__)
20
+
16
21
 
17
22
  # TODO: Needs memory usage estimate method
18
23
  class MitraModel(AbstractModel):
@@ -26,12 +31,26 @@ class MitraModel(AbstractModel):
26
31
  self.problem_type = problem_type
27
32
  self._weights_saved = False
28
33
 
34
+ @staticmethod
35
+ def _get_default_device():
36
+ """Get the best available device for the current system."""
37
+ if ResourceManager.get_gpu_count_torch(cuda_only=True) > 0:
38
+ logger.info("Using CUDA GPU")
39
+ return "cuda"
40
+ elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
41
+ logger.info("Using MPS GPU")
42
+ return "mps" # Apple silicon
43
+ else:
44
+ return "cpu"
45
+
29
46
  def get_model_cls(self):
30
47
  from .sklearn_interface import MitraClassifier
31
- if self.problem_type in ['binary', 'multiclass']:
48
+
49
+ if self.problem_type in ["binary", "multiclass"]:
32
50
  model_cls = MitraClassifier
33
- elif self.problem_type == 'regression':
51
+ elif self.problem_type == "regression":
34
52
  from .sklearn_interface import MitraRegressor
53
+
35
54
  model_cls = MitraRegressor
36
55
  else:
37
56
  raise AssertionError(f"Unsupported problem_type: {self.problem_type}")
@@ -47,7 +66,6 @@ class MitraModel(AbstractModel):
47
66
  num_cpus: int = 1,
48
67
  **kwargs,
49
68
  ):
50
-
51
69
  # TODO: Reset the number of threads based on the specified num_cpus
52
70
  need_to_reset_torch_threads = False
53
71
  torch_threads_og = None
@@ -91,7 +109,7 @@ class MitraModel(AbstractModel):
91
109
 
92
110
  def _set_default_params(self):
93
111
  default_params = {
94
- "device": "cpu",
112
+ "device": self._get_default_device(),
95
113
  "n_estimators": 1,
96
114
  }
97
115
  for param, val in default_params.items():
@@ -127,6 +145,7 @@ class MitraModel(AbstractModel):
127
145
  path = super().save(path=path, verbose=verbose)
128
146
  if _model_weights_list is not None:
129
147
  import torch
148
+
130
149
  os.makedirs(self.path, exist_ok=True)
131
150
  torch.save(_model_weights_list, self.weights_path)
132
151
  for i in range(len(self.model.trainers)):
@@ -139,6 +158,7 @@ class MitraModel(AbstractModel):
139
158
 
140
159
  if model._weights_saved:
141
160
  import torch
161
+
142
162
  model_weights_list = torch.load(model.weights_path, weights_only=False) # nosec B614
143
163
  for i in range(len(model.model.trainers)):
144
164
  model.model.trainers[i].model = model_weights_list[i]
@@ -154,7 +174,7 @@ class MitraModel(AbstractModel):
154
174
  default_ag_args_ensemble = super()._get_default_ag_args_ensemble(**kwargs)
155
175
  # FIXME: Test if it works with parallel, need to enable n_cpus support
156
176
  extra_ag_args_ensemble = {
157
- "fold_fitting_strategy": "sequential_local", # FIXME: Comment out after debugging for large speedup
177
+ "fold_fitting_strategy": "sequential_local", # FIXME: Comment out after debugging for large speedup
158
178
  }
159
179
  default_ag_args_ensemble.update(extra_ag_args_ensemble)
160
180
  return default_ag_args_ensemble
@@ -168,7 +188,9 @@ class MitraModel(AbstractModel):
168
188
  return num_cpus, num_gpus
169
189
 
170
190
  def _estimate_memory_usage(self, X: pd.DataFrame, **kwargs) -> int:
171
- return self.estimate_memory_usage_static(X=X, problem_type=self.problem_type, num_classes=self.num_classes, **kwargs)
191
+ return self.estimate_memory_usage_static(
192
+ X=X, problem_type=self.problem_type, num_classes=self.num_classes, **kwargs
193
+ )
172
194
 
173
195
  @classmethod
174
196
  def _estimate_memory_usage_static(
@@ -199,10 +221,9 @@ class MitraModel(AbstractModel):
199
221
  cpu_memory_kb = 1.3 * (100 * rows * features + 1000000) # 1GB base + linear scaling
200
222
  else:
201
223
  # Original formula for larger datasets
202
- cpu_memory_kb = 1.3 * (0.001748 * (rows**2) * features + \
203
- 0.001206 * rows * (features**2) + \
204
- 10.3482 * rows * features + \
205
- 6409698)
224
+ cpu_memory_kb = 1.3 * (
225
+ 0.001748 * (rows**2) * features + 0.001206 * rows * (features**2) + 10.3482 * rows * features + 6409698
226
+ )
206
227
  return int(cpu_memory_kb * 1e3)
207
228
 
208
229
  @classmethod
@@ -220,10 +241,9 @@ class MitraModel(AbstractModel):
220
241
  cpu_memory_kb = 1.3 * (200 * rows * features + 2000000) # 2GB base + linear scaling
221
242
  else:
222
243
  # Original formula for larger datasets
223
- cpu_memory_kb = 1.3 * (0.001 * (rows**2) * features + \
224
- 0.004541 * rows * (features**2) + \
225
- 46.2974 * rows * features + \
226
- 5605681)
244
+ cpu_memory_kb = 1.3 * (
245
+ 0.001 * (rows**2) * features + 0.004541 * rows * (features**2) + 46.2974 * rows * features + 5605681
246
+ )
227
247
  return int(cpu_memory_kb * 1e3)
228
248
 
229
249
  @classmethod
@@ -1,4 +1,4 @@
1
1
  """This is the autogluon version file."""
2
2
 
3
- __version__ = "1.3.2b20250718"
3
+ __version__ = "1.3.2b20250719"
4
4
  __lite__ = False
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: autogluon.tabular
3
- Version: 1.3.2b20250718
3
+ Version: 1.3.2b20250719
4
4
  Summary: Fast and Accurate ML in 3 Lines of Code
5
5
  Home-page: https://github.com/autogluon/autogluon
6
6
  Author: AutoGluon Community
@@ -41,20 +41,20 @@ Requires-Dist: scipy<1.17,>=1.5.4
41
41
  Requires-Dist: pandas<2.4.0,>=2.0.0
42
42
  Requires-Dist: scikit-learn<1.8.0,>=1.4.0
43
43
  Requires-Dist: networkx<4,>=3.0
44
- Requires-Dist: autogluon.core==1.3.2b20250718
45
- Requires-Dist: autogluon.features==1.3.2b20250718
44
+ Requires-Dist: autogluon.core==1.3.2b20250719
45
+ Requires-Dist: autogluon.features==1.3.2b20250719
46
46
  Provides-Extra: all
47
+ Requires-Dist: autogluon.core[all]==1.3.2b20250719; extra == "all"
47
48
  Requires-Dist: einops<0.9,>=0.7; extra == "all"
49
+ Requires-Dist: catboost<1.3,>=1.2; extra == "all"
50
+ Requires-Dist: xgboost<3.1,>=2.0; extra == "all"
48
51
  Requires-Dist: huggingface-hub[torch]; extra == "all"
49
- Requires-Dist: numpy<2.3.0,>=1.25; extra == "all"
50
- Requires-Dist: fastai<2.9,>=2.3.1; extra == "all"
52
+ Requires-Dist: pytabkit<1.6,>=1.5; extra == "all"
51
53
  Requires-Dist: torch<2.8,>=2.2; extra == "all"
52
- Requires-Dist: autogluon.core[all]==1.3.2b20250718; extra == "all"
53
54
  Requires-Dist: spacy<3.9; extra == "all"
54
- Requires-Dist: xgboost<3.1,>=2.0; extra == "all"
55
- Requires-Dist: pytabkit<1.6,>=1.5; extra == "all"
56
- Requires-Dist: catboost<1.3,>=1.2; extra == "all"
57
55
  Requires-Dist: lightgbm<4.7,>=4.0; extra == "all"
56
+ Requires-Dist: fastai<2.9,>=2.3.1; extra == "all"
57
+ Requires-Dist: numpy<2.3.0,>=1.25; extra == "all"
58
58
  Provides-Extra: catboost
59
59
  Requires-Dist: numpy<2.3.0,>=1.25; extra == "catboost"
60
60
  Requires-Dist: catboost<1.3,>=1.2; extra == "catboost"
@@ -72,7 +72,7 @@ Requires-Dist: einx; extra == "mitra"
72
72
  Requires-Dist: omegaconf; extra == "mitra"
73
73
  Requires-Dist: transformers; extra == "mitra"
74
74
  Provides-Extra: ray
75
- Requires-Dist: autogluon.core[all]==1.3.2b20250718; extra == "ray"
75
+ Requires-Dist: autogluon.core[all]==1.3.2b20250719; extra == "ray"
76
76
  Provides-Extra: realmlp
77
77
  Requires-Dist: pytabkit<1.6,>=1.5; extra == "realmlp"
78
78
  Provides-Extra: skex
@@ -1,6 +1,6 @@
1
- autogluon.tabular-1.3.2b20250718-py3.9-nspkg.pth,sha256=cQGwpuGPqg1GXscIwt-7PmME1OnSpD-7ixkikJ31WAY,554
1
+ autogluon.tabular-1.3.2b20250719-py3.9-nspkg.pth,sha256=cQGwpuGPqg1GXscIwt-7PmME1OnSpD-7ixkikJ31WAY,554
2
2
  autogluon/tabular/__init__.py,sha256=2OXpJCvENRHubBTYNIPpHX93WWuFZzsJBtTZbNVHVas,400
3
- autogluon/tabular/version.py,sha256=t7hPQFF0BzYTBfD-vM9hoER3q-C5x0pjSWoVO1dcT0w,91
3
+ autogluon/tabular/version.py,sha256=YLKHqDF99tyKxzCGdB6cvf6IZT4wW9FS65QYC0TWIYM,91
4
4
  autogluon/tabular/configs/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
5
5
  autogluon/tabular/configs/config_helper.py,sha256=JsdVGmpcYL88GPKBznPtqJ1sGaByOSvLn7KWU-HyVoQ,21085
6
6
  autogluon/tabular/configs/feature_generator_presets.py,sha256=EV5Ym8VW15q92MwOUpTi7wZFS2QooM51fLg3RdUsn-M,1223
@@ -68,7 +68,7 @@ autogluon/tabular/models/lr/hyperparameters/__init__.py,sha256=47DEQpj8HBSa-_TIm
68
68
  autogluon/tabular/models/lr/hyperparameters/parameters.py,sha256=Hr5YC13zjbt3CfCbzGj8iXUIuDn-Q7FvDT2uSuiSVlM,1414
69
69
  autogluon/tabular/models/lr/hyperparameters/searchspaces.py,sha256=Igywc-B6qJ9EBLdasrDhW-Ot5FGirIzbXLwv5HRe5Xo,276
70
70
  autogluon/tabular/models/mitra/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
71
- autogluon/tabular/models/mitra/mitra_model.py,sha256=XiTlzy-RbbHe1t8VCU1y976zmqrDAkO_HhkBiSlk7mM,9985
71
+ autogluon/tabular/models/mitra/mitra_model.py,sha256=gvoh22xk17G9ctcYL3RDDWkFcqEbFsJsNKXClPHSK8M,10429
72
72
  autogluon/tabular/models/mitra/sklearn_interface.py,sha256=nX830-_7KYjMnwJ8m8jhCfG7BXU379Ecn5Lu3RvN8Us,18513
73
73
  autogluon/tabular/models/mitra/_internal/__init__.py,sha256=dN2dz1pGMgQTFiSf9oYbyq23iJUxV8QNlOX3qw3KUO4,35
74
74
  autogluon/tabular/models/mitra/_internal/config/__init__.py,sha256=Exu_Sx6-K-D5peDQ_TibsjZpqAALs2-9IXfq8hu1mwU,40
@@ -188,11 +188,11 @@ autogluon/tabular/trainer/model_presets/presets.py,sha256=hoWADaOG576Q_XLV1nY_ju
188
188
  autogluon/tabular/trainer/model_presets/presets_distill.py,sha256=MnFC2GJc6RmDBNAGbsO2XMfo3PjR8cUrZoilWW8gTYQ,3295
189
189
  autogluon/tabular/tuning/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
190
190
  autogluon/tabular/tuning/feature_pruner.py,sha256=9iNku8gVbYEkjuKlyITPJDicsNkoraaQOlINQq9iZlQ,6877
191
- autogluon.tabular-1.3.2b20250718.dist-info/LICENSE,sha256=CeipvOyAZxBGUsFoaFqwkx54aPnIKEtm9a5u2uXxEws,10142
192
- autogluon.tabular-1.3.2b20250718.dist-info/METADATA,sha256=edzh0r-bATMf-HfQBg7-Gdox3Bq8KfuT101utiZe35s,14646
193
- autogluon.tabular-1.3.2b20250718.dist-info/NOTICE,sha256=7nPQuj8Kp-uXsU0S5so3-2dNU5EctS5hDXvvzzehd7E,114
194
- autogluon.tabular-1.3.2b20250718.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
195
- autogluon.tabular-1.3.2b20250718.dist-info/namespace_packages.txt,sha256=giERA4R78OkJf2ijn5slgjURlhRPzfLr7waIcGkzYAo,10
196
- autogluon.tabular-1.3.2b20250718.dist-info/top_level.txt,sha256=giERA4R78OkJf2ijn5slgjURlhRPzfLr7waIcGkzYAo,10
197
- autogluon.tabular-1.3.2b20250718.dist-info/zip-safe,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
198
- autogluon.tabular-1.3.2b20250718.dist-info/RECORD,,
191
+ autogluon.tabular-1.3.2b20250719.dist-info/LICENSE,sha256=CeipvOyAZxBGUsFoaFqwkx54aPnIKEtm9a5u2uXxEws,10142
192
+ autogluon.tabular-1.3.2b20250719.dist-info/METADATA,sha256=KRUVAwkQN8aBxfrNHAd02IIgTbtJx-cxqm9rfUP4E68,14646
193
+ autogluon.tabular-1.3.2b20250719.dist-info/NOTICE,sha256=7nPQuj8Kp-uXsU0S5so3-2dNU5EctS5hDXvvzzehd7E,114
194
+ autogluon.tabular-1.3.2b20250719.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
195
+ autogluon.tabular-1.3.2b20250719.dist-info/namespace_packages.txt,sha256=giERA4R78OkJf2ijn5slgjURlhRPzfLr7waIcGkzYAo,10
196
+ autogluon.tabular-1.3.2b20250719.dist-info/top_level.txt,sha256=giERA4R78OkJf2ijn5slgjURlhRPzfLr7waIcGkzYAo,10
197
+ autogluon.tabular-1.3.2b20250719.dist-info/zip-safe,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
198
+ autogluon.tabular-1.3.2b20250719.dist-info/RECORD,,