autogluon.tabular 1.3.2b20250711__py3-none-any.whl → 1.3.2b20250713__py3-none-any.whl
- autogluon/tabular/models/__init__.py +1 -1
- autogluon/tabular/models/tabpfnv2/__init__.py +0 -0
- autogluon/tabular/models/tabpfnv2/rfpfn/__init__.py +20 -0
- autogluon/tabular/models/tabpfnv2/rfpfn/configs.py +40 -0
- autogluon/tabular/models/tabpfnv2/rfpfn/scoring_utils.py +201 -0
- autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_based_decision_tree_tabpfn.py +1464 -0
- autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_based_random_forest_tabpfn.py +747 -0
- autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_compat.py +863 -0
- autogluon/tabular/models/tabpfnv2/rfpfn/utils.py +106 -0
- autogluon/tabular/models/tabpfnv2/tabpfnv2_model.py +376 -0
- autogluon/tabular/registry/_ag_model_registry.py +2 -2
- autogluon/tabular/version.py +1 -1
- {autogluon.tabular-1.3.2b20250711.dist-info → autogluon.tabular-1.3.2b20250713.dist-info}/METADATA +13 -15
- {autogluon.tabular-1.3.2b20250711.dist-info → autogluon.tabular-1.3.2b20250713.dist-info}/RECORD +21 -14
- autogluon/tabular/models/tabpfn/__init__.py +0 -1
- autogluon/tabular/models/tabpfn/tabpfn_model.py +0 -153
- /autogluon.tabular-1.3.2b20250711-py3.9-nspkg.pth → /autogluon.tabular-1.3.2b20250713-py3.9-nspkg.pth +0 -0
- {autogluon.tabular-1.3.2b20250711.dist-info → autogluon.tabular-1.3.2b20250713.dist-info}/LICENSE +0 -0
- {autogluon.tabular-1.3.2b20250711.dist-info → autogluon.tabular-1.3.2b20250713.dist-info}/NOTICE +0 -0
- {autogluon.tabular-1.3.2b20250711.dist-info → autogluon.tabular-1.3.2b20250713.dist-info}/WHEEL +0 -0
- {autogluon.tabular-1.3.2b20250711.dist-info → autogluon.tabular-1.3.2b20250713.dist-info}/namespace_packages.txt +0 -0
- {autogluon.tabular-1.3.2b20250711.dist-info → autogluon.tabular-1.3.2b20250713.dist-info}/top_level.txt +0 -0
- {autogluon.tabular-1.3.2b20250711.dist-info → autogluon.tabular-1.3.2b20250713.dist-info}/zip-safe +0 -0
@@ -0,0 +1,1464 @@
#  Copyright (c) Prior Labs GmbH 2025.
#  Licensed under the Apache License, Version 2.0

from __future__ import annotations

import random
import warnings

# For type checking only
from typing import TYPE_CHECKING, Any
from copy import deepcopy

import numpy as np
import torch

if TYPE_CHECKING:
    from numpy.typing import NDArray
from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.tree import (
    BaseDecisionTree,
    DecisionTreeClassifier,
    DecisionTreeRegressor,
)
from sklearn.utils.multiclass import unique_labels
from sklearn.utils.validation import (
    _check_sample_weight,
    check_is_fitted,
)

from .sklearn_compat import validate_data
from .scoring_utils import (
    score_classification,
    score_regression,
)
from .utils import softmax

###############################################################################
#                            BASE DECISION TREE                               #
###############################################################################


