autogluon.tabular 1.3.2b20250711__py3-none-any.whl → 1.3.2b20250712__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (23)
  1. autogluon/tabular/models/__init__.py +1 -1
  2. autogluon/tabular/models/tabpfnv2/__init__.py +0 -0
  3. autogluon/tabular/models/tabpfnv2/rfpfn/__init__.py +20 -0
  4. autogluon/tabular/models/tabpfnv2/rfpfn/configs.py +40 -0
  5. autogluon/tabular/models/tabpfnv2/rfpfn/scoring_utils.py +201 -0
  6. autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_based_decision_tree_tabpfn.py +1464 -0
  7. autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_based_random_forest_tabpfn.py +747 -0
  8. autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_compat.py +863 -0
  9. autogluon/tabular/models/tabpfnv2/rfpfn/utils.py +106 -0
  10. autogluon/tabular/models/tabpfnv2/tabpfnv2_model.py +376 -0
  11. autogluon/tabular/registry/_ag_model_registry.py +2 -2
  12. autogluon/tabular/version.py +1 -1
  13. {autogluon.tabular-1.3.2b20250711.dist-info → autogluon.tabular-1.3.2b20250712.dist-info}/METADATA +13 -15
  14. {autogluon.tabular-1.3.2b20250711.dist-info → autogluon.tabular-1.3.2b20250712.dist-info}/RECORD +21 -14
  15. autogluon/tabular/models/tabpfn/__init__.py +0 -1
  16. autogluon/tabular/models/tabpfn/tabpfn_model.py +0 -153
  17. /autogluon.tabular-1.3.2b20250711-py3.9-nspkg.pth → /autogluon.tabular-1.3.2b20250712-py3.9-nspkg.pth +0 -0
  18. {autogluon.tabular-1.3.2b20250711.dist-info → autogluon.tabular-1.3.2b20250712.dist-info}/LICENSE +0 -0
  19. {autogluon.tabular-1.3.2b20250711.dist-info → autogluon.tabular-1.3.2b20250712.dist-info}/NOTICE +0 -0
  20. {autogluon.tabular-1.3.2b20250711.dist-info → autogluon.tabular-1.3.2b20250712.dist-info}/WHEEL +0 -0
  21. {autogluon.tabular-1.3.2b20250711.dist-info → autogluon.tabular-1.3.2b20250712.dist-info}/namespace_packages.txt +0 -0
  22. {autogluon.tabular-1.3.2b20250711.dist-info → autogluon.tabular-1.3.2b20250712.dist-info}/top_level.txt +0 -0
  23. {autogluon.tabular-1.3.2b20250711.dist-info → autogluon.tabular-1.3.2b20250712.dist-info}/zip-safe +0 -0
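The bulk of this diff is the new tabpfnv2 model and its rfpfn support code, which wraps TabPFN models in scikit-learn decision trees so that TabPFN is fitted at the leaves; the full content of sklearn_based_decision_tree_tabpfn.py follows below. For orientation only (not part of the diff), a minimal usage sketch of the classifier variant added here could look like this; it assumes the external tabpfn package, which provides TabPFNClassifier and downloads its model weights, is installed:

from sklearn.datasets import make_classification
from tabpfn import TabPFNClassifier  # external dependency; assumed available

from autogluon.tabular.models.tabpfnv2.rfpfn.sklearn_based_decision_tree_tabpfn import (
    DecisionTreeTabPFNClassifier,
)

X, y = make_classification(n_samples=600, n_features=20, n_informative=6,
                           n_classes=3, random_state=0)

model = DecisionTreeTabPFNClassifier(
    tabpfn=TabPFNClassifier(),   # TabPFN is fitted per leaf/node
    max_depth=3,
    min_samples_split=100,       # let the tree actually split this small dataset
    adaptive_tree=True,          # hold-out based node pruning, as implemented below
)
model.fit(X, y)
proba = model.predict_proba(X[:5])   # shape (5, 3)

DecisionTreeTabPFNRegressor, also defined in the module below, is the analogous regression variant and takes a TabPFNRegressor through the same tabpfn parameter.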
@@ -0,0 +1,1464 @@
1
+ # Copyright (c) Prior Labs GmbH 2025.
2
+ # Licensed under the Apache License, Version 2.0
3
+
4
+ from __future__ import annotations
5
+
6
+ import random
7
+ import warnings
8
+
9
+ # For type checking only
10
+ from typing import TYPE_CHECKING, Any
11
+ from copy import deepcopy
12
+
13
+ import numpy as np
14
+ import torch
15
+
16
+ if TYPE_CHECKING:
17
+ from numpy.typing import NDArray
18
+ from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin
19
+ from sklearn.model_selection import train_test_split
20
+ from sklearn.preprocessing import LabelEncoder
21
+ from sklearn.tree import (
22
+ BaseDecisionTree,
23
+ DecisionTreeClassifier,
24
+ DecisionTreeRegressor,
25
+ )
26
+ from sklearn.utils.multiclass import unique_labels
27
+ from sklearn.utils.validation import (
28
+ _check_sample_weight,
29
+ check_is_fitted,
30
+ )
31
+
32
+ from .sklearn_compat import validate_data
33
+ from .scoring_utils import (
34
+ score_classification,
35
+ score_regression,
36
+ )
37
+ from .utils import softmax
38
+
39
+ ###############################################################################
40
+ # BASE DECISION TREE #
41
+ ###############################################################################
42
+
43
+
44
+ class DecisionTreeTabPFNBase(BaseDecisionTree, BaseEstimator):
45
+ """Abstract base class combining a scikit-learn Decision Tree with TabPFN at the leaves.
46
+
47
+ This class provides a hybrid approach by combining the standard decision tree
48
+ splitting algorithm from scikit-learn with TabPFN models at the leaves or
49
+ internal nodes. This allows for both interpretable tree-based partitioning
50
+ and high-performance TabPFN prediction.
51
+
52
+ Key features:
53
+ -------------
54
+ • Inherits from sklearn's BaseDecisionTree to leverage standard tree splitting algorithms
55
+ • Uses TabPFN (Classifier or Regressor) to fit leaf nodes (or all internal nodes)
56
+ • Provides adaptive pruning logic (optional) that dynamically determines optimal tree depth
57
+ • Supports both classification and regression through specialized subclasses
58
+
59
+ Subclasses:
60
+ -----------
61
+ • DecisionTreeTabPFNClassifier - for classification tasks
62
+ • DecisionTreeTabPFNRegressor - for regression tasks
63
+
64
+ Parameters
65
+ ----------
66
+ tabpfn : Any
67
+ A TabPFN instance (TabPFNClassifier or TabPFNRegressor) that will be used at tree nodes.
68
+ criterion : str
69
+ The function to measure the quality of a split (from sklearn).
70
+ splitter : str
71
+ The strategy used to choose the split at each node (e.g. "best" or "random").
72
+ max_depth : int, optional
73
+ The maximum depth of the tree (None means unlimited).
74
+ min_samples_split : int
75
+ The minimum number of samples required to split an internal node.
76
+ min_samples_leaf : int
77
+ The minimum number of samples required to be at a leaf node.
78
+ min_weight_fraction_leaf : float
79
+ The minimum weighted fraction of the sum total of weights required to be at a leaf node.
80
+ max_features : Union[int, float, str, None]
81
+ The number of features to consider when looking for the best split.
82
+ random_state : Union[int, np.random.RandomState, None]
83
+ Controls the randomness of the estimator.
84
+ max_leaf_nodes : Optional[int]
85
+ If not None, grow a tree with max_leaf_nodes in best-first fashion.
86
+ min_impurity_decrease : float
87
+ A node will be split if this split induces a decrease of the impurity >= this value.
88
+ class_weight : Optional[Union[Dict[int, float], str]]
89
+ Only used in classification. Dict of class -> weight or "balanced".
90
+ ccp_alpha : float
91
+ Complexity parameter used for Minimal Cost-Complexity Pruning (non-negative).
92
+ monotonic_cst : Any
93
+ Optional monotonicity constraints (depending on sklearn version).
94
+ categorical_features : Optional[List[int]]
95
+ Indices of categorical features for TabPFN usage (if any).
96
+ verbose : Union[bool, int]
97
+ Verbosity level; higher values produce more output.
98
+ show_progress : bool
99
+ Whether to show progress bars for leaf/node fitting using TabPFN.
100
+ fit_nodes : bool
101
+ Whether to fit TabPFN at internal nodes (True) or only final leaves (False).
102
+ tree_seed : int
103
+ Used to set seeds for TabPFN fitting in each node.
104
+ adaptive_tree : bool
105
+ Whether to do adaptive node-by-node pruning using a hold-out strategy.
106
+ adaptive_tree_min_train_samples : int
107
+ Minimum number of training samples required to fit a TabPFN in a node.
108
+ adaptive_tree_max_train_samples : int
109
+ Maximum number of training samples above which a node might be pruned if not a final leaf.
110
+ adaptive_tree_min_valid_samples_fraction_of_train : float
111
+ Fraction controlling the minimum valid/test points to consider a node for re-fitting.
112
+ adaptive_tree_overwrite_metric : Optional[str]
113
+ If set, overrides the default metric for pruning. E.g., "roc" or "rmse".
114
+ adaptive_tree_test_size : float
115
+ Fraction of data to hold out for adaptive pruning if no separate valid set is provided.
116
+ average_logits : bool
117
+ Whether to average logits (True) or probabilities (False) when combining predictions.
118
+ adaptive_tree_skip_class_missing : bool
119
+ If True, skip re-fitting if the node's training set does not contain all classes (classification only).
120
+ """
121
+
122
+ # Task type set by subclasses: "multiclass" or "regression"
123
+ task_type: str | None = None
124
+
125
+ def __init__(
126
+ self,
127
+ *,
128
+ # Decision Tree arguments
129
+ criterion: str = "gini",
130
+ splitter: str = "best",
131
+ max_depth: int | None = None,
132
+ min_samples_split: int = 1000,
133
+ min_samples_leaf: int = 1,
134
+ min_weight_fraction_leaf: float = 0.0,
135
+ max_features: int | float | str | None = None,
136
+ random_state: int | np.random.RandomState | None = None,
137
+ max_leaf_nodes: int | None = None,
138
+ min_impurity_decrease: float = 0.0,
139
+ class_weight: dict[int, float] | str | None = None,
140
+ ccp_alpha: float = 0.0,
141
+ monotonic_cst: Any = None,
142
+ # TabPFN argument
143
+ tabpfn: Any = None, # TabPFNClassifier or TabPFNRegressor
144
+ categorical_features: list[int] | None = None,
145
+ verbose: bool | int = False,
146
+ show_progress: bool = False,
147
+ fit_nodes: bool = True,
148
+ tree_seed: int = 0,
149
+ adaptive_tree: bool = True,
150
+ adaptive_tree_min_train_samples: int = 50,
151
+ adaptive_tree_max_train_samples: int = 2000,
152
+ adaptive_tree_min_valid_samples_fraction_of_train: float = 0.2,
153
+ adaptive_tree_overwrite_metric: str | None = None,
154
+ adaptive_tree_test_size: float = 0.2,
155
+ average_logits: bool = True,
156
+ adaptive_tree_skip_class_missing: bool = True,
157
+ ):
158
+ # Collect recognized arguments
159
+ self.tabpfn = tabpfn
160
+ self.criterion = criterion
161
+ self.splitter = splitter
162
+ self.max_depth = max_depth
163
+ self.min_samples_split = min_samples_split
164
+ self.min_samples_leaf = min_samples_leaf
165
+ self.min_weight_fraction_leaf = min_weight_fraction_leaf
166
+ self.max_features = max_features
167
+ self.random_state = random_state
168
+ self.max_leaf_nodes = max_leaf_nodes
169
+ self.min_impurity_decrease = min_impurity_decrease
170
+ self.class_weight = class_weight
171
+ self.ccp_alpha = ccp_alpha
172
+ self.monotonic_cst = monotonic_cst
173
+
174
+ self.categorical_features = categorical_features
175
+ self.verbose = verbose
176
+ self.show_progress = show_progress
177
+ self.fit_nodes = fit_nodes
178
+ self.tree_seed = tree_seed
179
+ self.adaptive_tree = adaptive_tree
180
+ self.adaptive_tree_min_train_samples = adaptive_tree_min_train_samples
181
+ self.adaptive_tree_max_train_samples = adaptive_tree_max_train_samples
182
+ self.adaptive_tree_min_valid_samples_fraction_of_train = (
183
+ adaptive_tree_min_valid_samples_fraction_of_train
184
+ )
185
+ self.adaptive_tree_overwrite_metric = adaptive_tree_overwrite_metric
186
+ self.adaptive_tree_test_size = adaptive_tree_test_size
187
+ self.average_logits = average_logits
188
+ self.adaptive_tree_skip_class_missing = adaptive_tree_skip_class_missing
189
+
190
+ # Initialize internal flags/structures that will be set during fit
191
+ self._need_post_fit: bool = False
192
+ self._decision_tree = None
193
+
194
+ # Handling possible differences in sklearn versions, specifically monotonic_cst
195
+ optional_args_filtered = {}
196
+ if "monotonic_cst" in BaseDecisionTree.__init__.__code__.co_varnames:
197
+ optional_args_filtered["monotonic_cst"] = monotonic_cst
198
+
199
+ # Initialize the underlying DecisionTree
200
+ super().__init__(
201
+ criterion=self.criterion,
202
+ splitter=self.splitter,
203
+ max_depth=self.max_depth,
204
+ min_samples_split=self.min_samples_split,
205
+ min_samples_leaf=self.min_samples_leaf,
206
+ min_weight_fraction_leaf=self.min_weight_fraction_leaf,
207
+ max_features=self.max_features,
208
+ random_state=self.random_state,
209
+ max_leaf_nodes=self.max_leaf_nodes,
210
+ min_impurity_decrease=self.min_impurity_decrease,
211
+ ccp_alpha=self.ccp_alpha,
212
+ **optional_args_filtered,
213
+ )
214
+
215
+ # If the user gave a TabPFN, we do not want it to have a random_state forcibly set
216
+ # because we handle seeds ourselves at each node
217
+ if self.tabpfn is not None:
218
+ self.tabpfn.random_state = None
219
+
220
+ def _validate_tabpfn_runtime(self) -> None:
221
+ """Validate the TabPFN instance at runtime before using it.
222
+
223
+ This ensures the TabPFN instance is still available when needed during
224
+ prediction or fitting operations.
225
+
226
+ Raises:
227
+ ValueError: If self.tabpfn is None at runtime
228
+ """
229
+ if self.tabpfn is None:
230
+ raise ValueError("TabPFN was None at runtime - cannot proceed.")
231
+
232
+ def _more_tags(self) -> dict[str, Any]:
233
+ return {
234
+ "allow_nan": True,
235
+ }
236
+
237
+ def __sklearn_tags__(self):
238
+ tags = super().__sklearn_tags__()
239
+ tags.input_tags.allow_nan = True
240
+ tags.estimator_type = "regressor"
241
+ if self.task_type == "multiclass":
242
+ tags.estimator_type = "classifier"
243
+ else:
244
+ tags.estimator_type = "regressor"
245
+ return tags
246
+
247
+ def fit(
248
+ self,
249
+ X: NDArray[np.float64],
250
+ y: NDArray[Any],
251
+ sample_weight: NDArray[np.float64] | None = None,
252
+ check_input: bool = True,
253
+ ) -> DecisionTreeTabPFNBase:
254
+ """Fit the DecisionTree + TabPFN model.
255
+
256
+ This method trains the hybrid model by:
257
+ 1. Building a decision tree structure
258
+ 2. Fitting TabPFN models at the leaves (or at all nodes if fit_nodes=True)
259
+ 3. Optionally performing adaptive pruning if adaptive_tree=True
260
+
261
+ Args:
262
+ X: The training input samples, shape (n_samples, n_features).
263
+ y: The target values (class labels for classification, real values for regression),
264
+ shape (n_samples,) or (n_samples, n_outputs).
265
+ sample_weight: Sample weights. If None, then samples are equally weighted.
266
+ check_input: Whether to validate the input data arrays. Default is True.
267
+
268
+ Returns:
269
+ self: Fitted estimator.
270
+ """
271
+ return self._fit(X, y, sample_weight=sample_weight, check_input=check_input)
272
+
273
+ def _fit(
274
+ self,
275
+ X: NDArray[Any],
276
+ y: NDArray[Any],
277
+ sample_weight: NDArray[Any] | None = None,
278
+ check_input: bool = True,
279
+ missing_values_in_feature_mask: np.ndarray | None = None, # Unused placeholder
280
+ ) -> DecisionTreeTabPFNBase:
281
+ """Internal method to fit the DecisionTree-TabPFN model on X, y.
282
+
283
+ Parameters
284
+ ----------
285
+ X : NDArray
286
+ Training features of shape (n_samples, n_features).
287
+ y : NDArray
288
+ Target labels/values of shape (n_samples,).
289
+ sample_weight : NDArray, optional
290
+ Sample weights for each sample.
291
+ check_input : bool
292
+ Whether to check inputs.
293
+ missing_values_in_feature_mask : np.ndarray, optional
294
+ Unused placeholder for older code or possible expansions.
295
+
296
+ Returns:
297
+ -------
298
+ self : DecisionTreeTabPFNBase
299
+ The fitted model.
300
+ """
301
+ # Initialize attributes (per scikit-learn conventions)
302
+ self._leaf_nodes = []
303
+ self._leaf_train_data = {}
304
+ self._label_encoder = LabelEncoder()
305
+ self._need_post_fit = False
306
+ self._node_prediction_type = {}
307
+
308
+ # Make sure tabpfn is valid
309
+ self._validate_tabpfn_runtime()
310
+
311
+ # Possibly randomize tree_seed if not set
312
+ if self.tree_seed == 0:
313
+ self.tree_seed = random.randint(1, 10000)
314
+
315
+ sample_weight = _check_sample_weight(sample_weight, X, dtype=np.float64)
316
+ X, y = validate_data(
317
+ self,
318
+ X,
319
+ y,
320
+ ensure_all_finite=False, # scikit-learn sets self.n_features_in_ automatically
321
+ )
322
+
323
+ if self.task_type == "multiclass":
324
+ self.classes_ = unique_labels(y)
325
+ self.n_classes_ = len(self.classes_)
326
+
327
+ # Convert torch tensor -> numpy if needed, handle NaNs
328
+ X_preprocessed = self._preprocess_data_for_tree(X)
329
+
330
+ if sample_weight is None:
331
+ sample_weight = np.ones((X_preprocessed.shape[0],), dtype=np.float64)
332
+
333
+ # Setup classes_ or n_classes_ if needed
334
+ if self.task_type == "multiclass":
335
+ # Classification
336
+ self.classes_ = np.unique(y)
337
+ self.n_classes_ = len(self.classes_)
338
+ else:
339
+ # Regression
340
+ self.n_classes_ = (
341
+ 1 # Not used for numeric tasks, but keep it for consistency
342
+ )
343
+
344
+ # Possibly label-encode y for classification if your TabPFN needs it
345
+ # (Here we just rely on uniqueness checks above.)
346
+ y_ = y.copy()
347
+
348
+ # If adaptive_tree is on, do a train/validation split
349
+ if self.adaptive_tree:
350
+ stratify = y_ if (self.task_type == "multiclass") else None
351
+
352
+ # Basic checks for classification to see if splitting is feasible
353
+ if self.task_type == "multiclass":
354
+ unique_classes, counts = np.unique(y_, return_counts=True)
355
+ # Disable adaptive tree in extreme cases
356
+ if counts.min() == 1 or len(unique_classes) < 2:
357
+ self.adaptive_tree = False
358
+ elif len(unique_classes) > int(len(y_) * self.adaptive_tree_test_size):
359
+ self.adaptive_tree_test_size = min(
360
+ 0.5,
361
+ len(unique_classes) / len(y_) * 1.5,
362
+ )
363
+ if len(y_) < 10:
364
+ self.adaptive_tree = False
365
+
366
+ if self.adaptive_tree:
367
+ (
368
+ X_train,
369
+ X_valid,
370
+ X_preproc_train,
371
+ X_preproc_valid,
372
+ y_train,
373
+ y_valid,
374
+ sw_train,
375
+ sw_valid,
376
+ ) = train_test_split(
377
+ X,
378
+ X_preprocessed,
379
+ y_,
380
+ sample_weight,
381
+ test_size=self.adaptive_tree_test_size,
382
+ random_state=self.random_state,
383
+ stratify=stratify,
384
+ )
385
+
386
+ # Safety check - if split is empty, revert
387
+ if len(y_train) == 0 or len(y_valid) == 0:
388
+ self.adaptive_tree = False
389
+ X_train, X_preproc_train, y_train, sw_train = (
390
+ X,
391
+ X_preprocessed,
392
+ y_,
393
+ sample_weight,
394
+ )
395
+ X_valid = X_preproc_valid = y_valid = sw_valid = None
396
+
397
+ # If classification, also ensure train/valid has same classes
398
+ if (
399
+ self.task_type == "multiclass"
400
+ and self.adaptive_tree
401
+ and (len(np.unique(y_train)) != len(np.unique(y_valid)))
402
+ ):
403
+ self.adaptive_tree = False
404
+ else:
405
+ # If we were disabled, keep all data as training
406
+ X_train, X_preproc_train, y_train, sw_train = (
407
+ X,
408
+ X_preprocessed,
409
+ y_,
410
+ sample_weight,
411
+ )
412
+ X_valid = X_preproc_valid = y_valid = sw_valid = None
413
+ else:
414
+ # Not adaptive, everything is train
415
+ X_train, X_preproc_train, y_train, sw_train = (
416
+ X,
417
+ X_preprocessed,
418
+ y_,
419
+ sample_weight,
420
+ )
421
+ X_valid = X_preproc_valid = y_valid = sw_valid = None
422
+
423
+ # Build the sklearn decision tree
424
+ self._decision_tree = self._init_decision_tree()
425
+ self._decision_tree.fit(X_preproc_train, y_train, sample_weight=sw_train)
426
+ self._tree = self._decision_tree # for sklearn compatibility
427
+
428
+ # Keep references for potential post-fitting (leaf-level fitting)
429
+ self.X = X
430
+ self.y = y_
431
+ self.train_X = X_train
432
+ self.train_X_preprocessed = X_preproc_train
433
+ self.train_y = y_train
434
+ self.train_sample_weight = sw_train
435
+
436
+ if self.adaptive_tree:
437
+ self.valid_X = X_valid
438
+ self.valid_X_preprocessed = X_preproc_valid
439
+ self.valid_y = y_valid
440
+ self.valid_sample_weight = sw_valid
441
+
442
+ # We will do a leaf-fitting step on demand (lazy) in predict
443
+ self._need_post_fit = True
444
+
445
+ # If verbose, optionally do it right away:
446
+ if self.verbose:
447
+ self._post_fit()
448
+
449
+ return self
450
+
451
+ def _init_decision_tree(self) -> BaseDecisionTree:
452
+ """Initialize the underlying scikit-learn Decision Tree.
453
+
454
+ Overridden by child classes for classifier vs regressor.
455
+
456
+ Returns:
457
+ -------
458
+ BaseDecisionTree
459
+ An instance of a scikit-learn DecisionTreeClassifier or DecisionTreeRegressor.
460
+ """
461
+ raise NotImplementedError("Must be implemented in subclass.")
462
+
463
+ def _post_fit(self) -> None:
464
+ """Hook after the decision tree is fitted. Can be used for final prints/logs."""
465
+ if self.verbose:
466
+ pass
467
+
468
+ def _preprocess_data_for_tree(self, X: np.ndarray) -> np.ndarray:
469
+ """Handle missing data prior to feeding into the decision tree.
470
+
471
+ Replaces NaNs with a default value, handles pandas DataFrames and other input types.
472
+ Uses scikit-learn's validation functions for compatibility.
473
+
474
+ Parameters
475
+ ----------
476
+ X : array-like
477
+ Input features, possibly containing NaNs.
478
+
479
+ Returns:
480
+ -------
481
+ np.ndarray
482
+ A copy of X with NaNs replaced by default value.
483
+ """
484
+ # Use check_array from sklearn_compat to handle different input types
485
+ from .sklearn_compat import check_array
486
+
487
+ # Handle torch tensor
488
+ if torch.is_tensor(X):
489
+ X = X.cpu().numpy()
490
+
491
+ # Convert to array and handle input validation
492
+ # Don't extract DataFrame values - let check_array handle it
493
+ X = check_array(
494
+ X,
495
+ dtype=np.float64,
496
+ ensure_all_finite=False, # We'll handle NaNs ourselves
497
+ ensure_2d=True,
498
+ copy=True, # Make a copy so we don't modify the original
499
+ )
500
+
501
+ # Replace NaN with our specific value (-1000.0)
502
+ X = np.nan_to_num(X, nan=-1000.0)
503
+ return X
504
+
505
+ def _apply_tree(self, X: np.ndarray) -> np.ndarray:
506
+ """Apply the fitted tree to X, returning a matrix of leaf membership.
507
+
508
+ Returns:
509
+ -------
510
+ np.ndarray
511
+ A dense matrix of shape (n_samples, n_nodes, n_estimators),
512
+ though we typically only have 1 estimator.
513
+ """
514
+ X_preprocessed = self._preprocess_data_for_tree(X)
515
+ decision_path = self.get_tree().decision_path(X_preprocessed)
516
+ return np.expand_dims(decision_path.todense(), axis=2)
517
+
518
+ def _apply_tree_train(
519
+ self,
520
+ X: np.ndarray,
521
+ y: np.ndarray,
522
+ ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
523
+ """Apply the tree for training data, returning leaf membership plus (X, y) unchanged.
524
+
525
+ Returns:
526
+ -------
527
+ leaf_matrix : np.ndarray
528
+ Shape (n_samples, n_nodes, n_estimators)
529
+ X_array : np.ndarray
530
+ Same as X input
531
+ y_array : np.ndarray
532
+ Same as y input
533
+ """
534
+ return self._apply_tree(X), X, y
535
+
536
+ def get_tree(self) -> BaseDecisionTree:
537
+ """Return the underlying fitted sklearn decision tree.
538
+
539
+ Returns:
540
+ DecisionTreeClassifier or DecisionTreeRegressor: The fitted decision tree.
541
+
542
+ Raises:
543
+ sklearn.exceptions.NotFittedError: If the model has not been fitted yet.
544
+ """
545
+ # This will raise NotFittedError if the model is not fitted
546
+ check_is_fitted(self, ["_tree", "X", "y"])
547
+ return self._tree
548
+
549
+ @property
550
+ def tree_(self):
551
+ """Expose the fitted tree for sklearn compatibility.
552
+
553
+ Returns:
554
+ -------
555
+ sklearn.tree._tree.Tree
556
+ Underlying scikit-learn tree object.
557
+ """
558
+ return self.get_tree().tree_
559
+
560
+ def fit_leaves(
561
+ self,
562
+ train_X: np.ndarray,
563
+ train_y: np.ndarray,
564
+ ) -> None:
565
+ """Fit a TabPFN model in each leaf node (or each node, if self.fit_nodes=True).
566
+
567
+ This populates an internal dictionary of training data for each leaf
568
+ so that TabPFN can make predictions at these leaves.
569
+
570
+ Parameters
571
+ ----------
572
+ train_X : np.ndarray
573
+ Training features for all samples.
574
+ train_y : np.ndarray
575
+ Training labels/targets for all samples.
576
+ """
577
+ self._leaf_train_data = {}
578
+ leaf_node_matrix, _, _ = self._apply_tree_train(train_X, train_y)
579
+ self._leaf_nodes = leaf_node_matrix
580
+
581
+ n_samples, n_nodes, n_estims = leaf_node_matrix.shape
582
+
583
+ for estimator_id in range(n_estims):
584
+ self._leaf_train_data[estimator_id] = {}
585
+ for leaf_id in range(n_nodes):
586
+ indices = np.argwhere(
587
+ leaf_node_matrix[:, leaf_id, estimator_id],
588
+ ).ravel()
589
+ X_leaf_samples = np.take(train_X, indices, axis=0)
590
+ y_leaf_samples = np.take(train_y, indices, axis=0).ravel()
591
+
592
+ self._leaf_train_data[estimator_id][leaf_id] = (
593
+ X_leaf_samples,
594
+ y_leaf_samples,
595
+ )
596
+
597
+ def _predict_internal(
598
+ self,
599
+ X: np.ndarray,
600
+ y: np.ndarray | None = None,
601
+ check_input: bool = True,
602
+ ) -> np.ndarray:
603
+ """Internal method used to produce probabilities or regression predictions,
604
+ with optional adaptive pruning logic.
605
+
606
+ If y is given and we have adaptive_tree=True, node-level pruning is applied.
607
+
608
+ Parameters
609
+ ----------
610
+ X : np.ndarray
611
+ Features to predict.
612
+ y : np.ndarray, optional
613
+ Target values, only required if we are in adaptive pruning mode
614
+ and need to compare node performance.
615
+ check_input : bool, default=True
616
+ Whether to validate input arrays.
617
+
618
+ Returns:
619
+ -------
620
+ np.ndarray
621
+ The final predictions (probabilities for classification, or continuous values for regression).
622
+ """
623
+ # If we haven't yet done the final leaf fit, do it here
624
+ if self._need_post_fit:
625
+ self._need_post_fit = False
626
+ if self.adaptive_tree:
627
+ # Fit leaves on train data, check performance on valid data if available
628
+ self.fit_leaves(self.train_X, self.train_y)
629
+ if (
630
+ hasattr(self, "valid_X")
631
+ and self.valid_X is not None
632
+ and self.valid_y is not None
633
+ ):
634
+ # Force a pass to evaluate node performance
635
+ # so we can prune or decide node updates
636
+ self._predict_internal(
637
+ self.valid_X,
638
+ self.valid_y,
639
+ check_input=False,
640
+ )
641
+ # Now fit leaves again using the entire dataset (train + valid, effectively)
642
+ self.fit_leaves(self.X, self.y)
643
+
644
+ # Assign TabPFNs categorical features if needed
645
+ if self.tabpfn is not None:
646
+ self.tabpfn.categorical_features_indices = self.categorical_features
647
+
648
+ # Find leaf membership in X
649
+ X_leaf_nodes = self._apply_tree(X)
650
+ n_samples, n_nodes, n_estims = X_leaf_nodes.shape
651
+
652
+ # Track intermediate predictions
653
+ y_prob: dict[int, dict[int, np.ndarray]] = {}
654
+ y_metric: dict[int, dict[int, float]] = {}
655
+
656
+ # If pruning, track how each node is updated
657
+ do_pruning = (y is not None) and self.adaptive_tree
658
+ if do_pruning:
659
+ self._node_prediction_type: dict[int, dict[int, str]] = {}
660
+
661
+ for est_id in range(n_estims):
662
+ if do_pruning:
663
+ self._node_prediction_type[est_id] = {}
664
+ y_prob[est_id] = {}
665
+ y_metric[est_id] = {}
666
+ if self.show_progress:
667
+ import tqdm.auto
668
+
669
+ node_iter = tqdm.auto.tqdm(range(n_nodes), desc=f"Estimator {est_id}")
670
+ else:
671
+ node_iter = range(n_nodes)
672
+
673
+ for leaf_id in node_iter:
674
+ self._pruning_init_node_predictions(
675
+ leaf_id,
676
+ est_id,
677
+ y_prob,
678
+ y_metric,
679
+ n_nodes,
680
+ n_samples,
681
+ )
682
+ if est_id > 0 and leaf_id == 0:
683
+ # Skip repeated re-initialization if multiple trees
684
+ continue
685
+
686
+ # Gather test-sample indices that belong to this leaf
687
+ test_sample_indices = np.argwhere(
688
+ X_leaf_nodes[:, leaf_id, est_id],
689
+ ).ravel()
690
+
691
+ # Gather training samples that belong to this leaf
692
+ X_train_leaf, y_train_leaf = self._leaf_train_data[est_id][leaf_id]
693
+
694
+ # If no training or test samples in this node, skip
695
+ if (X_train_leaf.shape[0] == 0) or (len(test_sample_indices) == 0):
696
+ if do_pruning:
697
+ self._node_prediction_type[est_id][leaf_id] = "previous"
698
+ continue
699
+
700
+ # Determine if this is a final leaf
701
+ # If the sum of membership in subsequent nodes is zero, it is a final leaf
702
+ is_leaf = (
703
+ X_leaf_nodes[test_sample_indices, leaf_id + 1 :, est_id].sum()
704
+ == 0.0
705
+ )
706
+
707
+ # If it's not a leaf and we are not fitting internal nodes, skip
708
+ # (unless leaf_id==0 and we do a top-level check for adaptive_tree)
709
+ if (
710
+ (not is_leaf)
711
+ and (not self.fit_nodes)
712
+ and not (leaf_id == 0 and self.adaptive_tree)
713
+ ):
714
+ if do_pruning:
715
+ self._node_prediction_type[est_id][leaf_id] = "previous"
716
+ continue
717
+
718
+ # Additional adaptive checks
719
+ if self.adaptive_tree and leaf_id != 0:
720
+ should_skip_previously_pruned = False
721
+ if y is None:
722
+ # Safely check if the key exists before accessing
723
+ node_type = self._node_prediction_type.get(est_id, {}).get(
724
+ leaf_id,
725
+ )
726
+ if node_type == "previous":
727
+ should_skip_previously_pruned = True
728
+
729
+ if should_skip_previously_pruned:
730
+ continue
731
+
732
+ # Skip if classification is missing a class
733
+ if (
734
+ self.task_type == "multiclass"
735
+ and len(np.unique(y_train_leaf)) < self.n_classes_
736
+ and self.adaptive_tree_skip_class_missing
737
+ ):
738
+ self._node_prediction_type[est_id][leaf_id] = "previous"
739
+ continue
740
+
741
+ # Skip if too few or too many training points
742
+ if (
743
+ (X_train_leaf.shape[0] < self.adaptive_tree_min_train_samples)
744
+ or (
745
+ len(test_sample_indices)
746
+ < self.adaptive_tree_min_valid_samples_fraction_of_train
747
+ * self.adaptive_tree_min_train_samples
748
+ )
749
+ or (
750
+ X_train_leaf.shape[0] > self.adaptive_tree_max_train_samples
751
+ and not is_leaf
752
+ )
753
+ ):
754
+ if do_pruning:
755
+ self._node_prediction_type[est_id][leaf_id] = "previous"
756
+ continue
757
+
758
+ # Perform leaf-level TabPFN prediction
759
+ leaf_prediction = self._predict_leaf(
760
+ X_train_leaf,
761
+ y_train_leaf,
762
+ leaf_id,
763
+ X,
764
+ test_sample_indices,
765
+ )
766
+
767
+ # Evaluate “averaging” and “replacement” for pruning
768
+ y_prob_averaging, y_prob_replacement = (
769
+ self._pruning_get_prediction_type_results(
770
+ y_prob,
771
+ leaf_prediction,
772
+ test_sample_indices,
773
+ est_id,
774
+ leaf_id,
775
+ )
776
+ )
777
+
778
+ # Decide best approach if in adaptive mode
779
+ if not self.adaptive_tree:
780
+ # If not adaptive, we simply do replacement
781
+ y_prob[est_id][leaf_id] = y_prob_replacement
782
+ elif y is not None:
783
+ self._pruning_set_node_prediction_type(
784
+ y,
785
+ y_prob_averaging,
786
+ y_prob_replacement,
787
+ y_metric,
788
+ est_id,
789
+ leaf_id,
790
+ )
791
+ self._pruning_set_predictions(
792
+ y_prob,
793
+ y_prob_averaging,
794
+ y_prob_replacement,
795
+ est_id,
796
+ leaf_id,
797
+ )
798
+ y_metric[est_id][leaf_id] = self._score(
799
+ y,
800
+ y_prob[est_id][leaf_id],
801
+ )
802
+ else:
803
+ # If not validating and not adaptive, just use replacement
804
+ y_prob[est_id][leaf_id] = y_prob_replacement
805
+
806
+ # Final predictions come from the last estimator's last node
807
+ return y_prob[n_estims - 1][n_nodes - 1]
808
+
809
+ def _pruning_init_node_predictions(
810
+ self,
811
+ leaf_id: int,
812
+ estimator_id: int,
813
+ y_prob: dict[int, dict[int, np.ndarray]],
814
+ y_metric: dict[int, dict[int, float]],
815
+ n_nodes: int,
816
+ n_samples: int,
817
+ ) -> None:
818
+ """Initialize node predictions for the pruning logic.
819
+
820
+ Parameters
821
+ ----------
822
+ leaf_id : int
823
+ Index of the leaf/node being processed.
824
+ estimator_id : int
825
+ Index of the current estimator (if multiple).
826
+ y_prob : dict
827
+ Nested dictionary of predictions.
828
+ y_metric : dict
829
+ Nested dictionary of scores/metrics.
830
+ n_nodes : int
831
+ Total number of nodes in the tree.
832
+ n_samples : int
833
+ Number of samples in X.
834
+ """
835
+ if estimator_id == 0 and leaf_id == 0:
836
+ y_prob[0][0] = self._init_eval_probability_array(n_samples, to_zero=True)
837
+ y_metric[0][0] = 0.0
838
+ elif leaf_id == 0 and estimator_id > 0:
839
+ # If first leaf of new estimator, carry from last node of previous estimator
840
+ y_prob[estimator_id][leaf_id] = y_prob[estimator_id - 1][n_nodes - 1]
841
+ y_metric[estimator_id][leaf_id] = y_metric[estimator_id - 1][n_nodes - 1]
842
+ else:
843
+ # Use last leaf of the same estimator
844
+ y_prob[estimator_id][leaf_id] = y_prob[estimator_id][leaf_id - 1]
845
+ y_metric[estimator_id][leaf_id] = y_metric[estimator_id][leaf_id - 1]
846
+
847
+ def _pruning_get_prediction_type_results(
848
+ self,
849
+ y_eval_prob: dict[int, dict[int, np.ndarray]],
850
+ leaf_prediction: np.ndarray,
851
+ test_sample_indices: np.ndarray,
852
+ estimator_id: int,
853
+ leaf_id: int,
854
+ ) -> tuple[np.ndarray, np.ndarray]:
855
+ """Produce the “averaging” and “replacement” predictions for pruning decisions.
856
+
857
+ Parameters
858
+ ----------
859
+ y_eval_prob : dict
860
+ Nested dictionary of predictions.
861
+ leaf_prediction : np.ndarray
862
+ Predictions from the newly fitted leaf (for relevant samples).
863
+ test_sample_indices : np.ndarray
864
+ Indices of the test samples that fall into this leaf.
865
+ estimator_id : int
866
+ Index of the current estimator.
867
+ leaf_id : int
868
+ Index of the current leaf/node.
869
+
870
+ Returns:
871
+ -------
872
+ y_prob_averaging : np.ndarray
873
+ Updated predictions using an “averaging” rule.
874
+ y_prob_replacement : np.ndarray
875
+ Updated predictions using a “replacement” rule.
876
+ """
877
+ y_prob_current = y_eval_prob[estimator_id][leaf_id]
878
+ y_prob_replacement = np.copy(y_prob_current)
879
+ # "replacement" sets the new leaf prediction directly
880
+ y_prob_replacement[test_sample_indices] = leaf_prediction[test_sample_indices]
881
+
882
+ if self.task_type == "multiclass":
883
+ # Normalize
884
+ row_sums = y_prob_replacement.sum(axis=1, keepdims=True)
885
+ row_sums[row_sums == 0] = 1.0
886
+ y_prob_replacement /= row_sums
887
+
888
+ # "averaging" -> combine old predictions with new
889
+ y_prob_averaging = np.copy(y_prob_current)
890
+
891
+ if self.task_type == "multiclass":
892
+ if self.average_logits:
893
+ # Convert old + new to log, sum them, then softmax
894
+ y_prob_averaging[test_sample_indices] = np.log(
895
+ y_prob_averaging[test_sample_indices] + 1e-6,
896
+ )
897
+ leaf_pred_log = np.log(leaf_prediction[test_sample_indices] + 1e-6)
898
+ y_prob_averaging[test_sample_indices] += leaf_pred_log
899
+ y_prob_averaging[test_sample_indices] = softmax(
900
+ y_prob_averaging[test_sample_indices],
901
+ )
902
+ else:
903
+ # Average probabilities directly
904
+ y_prob_averaging[test_sample_indices] += leaf_prediction[
905
+ test_sample_indices
906
+ ]
907
+ row_sums = y_prob_averaging.sum(axis=1, keepdims=True)
908
+ row_sums[row_sums == 0] = 1.0
909
+ y_prob_averaging /= row_sums
910
+ elif self.task_type == "regression":
911
+ # Regression -> simply average
912
+ y_prob_averaging[test_sample_indices] += leaf_prediction[
913
+ test_sample_indices
914
+ ]
915
+ y_prob_averaging[test_sample_indices] /= 2.0
916
+
917
+ return y_prob_averaging, y_prob_replacement
918
+
919
+ def _pruning_set_node_prediction_type(
920
+ self,
921
+ y_true: np.ndarray,
922
+ y_prob_averaging: np.ndarray,
923
+ y_prob_replacement: np.ndarray,
924
+ y_metric: dict[int, dict[int, float]],
925
+ estimator_id: int,
926
+ leaf_id: int,
927
+ ) -> None:
928
+ """Decide which approach is better: “averaging” vs “replacement” vs “previous,”
929
+ using the node's previous metric vs the new metrics.
930
+
931
+ Parameters
932
+ ----------
933
+ y_true : np.ndarray
934
+ Ground-truth labels/targets for pruning comparison.
935
+ y_prob_averaging : np.ndarray
936
+ Predictions if we use averaging.
937
+ y_prob_replacement : np.ndarray
938
+ Predictions if we use replacement.
939
+ y_metric : dict
940
+ Nested dictionary of scores/metrics for each node.
941
+ estimator_id : int
942
+ Index of the current estimator.
943
+ leaf_id : int
944
+ Index of the current leaf/node.
945
+ """
946
+ averaging_score = self._score(y_true, y_prob_averaging)
947
+ replacement_score = self._score(y_true, y_prob_replacement)
948
+ prev_score = y_metric[estimator_id][leaf_id - 1] if (leaf_id > 0) else 0.0
949
+
950
+ if (leaf_id == 0) or (max(averaging_score, replacement_score) > prev_score):
951
+ # Pick whichever is better
952
+ if replacement_score > averaging_score:
953
+ prediction_type = "replacement"
954
+ else:
955
+ prediction_type = "averaging"
956
+ else:
957
+ prediction_type = "previous"
958
+
959
+ self._node_prediction_type[estimator_id][leaf_id] = prediction_type
960
+
961
+ def _pruning_set_predictions(
962
+ self,
963
+ y_prob: dict[int, dict[int, np.ndarray]],
964
+ y_prob_averaging: np.ndarray,
965
+ y_prob_replacement: np.ndarray,
966
+ estimator_id: int,
967
+ leaf_id: int,
968
+ ) -> None:
969
+ """Based on the chosen node_prediction_type, finalize the predictions.
970
+
971
+ Parameters
972
+ ----------
973
+ y_prob : dict
974
+ Nested dictionary of predictions.
975
+ y_prob_averaging : np.ndarray
976
+ Predictions if we use averaging.
977
+ y_prob_replacement : np.ndarray
978
+ Predictions if we use replacement.
979
+ estimator_id : int
980
+ Index of the current estimator.
981
+ leaf_id : int
982
+ Index of the current leaf/node.
983
+ """
984
+ node_type = self._node_prediction_type[estimator_id][leaf_id]
985
+ if node_type == "averaging":
986
+ y_prob[estimator_id][leaf_id] = y_prob_averaging
987
+ elif node_type == "replacement":
988
+ y_prob[estimator_id][leaf_id] = y_prob_replacement
989
+ else:
990
+ # “previous”
991
+ y_prob[estimator_id][leaf_id] = y_prob[estimator_id][leaf_id - 1]
992
+
993
+ def _init_eval_probability_array(
994
+ self,
995
+ n_samples: int,
996
+ to_zero: bool = False,
997
+ ) -> np.ndarray:
998
+ """Initialize an array of predictions for the entire dataset.
999
+
1000
+ For classification, this is (n_samples, n_classes).
1001
+ For regression, this is (n_samples,).
1002
+
1003
+ Parameters
1004
+ ----------
1005
+ n_samples : int
1006
+ Number of samples to predict.
1007
+ to_zero : bool, default=False
1008
+ If True, fill with zeros. Otherwise use uniform for classification,
1009
+ or zeros for regression.
1010
+
1011
+ Returns:
1012
+ -------
1013
+ np.ndarray
1014
+ An appropriately sized array of initial predictions.
1015
+ """
1016
+ if self.task_type == "multiclass":
1017
+ if to_zero:
1018
+ return np.zeros((n_samples, self.n_classes_), dtype=np.float64)
1019
+ return (
1020
+ np.ones((n_samples, self.n_classes_), dtype=np.float64)
1021
+ / self.n_classes_
1022
+ )
1023
+ else:
1024
+ # Regression
1025
+ return np.zeros((n_samples,), dtype=np.float64)
1026
+
1027
+ def _score(self, y_true: np.ndarray, y_pred: np.ndarray) -> float:
1028
+ """Compute a performance score given ground truth and predictions.
1029
+
1030
+ Parameters
1031
+ ----------
1032
+ y_true : np.ndarray
1033
+ Ground truth labels or values.
1034
+ y_pred : np.ndarray
1035
+ Predictions (probabilities for classification, continuous for regression).
1036
+
1037
+ Returns:
1038
+ -------
1039
+ float
1040
+ The performance score (higher is better for classification,
1041
+ or depends on the specific metric).
1042
+ """
1043
+ metric = self._get_optimize_metric()
1044
+ if self.task_type == "multiclass":
1045
+ return score_classification(metric, y_true, y_pred)
1046
+ elif self.task_type == "regression":
1047
+ return score_regression(metric, y_true, y_pred)
1048
+ else:
1049
+ raise NotImplementedError
1050
+
1051
+ def _get_optimize_metric(self) -> str:
1052
+ """Return which metric name to use for scoring.
1053
+
1054
+ Returns:
1055
+ -------
1056
+ str
1057
+ The metric name, e.g. "roc" for classification or "rmse" for regression.
1058
+ """
1059
+ if self.adaptive_tree_overwrite_metric is not None:
1060
+ return self.adaptive_tree_overwrite_metric
1061
+ if self.task_type == "multiclass":
1062
+ return "roc"
1063
+ return "rmse"
1064
+
1065
+ def _predict_leaf(
1066
+ self,
1067
+ X_train_leaf: np.ndarray,
1068
+ y_train_leaf: np.ndarray,
1069
+ leaf_id: int,
1070
+ X_full: np.ndarray,
1071
+ indices: np.ndarray,
1072
+ ) -> np.ndarray:
1073
+ """Each subclass implements how to call TabPFN for classification or regression.
1074
+
1075
+ Parameters
1076
+ ----------
1077
+ X_train_leaf : np.ndarray
1078
+ Training features for the samples in this leaf/node.
1079
+ y_train_leaf : np.ndarray
1080
+ Training targets for the samples in this leaf/node.
1081
+ leaf_id : int
1082
+ Leaf/node index (for seeding or debugging).
1083
+ X_full : np.ndarray
1084
+ The entire set of features we are predicting on.
1085
+ indices : np.ndarray
1086
+ The indices in X_full that belong to this leaf.
1087
+
1088
+ Returns:
1089
+ -------
1090
+ np.ndarray
1091
+ Predictions for all n_samples, but only indices are filled meaningfully.
1092
+ """
1093
+ raise NotImplementedError("Must be implemented in subclass.")
1094
+
1095
+
1096
+ ###############################################################################
1097
+ # CLASSIFIER SUBCLASS #
1098
+ ###############################################################################
1099
+
1100
+
1101
+ class DecisionTreeTabPFNClassifier(DecisionTreeTabPFNBase, ClassifierMixin):
1102
+ """Decision tree that uses TabPFNClassifier at the leaves."""
1103
+
1104
+ task_type: str = "multiclass"
1105
+
1106
+ def _init_decision_tree(self) -> DecisionTreeClassifier:
1107
+ """Create a scikit-learn DecisionTreeClassifier with stored parameters."""
1108
+ return DecisionTreeClassifier(
1109
+ criterion=self.criterion,
1110
+ max_depth=self.max_depth,
1111
+ min_samples_split=self.min_samples_split,
1112
+ min_samples_leaf=self.min_samples_leaf,
1113
+ min_weight_fraction_leaf=self.min_weight_fraction_leaf,
1114
+ max_features=self.max_features,
1115
+ random_state=self.random_state,
1116
+ max_leaf_nodes=self.max_leaf_nodes,
1117
+ min_impurity_decrease=self.min_impurity_decrease,
1118
+ class_weight=self.class_weight,
1119
+ ccp_alpha=self.ccp_alpha,
1120
+ splitter=self.splitter,
1121
+ )
1122
+
1123
+ def _predict_leaf(
1124
+ self,
1125
+ X_train_leaf: np.ndarray,
1126
+ y_train_leaf: np.ndarray,
1127
+ leaf_id: int,
1128
+ X_full: np.ndarray,
1129
+ indices: np.ndarray,
1130
+ ) -> np.ndarray:
1131
+ """Fit a TabPFNClassifier on the leafs train data and predict_proba for the relevant samples.
1132
+
1133
+ Parameters
1134
+ ----------
1135
+ X_train_leaf : np.ndarray
1136
+ Training features for the samples in this leaf/node.
1137
+ y_train_leaf : np.ndarray
1138
+ Training targets for the samples in this leaf/node.
1139
+ leaf_id : int
1140
+ Leaf/node index.
1141
+ X_full : np.ndarray
1142
+ Full feature matrix to predict on.
1143
+ indices : np.ndarray
1144
+ Indices of X_full that belong to this leaf.
1145
+
1146
+ Returns:
1147
+ -------
1148
+ np.ndarray
1149
+ A (n_samples, n_classes) array of probabilities, with only `indices` updated for this leaf.
1150
+ """
1151
+ y_eval_prob = self._init_eval_probability_array(X_full.shape[0], to_zero=True)
1152
+ classes_in_leaf = [i for i in range(len(np.unique(y_train_leaf)))]
1153
+
1154
+ # If only one class, fill probability 1.0 for that class
1155
+ if len(classes_in_leaf) == 1:
1156
+ y_eval_prob[indices, classes_in_leaf[0]] = 1.0
1157
+ return y_eval_prob
1158
+
1159
+ # Otherwise, fit TabPFN
1160
+ leaf_seed = leaf_id + self.tree_seed
1161
+ try:
1162
+ # Handle pandas DataFrame or numpy array
1163
+ if hasattr(X_full, "iloc"):
1164
+ # Use .iloc for pandas
1165
+ X_subset = X_full.iloc[indices]
1166
+ else:
1167
+ # Use direct indexing for numpy
1168
+ X_subset = X_full[indices]
1169
+
1170
+ try:
1171
+ self.tabpfn.random_state = leaf_seed
1172
+ self.tabpfn.fit(X_train_leaf, y_train_leaf)
1173
+ proba = self.tabpfn.predict_proba(X_subset)
1174
+ except Exception as e:
1175
+ from tabpfn.preprocessing import default_classifier_preprocessor_configs, \
1176
+ default_regressor_preprocessor_configs
1177
+ backup_inf_conf = deepcopy(self.tabpfn.inference_config)
1178
+ default_pre = default_classifier_preprocessor_configs if self.task_type == "multiclass" else default_regressor_preprocessor_configs
1179
+
1180
+ # Try to run again without preprocessing which might crash
1181
+ self.tabpfn.random_state = leaf_seed
1182
+ self.tabpfn.inference_config["PREPROCESS_TRANSFORMS"] = default_pre()
1183
+ self.tabpfn.inference_config["REGRESSION_Y_PREPROCESS_TRANSFORMS"] = (None, "safepower")
1184
+ # Debug output removed: print(self.tabpfn.inference_config)
1185
+ self.tabpfn.fit(X_train_leaf, y_train_leaf)
1186
+ proba = self.tabpfn.predict_proba(X_subset)
1187
+ # reset preprocessing
1188
+ self.tabpfn.inference_config = backup_inf_conf
1189
+
1190
+ for i, c in enumerate(classes_in_leaf):
1191
+ y_eval_prob[indices, c] = proba[:, i]
1192
+
1193
+ except ValueError as e:
1194
+ if (
1195
+ not e.args
1196
+ or e.args[0]
1197
+ != "All features are constant and would have been removed! Unable to predict using TabPFN."
1198
+ ):
1199
+ raise e
1200
+ warnings.warn(
1201
+ "One node has constant features for TabPFN. Using class-ratio fallback.",
1202
+ stacklevel=2,
1203
+ )
1204
+ _, counts = np.unique(y_train_leaf, return_counts=True)
1205
+ ratio = counts / counts.sum()
1206
+ for i, c in enumerate(classes_in_leaf):
1207
+ y_eval_prob[indices, c] = ratio[i]
1208
+
1209
+ return y_eval_prob
1210
+
1211
+ def predict(self, X: np.ndarray, check_input: bool = True) -> np.ndarray:
1212
+ """Predict class labels for X.
1213
+
1214
+ Args:
1215
+ X: Input features.
1216
+ check_input: Whether to validate input arrays. Default is True.
1217
+
1218
+ Returns:
1219
+ np.ndarray: Predicted class labels.
1220
+ """
1221
+ # Validate the model is fitted
1222
+ X = validate_data(
1223
+ self,
1224
+ X,
1225
+ ensure_all_finite=False,
1226
+ )
1227
+ check_is_fitted(self, ["_tree", "X", "y"])
1228
+ proba = self.predict_proba(X, check_input=check_input)
1229
+ return np.argmax(proba, axis=1)
1230
+
1231
+ def predict_proba(self, X: np.ndarray, check_input: bool = True) -> np.ndarray:
1232
+ """Predict class probabilities for X using the TabPFN leaves.
1233
+
1234
+ Args:
1235
+ X: Input features.
1236
+ check_input: Whether to validate input arrays. Default is True.
1237
+
1238
+ Returns:
1239
+ np.ndarray: Predicted probabilities of shape (n_samples, n_classes).
1240
+ """
1241
+ # Validate the model is fitted
1242
+ X = validate_data(
1243
+ self,
1244
+ X,
1245
+ ensure_all_finite=False,
1246
+ )
1247
+ check_is_fitted(self, ["_tree", "X", "y"])
1248
+ return self._predict_internal(X, check_input=check_input)
1249
+
1250
+ def _post_fit(self) -> None:
1251
+ """Optional hook after the decision tree is fitted."""
1252
+ if self.verbose:
1253
+ pass
1254
+
1255
+
1256
+ ###############################################################################
1257
+ # REGRESSOR SUBCLASS #
1258
+ ###############################################################################
1259
+
1260
+
1261
+ class DecisionTreeTabPFNRegressor(DecisionTreeTabPFNBase, RegressorMixin):
1262
+ """Decision tree that uses TabPFNRegressor at the leaves."""
1263
+
1264
+ task_type: str = "regression"
1265
+
1266
+ def __init__(
1267
+ self,
1268
+ *,
1269
+ criterion="squared_error",
1270
+ splitter="best",
1271
+ max_depth=None,
1272
+ min_samples_split=1000,
1273
+ min_samples_leaf=1,
1274
+ min_weight_fraction_leaf=0.0,
1275
+ max_features=None,
1276
+ random_state=None,
1277
+ max_leaf_nodes=None,
1278
+ min_impurity_decrease=0.0,
1279
+ ccp_alpha=0.0,
1280
+ monotonic_cst=None,
1281
+ tabpfn=None,
1282
+ categorical_features=None,
1283
+ verbose=False,
1284
+ show_progress=False,
1285
+ fit_nodes=True,
1286
+ tree_seed=0,
1287
+ adaptive_tree=True,
1288
+ adaptive_tree_min_train_samples=50,
1289
+ adaptive_tree_max_train_samples=2000,
1290
+ adaptive_tree_min_valid_samples_fraction_of_train=0.2,
1291
+ adaptive_tree_overwrite_metric=None,
1292
+ adaptive_tree_test_size=0.2,
1293
+ average_logits=True,
1294
+ adaptive_tree_skip_class_missing=True,
1295
+ ):
1296
+ # Call parent constructor
1297
+ super().__init__(
1298
+ tabpfn=tabpfn,
1299
+ criterion=criterion,
1300
+ splitter=splitter,
1301
+ max_depth=max_depth,
1302
+ min_samples_split=min_samples_split,
1303
+ min_samples_leaf=min_samples_leaf,
1304
+ min_weight_fraction_leaf=min_weight_fraction_leaf,
1305
+ max_features=max_features,
1306
+ random_state=random_state,
1307
+ max_leaf_nodes=max_leaf_nodes,
1308
+ min_impurity_decrease=min_impurity_decrease,
1309
+ ccp_alpha=ccp_alpha,
1310
+ monotonic_cst=monotonic_cst,
1311
+ categorical_features=categorical_features,
1312
+ verbose=verbose,
1313
+ show_progress=show_progress,
1314
+ fit_nodes=fit_nodes,
1315
+ tree_seed=tree_seed,
1316
+ adaptive_tree=adaptive_tree,
1317
+ adaptive_tree_min_train_samples=adaptive_tree_min_train_samples,
1318
+ adaptive_tree_max_train_samples=adaptive_tree_max_train_samples,
1319
+ adaptive_tree_min_valid_samples_fraction_of_train=(
1320
+ adaptive_tree_min_valid_samples_fraction_of_train
1321
+ ),
1322
+ adaptive_tree_overwrite_metric=adaptive_tree_overwrite_metric,
1323
+ adaptive_tree_test_size=adaptive_tree_test_size,
1324
+ average_logits=average_logits,
1325
+ adaptive_tree_skip_class_missing=adaptive_tree_skip_class_missing,
1326
+ )
1327
+
1328
+ def _init_decision_tree(self) -> DecisionTreeRegressor:
1329
+ """Create a scikit-learn DecisionTreeRegressor with stored parameters."""
1330
+ return DecisionTreeRegressor(
1331
+ criterion=self.criterion,
1332
+ max_depth=self.max_depth,
1333
+ min_samples_split=self.min_samples_split,
1334
+ min_samples_leaf=self.min_samples_leaf,
1335
+ min_weight_fraction_leaf=self.min_weight_fraction_leaf,
1336
+ max_features=self.max_features,
1337
+ random_state=self.random_state,
1338
+ max_leaf_nodes=self.max_leaf_nodes,
1339
+ min_impurity_decrease=self.min_impurity_decrease,
1340
+ ccp_alpha=self.ccp_alpha,
1341
+ splitter=self.splitter,
1342
+ )
1343
+
1344
+ def _predict_leaf(
1345
+ self,
1346
+ X_train_leaf: np.ndarray,
1347
+ y_train_leaf: np.ndarray,
1348
+ leaf_id: int,
1349
+ X_full: np.ndarray,
1350
+ indices: np.ndarray,
1351
+ ) -> np.ndarray:
1352
+ """Fit a TabPFNRegressor on the nodes train data, then predict for the relevant samples.
1353
+
1354
+ Parameters
1355
+ ----------
1356
+ X_train_leaf : np.ndarray
1357
+ Training features for the samples in this leaf/node.
1358
+ y_train_leaf : np.ndarray
1359
+ Training targets for the samples in this leaf/node.
1360
+ leaf_id : int
1361
+ Leaf/node index.
1362
+ X_full : np.ndarray
1363
+ Full feature matrix to predict on.
1364
+ indices : np.ndarray
1365
+ Indices of X_full that fall into this leaf.
1366
+
1367
+ Returns:
1368
+ -------
1369
+ np.ndarray
1370
+ An array of shape (n_samples,) with predictions; only `indices` are updated.
1371
+ """
1372
+ y_eval = np.zeros(X_full.shape[0], dtype=float)
1373
+
1374
+ # If no training data or just 1 sample, fall back to 0 or single value
1375
+ if len(X_train_leaf) < 1:
1376
+ warnings.warn(
1377
+ f"Leaf {leaf_id} has zero training samples. Returning 0.0 predictions.",
1378
+ stacklevel=2,
1379
+ )
1380
+ return y_eval
1381
+ elif len(X_train_leaf) == 1:
1382
+ y_eval[indices] = y_train_leaf[0]
1383
+ return y_eval
1384
+
1385
+ # If all y are identical, return that constant
1386
+ if np.all(y_train_leaf == y_train_leaf[0]):
1387
+ y_eval[indices] = y_train_leaf[0]
1388
+ return y_eval
1389
+
1390
+ # Fit TabPFNRegressor
1391
+ leaf_seed = leaf_id + self.tree_seed
1392
+ try:
1393
+ self.tabpfn.random_state = leaf_seed
1394
+ self.tabpfn.fit(X_train_leaf, y_train_leaf)
1395
+
1396
+ # Handle pandas DataFrame or numpy array
1397
+ if hasattr(X_full, "iloc"):
1398
+ # Use .iloc for pandas
1399
+ X_subset = X_full.iloc[indices]
1400
+ else:
1401
+ # Use direct indexing for numpy
1402
+ X_subset = X_full[indices]
1403
+
1404
+ preds = self.tabpfn.predict(X_subset)
1405
+ y_eval[indices] = preds
1406
+ except (ValueError, RuntimeError, NotImplementedError, AssertionError) as e:
1407
+ warnings.warn(
1408
+ f"TabPFN fit/predict failed at leaf {leaf_id}: {e}. Using mean fallback.",
1409
+ stacklevel=2,
1410
+ )
1411
+ y_eval[indices] = np.mean(y_train_leaf)
1412
+
1413
+ return y_eval
1414
+
1415
+ def predict(self, X: np.ndarray, check_input: bool = True) -> np.ndarray:
1416
+ """Predict regression values using the TabPFN leaves.
1417
+
1418
+ Parameters
1419
+ ----------
1420
+ X : np.ndarray
1421
+ Input features.
1422
+ check_input : bool, default=True
1423
+ Whether to validate the input arrays.
1424
+
1425
+ Returns:
1426
+ -------
1427
+ np.ndarray
1428
+ Continuous predictions of shape (n_samples,).
1429
+ """
1430
+ # Validate the model is fitted
1431
+ X = validate_data(
1432
+ self,
1433
+ X,
1434
+ ensure_all_finite=False,
1435
+ )
1436
+ check_is_fitted(self, ["_tree", "X", "y"])
1437
+ return self._predict_internal(X, check_input=check_input)
1438
+
1439
+ def predict_full(self, X: np.ndarray) -> np.ndarray:
1440
+ """Convenience method to predict with no input checks (optional).
1441
+
1442
+ Parameters
1443
+ ----------
1444
+ X : np.ndarray
1445
+ Input features.
1446
+
1447
+ Returns:
1448
+ -------
1449
+ np.ndarray
1450
+ Continuous predictions of shape (n_samples,).
1451
+ """
1452
+ # Validate the model is fitted
1453
+ X = validate_data(
1454
+ self,
1455
+ X,
1456
+ ensure_all_finite=False,
1457
+ )
1458
+ check_is_fitted(self, ["_tree", "X", "y"])
1459
+ return self._predict_internal(X, check_input=False)
1460
+
1461
+ def _post_fit(self) -> None:
1462
+ """Optional hook after the regressor's tree is fitted."""
1463
+ if self.verbose:
1464
+ pass
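
One detail of the adaptive pruning above is worth spelling out: when a freshly fitted leaf competes with the prediction carried over from the previous node, the classification branch with average_logits=True combines the two in log space and renormalizes with softmax, while average_logits=False adds the probabilities and renormalizes. A small standalone numpy illustration of those two rules (using a local softmax instead of the module's .utils.softmax helper; values are made up):

import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

p_prev = np.array([[0.70, 0.20, 0.10]])  # prediction carried over from the previous node
p_leaf = np.array([[0.20, 0.60, 0.20]])  # prediction from the newly fitted TabPFN leaf

# average_logits=True: sum the logs (with the same 1e-6 floor used above), then softmax
p_avg_logits = softmax(np.log(p_prev + 1e-6) + np.log(p_leaf + 1e-6))

# average_logits=False: add probabilities and renormalize each row
p_avg_proba = p_prev + p_leaf
p_avg_proba /= p_avg_proba.sum(axis=1, keepdims=True)

# "replacement" would instead overwrite the affected rows with p_leaf outright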