autogluon.tabular 1.2.1b20250114__py3-none-any.whl → 1.2.1b20250116__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -156,12 +156,26 @@ class TabularNeuralNetTorchModel(AbstractNeuralNetworkModel):
156
156
 
157
157
  return processor_kwargs, optimizer_kwargs, fit_kwargs, loss_kwargs, params
158
158
 
159
- def _fit(self, X, y, X_val=None, y_val=None, time_limit=None, sample_weight=None, num_cpus=1, num_gpus=0, reporter=None, verbosity=2, **kwargs):
159
+ def _fit(
160
+ self,
161
+ X: pd.DataFrame,
162
+ y: pd.Series,
163
+ X_val: pd.DataFrame = None,
164
+ y_val: pd.Series = None,
165
+ X_test: pd.DataFrame = None,
166
+ y_test: pd.Series = None,
167
+ time_limit: float = None,
168
+ sample_weight=None,
169
+ num_cpus: int = 1,
170
+ num_gpus: float = 0,
171
+ reporter=None,
172
+ verbosity: int = 2,
173
+ **kwargs,
174
+ ):
160
175
  try_import_torch()
161
176
  import torch
162
177
 
163
178
  torch.set_num_threads(num_cpus)
164
- from .tabular_torch_dataset import TabularTorchDataset
165
179
 
166
180
  start_time = time.time()
167
181
 
@@ -188,19 +202,20 @@ class TabularNeuralNetTorchModel(AbstractNeuralNetworkModel):
188
202
  self.num_dataloading_workers = 0 # TODO: verify 0 is typically faster and uses less memory than 1 in pytorch
189
203
  self.num_dataloading_workers = 0 # TODO: >0 crashes on MacOS
190
204
  self.max_batch_size = params.pop("max_batch_size", 512)
191
- batch_size = params.pop("batch_size", None)
192
- if batch_size is None:
193
- if isinstance(X, TabularTorchDataset):
194
- batch_size = min(int(2 ** (3 + np.floor(np.log10(len(X))))), self.max_batch_size)
195
- else:
196
- batch_size = min(int(2 ** (3 + np.floor(np.log10(X.shape[0])))), self.max_batch_size)
197
205
 
198
- X_test = kwargs.get("X_test", None)
199
- y_test = kwargs.get("y_test", None)
206
+ train_dataset = self._generate_dataset(X=X, y=y, train_params=processor_kwargs, is_train=True)
207
+ if X_val is not None and y_val is not None:
208
+ val_dataset = self._generate_dataset(X=X_val, y=y_val)
209
+ else:
210
+ val_dataset = None
211
+ if X_test is not None and y_test is not None:
212
+ test_dataset = self._generate_dataset(X=X_test, y=y_test)
213
+ else:
214
+ test_dataset = None
200
215
 
201
- train_dataset = self._generate_dataset(X, y, train_params=processor_kwargs, is_train=True)
202
- val_dataset = self._generate_dataset(X_val, y_val)
203
- test_dataset = self._generate_dataset(X_test, y_test)
216
+ batch_size = params.pop("batch_size", None)
217
+ if batch_size is None:
218
+ batch_size = min(int(2 ** (3 + np.floor(np.log10(len(X))))), self.max_batch_size, len(X))
204
219
 