class DecisionTreeTabPFNBase(BaseDecisionTree, BaseEstimator):
    """Abstract base class combining a scikit-learn Decision Tree with TabPFN at the leaves.

    This class provides a hybrid approach by combining the standard decision tree
    splitting algorithm from scikit-learn with TabPFN models at the leaves or
    internal nodes. This allows for both interpretable tree-based partitioning
    and high-performance TabPFN prediction.

    Key features:
    -------------
    • Inherits from sklearn's BaseDecisionTree to leverage standard tree splitting algorithms
    • Uses TabPFN (Classifier or Regressor) to fit leaf nodes (or all internal nodes)
    • Provides adaptive pruning logic (optional) that dynamically determines optimal tree depth
    • Supports both classification and regression through specialized subclasses

    Subclasses:
    -----------
    • DecisionTreeTabPFNClassifier - for classification tasks
    • DecisionTreeTabPFNRegressor - for regression tasks

    Parameters
    ----------
    tabpfn : Any
        A TabPFN instance (TabPFNClassifier or TabPFNRegressor) that will be used at tree nodes.
    criterion : str
        The function to measure the quality of a split (from sklearn).
    splitter : str
        The strategy used to choose the split at each node (e.g. "best" or "random").
    max_depth : int, optional
        The maximum depth of the tree (None means unlimited).
    min_samples_split : int
        The minimum number of samples required to split an internal node.
    min_samples_leaf : int
        The minimum number of samples required to be at a leaf node.
    min_weight_fraction_leaf : float
        The minimum weighted fraction of the sum total of weights required to be at a leaf node.
    max_features : Union[int, float, str, None]
        The number of features to consider when looking for the best split.
    random_state : Union[int, np.random.RandomState, None]
        Controls the randomness of the estimator.
    max_leaf_nodes : Optional[int]
        If not None, grow a tree with max_leaf_nodes in best-first fashion.
    min_impurity_decrease : float
        A node will be split if this split induces a decrease of the impurity >= this value.
    class_weight : Optional[Union[Dict[int, float], str]]
        Only used in classification. Dict of class -> weight or "balanced".
    ccp_alpha : float
        Complexity parameter used for Minimal Cost-Complexity Pruning (non-negative).
    monotonic_cst : Any
        Optional monotonicity constraints (depending on sklearn version).
    categorical_features : Optional[List[int]]
        Indices of categorical features for TabPFN usage (if any).
    verbose : Union[bool, int]
        Verbosity level; higher values produce more output.
    show_progress : bool
        Whether to show progress bars for leaf/node fitting using TabPFN.
    fit_nodes : bool
        Whether to fit TabPFN at internal nodes (True) or only final leaves (False).
    tree_seed : int
        Used to set seeds for TabPFN fitting in each node.
    adaptive_tree : bool
        Whether to do adaptive node-by-node pruning using a hold-out strategy.
    adaptive_tree_min_train_samples : int
        Minimum number of training samples required to fit a TabPFN in a node.
    adaptive_tree_max_train_samples : int
        Maximum number of training samples above which a node might be pruned if not a final leaf.
    adaptive_tree_min_valid_samples_fraction_of_train : float
        Fraction controlling the minimum valid/test points to consider a node for re-fitting.
    adaptive_tree_overwrite_metric : Optional[str]
        If set, overrides the default metric for pruning. E.g., "roc" or "rmse".
    adaptive_tree_test_size : float
        Fraction of data to hold out for adaptive pruning if no separate valid set is provided.
    average_logits : bool
        Whether to average logits (True) or probabilities (False) when combining predictions.
    adaptive_tree_skip_class_missing : bool
        If True, skip re-fitting if the node's training set does not contain all classes (classification only).
    """

    # Task type set by subclasses: "multiclass" or "regression"
    task_type: str | None = None

    def __init__(
        self,
        *,
        # Decision Tree arguments
        criterion: str = "gini",
        splitter: str = "best",
        max_depth: int | None = None,
        min_samples_split: int = 1000,
        min_samples_leaf: int = 1,
        min_weight_fraction_leaf: float = 0.0,
        max_features: int | float | str | None = None,
        random_state: int | np.random.RandomState | None = None,
        max_leaf_nodes: int | None = None,
        min_impurity_decrease: float = 0.0,
        class_weight: dict[int, float] | str | None = None,
        ccp_alpha: float = 0.0,
        monotonic_cst: Any = None,
        # TabPFN argument
        tabpfn: Any = None,  # TabPFNClassifier or TabPFNRegressor
        categorical_features: list[int] | None = None,
        verbose: bool | int = False,
        show_progress: bool = False,
        fit_nodes: bool = True,
        tree_seed: int = 0,
        adaptive_tree: bool = True,
        adaptive_tree_min_train_samples: int = 50,
        adaptive_tree_max_train_samples: int = 2000,
        adaptive_tree_min_valid_samples_fraction_of_train: float = 0.2,
        adaptive_tree_overwrite_metric: str | None = None,
        adaptive_tree_test_size: float = 0.2,
        average_logits: bool = True,
        adaptive_tree_skip_class_missing: bool = True,
    ):
        # Collect recognized arguments
        self.tabpfn = tabpfn
        self.criterion = criterion
        self.splitter = splitter
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.min_samples_leaf = min_samples_leaf
        self.min_weight_fraction_leaf = min_weight_fraction_leaf
        self.max_features = max_features
        self.random_state = random_state
        self.max_leaf_nodes = max_leaf_nodes
        self.min_impurity_decrease = min_impurity_decrease
        self.class_weight = class_weight
        self.ccp_alpha = ccp_alpha
        self.monotonic_cst = monotonic_cst

        self.categorical_features = categorical_features
        self.verbose = verbose
        self.show_progress = show_progress
        self.fit_nodes = fit_nodes
        self.tree_seed = tree_seed
        self.adaptive_tree = adaptive_tree
        self.adaptive_tree_min_train_samples = adaptive_tree_min_train_samples
        self.adaptive_tree_max_train_samples = adaptive_tree_max_train_samples
        self.adaptive_tree_min_valid_samples_fraction_of_train = (
            adaptive_tree_min_valid_samples_fraction_of_train
        )
        self.adaptive_tree_overwrite_metric = adaptive_tree_overwrite_metric
        self.adaptive_tree_test_size = adaptive_tree_test_size
        self.average_logits = average_logits
        self.adaptive_tree_skip_class_missing = adaptive_tree_skip_class_missing

        # Initialize internal flags/structures that will be set during fit
        self._need_post_fit: bool = False
        self._decision_tree = None

        # Handle possible differences in sklearn versions, specifically monotonic_cst
        optional_args_filtered = {}
        if BaseDecisionTree.__init__.__code__.co_varnames.__contains__("monotonic_cst"):
            optional_args_filtered["monotonic_cst"] = monotonic_cst

        # Initialize the underlying DecisionTree
        super().__init__(
            criterion=self.criterion,
            splitter=self.splitter,
            max_depth=self.max_depth,
            min_samples_split=self.min_samples_split,
            min_samples_leaf=self.min_samples_leaf,
            min_weight_fraction_leaf=self.min_weight_fraction_leaf,
            max_features=self.max_features,
            random_state=self.random_state,
            max_leaf_nodes=self.max_leaf_nodes,
            min_impurity_decrease=self.min_impurity_decrease,
            ccp_alpha=self.ccp_alpha,
            **optional_args_filtered,
        )

        # If the user gave a TabPFN, we do not want it to have a random_state forcibly set
        # because we handle seeds ourselves at each node
        if self.tabpfn is not None:
            self.tabpfn.random_state = None

    def _validate_tabpfn_runtime(self) -> None:
        """Validate the TabPFN instance at runtime before using it.

        This ensures the TabPFN instance is still available when needed during
        prediction or fitting operations.

        Raises:
            ValueError: If self.tabpfn is None at runtime.
        """
        if self.tabpfn is None:
            raise ValueError("TabPFN was None at runtime - cannot proceed.")

    def _more_tags(self) -> dict[str, Any]:
        return {
            "allow_nan": True,
        }

    def __sklearn_tags__(self):
        tags = super().__sklearn_tags__()
        tags.input_tags.allow_nan = True
        tags.estimator_type = "regressor"
        if self.task_type == "multiclass":
            tags.estimator_type = "classifier"
        else:
            tags.estimator_type = "regressor"
        return tags

    def fit(
        self,
        X: NDArray[np.float64],
        y: NDArray[Any],
        sample_weight: NDArray[np.float64] | None = None,
        check_input: bool = True,
    ) -> DecisionTreeTabPFNBase:
        """Fit the DecisionTree + TabPFN model.

        This method trains the hybrid model by:
        1. Building a decision tree structure
        2. Fitting TabPFN models at the leaves (or at all nodes if fit_nodes=True)
        3. Optionally performing adaptive pruning if adaptive_tree=True

        Args:
            X: The training input samples, shape (n_samples, n_features).
            y: The target values (class labels for classification, real values for regression),
                shape (n_samples,) or (n_samples, n_outputs).
            sample_weight: Sample weights. If None, then samples are equally weighted.
            check_input: Whether to validate the input data arrays. Default is True.

        Returns:
            self: Fitted estimator.
        """
        return self._fit(X, y, sample_weight=sample_weight, check_input=check_input)

    def _fit(
        self,
        X: NDArray[Any],
        y: NDArray[Any],
        sample_weight: NDArray[Any] | None = None,
        check_input: bool = True,
        missing_values_in_feature_mask: np.ndarray | None = None,  # Unused placeholder
    ) -> DecisionTreeTabPFNBase:
        """Internal method to fit the DecisionTree-TabPFN model on X, y.

        Parameters
        ----------
        X : NDArray
            Training features of shape (n_samples, n_features).
        y : NDArray
            Target labels/values of shape (n_samples,).
        sample_weight : NDArray, optional
            Sample weights for each sample.
        check_input : bool
            Whether to check inputs.
        missing_values_in_feature_mask : np.ndarray, optional
            Unused placeholder for older code or possible expansions.

        Returns:
        -------
        self : DecisionTreeTabPFNBase
            The fitted model.
        """
        # Initialize attributes (per scikit-learn conventions)
        self._leaf_nodes = []
        self._leaf_train_data = {}
        self._label_encoder = LabelEncoder()
        self._need_post_fit = False
        self._node_prediction_type = {}

        # Make sure tabpfn is valid
        self._validate_tabpfn_runtime()

        # Possibly randomize tree_seed if not set
        if self.tree_seed == 0:
            self.tree_seed = random.randint(1, 10000)

        sample_weight = _check_sample_weight(sample_weight, X, dtype=np.float64)
        X, y = validate_data(
            self,
            X,
            y,
            ensure_all_finite=False,  # scikit-learn sets self.n_features_in_ automatically
        )

        if self.task_type == "multiclass":
            self.classes_ = unique_labels(y)
            self.n_classes_ = len(self.classes_)

        # Convert torch tensor -> numpy if needed, handle NaNs
        X_preprocessed = self._preprocess_data_for_tree(X)

        if sample_weight is None:
            sample_weight = np.ones((X_preprocessed.shape[0],), dtype=np.float64)

        # Setup classes_ or n_classes_ if needed
        if self.task_type == "multiclass":
            # Classification
            self.classes_ = np.unique(y)
            self.n_classes_ = len(self.classes_)
        else:
            # Regression
            self.n_classes_ = (
                1  # Not used for numeric tasks, but keep it for consistency
            )

        # Possibly label-encode y for classification if your TabPFN needs it
        # (Here we just rely on uniqueness checks above.)
        y_ = y.copy()

        # If adaptive_tree is on, do a train/validation split
        if self.adaptive_tree:
            stratify = y_ if (self.task_type == "multiclass") else None

            # Basic checks for classification to see if splitting is feasible
            if self.task_type == "multiclass":
                unique_classes, counts = np.unique(y_, return_counts=True)
                # Disable adaptive tree in extreme cases
                if counts.min() == 1 or len(unique_classes) < 2:
                    self.adaptive_tree = False
                elif len(unique_classes) > int(len(y_) * self.adaptive_tree_test_size):
                    self.adaptive_tree_test_size = min(
                        0.5,
                        len(unique_classes) / len(y_) * 1.5,
                    )
            if len(y_) < 10:
                self.adaptive_tree = False

            if self.adaptive_tree:
                (
                    X_train,
                    X_valid,
                    X_preproc_train,
                    X_preproc_valid,
                    y_train,
                    y_valid,
                    sw_train,
                    sw_valid,
                ) = train_test_split(
                    X,
                    X_preprocessed,
                    y_,
                    sample_weight,
                    test_size=self.adaptive_tree_test_size,
                    random_state=self.random_state,
                    stratify=stratify,
                )

                # Safety check - if split is empty, revert
                if len(y_train) == 0 or len(y_valid) == 0:
                    self.adaptive_tree = False
                    X_train, X_preproc_train, y_train, sw_train = (
                        X,
                        X_preprocessed,
                        y_,
                        sample_weight,
                    )
                    X_valid = X_preproc_valid = y_valid = sw_valid = None

                # If classification, also ensure train/valid has same classes
                if (
                    self.task_type == "multiclass"
                    and self.adaptive_tree
                    and (len(np.unique(y_train)) != len(np.unique(y_valid)))
                ):
                    self.adaptive_tree = False
            else:
                # If we were disabled, keep all data as training
                X_train, X_preproc_train, y_train, sw_train = (
                    X,
                    X_preprocessed,
                    y_,
                    sample_weight,
                )
                X_valid = X_preproc_valid = y_valid = sw_valid = None
        else:
            # Not adaptive, everything is train
            X_train, X_preproc_train, y_train, sw_train = (
                X,
                X_preprocessed,
                y_,
                sample_weight,
            )
            X_valid = X_preproc_valid = y_valid = sw_valid = None

        # Build the sklearn decision tree
        self._decision_tree = self._init_decision_tree()
        self._decision_tree.fit(X_preproc_train, y_train, sample_weight=sw_train)
        self._tree = self._decision_tree  # for sklearn compatibility

        # Keep references for potential post-fitting (leaf-level fitting)
        self.X = X
        self.y = y_
        self.train_X = X_train
        self.train_X_preprocessed = X_preproc_train
        self.train_y = y_train
        self.train_sample_weight = sw_train

        if self.adaptive_tree:
            self.valid_X = X_valid
            self.valid_X_preprocessed = X_preproc_valid
            self.valid_y = y_valid
            self.valid_sample_weight = sw_valid

        # We will do a leaf-fitting step on demand (lazy) in predict
        self._need_post_fit = True

        # If verbose, optionally do it right away:
        if self.verbose:
            self._post_fit()

        return self

    def _init_decision_tree(self) -> BaseDecisionTree:
        """Initialize the underlying scikit-learn Decision Tree.

        Overridden by child classes for classifier vs regressor.

        Returns:
        -------
        BaseDecisionTree
            An instance of a scikit-learn DecisionTreeClassifier or DecisionTreeRegressor.
        """
        raise NotImplementedError("Must be implemented in subclass.")

    def _post_fit(self) -> None:
        """Hook after the decision tree is fitted. Can be used for final prints/logs."""
        if self.verbose:
            pass

    def _preprocess_data_for_tree(self, X: np.ndarray) -> np.ndarray:
        """Handle missing data prior to feeding into the decision tree.

        Replaces NaNs with a default value, handles pandas DataFrames and other input types.
        Uses scikit-learn's validation functions for compatibility.

        Parameters
        ----------
        X : array-like
            Input features, possibly containing NaNs.

        Returns:
        -------
        np.ndarray
            A copy of X with NaNs replaced by a default value.
        """
        # Use check_array from sklearn_compat to handle different input types
        from .sklearn_compat import check_array

        # Handle torch tensor
        if torch.is_tensor(X):
            X = X.cpu().numpy()

        # Convert to array and handle input validation
        # Don't extract DataFrame values - let check_array handle it
        X = check_array(
            X,
            dtype=np.float64,
            ensure_all_finite=False,  # We'll handle NaNs ourselves
            ensure_2d=True,
            copy=True,  # Make a copy so we don't modify the original
        )

        # Replace NaN with our specific value (-1000.0)
        X = np.nan_to_num(X, nan=-1000.0)
        return X

    def _apply_tree(self, X: np.ndarray) -> np.ndarray:
        """Apply the fitted tree to X, returning a matrix of leaf membership.

        Returns:
        -------
        np.ndarray
            A dense matrix of shape (n_samples, n_nodes, n_estimators),
            though we typically only have 1 estimator.
        """
        X_preprocessed = self._preprocess_data_for_tree(X)
        decision_path = self.get_tree().decision_path(X_preprocessed)
        return np.expand_dims(decision_path.todense(), axis=2)

    def _apply_tree_train(
        self,
        X: np.ndarray,
        y: np.ndarray,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Apply the tree for training data, returning leaf membership plus (X, y) unchanged.

        Returns:
        -------
        leaf_matrix : np.ndarray
            Shape (n_samples, n_nodes, n_estimators)
        X_array : np.ndarray
            Same as X input
        y_array : np.ndarray
            Same as y input
        """
        return self._apply_tree(X), X, y

    def get_tree(self) -> BaseDecisionTree:
        """Return the underlying fitted sklearn decision tree.

        Returns:
            DecisionTreeClassifier or DecisionTreeRegressor: The fitted decision tree.

        Raises:
            sklearn.exceptions.NotFittedError: If the model has not been fitted yet.
        """
        # This will raise NotFittedError if the model is not fitted
        check_is_fitted(self, ["_tree", "X", "y"])
        return self._tree

    @property
    def tree_(self):
        """Expose the fitted tree for sklearn compatibility.

        Returns:
        -------
        sklearn.tree._tree.Tree
            Underlying scikit-learn tree object.
        """
        return self.get_tree().tree_

    def fit_leaves(
        self,
        train_X: np.ndarray,
        train_y: np.ndarray,
    ) -> None:
        """Fit a TabPFN model in each leaf node (or each node, if self.fit_nodes=True).

        This populates an internal dictionary of training data for each leaf
        so that TabPFN can make predictions at these leaves.

        Parameters
        ----------
        train_X : np.ndarray
            Training features for all samples.
        train_y : np.ndarray
            Training labels/targets for all samples.
        """
        self._leaf_train_data = {}
        leaf_node_matrix, _, _ = self._apply_tree_train(train_X, train_y)
        self._leaf_nodes = leaf_node_matrix

        n_samples, n_nodes, n_estims = leaf_node_matrix.shape

        for estimator_id in range(n_estims):
            self._leaf_train_data[estimator_id] = {}
            for leaf_id in range(n_nodes):
                indices = np.argwhere(
                    leaf_node_matrix[:, leaf_id, estimator_id],
                ).ravel()
                X_leaf_samples = np.take(train_X, indices, axis=0)
                y_leaf_samples = np.take(train_y, indices, axis=0).ravel()

                self._leaf_train_data[estimator_id][leaf_id] = (
                    X_leaf_samples,
                    y_leaf_samples,
                )

    def _predict_internal(
        self,
        X: np.ndarray,
        y: np.ndarray | None = None,
        check_input: bool = True,
    ) -> np.ndarray:
        """Internal method used to produce probabilities or regression predictions,
        with optional adaptive pruning logic.

        If y is given and we have adaptive_tree=True, node-level pruning is applied.

        Parameters
        ----------
        X : np.ndarray
            Features to predict.
        y : np.ndarray, optional
            Target values, only required if we are in adaptive pruning mode
            and need to compare node performance.
        check_input : bool, default=True
            Whether to validate input arrays.

        Returns:
        -------
        np.ndarray
            The final predictions (probabilities for classification, or continuous values for regression).
        """
        # If we haven't yet done the final leaf fit, do it here
        if self._need_post_fit:
            self._need_post_fit = False
            if self.adaptive_tree:
                # Fit leaves on train data, check performance on valid data if available
                self.fit_leaves(self.train_X, self.train_y)
                if (
                    hasattr(self, "valid_X")
                    and self.valid_X is not None
                    and self.valid_y is not None
                ):
                    # Force a pass to evaluate node performance
                    # so we can prune or decide node updates
                    self._predict_internal(
                        self.valid_X,
                        self.valid_y,
                        check_input=False,
                    )
            # Now fit leaves again using the entire dataset (train + valid, effectively)
            self.fit_leaves(self.X, self.y)

        # Assign TabPFN's categorical features if needed
        if self.tabpfn is not None:
            self.tabpfn.categorical_features_indices = self.categorical_features

        # Find leaf membership in X
        X_leaf_nodes = self._apply_tree(X)
        n_samples, n_nodes, n_estims = X_leaf_nodes.shape

        # Track intermediate predictions
        y_prob: dict[int, dict[int, np.ndarray]] = {}
        y_metric: dict[int, dict[int, float]] = {}

        # If pruning, track how each node is updated
        do_pruning = (y is not None) and self.adaptive_tree
        if do_pruning:
            self._node_prediction_type: dict[int, dict[int, str]] = {}

        for est_id in range(n_estims):
            if do_pruning:
                self._node_prediction_type[est_id] = {}
            y_prob[est_id] = {}
            y_metric[est_id] = {}
            if self.show_progress:
                import tqdm.auto

                node_iter = tqdm.auto.tqdm(range(n_nodes), desc=f"Estimator {est_id}")
            else:
                node_iter = range(n_nodes)

            for leaf_id in node_iter:
                self._pruning_init_node_predictions(
                    leaf_id,
                    est_id,
                    y_prob,
                    y_metric,
                    n_nodes,
                    n_samples,
                )
                if est_id > 0 and leaf_id == 0:
                    # Skip repeated re-initialization if multiple trees
                    continue

                # Gather test-sample indices that belong to this leaf
                test_sample_indices = np.argwhere(
                    X_leaf_nodes[:, leaf_id, est_id],
                ).ravel()

                # Gather training samples that belong to this leaf
                X_train_leaf, y_train_leaf = self._leaf_train_data[est_id][leaf_id]

                # If no training or test samples in this node, skip
                if (X_train_leaf.shape[0] == 0) or (len(test_sample_indices) == 0):
                    if do_pruning:
                        self._node_prediction_type[est_id][leaf_id] = "previous"
                    continue

                # Determine if this is a final leaf
                # If the sum of membership in subsequent nodes is zero, it's final
                is_leaf = (
                    X_leaf_nodes[test_sample_indices, leaf_id + 1 :, est_id].sum()
                    == 0.0
                )

                # If it's not a leaf and we are not fitting internal nodes, skip
                # (unless leaf_id==0 and we do a top-level check for adaptive_tree)
                if (
                    (not is_leaf)
                    and (not self.fit_nodes)
                    and not (leaf_id == 0 and self.adaptive_tree)
                ):
                    if do_pruning:
                        self._node_prediction_type[est_id][leaf_id] = "previous"
                    continue

                # Additional adaptive checks
                if self.adaptive_tree and leaf_id != 0:
                    should_skip_previously_pruned = False
                    if y is None:
                        # Safely check if the key exists before accessing
                        node_type = self._node_prediction_type.get(est_id, {}).get(
                            leaf_id,
                        )
                        if node_type == "previous":
                            should_skip_previously_pruned = True

                    if should_skip_previously_pruned:
                        continue

                    # Skip if classification is missing a class
                    if (
                        self.task_type == "multiclass"
                        and len(np.unique(y_train_leaf)) < self.n_classes_
                        and self.adaptive_tree_skip_class_missing
                    ):
                        self._node_prediction_type[est_id][leaf_id] = "previous"
                        continue

                    # Skip if too few or too many training points
                    if (
                        (X_train_leaf.shape[0] < self.adaptive_tree_min_train_samples)
                        or (
                            len(test_sample_indices)
                            < self.adaptive_tree_min_valid_samples_fraction_of_train
                            * self.adaptive_tree_min_train_samples
                        )
                        or (
                            X_train_leaf.shape[0] > self.adaptive_tree_max_train_samples
                            and not is_leaf
                        )
                    ):
                        if do_pruning:
                            self._node_prediction_type[est_id][leaf_id] = "previous"
                        continue

                # Perform leaf-level TabPFN prediction
                leaf_prediction = self._predict_leaf(
                    X_train_leaf,
                    y_train_leaf,
                    leaf_id,
                    X,
                    test_sample_indices,
                )

                # Evaluate "averaging" and "replacement" for pruning
                y_prob_averaging, y_prob_replacement = (
                    self._pruning_get_prediction_type_results(
                        y_prob,
                        leaf_prediction,
                        test_sample_indices,
                        est_id,
                        leaf_id,
                    )
                )

                # Decide best approach if in adaptive mode
                if not self.adaptive_tree:
                    # If not adaptive, we simply do replacement
                    y_prob[est_id][leaf_id] = y_prob_replacement
                elif y is not None:
                    self._pruning_set_node_prediction_type(
                        y,
                        y_prob_averaging,
                        y_prob_replacement,
                        y_metric,
                        est_id,
                        leaf_id,
                    )
                    self._pruning_set_predictions(
                        y_prob,
                        y_prob_averaging,
                        y_prob_replacement,
                        est_id,
                        leaf_id,
                    )
                    y_metric[est_id][leaf_id] = self._score(
                        y,
                        y_prob[est_id][leaf_id],
                    )
                else:
                    # If not validating and not adaptive, just use replacement
                    y_prob[est_id][leaf_id] = y_prob_replacement

        # Final predictions come from the last estimator's last node
        return y_prob[n_estims - 1][n_nodes - 1]

    def _pruning_init_node_predictions(
        self,
        leaf_id: int,
        estimator_id: int,
        y_prob: dict[int, dict[int, np.ndarray]],
        y_metric: dict[int, dict[int, float]],
        n_nodes: int,
        n_samples: int,
    ) -> None:
        """Initialize node predictions for the pruning logic.

        Parameters
        ----------
        leaf_id : int
            Index of the leaf/node being processed.
        estimator_id : int
            Index of the current estimator (if multiple).
        y_prob : dict
            Nested dictionary of predictions.
        y_metric : dict
            Nested dictionary of scores/metrics.
        n_nodes : int
            Total number of nodes in the tree.
        n_samples : int
            Number of samples in X.
        """
        if estimator_id == 0 and leaf_id == 0:
            y_prob[0][0] = self._init_eval_probability_array(n_samples, to_zero=True)
            y_metric[0][0] = 0.0
        elif leaf_id == 0 and estimator_id > 0:
            # If first leaf of new estimator, carry from last node of previous estimator
            y_prob[estimator_id][leaf_id] = y_prob[estimator_id - 1][n_nodes - 1]
            y_metric[estimator_id][leaf_id] = y_metric[estimator_id - 1][n_nodes - 1]
        else:
            # Use last leaf of the same estimator
            y_prob[estimator_id][leaf_id] = y_prob[estimator_id][leaf_id - 1]
            y_metric[estimator_id][leaf_id] = y_metric[estimator_id][leaf_id - 1]

    def _pruning_get_prediction_type_results(
        self,
        y_eval_prob: dict[int, dict[int, np.ndarray]],
        leaf_prediction: np.ndarray,
        test_sample_indices: np.ndarray,
        estimator_id: int,
        leaf_id: int,
    ) -> tuple[np.ndarray, np.ndarray]:
        """Produce the "averaging" and "replacement" predictions for pruning decisions.

        Parameters
        ----------
        y_eval_prob : dict
            Nested dictionary of predictions.
        leaf_prediction : np.ndarray
            Predictions from the newly fitted leaf (for relevant samples).
        test_sample_indices : np.ndarray
            Indices of the test samples that fall into this leaf.
        estimator_id : int
            Index of the current estimator.
        leaf_id : int
            Index of the current leaf/node.

        Returns:
        -------
        y_prob_averaging : np.ndarray
            Updated predictions using an "averaging" rule.
        y_prob_replacement : np.ndarray
            Updated predictions using a "replacement" rule.
        """
        y_prob_current = y_eval_prob[estimator_id][leaf_id]
        y_prob_replacement = np.copy(y_prob_current)
        # "replacement" sets the new leaf prediction directly
        y_prob_replacement[test_sample_indices] = leaf_prediction[test_sample_indices]

        if self.task_type == "multiclass":
            # Normalize
            row_sums = y_prob_replacement.sum(axis=1, keepdims=True)
            row_sums[row_sums == 0] = 1.0
            y_prob_replacement /= row_sums

        # "averaging" -> combine old predictions with new
        y_prob_averaging = np.copy(y_prob_current)

        if self.task_type == "multiclass":
            if self.average_logits:
                # Convert old + new to log, sum them, then softmax
                y_prob_averaging[test_sample_indices] = np.log(
                    y_prob_averaging[test_sample_indices] + 1e-6,
                )
                leaf_pred_log = np.log(leaf_prediction[test_sample_indices] + 1e-6)
                y_prob_averaging[test_sample_indices] += leaf_pred_log
                y_prob_averaging[test_sample_indices] = softmax(
                    y_prob_averaging[test_sample_indices],
                )
            else:
                # Average probabilities directly
                y_prob_averaging[test_sample_indices] += leaf_prediction[
                    test_sample_indices
                ]
                row_sums = y_prob_averaging.sum(axis=1, keepdims=True)
                row_sums[row_sums == 0] = 1.0
                y_prob_averaging /= row_sums
        elif self.task_type == "regression":
            # Regression -> simply average
            y_prob_averaging[test_sample_indices] += leaf_prediction[
                test_sample_indices
            ]
            y_prob_averaging[test_sample_indices] /= 2.0

        return y_prob_averaging, y_prob_replacement

    def _pruning_set_node_prediction_type(
        self,
        y_true: np.ndarray,
        y_prob_averaging: np.ndarray,
        y_prob_replacement: np.ndarray,
        y_metric: dict[int, dict[int, float]],
        estimator_id: int,
        leaf_id: int,
    ) -> None:
        """Decide which approach is better: "averaging" vs "replacement" vs "previous",
        using the node's previous metric vs the new metrics.

        Parameters
        ----------
        y_true : np.ndarray
            Ground-truth labels/targets for pruning comparison.
        y_prob_averaging : np.ndarray
            Predictions if we use averaging.
        y_prob_replacement : np.ndarray
            Predictions if we use replacement.
        y_metric : dict
            Nested dictionary of scores/metrics for each node.
        estimator_id : int
            Index of the current estimator.
        leaf_id : int
            Index of the current leaf/node.
        """
        averaging_score = self._score(y_true, y_prob_averaging)
        replacement_score = self._score(y_true, y_prob_replacement)
        prev_score = y_metric[estimator_id][leaf_id - 1] if (leaf_id > 0) else 0.0

        if (leaf_id == 0) or (max(averaging_score, replacement_score) > prev_score):
            # Pick whichever is better
            if replacement_score > averaging_score:
                prediction_type = "replacement"
            else:
                prediction_type = "averaging"
        else:
            prediction_type = "previous"

        self._node_prediction_type[estimator_id][leaf_id] = prediction_type

    def _pruning_set_predictions(
        self,
        y_prob: dict[int, dict[int, np.ndarray]],
        y_prob_averaging: np.ndarray,
        y_prob_replacement: np.ndarray,
        estimator_id: int,
        leaf_id: int,
    ) -> None:
        """Based on the chosen node_prediction_type, finalize the predictions.

        Parameters
        ----------
        y_prob : dict
            Nested dictionary of predictions.
        y_prob_averaging : np.ndarray
            Predictions if we use averaging.
        y_prob_replacement : np.ndarray
            Predictions if we use replacement.
        estimator_id : int
            Index of the current estimator.
        leaf_id : int
            Index of the current leaf/node.
        """
        node_type = self._node_prediction_type[estimator_id][leaf_id]
        if node_type == "averaging":
            y_prob[estimator_id][leaf_id] = y_prob_averaging
        elif node_type == "replacement":
            y_prob[estimator_id][leaf_id] = y_prob_replacement
        else:
            # "previous"
            y_prob[estimator_id][leaf_id] = y_prob[estimator_id][leaf_id - 1]

    def _init_eval_probability_array(
        self,
        n_samples: int,
        to_zero: bool = False,
    ) -> np.ndarray:
        """Initialize an array of predictions for the entire dataset.

        For classification, this is (n_samples, n_classes).
        For regression, this is (n_samples,).

        Parameters
        ----------
        n_samples : int
            Number of samples to predict.
        to_zero : bool, default=False
            If True, fill with zeros. Otherwise use uniform for classification,
            or zeros for regression.

        Returns:
        -------
        np.ndarray
            An appropriately sized array of initial predictions.
        """
        if self.task_type == "multiclass":
            if to_zero:
                return np.zeros((n_samples, self.n_classes_), dtype=np.float64)
            return (
                np.ones((n_samples, self.n_classes_), dtype=np.float64)
                / self.n_classes_
            )
        else:
            # Regression
            return np.zeros((n_samples,), dtype=np.float64)

    def _score(self, y_true: np.ndarray, y_pred: np.ndarray) -> float:
        """Compute a performance score given ground truth and predictions.

        Parameters
        ----------
        y_true : np.ndarray
            Ground truth labels or values.
        y_pred : np.ndarray
            Predictions (probabilities for classification, continuous for regression).

        Returns:
        -------
        float
            The performance score (higher is better for classification,
            or depends on the specific metric).
        """
        metric = self._get_optimize_metric()
        if self.task_type == "multiclass":
            return score_classification(metric, y_true, y_pred)
        elif self.task_type == "regression":
            return score_regression(metric, y_true, y_pred)
        else:
            raise NotImplementedError

    def _get_optimize_metric(self) -> str:
        """Return which metric name to use for scoring.

        Returns:
        -------
        str
            The metric name, e.g. "roc" for classification or "rmse" for regression.
        """
        if self.adaptive_tree_overwrite_metric is not None:
            return self.adaptive_tree_overwrite_metric
        if self.task_type == "multiclass":
            return "roc"
        return "rmse"

    def _predict_leaf(
        self,
        X_train_leaf: np.ndarray,
        y_train_leaf: np.ndarray,
        leaf_id: int,
        X_full: np.ndarray,
        indices: np.ndarray,
    ) -> np.ndarray:
        """Each subclass implements how to call TabPFN for classification or regression.

        Parameters
        ----------
        X_train_leaf : np.ndarray
            Training features for the samples in this leaf/node.
        y_train_leaf : np.ndarray
            Training targets for the samples in this leaf/node.
        leaf_id : int
            Leaf/node index (for seeding or debugging).
        X_full : np.ndarray
            The entire set of features we are predicting on.
        indices : np.ndarray
            The indices in X_full that belong to this leaf.

        Returns:
        -------
        np.ndarray
            Predictions for all n_samples, but only `indices` are filled meaningfully.
        """
        raise NotImplementedError("Must be implemented in subclass.")


###############################################################################
#                            CLASSIFIER SUBCLASS                              #
###############################################################################


class DecisionTreeTabPFNClassifier(DecisionTreeTabPFNBase, ClassifierMixin):
    """Decision tree that uses TabPFNClassifier at the leaves."""

    task_type: str = "multiclass"

    def _init_decision_tree(self) -> DecisionTreeClassifier:
        """Create a scikit-learn DecisionTreeClassifier with stored parameters."""
        return DecisionTreeClassifier(
            criterion=self.criterion,
            max_depth=self.max_depth,
            min_samples_split=self.min_samples_split,
            min_samples_leaf=self.min_samples_leaf,
            min_weight_fraction_leaf=self.min_weight_fraction_leaf,
            max_features=self.max_features,
            random_state=self.random_state,
            max_leaf_nodes=self.max_leaf_nodes,
            min_impurity_decrease=self.min_impurity_decrease,
            class_weight=self.class_weight,
            ccp_alpha=self.ccp_alpha,
            splitter=self.splitter,
        )

    def _predict_leaf(
        self,
        X_train_leaf: np.ndarray,
        y_train_leaf: np.ndarray,
        leaf_id: int,
        X_full: np.ndarray,
        indices: np.ndarray,
    ) -> np.ndarray:
        """Fit a TabPFNClassifier on the leaf's train data and predict_proba for the relevant samples.

        Parameters
        ----------
        X_train_leaf : np.ndarray
            Training features for the samples in this leaf/node.
        y_train_leaf : np.ndarray
            Training targets for the samples in this leaf/node.
        leaf_id : int
            Leaf/node index.
        X_full : np.ndarray
            Full feature matrix to predict on.
        indices : np.ndarray
            Indices of X_full that belong to this leaf.

        Returns:
        -------
        np.ndarray
            A (n_samples, n_classes) array of probabilities, with only `indices` updated for this leaf.
        """
        y_eval_prob = self._init_eval_probability_array(X_full.shape[0], to_zero=True)
        classes_in_leaf = [i for i in range(len(np.unique(y_train_leaf)))]

        # If only one class, fill probability 1.0 for that class
        if len(classes_in_leaf) == 1:
            y_eval_prob[indices, classes_in_leaf[0]] = 1.0
            return y_eval_prob

        # Otherwise, fit TabPFN
        leaf_seed = leaf_id + self.tree_seed
        try:
            # Handle pandas DataFrame or numpy array
            if hasattr(X_full, "iloc"):
                # Use .iloc for pandas
                X_subset = X_full.iloc[indices]
            else:
                # Use direct indexing for numpy
                X_subset = X_full[indices]

            try:
                self.tabpfn.random_state = leaf_seed
                self.tabpfn.fit(X_train_leaf, y_train_leaf)
                proba = self.tabpfn.predict_proba(X_subset)
            except Exception as e:
                from tabpfn.preprocessing import default_classifier_preprocessor_configs, \
                    default_regressor_preprocessor_configs
                backup_inf_conf = deepcopy(self.tabpfn.inference_config)
                default_pre = default_classifier_preprocessor_configs if self.task_type == "multiclass" else default_regressor_preprocessor_configs

                # Try to run again with default preprocessing, since the configured preprocessing might be what crashed
                self.tabpfn.random_state = leaf_seed
                self.tabpfn.inference_config["PREPROCESS_TRANSFORMS"] = default_pre()
                self.tabpfn.inference_config["REGRESSION_Y_PREPROCESS_TRANSFORMS"] = (None, "safepower")
                print(self.tabpfn.inference_config)
                self.tabpfn.fit(X_train_leaf, y_train_leaf)
                proba = self.tabpfn.predict_proba(X_subset)
                # Reset preprocessing
                self.tabpfn.inference_config = backup_inf_conf

            for i, c in enumerate(classes_in_leaf):
                y_eval_prob[indices, c] = proba[:, i]

        except ValueError as e:
            if (
                not e.args
                or e.args[0]
                != "All features are constant and would have been removed! Unable to predict using TabPFN."
            ):
                raise e
            warnings.warn(
                "One node has constant features for TabPFN. Using class-ratio fallback.",
                stacklevel=2,
            )
            _, counts = np.unique(y_train_leaf, return_counts=True)
            ratio = counts / counts.sum()
            for i, c in enumerate(classes_in_leaf):
                y_eval_prob[indices, c] = ratio[i]

        return y_eval_prob

    def predict(self, X: np.ndarray, check_input: bool = True) -> np.ndarray:
        """Predict class labels for X.

        Args:
            X: Input features.
            check_input: Whether to validate input arrays. Default is True.

        Returns:
            np.ndarray: Predicted class labels.
        """
        # Validate the model is fitted
        X = validate_data(
            self,
            X,
            ensure_all_finite=False,
        )
        check_is_fitted(self, ["_tree", "X", "y"])
        proba = self.predict_proba(X, check_input=check_input)
        return np.argmax(proba, axis=1)

    def predict_proba(self, X: np.ndarray, check_input: bool = True) -> np.ndarray:
        """Predict class probabilities for X using the TabPFN leaves.

        Args:
            X: Input features.
            check_input: Whether to validate input arrays. Default is True.

        Returns:
            np.ndarray: Predicted probabilities of shape (n_samples, n_classes).
        """
        # Validate the model is fitted
        X = validate_data(
            self,
            X,
            ensure_all_finite=False,
        )
        check_is_fitted(self, ["_tree", "X", "y"])
        return self._predict_internal(X, check_input=check_input)

    def _post_fit(self) -> None:
        """Optional hook after the decision tree is fitted."""
        if self.verbose:
            pass


###############################################################################
#                             REGRESSOR SUBCLASS                              #
###############################################################################


class DecisionTreeTabPFNRegressor(DecisionTreeTabPFNBase, RegressorMixin):
    """Decision tree that uses TabPFNRegressor at the leaves."""

    task_type: str = "regression"

    def __init__(
        self,
        *,
        criterion="squared_error",
        splitter="best",
        max_depth=None,
        min_samples_split=1000,
        min_samples_leaf=1,
        min_weight_fraction_leaf=0.0,
        max_features=None,
        random_state=None,
        max_leaf_nodes=None,
        min_impurity_decrease=0.0,
        ccp_alpha=0.0,
        monotonic_cst=None,
        tabpfn=None,
        categorical_features=None,
        verbose=False,
        show_progress=False,
        fit_nodes=True,
        tree_seed=0,
        adaptive_tree=True,
        adaptive_tree_min_train_samples=50,
        adaptive_tree_max_train_samples=2000,
        adaptive_tree_min_valid_samples_fraction_of_train=0.2,
        adaptive_tree_overwrite_metric=None,
        adaptive_tree_test_size=0.2,
        average_logits=True,
        adaptive_tree_skip_class_missing=True,
    ):
        # Call parent constructor
        super().__init__(
            tabpfn=tabpfn,
            criterion=criterion,
            splitter=splitter,
            max_depth=max_depth,
            min_samples_split=min_samples_split,
            min_samples_leaf=min_samples_leaf,
            min_weight_fraction_leaf=min_weight_fraction_leaf,
            max_features=max_features,
            random_state=random_state,
            max_leaf_nodes=max_leaf_nodes,
            min_impurity_decrease=min_impurity_decrease,
            ccp_alpha=ccp_alpha,
            monotonic_cst=monotonic_cst,
            categorical_features=categorical_features,
            verbose=verbose,
            show_progress=show_progress,
            fit_nodes=fit_nodes,
            tree_seed=tree_seed,
            adaptive_tree=adaptive_tree,
            adaptive_tree_min_train_samples=adaptive_tree_min_train_samples,
            adaptive_tree_max_train_samples=adaptive_tree_max_train_samples,
            adaptive_tree_min_valid_samples_fraction_of_train=(
                adaptive_tree_min_valid_samples_fraction_of_train
            ),
            adaptive_tree_overwrite_metric=adaptive_tree_overwrite_metric,
            adaptive_tree_test_size=adaptive_tree_test_size,
            average_logits=average_logits,
            adaptive_tree_skip_class_missing=adaptive_tree_skip_class_missing,
        )

    def _init_decision_tree(self) -> DecisionTreeRegressor:
        """Create a scikit-learn DecisionTreeRegressor with stored parameters."""
        return DecisionTreeRegressor(
            criterion=self.criterion,
            max_depth=self.max_depth,
            min_samples_split=self.min_samples_split,
            min_samples_leaf=self.min_samples_leaf,
            min_weight_fraction_leaf=self.min_weight_fraction_leaf,
            max_features=self.max_features,
            random_state=self.random_state,
            max_leaf_nodes=self.max_leaf_nodes,
            min_impurity_decrease=self.min_impurity_decrease,
            ccp_alpha=self.ccp_alpha,
            splitter=self.splitter,
        )

    def _predict_leaf(
        self,
        X_train_leaf: np.ndarray,
        y_train_leaf: np.ndarray,
        leaf_id: int,
        X_full: np.ndarray,
        indices: np.ndarray,
    ) -> np.ndarray:
        """Fit a TabPFNRegressor on the node's train data, then predict for the relevant samples.

        Parameters
        ----------
        X_train_leaf : np.ndarray
            Training features for the samples in this leaf/node.
        y_train_leaf : np.ndarray
            Training targets for the samples in this leaf/node.
        leaf_id : int
            Leaf/node index.
        X_full : np.ndarray
            Full feature matrix to predict on.
        indices : np.ndarray
            Indices of X_full that fall into this leaf.

        Returns:
        -------
        np.ndarray
            An array of shape (n_samples,) with predictions; only `indices` are updated.
        """
        y_eval = np.zeros(X_full.shape[0], dtype=float)

        # If no training data or just 1 sample, fall back to 0 or the single value
        if len(X_train_leaf) < 1:
            warnings.warn(
                f"Leaf {leaf_id} has zero training samples. Returning 0.0 predictions.",
                stacklevel=2,
            )
            return y_eval
        elif len(X_train_leaf) == 1:
            y_eval[indices] = y_train_leaf[0]
            return y_eval

        # If all y are identical, return that constant
        if np.all(y_train_leaf == y_train_leaf[0]):
            y_eval[indices] = y_train_leaf[0]
            return y_eval

        # Fit TabPFNRegressor
        leaf_seed = leaf_id + self.tree_seed
        try:
            self.tabpfn.random_state = leaf_seed
            self.tabpfn.fit(X_train_leaf, y_train_leaf)

            # Handle pandas DataFrame or numpy array
            if hasattr(X_full, "iloc"):
                # Use .iloc for pandas
                X_subset = X_full.iloc[indices]
            else:
                # Use direct indexing for numpy
                X_subset = X_full[indices]

            preds = self.tabpfn.predict(X_subset)
            y_eval[indices] = preds
        except (ValueError, RuntimeError, NotImplementedError, AssertionError) as e:
            warnings.warn(
                f"TabPFN fit/predict failed at leaf {leaf_id}: {e}. Using mean fallback.",
                stacklevel=2,
            )
            y_eval[indices] = np.mean(y_train_leaf)

        return y_eval

    def predict(self, X: np.ndarray, check_input: bool = True) -> np.ndarray:
        """Predict regression values using the TabPFN leaves.

        Parameters
        ----------
        X : np.ndarray
            Input features.
        check_input : bool, default=True
            Whether to validate the input arrays.

        Returns:
        -------
        np.ndarray
            Continuous predictions of shape (n_samples,).
        """
        # Validate the model is fitted
        X = validate_data(
            self,
            X,
            ensure_all_finite=False,
        )
        check_is_fitted(self, ["_tree", "X", "y"])
        return self._predict_internal(X, check_input=check_input)

    def predict_full(self, X: np.ndarray) -> np.ndarray:
        """Convenience method to predict with no input checks (optional).

        Parameters
        ----------
        X : np.ndarray
            Input features.

        Returns:
        -------
        np.ndarray
            Continuous predictions of shape (n_samples,).
        """
        # Validate the model is fitted
        X = validate_data(
            self,
            X,
            ensure_all_finite=False,
        )
        check_is_fitted(self, ["_tree", "X", "y"])
        return self._predict_internal(X, check_input=False)

    def _post_fit(self) -> None:
        """Optional hook after the regressor's tree is fitted."""
        if self.verbose:
            pass