autogluon.tabular 1.4.1b20251014__py3-none-any.whl → 1.5.0b20251222__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. autogluon/tabular/configs/hyperparameter_configs.py +4 -0
  2. autogluon/tabular/configs/presets_configs.py +39 -2
  3. autogluon/tabular/configs/zeroshot/zeroshot_portfolio_2025.py +2 -44
  4. autogluon/tabular/configs/zeroshot/zeroshot_portfolio_cpu_2025_12_18.py +2 -0
  5. autogluon/tabular/configs/zeroshot/zeroshot_portfolio_gpu_2025_12_18.py +2 -0
  6. autogluon/tabular/learner/default_learner.py +1 -0
  7. autogluon/tabular/models/__init__.py +3 -1
  8. autogluon/tabular/models/abstract/__init__.py +0 -0
  9. autogluon/tabular/models/abstract/abstract_torch_model.py +148 -0
  10. autogluon/tabular/models/catboost/catboost_model.py +2 -5
  11. autogluon/tabular/models/ebm/ebm_model.py +2 -6
  12. autogluon/tabular/models/fastainn/tabular_nn_fastai.py +9 -3
  13. autogluon/tabular/models/lgb/lgb_model.py +60 -17
  14. autogluon/tabular/models/lgb/lgb_utils.py +2 -2
  15. autogluon/tabular/models/lr/lr_model.py +2 -4
  16. autogluon/tabular/models/lr/lr_preprocessing_utils.py +6 -7
  17. autogluon/tabular/models/mitra/_internal/core/trainer_finetune.py +14 -1
  18. autogluon/tabular/models/mitra/mitra_model.py +55 -29
  19. autogluon/tabular/models/realmlp/realmlp_model.py +8 -5
  20. autogluon/tabular/models/rf/rf_model.py +6 -8
  21. autogluon/tabular/models/tabdpt/__init__.py +0 -0
  22. autogluon/tabular/models/tabdpt/tabdpt_model.py +253 -0
  23. autogluon/tabular/models/tabicl/tabicl_model.py +15 -5
  24. autogluon/tabular/models/tabm/tabm_model.py +25 -8
  25. autogluon/tabular/models/tabpfnmix/tabpfnmix_model.py +7 -5
  26. autogluon/tabular/models/tabpfnv2/tabpfnv2_5_model.py +451 -0
  27. autogluon/tabular/models/tabpfnv2/tabpfnv2_model.py +87 -12
  28. autogluon/tabular/models/tabprep/__init__.py +0 -0
  29. autogluon/tabular/models/tabprep/prep_lgb_model.py +21 -0
  30. autogluon/tabular/models/tabprep/prep_mixin.py +220 -0
  31. autogluon/tabular/models/tabular_nn/torch/tabular_nn_torch.py +3 -6
  32. autogluon/tabular/models/tabular_nn/utils/data_preprocessor.py +12 -4
  33. autogluon/tabular/models/xgboost/xgboost_model.py +3 -4
  34. autogluon/tabular/predictor/predictor.py +50 -20
  35. autogluon/tabular/registry/_ag_model_registry.py +8 -2
  36. autogluon/tabular/testing/fit_helper.py +61 -0
  37. autogluon/tabular/trainer/abstract_trainer.py +45 -9
  38. autogluon/tabular/trainer/auto_trainer.py +5 -0
  39. autogluon/tabular/version.py +1 -1
  40. autogluon.tabular-1.5.0b20251222-py3.11-nspkg.pth +1 -0
  41. {autogluon.tabular-1.4.1b20251014.dist-info → autogluon_tabular-1.5.0b20251222.dist-info}/METADATA +97 -87
  42. {autogluon.tabular-1.4.1b20251014.dist-info → autogluon_tabular-1.5.0b20251222.dist-info}/RECORD +48 -38
  43. {autogluon.tabular-1.4.1b20251014.dist-info → autogluon_tabular-1.5.0b20251222.dist-info}/WHEEL +1 -1
  44. autogluon.tabular-1.4.1b20251014-py3.9-nspkg.pth +0 -1
  45. {autogluon.tabular-1.4.1b20251014.dist-info → autogluon_tabular-1.5.0b20251222.dist-info/licenses}/LICENSE +0 -0
  46. {autogluon.tabular-1.4.1b20251014.dist-info → autogluon_tabular-1.5.0b20251222.dist-info/licenses}/NOTICE +0 -0
  47. {autogluon.tabular-1.4.1b20251014.dist-info → autogluon_tabular-1.5.0b20251222.dist-info}/namespace_packages.txt +0 -0
  48. {autogluon.tabular-1.4.1b20251014.dist-info → autogluon_tabular-1.5.0b20251222.dist-info}/top_level.txt +0 -0
  49. {autogluon.tabular-1.4.1b20251014.dist-info → autogluon_tabular-1.5.0b20251222.dist-info}/zip-safe +0 -0