205
220
  logger.log(
206
221
  15,
@@ -255,16 +270,16 @@ class TabularNeuralNetTorchModel(AbstractNeuralNetworkModel):
255
270
 
256
271
  def _train_net(
257
272
  self,
258
- train_dataset,
259
- loss_kwargs,
260
- batch_size,
261
- num_epochs,
262
- epochs_wo_improve,
263
- val_dataset=None,
264
- test_dataset=None,
265
- time_limit=None,
273
+ train_dataset: TabularTorchDataset,
274
+ loss_kwargs: dict,
275
+ batch_size: int,
276
+ num_epochs: int,
277
+ epochs_wo_improve: int,
278
+ val_dataset: TabularTorchDataset = None,
279
+ test_dataset: TabularTorchDataset = None,
280
+ time_limit: float = None,
266
281
  reporter=None,
267
- verbosity=2,
282
+ verbosity: int = 2,
268
283
  ):
269
284
  import torch
270
285
 
@@ -634,13 +649,13 @@ class TabularNeuralNetTorchModel(AbstractNeuralNetworkModel):
634
649
  preds_dataset = np.concatenate(preds_dataset, 0)
635
650
  return preds_dataset
636
651
 
637
- def _generate_dataset(self, X: pd.DataFrame, y: pd.Series, train_params: dict = {}, is_train: bool = False):
652
+ def _generate_dataset(self, X: pd.DataFrame | TabularTorchDataset, y: pd.Series, train_params: dict = {}, is_train: bool = False) -> TabularTorchDataset:
638
653
  """
639
654
  Generate TabularTorchDataset from X and y.
640
655
 
641
656
  Params:
642
657
  -------
643
- X: pd.DataFrame
658
+ X: pd.DataFrame | TabularTorchDataset
644
659
  The X data.
645
660
  y: pd.Series
646
661
  The y data.
@@ -676,14 +691,11 @@ class TabularNeuralNetTorchModel(AbstractNeuralNetworkModel):
676
691
  use_ngram_features=use_ngram_features,
677
692
  )
678
693
  else:
679
- if X is not None:
680
- if isinstance(X, TabularTorchDataset):
681
- dataset = X
682
- else:
683
- X = self.preprocess(X)
684
- dataset = self._process_test_data(df=X, labels=y)
694
+ if isinstance(X, TabularTorchDataset):
695
+ dataset = X
685
696
  else:
686
- dataset = None
697
+ X = self.preprocess(X)
698
+ dataset = self._process_test_data(df=X, labels=y)
687
699
 
688
700
  return dataset
689
701
 
@@ -1,3 +1,4 @@
1
1
  """This is the autogluon version file."""
2
- __version__ = '1.2.1b20250114'
2
+
3
+ __version__ = "1.2.1b20250116"
3
4
  __lite__ = False
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: autogluon.tabular
3
- Version: 1.2.1b20250114
3
+ Version: 1.2.1b20250116
4
4
  Summary: Fast and Accurate ML in 3 Lines of Code
5
5
  Home-page: https://github.com/autogluon/autogluon
6
6
  Author: AutoGluon Community
@@ -39,19 +39,19 @@ Requires-Dist: scipy<1.16,>=1.5.4
39
39
  Requires-Dist: pandas<2.3.0,>=2.0.0
40
40
  Requires-Dist: scikit-learn<1.5.3,>=1.4.0
41
41
  Requires-Dist: networkx<4,>=3.0
42
- Requires-Dist: autogluon.core==1.2.1b20250114
43
- Requires-Dist: autogluon.features==1.2.1b20250114
42
+ Requires-Dist: autogluon.core==1.2.1b20250116
43
+ Requires-Dist: autogluon.features==1.2.1b20250116
44
44
  Provides-Extra: all
45
45
  Requires-Dist: numpy<2.0.0,>=1.25; extra == "all"
46
- Requires-Dist: fastai<2.8,>=2.3.1; extra == "all"
47
- Requires-Dist: lightgbm<4.6,>=4.0; extra == "all"
48
- Requires-Dist: torch<2.6,>=2.2; extra == "all"
49
- Requires-Dist: spacy<3.8; extra == "all"
50
46
  Requires-Dist: xgboost<2.2,>=1.6; extra == "all"
51
- Requires-Dist: autogluon.core[all]==1.2.1b20250114; extra == "all"
52
47
  Requires-Dist: einops<0.9,>=0.7; extra == "all"
48
+ Requires-Dist: autogluon.core[all]==1.2.1b20250116; extra == "all"
53
49
  Requires-Dist: catboost<1.3,>=1.2; extra == "all"
50
+ Requires-Dist: torch<2.6,>=2.2; extra == "all"
51
+ Requires-Dist: spacy<3.8; extra == "all"
52
+ Requires-Dist: fastai<2.8,>=2.3.1; extra == "all"
54
53
  Requires-Dist: huggingface-hub[torch]; extra == "all"
54
+ Requires-Dist: lightgbm<4.6,>=4.0; extra == "all"
55
55
  Provides-Extra: catboost
56
56
  Requires-Dist: numpy<2.0.0,>=1.25; extra == "catboost"
57
57
  Requires-Dist: catboost<1.3,>=1.2; extra == "catboost"
@@ -64,7 +64,7 @@ Requires-Dist: imodels<1.4.0,>=1.3.10; extra == "imodels"
64
64
  Provides-Extra: lightgbm
65
65
  Requires-Dist: lightgbm<4.6,>=4.0; extra == "lightgbm"
66
66
  Provides-Extra: ray
67
- Requires-Dist: autogluon.core[all]==1.2.1b20250114; extra == "ray"
67
+ Requires-Dist: autogluon.core[all]==1.2.1b20250116; extra == "ray"
68
68
  Provides-Extra: skex
69
69
  Requires-Dist: scikit-learn-intelex<2025.1,>=2024.0; extra == "skex"
70
70
  Provides-Extra: skl2onnx
@@ -1,6 +1,6 @@
1
- autogluon.tabular-1.2.1b20250114-py3.8-nspkg.pth,sha256=cQGwpuGPqg1GXscIwt-7PmME1OnSpD-7ixkikJ31WAY,554
1
+ autogluon.tabular-1.2.1b20250116-py3.8-nspkg.pth,sha256=cQGwpuGPqg1GXscIwt-7PmME1OnSpD-7ixkikJ31WAY,554
2
2
  autogluon/tabular/__init__.py,sha256=2OXpJCvENRHubBTYNIPpHX93WWuFZzsJBtTZbNVHVas,400
3
- autogluon/tabular/version.py,sha256=Um-j8InMDDRIJ5WoL607BMYYvkzsVWU8cxXJ2aXfjbQ,90
3
+ autogluon/tabular/version.py,sha256=qtf1yrCzRdeJXfqZGb9YWrrjOFhl2omGSnhF5i131QY,91
4
4
  autogluon/tabular/configs/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
5
5
  autogluon/tabular/configs/config_helper.py,sha256=Pb2aW9Z9w77pYKPRVZ3nBzHY3KJaiEJSJ747zZcJIVk,21132
6
6
  autogluon/tabular/configs/feature_generator_presets.py,sha256=EV5Ym8VW15q92MwOUpTi7wZFS2QooM51fLg3RdUsn-M,1223
@@ -120,7 +120,7 @@ autogluon/tabular/models/tabular_nn/hyperparameters/__init__.py,sha256=47DEQpj8H
120
120
  autogluon/tabular/models/tabular_nn/hyperparameters/parameters.py,sha256=Z3t_U1f7jfolPey6lzqgJyoFbVgoncFNSvCKXSuLxeU,6465
121
121
  autogluon/tabular/models/tabular_nn/hyperparameters/searchspaces.py,sha256=pT9cJ3MaWPnaQwAf47Yz6f0-L9qDBknahERbggAp52U,2810
122
122
  autogluon/tabular/models/tabular_nn/torch/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
123
- autogluon/tabular/models/tabular_nn/torch/tabular_nn_torch.py,sha256=XVtRg9y4jNOScdT3RK-UhcbqHuLtyptxbI4drNg3RaE,42093
123
+ autogluon/tabular/models/tabular_nn/torch/tabular_nn_torch.py,sha256=1ZJHM52wMZHSbXty19ykKA5tNql6ISO7nuFmAAuR-uU,42324
124
124
  autogluon/tabular/models/tabular_nn/torch/tabular_torch_dataset.py,sha256=oelC0uA9KNVtNKXU5jTywg-OfIF-5AguAXFYSKwN3zU,13499
125
125
  autogluon/tabular/models/tabular_nn/torch/torch_network_modules.py,sha256=Qc3PwXTD8A7PgXi6EGuaBCrN3jsFAXDLCW7i6tE5wYI,11338
126
126
  autogluon/tabular/models/tabular_nn/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -152,11 +152,11 @@ autogluon/tabular/trainer/model_presets/presets.py,sha256=1E-Z1FxUpyydaoEdxcTCg7
152
152
  autogluon/tabular/trainer/model_presets/presets_distill.py,sha256=MnFC2GJc6RmDBNAGbsO2XMfo3PjR8cUrZoilWW8gTYQ,3295
153
153
  autogluon/tabular/tuning/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
154
154
  autogluon/tabular/tuning/feature_pruner.py,sha256=9iNku8gVbYEkjuKlyITPJDicsNkoraaQOlINQq9iZlQ,6877
155
- autogluon.tabular-1.2.1b20250114.dist-info/LICENSE,sha256=CeipvOyAZxBGUsFoaFqwkx54aPnIKEtm9a5u2uXxEws,10142
156
- autogluon.tabular-1.2.1b20250114.dist-info/METADATA,sha256=yz9jxr4weyLJLpGkDAqM-ZSWeAnSCC_xU8w1zmbHxpE,14315
157
- autogluon.tabular-1.2.1b20250114.dist-info/NOTICE,sha256=7nPQuj8Kp-uXsU0S5so3-2dNU5EctS5hDXvvzzehd7E,114
158
- autogluon.tabular-1.2.1b20250114.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
159
- autogluon.tabular-1.2.1b20250114.dist-info/namespace_packages.txt,sha256=giERA4R78OkJf2ijn5slgjURlhRPzfLr7waIcGkzYAo,10
160
- autogluon.tabular-1.2.1b20250114.dist-info/top_level.txt,sha256=giERA4R78OkJf2ijn5slgjURlhRPzfLr7waIcGkzYAo,10
161
- autogluon.tabular-1.2.1b20250114.dist-info/zip-safe,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
162
- autogluon.tabular-1.2.1b20250114.dist-info/RECORD,,
155
+ autogluon.tabular-1.2.1b20250116.dist-info/LICENSE,sha256=CeipvOyAZxBGUsFoaFqwkx54aPnIKEtm9a5u2uXxEws,10142
156
+ autogluon.tabular-1.2.1b20250116.dist-info/METADATA,sha256=QSKAy7j9NWzgoqH_awYRFKYq2fif9r603N26Xl6eweM,14315
157
+ autogluon.tabular-1.2.1b20250116.dist-info/NOTICE,sha256=7nPQuj8Kp-uXsU0S5so3-2dNU5EctS5hDXvvzzehd7E,114
158
+ autogluon.tabular-1.2.1b20250116.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
159
+ autogluon.tabular-1.2.1b20250116.dist-info/namespace_packages.txt,sha256=giERA4R78OkJf2ijn5slgjURlhRPzfLr7waIcGkzYAo,10
160
+ autogluon.tabular-1.2.1b20250116.dist-info/top_level.txt,sha256=giERA4R78OkJf2ijn5slgjURlhRPzfLr7waIcGkzYAo,10
161
+ autogluon.tabular-1.2.1b20250116.dist-info/zip-safe,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
162
+ autogluon.tabular-1.2.1b20250116.dist-info/RECORD,,