autogluon.tabular 1.3.2b20250711__py3-none-any.whl → 1.3.2b20250712__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to their respective public registries. It is provided for informational purposes only.
- autogluon/tabular/models/__init__.py +1 -1
- autogluon/tabular/models/tabpfnv2/__init__.py +0 -0
- autogluon/tabular/models/tabpfnv2/rfpfn/__init__.py +20 -0
- autogluon/tabular/models/tabpfnv2/rfpfn/configs.py +40 -0
- autogluon/tabular/models/tabpfnv2/rfpfn/scoring_utils.py +201 -0
- autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_based_decision_tree_tabpfn.py +1464 -0
- autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_based_random_forest_tabpfn.py +747 -0
- autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_compat.py +863 -0
- autogluon/tabular/models/tabpfnv2/rfpfn/utils.py +106 -0
- autogluon/tabular/models/tabpfnv2/tabpfnv2_model.py +376 -0
- autogluon/tabular/registry/_ag_model_registry.py +2 -2
- autogluon/tabular/version.py +1 -1
- {autogluon.tabular-1.3.2b20250711.dist-info → autogluon.tabular-1.3.2b20250712.dist-info}/METADATA +13 -15
- {autogluon.tabular-1.3.2b20250711.dist-info → autogluon.tabular-1.3.2b20250712.dist-info}/RECORD +21 -14
- autogluon/tabular/models/tabpfn/__init__.py +0 -1
- autogluon/tabular/models/tabpfn/tabpfn_model.py +0 -153
- /autogluon.tabular-1.3.2b20250711-py3.9-nspkg.pth → /autogluon.tabular-1.3.2b20250712-py3.9-nspkg.pth +0 -0
- {autogluon.tabular-1.3.2b20250711.dist-info → autogluon.tabular-1.3.2b20250712.dist-info}/LICENSE +0 -0
- {autogluon.tabular-1.3.2b20250711.dist-info → autogluon.tabular-1.3.2b20250712.dist-info}/NOTICE +0 -0
- {autogluon.tabular-1.3.2b20250711.dist-info → autogluon.tabular-1.3.2b20250712.dist-info}/WHEEL +0 -0
- {autogluon.tabular-1.3.2b20250711.dist-info → autogluon.tabular-1.3.2b20250712.dist-info}/namespace_packages.txt +0 -0
- {autogluon.tabular-1.3.2b20250711.dist-info → autogluon.tabular-1.3.2b20250712.dist-info}/top_level.txt +0 -0
- {autogluon.tabular-1.3.2b20250711.dist-info → autogluon.tabular-1.3.2b20250712.dist-info}/zip-safe +0 -0
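
The headline change in this build is the switch from TabPFN v1 to TabPFN v2: the old `models/tabpfn` wrapper is deleted, and a new `tabpfnv2` package is added, including a vendored `rfpfn` subpackage (a Random Forest that uses TabPFN at the leaves, from Prior Labs) and a registry update. The hunk below reproduces the largest new `rfpfn` file, `sklearn_based_random_forest_tabpfn.py`. As a hedged orientation sketch (not taken from the diff; it assumes the optional `tabpfn` dependency is installed and that `rfpfn/__init__.py` re-exports the estimators defined in this file), the new classes follow the standard scikit-learn contract:

# Hypothetical usage sketch; the import path is inferred from the file
# list above, not confirmed by the diff body.
from sklearn.datasets import make_classification
from tabpfn import TabPFNClassifier  # assumes the optional tabpfn dependency
from autogluon.tabular.models.tabpfnv2.rfpfn import RandomForestTabPFNClassifier

X, y = make_classification(n_samples=400, n_features=8, random_state=0)
clf = RandomForestTabPFNClassifier(
    tabpfn=TabPFNClassifier(),  # leaf-node model, duck-type checked at init/fit
    n_estimators=4,
    max_depth=2,
)
clf.fit(X, y)
print(clf.predict_proba(X[:3]).shape)  # (3, 2)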
autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_based_random_forest_tabpfn.py
@@ -0,0 +1,747 @@
"""Random Forest implementation that uses TabPFN at the leaf nodes."""

#  Copyright (c) Prior Labs GmbH 2025.
#  Licensed under the Apache License, Version 2.0

from __future__ import annotations

import logging
import time

import numpy as np
import torch
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.utils.multiclass import unique_labels

from .sklearn_compat import validate_data
from .sklearn_based_decision_tree_tabpfn import (
    DecisionTreeTabPFNClassifier,
    DecisionTreeTabPFNRegressor,
)

logging.basicConfig(
    format="%(asctime)s %(levelname)-8s %(message)s",
    level=logging.INFO,
    datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger("RF-PFN")


def softmax_numpy(logits: np.ndarray) -> np.ndarray:
    """Apply softmax to numpy array of logits.

    Args:
        logits: Input logits array

    Returns:
        Probabilities after softmax
    """
    exp_logits = np.exp(logits)  # Apply exponential to each logit
    sum_exp_logits = np.sum(
        exp_logits,
        axis=-1,
        keepdims=True,
    )  # Sum of exponentials across classes
    return exp_logits / sum_exp_logits  # Normalize to get probabilities
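
# ---- Editor's aside (illustration, not part of the packaged file) ----
# softmax_numpy above exponentiates raw logits, which can overflow for
# large magnitudes. The standard stabilization subtracts the per-row
# maximum before exponentiating; a minimal, mathematically equivalent sketch:
def softmax_numpy_stable(logits: np.ndarray) -> np.ndarray:
    shifted = logits - np.max(logits, axis=-1, keepdims=True)  # keeps exp() bounded
    exp_logits = np.exp(shifted)
    return exp_logits / np.sum(exp_logits, axis=-1, keepdims=True)
# ----------------------------------------------------------------------
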
class RandomForestTabPFNBase:
    """Base Class for common functionalities."""

    def get_n_estimators(self, X: np.ndarray) -> int:
        """Get the number of estimators to use.

        Args:
            X: Input features

        Returns:
            Number of estimators
        """
        return self.n_estimators

    def _validate_tabpfn(self):
        """Validate that tabpfn is not None and is of the correct type.

        Raises:
            ValueError: If tabpfn is None
            TypeError: If tabpfn is not of the expected type
        """
        if self.tabpfn is None:
            raise ValueError(
                f"The tabpfn parameter cannot be None. Please provide a TabPFN{'Classifier' if self.task_type == 'multiclass' else 'Regressor'} instance.",
            )

        if self.task_type == "multiclass":
            # For classifier, check for predict_proba method
            if not hasattr(self.tabpfn, "predict_proba"):
                raise TypeError(
                    f"Expected a TabPFNClassifier instance with predict_proba method, but got {type(self.tabpfn).__name__}",
                )
        else:
            # For regressor, check for predict method but no predict_proba
            if not hasattr(self.tabpfn, "predict"):
                raise TypeError(
                    f"Expected a TabPFNRegressor instance with predict method, but got {type(self.tabpfn).__name__}",
                )
            if hasattr(self.tabpfn, "predict_proba"):
                raise TypeError(
                    "Expected a TabPFNRegressor instance, but got a classifier with predict_proba method. "
                    "Please use TabPFNRegressor with RandomForestTabPFNRegressor.",
                )

    def fit(self, X: np.ndarray, y: np.ndarray, sample_weight: np.ndarray = None):
        """Fits RandomForestTabPFN.

        Args:
            X: Feature training data
            y: Label training data
            sample_weight: Weights of each sample

        Returns:
            Fitted model

        Raises:
            ValueError: If n_estimators is not positive
            ValueError: If tabpfn is None
            TypeError: If tabpfn is not of the expected type
        """
        # Validate tabpfn parameter
        self._validate_tabpfn()

        self.estimator = self.init_base_estimator()
        self.X = X
        self.n_estimators = self.get_n_estimators(X)

        X, y = validate_data(
            self,
            X,
            y,
            ensure_all_finite=False,
        )

        if self.task_type == "multiclass":
            self.classes_ = unique_labels(y)
            self.n_classes_ = len(self.classes_)

        # Special case for depth 0 - just use TabPFN directly
        if self.max_depth == 0:
            self.tabpfn.fit(X, y)
            return self

        # Initialize the tree estimators - convert to Python int to ensure client compatibility
        n_estimators = (
            int(self.n_estimators)
            if hasattr(self.n_estimators, "item")
            else self.n_estimators
        )
        if n_estimators <= 0:
            raise ValueError(
                f"n_estimators must be greater than zero, got {n_estimators}",
            )

        # Initialize estimators list
        self.estimators_ = []

        # Generate bootstrapped datasets and fit trees
        for i in range(n_estimators):
            # Clone the base estimator
            tree = self.init_base_estimator()

            # Bootstrap sample if requested (like in RandomForest)
            if self.bootstrap:
                n_samples = X.shape[0]

                # Convert max_samples to Python int if needed for client compatibility
                max_samples = self.max_samples
                if max_samples is not None and hasattr(max_samples, "item"):
                    max_samples = int(max_samples)

                # Calculate sample size (convert to Python int for client compatibility)
                sample_size = (
                    n_samples if max_samples is None else int(max_samples * n_samples)
                )
                sample_size = (
                    int(sample_size) if hasattr(sample_size, "item") else sample_size
                )

                # Generate random indices for bootstrapping
                indices = np.random.choice(
                    n_samples,
                    size=sample_size,
                    replace=True,
                )

                # Handle pandas DataFrame properly by converting to numpy or using iloc
                if hasattr(X, "iloc") and hasattr(
                    X,
                    "values",
                ):  # It's a pandas DataFrame
                    X_boot = (
                        X.iloc[indices].values
                        if hasattr(X, "values")
                        else X.iloc[indices]
                    )
                    y_boot = (
                        y[indices]
                        if isinstance(y, np.ndarray)
                        else y.iloc[indices]
                        if hasattr(y, "iloc")
                        else np.array(y)[indices]
                    )
                else:  # It's a numpy array or similar
                    X_boot = X[indices]
                    y_boot = y[indices]
            else:
                X_boot = X
                y_boot = y

            # Fit the tree on bootstrapped data
            tree.fit(X_boot, y_boot)
            self.estimators_.append(tree)

        # Track features seen during fit
        self.n_features_in_ = X.shape[1]

        # Set flag to indicate successful fit
        self._fitted = True

        return self


class RandomForestTabPFNClassifier(RandomForestTabPFNBase, RandomForestClassifier):
    """RandomForestTabPFNClassifier implements Random Forest using TabPFN at leaf nodes.

    This classifier combines decision trees with TabPFN models at the leaf nodes for
    improved performance on tabular data. It extends scikit-learn's RandomForestClassifier
    with TabPFN's neural network capabilities.

    Parameters:
        tabpfn: TabPFNClassifier instance to use at leaf nodes
        n_jobs: Number of parallel jobs
        categorical_features: List of categorical feature indices
        show_progress: Whether to display progress during fitting
        verbose: Verbosity level (0=quiet, >0=verbose)
        adaptive_tree: Whether to use adaptive tree-based method
        fit_nodes: Whether to fit the leaf node models
        adaptive_tree_overwrite_metric: Metric used for adaptive node fitting
        adaptive_tree_test_size: Test size for adaptive node fitting
        adaptive_tree_min_train_samples: Minimum samples for training leaf nodes
        adaptive_tree_max_train_samples: Maximum samples for training leaf nodes
        adaptive_tree_min_valid_samples_fraction_of_train: Min fraction of validation samples
        preprocess_X_once: Whether to preprocess X only once
        max_predict_time: Maximum time allowed for prediction (seconds)
        rf_average_logits: Whether to average logits instead of probabilities
        dt_average_logits: Whether to average logits in decision trees
        adaptive_tree_skip_class_missing: Whether to skip classes missing in nodes
        n_estimators: Number of trees in the forest
        criterion: Function to measure split quality
        max_depth: Maximum depth of the trees
        min_samples_split: Minimum samples required to split a node
        min_samples_leaf: Minimum samples required at a leaf node
        min_weight_fraction_leaf: Minimum weighted fraction of sum total
        max_features: Number of features to consider for best split
        max_leaf_nodes: Maximum number of leaf nodes
        min_impurity_decrease: Minimum impurity decrease required for split
        bootstrap: Whether to use bootstrap samples
        oob_score: Whether to use out-of-bag samples
        random_state: Controls randomness of the estimator
        warm_start: Whether to reuse previous solution
        class_weight: Weights associated with classes
        ccp_alpha: Complexity parameter for minimal cost-complexity pruning
        max_samples: Number of samples to draw to train each tree
    """

    task_type = "multiclass"

    def __init__(
        self,
        tabpfn=None,
        n_jobs=1,
        categorical_features=None,
        show_progress=False,
        verbose=0,
        adaptive_tree=True,
        fit_nodes=True,
        adaptive_tree_overwrite_metric="log_loss",
        adaptive_tree_test_size=0.2,
        adaptive_tree_min_train_samples=100,
        adaptive_tree_max_train_samples=5000,
        adaptive_tree_min_valid_samples_fraction_of_train=0.2,
        preprocess_X_once=False,
        max_predict_time=60,
        rf_average_logits=True,
        dt_average_logits=True,
        adaptive_tree_skip_class_missing=True,
        # Added to make cloneable.
        n_estimators=100,
        criterion="gini",
        max_depth=5,
        min_samples_split=1000,
        min_samples_leaf=5,
        min_weight_fraction_leaf=0.0,
        max_features="sqrt",
        max_leaf_nodes=None,
        min_impurity_decrease=0.0,
        bootstrap=True,
        oob_score=False,
        random_state=None,
        warm_start=False,
        class_weight=None,
        ccp_alpha=0.0,
        max_samples=None,
    ):
        super().__init__(
            n_estimators=n_estimators,
            criterion=criterion,
            max_depth=max_depth,
            min_samples_split=min_samples_split,
            min_samples_leaf=min_samples_leaf,
            min_weight_fraction_leaf=min_weight_fraction_leaf,
            verbose=verbose,
            n_jobs=n_jobs,
            max_features=max_features,
            max_leaf_nodes=max_leaf_nodes,
            min_impurity_decrease=min_impurity_decrease,
            bootstrap=bootstrap,
            oob_score=oob_score,
            random_state=random_state,
            warm_start=warm_start,
            class_weight=class_weight,
            ccp_alpha=ccp_alpha,
            max_samples=max_samples,
        )

        if tabpfn is None:
            raise ValueError(
                "The tabpfn parameter cannot be None. Please provide a TabPFNClassifier instance.",
            )

        # Check if tabpfn is a classifier instance
        if not hasattr(tabpfn, "predict_proba"):
            raise TypeError(
                f"Expected a TabPFNClassifier instance with predict_proba method, but got {type(tabpfn).__name__}",
            )

        self.tabpfn = tabpfn

        self.categorical_features = categorical_features
        self.show_progress = show_progress
        self.verbose = verbose
        self.n_jobs = n_jobs
        self.adaptive_tree = adaptive_tree
        self.fit_nodes = fit_nodes
        self.adaptive_tree_overwrite_metric = adaptive_tree_overwrite_metric
        self.adaptive_tree_test_size = adaptive_tree_test_size
        self.adaptive_tree_min_train_samples = adaptive_tree_min_train_samples
        self.adaptive_tree_max_train_samples = adaptive_tree_max_train_samples
        self.adaptive_tree_min_valid_samples_fraction_of_train = (
            adaptive_tree_min_valid_samples_fraction_of_train
        )
        self.preprocess_X_once = preprocess_X_once
        self.max_predict_time = max_predict_time
        self.rf_average_logits = rf_average_logits
        self.dt_average_logits = dt_average_logits
        self.adaptive_tree_skip_class_missing = adaptive_tree_skip_class_missing
        self.n_estimators = n_estimators

    def _more_tags(self):
        return {
            "allow_nan": True,
        }

    def __sklearn_tags__(self):
        tags = super().__sklearn_tags__()
        tags.input_tags.allow_nan = True
        if self.task_type == "multiclass":
            tags.estimator_type = "classifier"
        else:
            tags.estimator_type = "regressor"
        return tags

    def init_base_estimator(self):
        """Initialize a base decision tree estimator.

        Returns:
            A new DecisionTreeTabPFNClassifier instance
        """
        return DecisionTreeTabPFNClassifier(
            tabpfn=self.tabpfn,
            min_samples_split=self.min_samples_split,
            min_samples_leaf=self.min_samples_leaf,
            max_features=self.max_features,
            random_state=self.random_state,
            categorical_features=self.categorical_features,
            max_depth=self.max_depth,
            show_progress=self.show_progress,
            adaptive_tree=self.adaptive_tree,
            fit_nodes=self.fit_nodes,
            verbose=self.verbose,
            adaptive_tree_test_size=self.adaptive_tree_test_size,
            adaptive_tree_overwrite_metric=self.adaptive_tree_overwrite_metric,
            adaptive_tree_min_train_samples=self.adaptive_tree_min_train_samples,
            adaptive_tree_max_train_samples=self.adaptive_tree_max_train_samples,
            adaptive_tree_min_valid_samples_fraction_of_train=self.adaptive_tree_min_valid_samples_fraction_of_train,
            average_logits=self.dt_average_logits,
            adaptive_tree_skip_class_missing=self.adaptive_tree_skip_class_missing,
        )

    def predict(self, X: np.ndarray) -> np.ndarray:
        """Predict class for X.

        The predicted class of an input sample is a vote by the trees in
        the forest, weighted by their probability estimates. That is,
        the predicted class is the one with highest mean probability
        estimate across the trees.

        Parameters:
            X: {array-like, sparse matrix} of shape (n_samples, n_features)
                The input samples.

        Returns:
            y: ndarray of shape (n_samples,)
                The predicted classes.

        Raises:
            ValueError: If model is not fitted
        """
        # Get class probabilities
        proba = self.predict_proba(X)

        # Return class with highest probability
        if hasattr(self, "classes_"):
            return self.classes_.take(np.argmax(proba, axis=1), axis=0)
        else:
            return np.argmax(proba, axis=1)

    def predict_proba(self, X: np.ndarray) -> np.ndarray:
        """Predict class probabilities for X.

        The predicted class probabilities of an input sample are computed as
        the mean predicted class probabilities of the trees in the forest.

        Parameters:
            X: {array-like, sparse matrix} of shape (n_samples, n_features)
                The input samples.

        Returns:
            p: ndarray of shape (n_samples, n_classes)
                The class probabilities of the input samples.

        Raises:
            ValueError: If model is not fitted
        """
        # Check if fitted
        if not hasattr(self, "_fitted") or not self._fitted:
            raise ValueError(
                "This RandomForestTabPFNClassifier instance is not fitted yet. "
                "Call 'fit' with appropriate arguments before using this estimator.",
            )

        # Convert input if needed
        if torch.is_tensor(X):
            X = X.numpy()

        # Special case for depth 0 - TabPFN can handle missing values directly
        if self.max_depth == 0:
            # No need for preprocessing - TabPFN handles NaN values
            return self.tabpfn.predict_proba(X)

        # First collect all the classes from all estimators to ensure we handle all possible classes
        if not hasattr(self, "classes_"):
            all_classes_sets = [
                set(np.unique(estimator.classes_)) for estimator in self.estimators_
            ]
            all_classes = sorted(set().union(*all_classes_sets))
            self.classes_ = np.array(all_classes)
            self.n_classes_ = len(self.classes_)

        # Initialize probabilities array
        n_samples = X.shape[0]
        all_proba = np.zeros((n_samples, self.n_classes_), dtype=np.float64)

        # Accumulate predictions from trees
        start_time = time.time()
        evaluated_estimators = 0

        for estimator in self.estimators_:
            # Get predictions from this tree
            proba = estimator.predict_proba(X)

            # If this estimator has fewer classes than the overall set, expand it
            if proba.shape[1] < self.n_classes_:
                expanded_proba = np.zeros(
                    (n_samples, self.n_classes_),
                    dtype=np.float64,
                )
                for i, class_val in enumerate(estimator.classes_):
                    # Find the index of this class in the overall classes array
                    idx = np.where(self.classes_ == class_val)[0][0]
                    expanded_proba[:, idx] = proba[:, i]
                proba = expanded_proba

            # Convert to logits if needed
            if self.rf_average_logits:
                proba = np.log(proba + 1e-10)  # Add small constant to avoid log(0)

            # Accumulate
            all_proba += proba

            # Check timeout
            evaluated_estimators += 1
            time_elapsed = time.time() - start_time
            if time_elapsed > self.max_predict_time and self.max_predict_time > 0:
                break

        # Average probabilities
        all_proba /= evaluated_estimators

        # Convert back from logits if needed
        if self.rf_average_logits:
            all_proba = softmax_numpy(all_proba)

        return all_proba
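
# ---- Editor's aside (illustration, not part of the packaged file) ----
# With rf_average_logits=True the forest averages log-probabilities and
# re-normalizes via softmax_numpy, i.e. a geometric rather than arithmetic
# mean of the trees' probabilities. A runnable two-tree toy comparison:
p1 = np.array([[0.9, 0.1]])
p2 = np.array([[0.6, 0.4]])
print((p1 + p2) / 2)                                 # arithmetic: [[0.75 0.25]]
print(softmax_numpy((np.log(p1) + np.log(p2)) / 2))  # geometric: ~[[0.786 0.214]]
# The geometric mean strongly penalizes any class that a single tree rates
# near zero, which is why the code guards log() with a +1e-10 offset.
# ----------------------------------------------------------------------
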
class RandomForestTabPFNRegressor(RandomForestTabPFNBase, RandomForestRegressor):
    """RandomForestTabPFNRegressor implements a Random Forest using TabPFN at leaf nodes.

    This regressor combines decision trees with TabPFN models at the leaf nodes for
    improved regression performance on tabular data. It extends scikit-learn's
    RandomForestRegressor with TabPFN's neural network capabilities.

    Parameters:
        tabpfn: TabPFNRegressor instance to use at leaf nodes
        n_jobs: Number of parallel jobs
        categorical_features: List of categorical feature indices
        show_progress: Whether to display progress during fitting
        verbose: Verbosity level (0=quiet, >0=verbose)
        adaptive_tree: Whether to use adaptive tree-based method
        fit_nodes: Whether to fit the leaf node models
        adaptive_tree_overwrite_metric: Metric used for adaptive node fitting
        adaptive_tree_test_size: Test size for adaptive node fitting
        adaptive_tree_min_train_samples: Minimum samples for training leaf nodes
        adaptive_tree_max_train_samples: Maximum samples for training leaf nodes
        adaptive_tree_min_valid_samples_fraction_of_train: Min fraction of validation samples
        preprocess_X_once: Whether to preprocess X only once
        max_predict_time: Maximum time allowed for prediction (seconds)
        rf_average_logits: Whether to average logits instead of raw predictions
        n_estimators: Number of trees in the forest
        criterion: Function to measure split quality
        max_depth: Maximum depth of the trees
        min_samples_split: Minimum samples required to split a node
        min_samples_leaf: Minimum samples required at a leaf node
        min_weight_fraction_leaf: Minimum weighted fraction of sum total
        max_features: Number of features to consider for best split
        max_leaf_nodes: Maximum number of leaf nodes
        min_impurity_decrease: Minimum impurity decrease required for split
        bootstrap: Whether to use bootstrap samples
        oob_score: Whether to use out-of-bag samples
        random_state: Controls randomness of the estimator
        warm_start: Whether to reuse previous solution
        ccp_alpha: Complexity parameter for minimal cost-complexity pruning
        max_samples: Number of samples to draw to train each tree
    """

    task_type = "regression"

    def _more_tags(self):
        return {
            "allow_nan": True,
        }

    def __sklearn_tags__(self):
        tags = super().__sklearn_tags__()
        tags.input_tags.allow_nan = True
        tags.estimator_type = "regressor"
        return tags

    def __init__(
        self,
        tabpfn=None,
        n_jobs=1,
        categorical_features=None,
        show_progress=False,
        verbose=0,
        adaptive_tree=True,
        fit_nodes=True,
        adaptive_tree_overwrite_metric="rmse",
        adaptive_tree_test_size=0.2,
        adaptive_tree_min_train_samples=100,
        adaptive_tree_max_train_samples=5000,
        adaptive_tree_min_valid_samples_fraction_of_train=0.2,
        preprocess_X_once=False,
        max_predict_time=-1,
        rf_average_logits=False,
        # Added to make cloneable.
        n_estimators=16,
        criterion="friedman_mse",
        max_depth=5,
        min_samples_split=300,
        min_samples_leaf=5,
        min_weight_fraction_leaf=0.0,
        max_features="sqrt",
        max_leaf_nodes=None,
        min_impurity_decrease=0.0,
        bootstrap=True,
        oob_score=False,
        random_state=None,
        warm_start=False,
        ccp_alpha=0.0,
        max_samples=None,
    ):
        super().__init__(
            n_estimators=n_estimators,
            criterion=criterion,
            max_depth=max_depth,
            min_samples_split=min_samples_split,
            min_samples_leaf=min_samples_leaf,
            min_weight_fraction_leaf=min_weight_fraction_leaf,
            max_features=max_features,
            max_leaf_nodes=max_leaf_nodes,
            min_impurity_decrease=min_impurity_decrease,
            bootstrap=bootstrap,
            oob_score=oob_score,
            n_jobs=n_jobs,
            random_state=random_state,
            verbose=verbose,
            warm_start=warm_start,
            ccp_alpha=ccp_alpha,
            max_samples=max_samples,
        )

        self.tabpfn = tabpfn

        self.categorical_features = categorical_features
        self.show_progress = show_progress
        self.verbose = verbose
        self.n_jobs = n_jobs
        self.adaptive_tree = adaptive_tree
        self.fit_nodes = fit_nodes
        self.adaptive_tree_overwrite_metric = adaptive_tree_overwrite_metric
        self.adaptive_tree_test_size = adaptive_tree_test_size
        self.adaptive_tree_min_train_samples = adaptive_tree_min_train_samples
        self.adaptive_tree_max_train_samples = adaptive_tree_max_train_samples
        self.adaptive_tree_min_valid_samples_fraction_of_train = (
            adaptive_tree_min_valid_samples_fraction_of_train
        )
        self.preprocess_X_once = preprocess_X_once
        self.max_predict_time = max_predict_time
        self.rf_average_logits = rf_average_logits

    def init_base_estimator(self):
        """Initialize a base decision tree estimator.

        Returns:
            A new DecisionTreeTabPFNRegressor instance
        """
        return DecisionTreeTabPFNRegressor(
            tabpfn=self.tabpfn,
            min_samples_split=self.min_samples_split,
            min_samples_leaf=self.min_samples_leaf,
            max_features=self.max_features,
            random_state=self.random_state,
            categorical_features=self.categorical_features,
            max_depth=self.max_depth,
            show_progress=self.show_progress,
            adaptive_tree=self.adaptive_tree,
            fit_nodes=self.fit_nodes,
            verbose=self.verbose,
            adaptive_tree_test_size=self.adaptive_tree_test_size,
            adaptive_tree_overwrite_metric=self.adaptive_tree_overwrite_metric,
            adaptive_tree_min_train_samples=self.adaptive_tree_min_train_samples,
            adaptive_tree_max_train_samples=self.adaptive_tree_max_train_samples,
            adaptive_tree_min_valid_samples_fraction_of_train=self.adaptive_tree_min_valid_samples_fraction_of_train,
        )

    def predict(self, X: np.ndarray) -> np.ndarray:
        """Predict regression target for X.

        The predicted regression target of an input sample is computed as the
        mean predicted regression targets of the trees in the forest.

        Parameters:
            X: {array-like, sparse matrix} of shape (n_samples, n_features)
                The input samples.

        Returns:
            y: ndarray of shape (n_samples,) or (n_samples, n_outputs)
                The predicted values.

        Raises:
            ValueError: If model is not fitted
        """
        # Check if fitted
        if not hasattr(self, "_fitted") or not self._fitted:
            raise ValueError(
                "This RandomForestTabPFNRegressor instance is not fitted yet. "
                "Call 'fit' with appropriate arguments before using this estimator.",
            )

        X = validate_data(
            self,
            X,
            ensure_all_finite=False,
        )

        # Special case for depth 0 - TabPFN can handle missing values directly
        if self.max_depth == 0:
            # No need for preprocessing - TabPFN handles NaN values
            return self.tabpfn.predict(X)

        # Initialize output array
        n_samples = X.shape[0]
        self.n_outputs_ = 1  # Only supporting single output for now
        y_hat = np.zeros(n_samples, dtype=np.float64)

        # Accumulate predictions from trees
        start_time = time.time()
        evaluated_estimators = 0

        for estimator in self.estimators_:
            # Get predictions from this tree
            pred = estimator.predict(X)

            # Accumulate
            y_hat += pred

            # Check timeout
            evaluated_estimators += 1
            time_elapsed = time.time() - start_time
            if time_elapsed > self.max_predict_time and self.max_predict_time > 0:
                break

        # Average predictions
        y_hat /= evaluated_estimators

        return y_hat
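
# ---- Editor's aside (illustration, not part of the packaged file) ----
# Hedged usage sketch for the regressor above, under the same import
# assumptions as the classifier example near the top of this diff:
#     from tabpfn import TabPFNRegressor
#     from autogluon.tabular.models.tabpfnv2.rfpfn import RandomForestTabPFNRegressor
#     reg = RandomForestTabPFNRegressor(tabpfn=TabPFNRegressor(), n_estimators=4)
#     reg.fit(X_train, y_train)
#     y_pred = reg.predict(X_test)
# Note the differing defaults: the classifier ships max_predict_time=60,
# the regressor -1, and the timeout branch only fires when
# max_predict_time > 0, so regressor prediction never times out by default.
# ----------------------------------------------------------------------
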
def _accumulate_prediction(
    predict,
    X: np.ndarray,
    out: list[np.ndarray],
    accumulate_logits: bool = False,
) -> None:
    """This is a utility function for joblib's Parallel.

    It can't go locally in ForestClassifier or ForestRegressor, because joblib
    complains that it cannot pickle it when placed there.

    Args:
        predict: Prediction function to call
        X: Input data
        out: Output array to accumulate predictions into
        accumulate_logits: Whether to accumulate logits instead of probabilities
    """
    prediction = predict(X, check_input=False)

    if accumulate_logits:
        # convert multiclass probabilities to logits
        prediction = np.log(prediction + 1e-10)  # Add small value to avoid log(0)

    if len(out) == 1:
        out[0] += prediction
    else:
        for i in range(len(out)):
            out[i] += prediction[i]
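
# ---- Editor's aside (illustration, not part of the packaged file) ----
# _accumulate_prediction appears to mirror the same-named helper in
# sklearn.ensemble._forest: each worker adds one tree's output into a
# shared buffer, and the caller divides by the tree count afterwards.
# A minimal sequential demo of that contract (_demo_predict is a
# hypothetical stand-in for a tree's predict_proba; the real caller
# would fan the trees out with joblib.Parallel):
def _demo_predict(X, check_input=False):
    return np.full((X.shape[0], 2), 0.5)

_out = [np.zeros((3, 2))]
for _ in range(4):
    _accumulate_prediction(_demo_predict, np.zeros((3, 4)), _out)
print(_out[0] / 4)  # averaged probabilities: all 0.5
# ----------------------------------------------------------------------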