@@ -0,0 +1,253 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING
4
+
5
+ from autogluon.common.utils.resource_utils import ResourceManager
6
+ from autogluon.core.constants import BINARY, MULTICLASS, REGRESSION
7
+ from autogluon.features.generators import LabelEncoderFeatureGenerator
8
+ from autogluon.tabular.models.abstract.abstract_torch_model import AbstractTorchModel
9
+
10
+ if TYPE_CHECKING:
11
+ import numpy as np
12
+ import pandas as pd
13
+
14
+
15
+ # FIXME: Nick:
16
+ # TODO: batch_size is linear to memory usage
17
+ # 512 default
18
+ # should be less for very large datasets
19
+ # 128 batch_size on Bioresponse -> 12 GB VRAM
20
+ # Train Data Rows: 2500
21
+ # Train Data Columns: 1776
22
+ # Problem Type: binary
23
+ # FIXME: Just set context_size = infinity, everything is way faster, memory usage is way lower, etc.
24
+ # Train Data Rows: 100000
25
+ # Train Data Columns: 10
26
+ # binary
27
+ # only takes 6.7 GB during inference with batch_size = 512
28
+ # FIXME: Make it work when loading on CPU?
29
+ # FIXME: Can we run 8 in parallel to speed up?
30
+ # TODO: clip_sigma == 1 performs terribly; clip_sigma == 16 may be very good. What about higher values?
31
+ # clip_sigma >= 16 is roughly all equivalent
32
+ # FIXME: TabDPT stores self.X_test for no reason
33
+ # FIXME: TabDPT creates faiss_knn even if it is never used. Better if `context_size=None` means it is never created.
34
+ # TODO: unit test
35
+ # TODO: memory estimate
36
class TabDPTModel(AbstractTorchModel):
    """AutoGluon wrapper around the TabDPT in-context-learning model.

    ``tabdpt.TabDPTClassifier`` is used for binary/multiclass problems and
    ``tabdpt.TabDPTRegressor`` for regression. Model hyperparameters are split
    into constructor ("fit") kwargs and kwargs forwarded at prediction time;
    see ``_get_tabdpt_params``.
    """

    ag_key = "TABDPT"
    ag_name = "TabDPT"
    seed_name = "seed"
    ag_priority = 50
    default_random_seed = 0

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Lazily fitted label encoder for categorical features (see `_preprocess`).
        self._feature_generator = None
        # Kwargs forwarded to the TabDPT predict calls; set in `_fit`.
        self._predict_hps = None
        # `use_flash` value chosen at fit time, restored when moving back off CPU.
        self._use_flash_og = None

    def _fit(
        self,
        X: pd.DataFrame,
        y: pd.Series,
        num_cpus: int = 1,
        num_gpus: int = 0,
        **kwargs,
    ):
        """Construct and fit the underlying TabDPT estimator.

        Args:
            X: Training features.
            y: Training labels.
            num_cpus: Number of CPUs allotted (not used directly by this method).
            num_gpus: Number of GPUs allotted; any non-zero value selects CUDA.
            **kwargs: Additional arguments from the caller; unused here.

        Raises:
            AssertionError: If a GPU is requested but CUDA is not available.
        """
        from torch.cuda import is_available

        device = "cuda" if num_gpus != 0 else "cpu"
        if device == "cuda" and not is_available():
            raise AssertionError(
                "Fit specified to use GPU, but CUDA is not available on this machine. "
                "Please switch to CPU usage instead.",
            )
        from tabdpt import TabDPTClassifier, TabDPTRegressor

        is_classification = self.problem_type in [BINARY, MULTICLASS]
        model_cls = TabDPTClassifier if is_classification else TabDPTRegressor
        fit_params, self._predict_hps = self._get_tabdpt_params(num_gpus=num_gpus)

        X = self.preprocess(X)
        y = y.to_numpy()
        self.model = model_cls(device=device, **fit_params)
        self.model.fit(X=X, y=y)

    def _get_tabdpt_params(self, num_gpus: float) -> tuple[dict, dict]:
        """Split the model hyperparameters into constructor and predict-time kwargs.

        Args:
            num_gpus: GPUs allotted to the model, used to decide ``use_flash``.

        Returns:
            ``(fit_params, predict_params)``.

            - ``predict_params`` always contains the seed (key ``self.seed_name``,
              default ``self.default_random_seed``) and ``context_size`` (default
              ``None``). ``permute_classes`` and ``temperature`` are kept only for
              classification.
            - ``fit_params`` has ``verbose`` and ``compile`` defaulted to ``False``.
              ``use_flash`` is decided automatically unless explicitly set falsy.

            Predict-time keys are always removed from ``fit_params``, even when
            they are dropped for regression.
        """
        model_params = self._get_model_params()

        regression_predict_params = (self.seed_name, "context_size", "n_ensembles")
        classification_only_params = ("permute_classes", "temperature")

        # Pop every predict-time key so it is never passed to the constructor.
        predict_params = {}
        for hp in (*regression_predict_params, *classification_only_params):
            if hp in model_params:
                predict_params[hp] = model_params.pop(hp)
        predict_params.setdefault(self.seed_name, self.default_random_seed)
        predict_params.setdefault("context_size", None)

        if self.problem_type not in [BINARY, MULTICLASS]:
            # The regressor does not accept the classification-only predict kwargs.
            for hp in classification_only_params:
                predict_params.pop(hp, None)

        fit_params = model_params
        fit_params.setdefault("verbose", False)
        fit_params.setdefault("compile", False)
        # Respect an explicit falsy `use_flash`; otherwise enable it only where supported.
        if fit_params.get("use_flash", True):
            fit_params["use_flash"] = self._use_flash(num_gpus=num_gpus)
        return fit_params, predict_params

    @staticmethod
    def _use_flash(num_gpus: float) -> bool:
        """Detect if torch's native flash attention is available on the current machine.

        Returns False when no GPUs are allotted or CUDA is unavailable. Otherwise
        returns True unless GPU 0 reports compute capability (7, 5).
        """
        if num_gpus == 0:
            return False

        import torch

        if not torch.cuda.is_available():
            return False

        # Only the first CUDA device is inspected.
        # NOTE(review): only capability (7, 5) is excluded; other older GPUs may
        # also lack flash-attention kernels — confirm against TabDPT/torch.
        capability = torch.cuda.get_device_capability(torch.device("cuda:0"))
        return capability != (7, 5)

    def _post_fit(self, **kwargs):
        """Remember the fitted model's `use_flash` setting so `_set_device` can restore it."""
        super()._post_fit(**kwargs)
        self._use_flash_og = self.model.use_flash
        return self

    def get_device(self) -> str:
        """Return the device the fitted TabDPT estimator reports."""
        return self.model.device

    def _set_device(self, device: str):
        """Move the fitted estimator to `device`.

        Flash attention is disabled on CPU. On any other device the value chosen
        at fit time is restored, on both the estimator and its inner model.
        """
        self.model.to(device)
        use_flash = False if device == "cpu" else self._use_flash_og
        self.model.use_flash = use_flash
        self.model.model.use_flash = use_flash

    def _get_default_resources(self) -> tuple[int, int]:
        """Return ``(num_cpus, num_gpus)``: all physical cores and at most one CUDA GPU."""
        # Use only physical cores for better performance based on benchmarks
        num_cpus = ResourceManager.get_cpu_count(only_physical_cores=True)
        num_gpus = min(1, ResourceManager.get_gpu_count_torch(cuda_only=True))
        return num_cpus, num_gpus

    def get_minimum_resources(self, is_gpu_available: bool = False) -> dict[str, int | float]:
        """Return the minimum resources: 1 CPU, plus half a GPU when one is available."""
        return {
            "num_cpus": 1,
            "num_gpus": 0.5 if is_gpu_available else 0,
        }

    def _predict_proba(self, X, **kwargs) -> np.ndarray:
        """Predict with the fitted estimator using the stored predict-time kwargs.

        Regression returns the raw `predict` output. Classification returns
        `ensemble_predict_proba` output converted to AutoGluon's unified form.
        """
        X = self.preprocess(X, **kwargs)

        if self.problem_type in [REGRESSION]:
            return self.model.predict(X, **self._predict_hps)

        y_pred_proba = self.model.ensemble_predict_proba(X, **self._predict_hps)
        return self._convert_proba_to_unified_form(y_pred_proba)

    def _preprocess(self, X: pd.DataFrame, **kwargs) -> np.ndarray:
        """Label-encode categorical features and return a numpy array (TabDPT requires numpy input).

        The encoder is fitted on the first call, which during training is the
        training data, and then reused for all later calls.
        """
        X = super()._preprocess(X, **kwargs)
        if self._feature_generator is None:
            self._feature_generator = LabelEncoderFeatureGenerator(verbosity=0)
            self._feature_generator.fit(X=X)
        if self._feature_generator.features_in:
            # Copy so the caller's DataFrame is not modified in place.
            X = X.copy()
            X[self._feature_generator.features_in] = self._feature_generator.transform(X=X)
        return X.to_numpy()

    @classmethod
    def supported_problem_types(cls) -> list[str] | None:
        """Return the problem types this model supports."""
        return [BINARY, MULTICLASS, REGRESSION]

    @classmethod
    def _class_tags(cls):
        return {"can_estimate_memory_usage_static": True}

    def _more_tags(self) -> dict:
        return {"can_refit_full": True}

    def _get_default_auxiliary_params(self) -> dict:
        """Extend the parent defaults with the data-size limits this model is tested on."""
        default_auxiliary_params = super()._get_default_auxiliary_params()
        default_auxiliary_params.update(
            {
                "max_rows": 100000,  # TODO: Test >100k rows
                "max_features": 2500,  # TODO: Test >2500 features
                "max_classes": 10,  # TODO: Test >10 classes
            }
        )
        return default_auxiliary_params

    @classmethod
    def _get_default_ag_args_ensemble(cls, **kwargs) -> dict:
        """Extend the parent bagging defaults with ``refit_folds=True``."""
        default_ag_args_ensemble = super()._get_default_ag_args_ensemble(**kwargs)
        default_ag_args_ensemble.update({"refit_folds": True})
        return default_ag_args_ensemble

    # FIXME: This is copied from TabPFN, but TabDPT is not the same
    @classmethod
    def _estimate_memory_usage_static(
        cls,
        *,
        X: pd.DataFrame,
        hyperparameters: dict | None = None,
        **kwargs,
    ) -> int:
        """Heuristic memory estimate (bytes) based on TabPFN's memory estimate logic in:
        https://github.com/PriorLabs/TabPFN/blob/57a2efd3ebdb3886245e4d097cefa73a5261a969/src/tabpfn/model/memory.py#L147.

        This is based on GPU memory usage, but hopefully with overheads it also approximates CPU memory usage.
        ``hyperparameters`` is currently ignored. The estimate therefore does not
        account for ``context_size`` and is likely far off when it is not None.
        """
        # TODO: update, this is not correct anymore, consider using internal TabPFN functions directly.
        features_per_group = 3  # Based on TabPFNv2 default
        n_layers = 12  # Based on TabPFNv2 default
        embedding_size = 192  # Based on TabPFNv2 default
        dtype_byte_size = 2  # Based on TabPFNv2 default

        model_mem = 14489108  # Based on TabPFNv2 default

        # NOTE(review): the 500-feature cap mirrors TabPFN, but this model allows up to
        # 2500 features (`max_features`), so wide datasets may be underestimated.
        n_samples, n_features = X.shape[0], min(X.shape[1], 500)
        n_feature_groups = n_features / features_per_group + 1  # TODO: Unsure how to calculate this

        X_mem = n_samples * n_feature_groups * dtype_byte_size
        activation_mem = n_samples * n_feature_groups * embedding_size * n_layers * dtype_byte_size

        baseline_overhead_mem_est = 1e9  # 1 GB generic overhead

        # Add some buffer to each term + 1 GB overhead to be safe
        memory_estimate = model_mem + 4 * X_mem + 2 * activation_mem + baseline_overhead_mem_est

        # TabDPT memory estimation is very inaccurate because it is using TabPFN memory estimate. Double it to be safe.
        memory_estimate = memory_estimate * 2

        return int(memory_estimate)
@@ -10,14 +10,14 @@ import pandas as pd
10
10
 
11
11
  from autogluon.common.utils.pandas_utils import get_approximate_df_mem_usage
12
12
  from autogluon.common.utils.resource_utils import ResourceManager
13
- from autogluon.core.models import AbstractModel
14
13
  from autogluon.tabular import __version__
14
+ from autogluon.tabular.models.abstract.abstract_torch_model import AbstractTorchModel
15
15
 
16
16
  logger = logging.getLogger(__name__)
17
17
 
18
18
 
19
19
  # TODO: Verify if crashes when weights are not yet downloaded and fit in parallel
20
- class TabICLModel(AbstractModel):
20
+ class TabICLModel(AbstractTorchModel):
21
21
  """
22
22
  TabICL is a foundation model for tabular data using in-context learning
23
23
  that is scalable to larger datasets than TabPFNv2. It is pretrained purely on synthetic data.
@@ -35,6 +35,7 @@ class TabICLModel(AbstractModel):
35
35
  ag_key = "TABICL"
36
36
  ag_name = "TabICL"
37
37
  ag_priority = 65
38
+ seed_name = "random_state"
38
39
 
39
40
  def get_model_cls(self):
40
41
  from tabicl import TabICLClassifier
@@ -89,7 +90,6 @@ class TabICLModel(AbstractModel):
89
90
  **hyp,
90
91
  device=device,
91
92
  n_jobs=num_cpus,
92
- random_state=self.random_seed,
93
93
  )
94
94
  X = self.preprocess(X)
95
95
  self.model = self.model.fit(
@@ -97,8 +97,18 @@ class TabICLModel(AbstractModel):
97
97
  y=y,
98
98
  )
99
99
 
100
- def _get_random_seed_from_hyperparameters(self, hyperparameters: dict) -> int | None | str:
101
- return hyperparameters.get("random_state", "N/A")
100
+ def get_device(self) -> str:
101
+ return self.model.device_.type
102
+
103
+ # TODO: Better to have an official TabICL method for this
104
+ def _set_device(self, device: str):
105
+ device = self.to_torch_device(device)
106
+ self.model.device_ = device
107
+ self.model.device = self.model.device_.type
108
+ self.model.model_ = self.model.model_.to(self.model.device_)
109
+ self.model.inference_config_.COL_CONFIG.device = self.model.device_
110
+ self.model.inference_config_.ROW_CONFIG.device = self.model.device_
111
+ self.model.inference_config_.ICL_CONFIG.device = self.model.device_
102
112
 
103
113
  def _get_default_auxiliary_params(self) -> dict:
104
114
  default_auxiliary_params = super()._get_default_auxiliary_params()
@@ -15,13 +15,13 @@ import time
15
15
  import pandas as pd
16
16
 
17
17
  from autogluon.common.utils.resource_utils import ResourceManager
18
- from autogluon.core.models import AbstractModel
19
18
  from autogluon.tabular import __version__
19
+ from autogluon.tabular.models.abstract.abstract_torch_model import AbstractTorchModel
20
20
 
21
21
  logger = logging.getLogger(__name__)
22
22
 
23
23
 
24
- class TabMModel(AbstractModel):
24
+ class TabMModel(AbstractTorchModel):
25
25
  """
26
26
  TabM is an efficient ensemble of MLPs that is trained simultaneously with mostly shared parameters.
27
27
 
@@ -39,6 +39,7 @@ class TabMModel(AbstractModel):
39
39
  ag_key = "TABM"
40
40
  ag_name = "TabM"
41
41
  ag_priority = 85
42
+ seed_name = "random_state"
42
43
 
43
44
  def __init__(self, **kwargs):
44
45
  super().__init__(**kwargs)
@@ -86,7 +87,7 @@ class TabMModel(AbstractModel):
86
87
  if X_val is None:
87
88
  from autogluon.core.utils import generate_train_test_split
88
89
 
89
- X_train, X_val, y_train, y_val = generate_train_test_split(
90
+ X, X_val, y, y_val = generate_train_test_split(
90
91
  X=X,
91
92
  y=y,
92
93
  problem_type=self.problem_type,
@@ -97,7 +98,7 @@ class TabMModel(AbstractModel):
97
98
  hyp = self._get_model_params()
98
99
  bool_to_cat = hyp.pop("bool_to_cat", True)
99
100
 
100
- X = self.preprocess(X, is_train=True, bool_to_cat=bool_to_cat)
101
+ X = self.preprocess(X, y=y, is_train=True, bool_to_cat=bool_to_cat)
101
102
  if X_val is not None:
102
103
  X_val = self.preprocess(X_val)
103
104
 
@@ -106,7 +107,6 @@ class TabMModel(AbstractModel):
106
107
  device=device,
107
108
  problem_type=self.problem_type,
108
109
  early_stopping_metric=self.stopping_metric,
109
- random_state=self.random_seed,
110
110
  **hyp,
111
111
  )
112
112
 
@@ -142,8 +142,13 @@ class TabMModel(AbstractModel):
142
142
 
143
143
  return X
144
144
 
145
- def _get_random_seed_from_hyperparameters(self, hyperparameters: dict) -> int | None | str:
146
- return hyperparameters.get("random_state", "N/A")
145
+ def get_device(self) -> str:
146
+ return self.model.device_.type
147
+
148
+ def _set_device(self, device: str):
149
+ device = self.to_torch_device(device)
150
+ self.model.device_ = device
151
+ self.model.model_ = self.model.model_.to(device)
147
152
 
148
153
  @classmethod
149
154
  def supported_problem_types(cls) -> list[str] | None:
@@ -258,6 +263,15 @@ class TabMModel(AbstractModel):
258
263
 
259
264
  return mem_total
260
265
 
266
+ def _get_default_auxiliary_params(self) -> dict:
267
+ default_auxiliary_params = super()._get_default_auxiliary_params()
268
+ default_auxiliary_params.update(
269
+ {
270
+ "max_batch_size": 16384, # avoid excessive VRAM usage
271
+ }
272
+ )
273
+ return default_auxiliary_params
274
+
261
275
  @classmethod
262
276
  def get_tabm_auto_batch_size(cls, n_samples: int) -> int:
263
277
  # by Yury Gorishniy, inferred from the choices in the TabM paper.
@@ -275,7 +289,10 @@ class TabMModel(AbstractModel):
275
289
 
276
290
  @classmethod
277
291
  def _class_tags(cls):
278
- return {"can_estimate_memory_usage_static": True}
292
+ return {
293
+ "can_estimate_memory_usage_static": True,
294
+ "reset_torch_threads": True,
295
+ }
279
296
 
280
297
  def _more_tags(self) -> dict:
281
298
  # TODO: Need to add train params support, track best epoch
@@ -42,6 +42,7 @@ class TabPFNMixModel(AbstractModel):
42
42
  ag_key = "TABPFNMIX"
43
43
  ag_name = "TabPFNMix"
44
44
  ag_priority = 45
45
+ seed_name = "random_state"
45
46
 
46
47
  weights_file_name = "model.pt"
47
48
 
@@ -123,6 +124,7 @@ class TabPFNMixModel(AbstractModel):
123
124
  raise AssertionError(f"Max allowed classes for the model is {max_classes}, " f"but found {self.num_classes} classes.")
124
125
 
125
126
  params = self._get_model_params()
127
+ random_state = params.pop(self.seed_name, self.default_random_seed)
126
128
  sample_rows = ag_params.get("sample_rows", None)
127
129
  sample_rows_val = ag_params.get("sample_rows_val", None)
128
130
  max_rows = ag_params.get("max_rows", None)
@@ -133,11 +135,11 @@ class TabPFNMixModel(AbstractModel):
133
135
 
134
136
  # TODO: Make sample_rows generic
135
137
  if sample_rows is not None and isinstance(sample_rows, int) and len(X) > sample_rows:
136
- X, y = self._subsample_data(X=X, y=y, num_rows=sample_rows)
138
+ X, y = self._subsample_data(X=X, y=y, num_rows=sample_rows, random_state=random_state)
137
139
 
138
140
  # TODO: Make sample_rows generic
139
141
  if X_val is not None and y_val is not None and sample_rows_val is not None and isinstance(sample_rows_val, int) and len(X_val) > sample_rows_val:
140
- X_val, y_val = self._subsample_data(X=X_val, y=y_val, num_rows=sample_rows_val)
142
+ X_val, y_val = self._subsample_data(X=X_val, y=y_val, num_rows=sample_rows_val, random_state=random_state)
141
143
 
142
144
  from ._internal.core.enums import Task
143
145
  if self.problem_type in [REGRESSION, QUANTILE]:
@@ -178,7 +180,7 @@ class TabPFNMixModel(AbstractModel):
178
180
  elif weights_path is not None:
179
181
  logger.log(15, f'\tLoading pre-trained weights from file... (weights_path="{weights_path}")')
180
182
 
181
- cfg = ConfigRun(hyperparams=params, task=task, device=device, seed=self.random_seed)
183
+ cfg = ConfigRun(hyperparams=params, task=task, device=device, seed=random_state)
182
184
 
183
185
  if cfg.hyperparams["max_epochs"] == 0 and cfg.hyperparams["n_ensembles"] != 1:
184
186
  logger.log(
@@ -242,14 +244,14 @@ class TabPFNMixModel(AbstractModel):
242
244
  return self
243
245
 
244
246
  # TODO: Make this generic by creating a generic `preprocess_train` and putting this logic prior to `_preprocess`.
245
- def _subsample_data(self, X: pd.DataFrame, y: pd.Series, num_rows: int) -> (pd.DataFrame, pd.Series):
247
+ def _subsample_data(self, X: pd.DataFrame, y: pd.Series, num_rows: int, random_state: int | None = 0) -> (pd.DataFrame, pd.Series):
246
248
  num_rows_to_drop = len(X) - num_rows
247
249
  X, _, y, _ = generate_train_test_split(
248
250
  X=X,
249
251
  y=y,
250
252
  problem_type=self.problem_type,
251
253
  test_size=num_rows_to_drop,
252
- random_state=self.random_seed,
254
+ random_state=random_state,
253
255
  min_cls_count_train=1,
254
256
  )
255
257
  return X, y