obliquetree 1.0.3__cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of obliquetree might be problematic. Click here for more details.

obliquetree/_pywrap.py ADDED
@@ -0,0 +1,747 @@
1
+ from __future__ import annotations
2
+
3
+ from .src.base import TreeClassifier
4
+
5
+ from typing import List, Optional
6
+ from numpy.typing import ArrayLike, NDArray
7
+ import numpy as np
8
+ from math import comb
9
+ import warnings
10
+
11
+
12
def formatwarning(message, category, filename, lineno, line=None, **kwargs):
    """
    Render a warning as a single ``<CategoryName>: <message>`` line.

    This function is installed as ``warnings.formatwarning`` right below, so it
    replaces the formatter for *every* warning emitted in the process (not only
    those raised by this package). It intentionally omits the file name, line
    number and source line that the default formatter prints.

    Parameters
    ----------
    message : Warning or str
        The warning message (or warning instance); rendered with ``str()``.
    category : type
        The warning class. Its ``__name__`` is used as the prefix so that, for
        example, a ``DeprecationWarning`` is no longer mislabelled as a
        ``UserWarning``.
    filename, lineno, line, **kwargs
        Accepted for signature compatibility with ``warnings.formatwarning``;
        unused.

    Returns
    -------
    str
        The formatted warning, terminated by a newline.
    """
    # Fall back to the historical label if an unusual category object is passed.
    category_name = getattr(category, "__name__", "UserWarning")
    return f"{category_name}: {message}\n"


warnings.formatwarning = formatwarning
17
+
18
+
19
class BaseTree(TreeClassifier):
    """
    Base class for decision tree classifiers and regressors.

    This class provides foundational functionality for building decision trees,
    including parameter validation, data preprocessing, and interfacing with the
    underlying `TreeClassifier`. It handles both classification and regression
    tasks based on the `task` parameter.

    Parameters
    ----------
    task : bool
        - If `True`, construct regression tree.
        - If `False`, construct classification tree.

    max_depth : int
        Maximum depth of the tree. Controls model complexity and prevents overfitting.

        - If `-1`: Expands until leaves are pure or contain fewer than `min_samples_split` samples.
        - If `int > 0`: Limits the tree to the specified depth.

        `-1` and any value above 255 are stored as 255.

    min_samples_leaf : int
        Minimum number of samples required at leaf nodes.

    min_samples_split : int
        Minimum number of samples required to split an internal node.

    min_impurity_decrease : float
        Minimum required decrease in impurity to create a split.

    ccp_alpha : float
        Complexity parameter for Minimal Cost-Complexity Pruning.

    categories : List[int]
        Indices of categorical features in the dataset. Duplicate indices are
        removed (first-occurrence order is kept).

    use_oblique : bool
        - If `True`, enables oblique splits using linear combinations of features.
        - If `False`, uses traditional axis-aligned splits only.

    random_state : int
        Seed for random number generation in oblique splits. If `None`, a seed
        is drawn from NumPy's global random generator.

        - Only used when `use_oblique=True`.

    n_pair : int
        Number of features to combine in oblique splits.

        - Only used when `use_oblique=True`.

    gamma : float
        Separation strength parameter for oblique splits.

        - Only used when `use_oblique=True`.

    max_iter : int
        Maximum iterations for L-BFGS optimization in oblique splits.

        - Only used when `use_oblique=True`.

    relative_change : float
        Early stopping threshold for L-BFGS optimization.

        - Only used when `use_oblique=True`.

    Integer parameters also accept NumPy integer scalars, and float parameters
    accept NumPy integer/floating scalars.

    Attributes
    ----------
    n_classes : int
        Set by `fit`: number of distinct class labels (classification) or 1
        (regression).
    n_features_in_ : int
        Set by `fit`: number of columns of `X`. `predict` rejects inputs with a
        different number of columns.
    """

    def __init__(
        self,
        task: bool,
        max_depth: int,
        min_samples_leaf: int,
        min_samples_split: int,
        min_impurity_decrease: float,
        ccp_alpha: float,
        categories: Optional[List[int]],
        use_oblique: bool,
        random_state: Optional[int],
        n_pair: int,
        gamma: float,
        max_iter: int,
        relative_change: float,
    ) -> None:
        # Validate and assign parameters
        self.task = task
        self.use_oblique = self._validate_use_oblique(use_oblique)
        self.max_depth = self._validate_max_depth(max_depth)
        self.min_samples_leaf = self._validate_min_samples_leaf(min_samples_leaf)
        self.min_samples_split = self._validate_min_samples_split(min_samples_split)
        self.min_impurity_decrease = self._validate_min_impurity_decrease(
            min_impurity_decrease
        )
        self.ccp_alpha = self._validate_ccp_alpha(ccp_alpha)
        self.n_pair = self._validate_n_pair(n_pair)
        self.gamma = self._validate_gamma(gamma)
        self.max_iter = self._validate_max_iter(max_iter)
        # Uses the already-validated use_oblique: the "very low" warning only
        # matters when oblique training is enabled.
        self.relative_change = self._validate_relative_change(
            relative_change, self.use_oblique
        )
        self.random_state = self._validate_random_state(random_state)
        self.categories = self._validate_categories(categories)
        self._fit = False

        # Initialize the TreeClassifier. Arguments are positional, so this order
        # must match TreeClassifier's constructor.
        super().__init__(
            self.max_depth,
            self.min_samples_leaf,
            self.min_samples_split,
            self.min_impurity_decrease,
            self.random_state,
            self.n_pair,
            self.gamma,
            self.max_iter,
            self.relative_change,
            self.categories,
            self.ccp_alpha,
            self.use_oblique,
            self.task,
            # NOTE(review): meaning of this constant is defined by
            # TreeClassifier (not visible here) — confirm.
            1,
        )

    def __getstate__(self):
        """Return the state for pickling, including the Python-level fit markers."""
        state = super().__getstate__()
        state["_fit"] = self._fit
        state["n_features_in_"] = getattr(self, "n_features_in_", None)

        return state

    def __setstate__(self, state):
        """Restore the state from pickle."""
        # Extract the Python-level attributes before handing the rest to the base.
        _fit = state.pop("_fit", False)
        n_features_in = state.pop("n_features_in_", None)
        super().__setstate__(state)

        # Restore state directly without re-initialization
        self.__dict__.update(state)
        self._fit = _fit
        if n_features_in is not None:
            self.n_features_in_ = n_features_in

    def __repr__(self):
        param_str = (
            f"use_oblique={getattr(self, 'use_oblique', None)}, "
            f"max_depth={getattr(self, 'max_depth', None)}, "
            f"min_samples_leaf={getattr(self, 'min_samples_leaf', None)}, "
            f"min_samples_split={getattr(self, 'min_samples_split', None)}, "
            f"min_impurity_decrease={getattr(self, 'min_impurity_decrease', None)}, "
            f"ccp_alpha={getattr(self, 'ccp_alpha', None)}, "
            f"categories={getattr(self, 'categories', None)}, "
            f"random_state={getattr(self, 'random_state', None)}, "
            f"n_pair={getattr(self, 'n_pair', None)}, "
            f"gamma={getattr(self, 'gamma', None)}, "
            f"max_iter={getattr(self, 'max_iter', None)}, "
            f"relative_change={getattr(self, 'relative_change', None)}"
        )
        return f"{self.__class__.__name__}({param_str})"

    # ------------------------------------------------------------------ #
    # Parameter validation
    # ------------------------------------------------------------------ #

    @staticmethod
    def _is_integer(value) -> bool:
        """Return True for Python ``int`` (``bool`` included) or a NumPy integer scalar."""
        return isinstance(value, (int, np.integer))

    @staticmethod
    def _is_number(value) -> bool:
        """Return True for Python ``int``/``float`` or a NumPy integer/floating scalar."""
        return isinstance(value, (int, float, np.integer, np.floating))

    def _validate_max_depth(self, max_depth: int) -> int:
        if not self._is_integer(max_depth):
            raise ValueError("max_depth must be an integer")
        max_depth = int(max_depth)
        if max_depth < -1:
            raise ValueError("max_depth must be >= -1")
        # -1 ("unlimited") and larger depths are capped at 255 — presumably a
        # limit of the underlying tree implementation; TODO confirm.
        return 255 if max_depth == -1 else min(max_depth, 255)

    def _validate_min_samples_leaf(self, min_samples_leaf: int) -> int:
        if not self._is_integer(min_samples_leaf):
            raise ValueError("min_samples_leaf must be an integer")
        if min_samples_leaf < 1:
            raise ValueError("min_samples_leaf must be >= 1")
        return int(min_samples_leaf)

    def _validate_min_samples_split(self, min_samples_split: int) -> int:
        if not self._is_integer(min_samples_split):
            raise ValueError("min_samples_split must be an integer")
        if min_samples_split < 2:
            raise ValueError("min_samples_split must be >= 2")
        return int(min_samples_split)

    def _validate_min_impurity_decrease(self, min_impurity_decrease: float) -> float:
        if not self._is_number(min_impurity_decrease):
            raise ValueError("min_impurity_decrease must be a number")
        if min_impurity_decrease < 0.0:
            raise ValueError("min_impurity_decrease must be >= 0.0")
        return float(min_impurity_decrease)

    def _validate_ccp_alpha(self, ccp_alpha: float) -> float:
        if not self._is_number(ccp_alpha):
            raise ValueError("ccp_alpha must be a number")
        if ccp_alpha < 0.0:
            raise ValueError("ccp_alpha must be >= 0.0")
        return float(ccp_alpha)

    def _validate_n_pair(self, n_pair: int) -> int:
        if not self._is_integer(n_pair):
            raise ValueError("n_pair must be an integer")
        if n_pair < 2:
            raise ValueError("n_pair must be >= 2")
        return int(n_pair)

    def _validate_gamma(self, gamma: float) -> float:
        if not self._is_number(gamma):
            raise ValueError("gamma must be a number")
        if gamma <= 0.0:
            raise ValueError("gamma must be > 0.0")
        return float(gamma)

    def _validate_max_iter(self, max_iter: int) -> int:
        if not self._is_integer(max_iter):
            raise ValueError("max_iter must be an integer")
        if max_iter < 1:
            raise ValueError("max_iter must be >= 1")
        return int(max_iter)

    def _validate_relative_change(
        self, relative_change: float, use_oblique: bool
    ) -> float:
        if not self._is_number(relative_change):
            raise ValueError("relative_change must be a number")
        if relative_change < 0.0:
            raise ValueError("relative_change must be >= 0.0")
        if use_oblique and relative_change <= 1e-5:
            warnings.warn(
                "relative_change is set very low. This may prolong the oblique training time."
            )
        return float(relative_change)

    def _validate_random_state(self, random_state: Optional[int]) -> int:
        """Return ``random_state`` as an int, drawing a random seed when it is None."""
        if random_state is None:
            return np.random.randint(0, np.iinfo(np.int32).max)
        if not self._is_integer(random_state):
            raise ValueError("random_state must be None or an integer")
        return int(random_state)

    def _validate_categories(self, categories: Optional[List[int]]) -> List[int]:
        """Return the categorical column indices as a de-duplicated list of ints."""
        if categories is None:
            return []
        if isinstance(categories, np.ndarray):
            # Accept 1-D index arrays; convert NumPy scalars to Python values.
            if categories.ndim != 1:
                raise ValueError("categories must be None or a list/tuple of integers")
            categories = categories.tolist()
        if not isinstance(categories, (list, tuple)):
            raise ValueError("categories must be None or a list/tuple of integers")
        if not all(self._is_integer(x) for x in categories):
            raise ValueError("All elements in categories must be integers")
        if any(x < 0 for x in categories):
            raise ValueError(
                "All elements in categories must be non-negative integers"
            )
        # Drop duplicates (keeping first-occurrence order) so len(self.categories)
        # is the number of distinct categorical columns.
        return list(dict.fromkeys(int(x) for x in categories))

    def _validate_use_oblique(self, use_oblique: bool) -> bool:
        if not isinstance(use_oblique, bool):
            raise ValueError("use_oblique must be a boolean")
        return use_oblique

    # ------------------------------------------------------------------ #
    # Fitting
    # ------------------------------------------------------------------ #

    def fit(
        self, X: ArrayLike, y: ArrayLike, sample_weight: Optional[ArrayLike] = None
    ) -> "BaseTree":
        """
        Fit the decision tree to the training data.

        Parameters
        ----------
        X : ArrayLike
            Training input samples of shape (n_samples, n_features).
        y : ArrayLike
            Target values of shape (n_samples,). For classification the labels
            must be the consecutive integers ``0 .. n_classes - 1``; for
            regression they must be finite.
        sample_weight : Optional[ArrayLike], default=None
            Sample weights of shape (n_samples,). They must be finite and
            non-negative, with at least one positive entry. If None, all
            samples are given equal weight. Weights are rescaled so the
            smallest positive weight becomes 1.

        Returns
        -------
        self : BaseTree
            Fitted estimator.

        Raises
        ------
        ValueError
            If input data is empty, has inconsistent shapes, or contains NaN/Inf
            values where not allowed.

        Notes
        -----
        When ``use_oblique`` is True, fitting may lower ``self.n_pair`` or set
        ``self.use_oblique`` to False to match the number of usable
        (non-categorical) features. These adjustments persist on the instance.
        """
        X = np.asarray(X, order="F", dtype=np.float64)
        y = np.asarray(y, order="C", dtype=np.float64)

        if X.ndim != 2:
            raise ValueError(
                f"Expected a 2D array for input samples, but got an array with {X.ndim} dimensions."
            )

        # Check y's rank before reading y.shape[0]; a 0-d y would otherwise
        # fail with an opaque IndexError.
        if y.ndim != 1:
            raise ValueError("y must be 1-dimensional")

        if X.shape[0] != y.shape[0]:
            raise ValueError(
                f"The number of samples in `X` ({X.shape[0]}) does not match the number of target values in `y` ({y.shape[0]})."
            )

        if X.shape[0] == 0:
            raise ValueError("Cannot fit on an empty dataset: `X` contains 0 samples.")

        # Validate target vector
        self._validate_target(y)

        # Validate sample weights
        sample_weight = self._process_sample_weight(sample_weight, y.shape[0])

        # Validate categorical features first: the oblique checks below subtract
        # len(self.categories) from X.shape[1], which assumes valid indices.
        self._validate_categories_in_data(X)

        # Validate feature matrix (may adjust n_pair / use_oblique)
        self._validate_features(X)

        # Classification or Regression setup
        self.n_classes = self._setup_task(y)

        # Warn if the number of feature combinations is too large for oblique splits
        if self.use_oblique:
            self._warn_large_combinations(X.shape[1] - len(self.categories))

        super().fit(X, y, sample_weight)

        self.n_features_in_ = X.shape[1]
        self._fit = True

        return self

    def _validate_target(self, y: NDArray) -> None:
        """Check that ``y`` is 1-D and satisfies the task's label requirements."""
        if y.ndim != 1:
            raise ValueError("y must be 1-dimensional")

        if self.task:  # Regression: any finite real target is accepted.
            if not np.all(np.isfinite(y)):
                raise ValueError("y cannot contain NaN or Inf values for regression")
            return

        # Classification: labels must be exactly 0, 1, ..., k-1.
        unique_labels = np.unique(y)
        expected_labels = np.arange(len(unique_labels))
        if not np.array_equal(unique_labels, expected_labels):
            raise ValueError(
                "Classification labels must start from 0 and increment by 1"
            )

    def _process_sample_weight(
        self, sample_weight: Optional[ArrayLike], n_samples: int
    ) -> NDArray:
        """
        Validate ``sample_weight`` and rescale it so its smallest positive entry is 1.

        Zero weights are allowed and stay zero. Returns all-ones weights when
        ``sample_weight`` is None.
        """
        if sample_weight is None:
            return np.ones(n_samples, dtype=np.float64)

        sample_weight = np.asarray(sample_weight, order="C", dtype=np.float64)

        if sample_weight.shape != (n_samples,):
            raise ValueError(
                f"sample_weight has incompatible shape: {sample_weight.shape} "
                f"while y has shape ({n_samples},)"
            )

        if not np.all(np.isfinite(sample_weight)) or np.any(sample_weight < 0):
            raise ValueError(
                "sample_weight cannot contain negative, NaN or inf values"
            )

        # Normalise by the smallest *positive* weight: dividing by a zero
        # minimum would turn the weights into inf/NaN.
        positive = sample_weight[sample_weight > 0]
        if positive.size == 0:
            raise ValueError("sample_weight must contain at least one positive value")

        min_positive = positive.min()
        if min_positive != 1:
            sample_weight = sample_weight / min_positive

        return sample_weight

    def _validate_features(self, X: NDArray) -> None:
        """
        Check X for oblique training and fit ``n_pair`` to the usable features.

        Only acts when ``use_oblique`` is True. With fewer than 2 non-categorical
        features, oblique splits are disabled. Otherwise ``n_pair`` is capped at
        the number of non-categorical features. Both adjustments emit a warning
        and mutate the instance.
        """
        if not self.use_oblique:
            return

        if not np.all(np.isfinite(X)):
            raise ValueError(
                "X cannot contain NaN or Inf values when use_oblique is True"
            )

        n_usable = X.shape[1] - len(self.categories)

        if n_usable < 2:
            # A linear combination needs at least 2 numeric features, and
            # n_pair < 2 is rejected by _validate_n_pair.
            if self.categories:
                warnings.warn(
                    f"Total features: {X.shape[1]}, categorical features: {len(self.categories)}. "
                    f"The number of possible feature pairs ({n_usable}) is less than 2. "
                    f"As a result, 'use_oblique' set 'False'."
                )
            else:
                warnings.warn(
                    f"Total features: {X.shape[1]}. At least 2 features are required "
                    f"for oblique splits. As a result, 'use_oblique' set 'False'."
                )
            self.use_oblique = False
        elif self.n_pair > n_usable:
            if self.categories:
                warnings.warn(
                    f"Total features: {X.shape[1]}, categorical features: {len(self.categories)}. "
                    f"n_pair ({self.n_pair}) exceeds the usable features, adjusting n_pair to {n_usable}."
                )
            else:
                warnings.warn(
                    f"n_pair ({self.n_pair}) exceeds the total features ({X.shape[1]}). "
                    f"Adjusting n_pair to {X.shape[1]}."
                )
            self.n_pair = n_usable

    def _setup_task(self, y: NDArray) -> int:
        """Return the number of classes (classification) or 1 (regression)."""
        if not self.task:
            n_classes = len(np.unique(y))
            return n_classes
        else:
            return 1  # Regression

    def _validate_categories_in_data(self, X: NDArray) -> None:
        """Check categorical column indices and values against ``X``."""
        if self.categories:
            for col_idx in self.categories:
                # A category index must refer to an existing column of X.
                if col_idx >= X.shape[1]:
                    raise ValueError(
                        f"Category column index {col_idx} exceeds X dimensions ({X.shape[1]} features)."
                    )

            # Values in categorical columns must not be negative.
            if (X[:, self.categories] < 0).any():
                raise ValueError(
                    "X contains negative values in the specified category columns, which are not allowed."
                )

    def _warn_large_combinations(self, n_features: int) -> None:
        """Warn when C(n_features, n_pair) exceeds a heuristic threshold of 1000."""
        total_combinations = comb(n_features, self.n_pair)
        if total_combinations > 1000:  # Optimal threshold can be adjusted
            warnings.warn(
                "The number of feature combinations for oblique splits is very large, which may lead to long training times. "
                "Consider reducing `n_pair` or the number of features."
            )

    def predict(self, X: ArrayLike) -> NDArray:
        """
        Predict target values for the input samples.

        Parameters
        ----------
        X : ArrayLike
            Input samples of shape (n_samples, n_features). ``n_features`` must
            match the number of columns seen during ``fit``.

        Returns
        -------
        NDArray
            Predicted values, as returned by the underlying TreeClassifier.

        Raises
        ------
        ValueError
            If the model has not been fitted yet, or X has the wrong shape.
        """
        if not self._fit:
            raise ValueError(
                "The model has not been fitted yet. Please call `fit` first."
            )

        X = np.asarray(X, order="F", dtype=np.float64)

        if X.ndim != 2:
            raise ValueError(
                f"Expected a 2D array for input samples, but got an array with {X.ndim} dimensions. "
            )

        # Reject a column-count mismatch before handing X to the native code.
        # The attribute may be missing on models unpickled from older versions.
        n_expected = getattr(self, "n_features_in_", None)
        if n_expected is not None and X.shape[1] != n_expected:
            raise ValueError(
                f"X has {X.shape[1]} features, but the model was fitted with {n_expected} features."
            )

        return super().predict(X)
472
+
473
+
474
class Classifier(BaseTree):
    """
    A decision tree classifier supporting both traditional axis-aligned and oblique splits.

    This decision tree classifier extends traditional classification trees by supporting
    oblique splits (linear combinations of features) alongside conventional axis-aligned
    splits. This allows more flexible decision boundaries while keeping the
    interpretability of decision trees.

    Class labels passed to ``fit`` must be the consecutive integers
    ``0, 1, ..., n_classes - 1``.

    Parameters
    ----------
    use_oblique : bool, default=True
        - If `True`, enables oblique splits using linear combinations of features.
        - If `False`, uses traditional axis-aligned splits only.

    max_depth : int, default=-1
        Maximum depth of the tree. Controls model complexity and prevents overfitting.

        - If `-1`: Expands until leaves are pure or contain fewer than `min_samples_split` samples.
        - If `int > 0`: Limits the tree to the specified depth.

    min_samples_leaf : int, default=1
        Minimum number of samples required at leaf nodes.

    min_samples_split : int, default=2
        Minimum number of samples required to split an internal node.

    min_impurity_decrease : float, default=0.0
        Minimum required decrease in impurity to create a split.

    ccp_alpha : float, default=0.0
        Complexity parameter for Minimal Cost-Complexity Pruning.

    categories : List[int], default=None
        Indices of categorical features in the dataset.

    random_state : int, default=None
        Seed for random number generation in oblique splits.

        - Only used when `use_oblique=True`.

    n_pair : int, default=2
        Number of features to combine in oblique splits.

        - Only used when `use_oblique=True`.

    gamma : float, default=1.0
        Separation strength parameter for oblique splits.

        - Only used when `use_oblique=True`.

    max_iter : int, default=100
        Maximum iterations for L-BFGS optimization in oblique splits.

        - Only used when `use_oblique=True`.

    relative_change : float, default=0.001
        Early stopping threshold for L-BFGS optimization.

        - Only used when `use_oblique=True`.
    """

    def __init__(
        self,
        use_oblique: bool = True,
        max_depth: int = -1,
        min_samples_leaf: int = 1,
        min_samples_split: int = 2,
        min_impurity_decrease: float = 0.0,
        ccp_alpha: float = 0.0,
        categories: Optional[List[int]] = None,
        random_state: Optional[int] = None,
        n_pair: int = 2,
        gamma: float = 1.0,
        max_iter: int = 100,
        relative_change: float = 0.001,
    ):
        """
        Create a classification tree.

        All parameters are documented on the class docstring. They are
        validated by ``BaseTree.__init__`` with ``task=False`` (classification).
        """
        super().__init__(
            task=False,
            max_depth=max_depth,
            min_samples_leaf=min_samples_leaf,
            min_samples_split=min_samples_split,
            min_impurity_decrease=min_impurity_decrease,
            ccp_alpha=ccp_alpha,
            categories=categories,
            use_oblique=use_oblique,
            random_state=random_state,
            n_pair=n_pair,
            gamma=gamma,
            max_iter=max_iter,
            relative_change=relative_change,
        )

    def fit(
        self, X: ArrayLike, y: ArrayLike, sample_weight: Optional[ArrayLike] = None
    ) -> "Classifier":
        """
        Build a decision tree classifier from the training set (X, y).

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            The training input samples.
        y : array-like of shape (n_samples,)
            Target class labels; must be the integers ``0 .. n_classes - 1``.
        sample_weight : array-like of shape (n_samples,), default=None
            Sample weights.

        Returns
        -------
        self : Classifier
            Fitted estimator.
        """
        return super().fit(X, y, sample_weight)

    def predict(self, X: ArrayLike) -> NDArray:
        """
        Predict class labels for X.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            The input samples to predict.

        Returns
        -------
        y : NDArray of shape (n_samples,)
            The predicted class labels: for each sample, the index of the
            column with the highest predicted probability.
        """
        # The base prediction is a (n_samples, n_classes) probability matrix;
        # labels are 0..n_classes-1, so the argmax column index is the label.
        return np.argmax(super().predict(X), axis=1)

    def predict_proba(self, X: ArrayLike) -> NDArray:
        """
        Predict class probabilities for X.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            The input samples.

        Returns
        -------
        proba : NDArray of shape (n_samples, n_classes)
            The class probabilities of the input samples.
        """
        return super().predict(X)
618
+
619
+
620
class Regressor(BaseTree):
    """
    Regression tree that can split on single features or on linear combinations of features.

    In addition to ordinary axis-aligned threshold splits, this regressor can
    learn oblique splits, i.e. thresholds on a weighted sum of several features.
    This gives more flexibility when modelling continuous targets while keeping
    the model a readable decision tree.

    Parameters
    ----------
    use_oblique : bool, default=True
        - If `True`, enables oblique splits using linear combinations of features.
        - If `False`, uses traditional axis-aligned splits only.

    max_depth : int, default=-1
        Maximum depth of the tree. Controls model complexity and prevents overfitting.

        - If `-1`: Expands until leaves are pure or contain fewer than `min_samples_split` samples.
        - If `int > 0`: Limits the tree to the specified depth.

    min_samples_leaf : int, default=1
        Minimum number of samples required at leaf nodes.

    min_samples_split : int, default=2
        Minimum number of samples required to split an internal node.

    min_impurity_decrease : float, default=0.0
        Minimum required decrease in impurity to create a split.

    ccp_alpha : float, default=0.0
        Complexity parameter for Minimal Cost-Complexity Pruning.

    categories : List[int], default=None
        Indices of categorical features in the dataset.

    random_state : int, default=None
        Seed for random number generation in oblique splits.

        - Only used when `use_oblique=True`.

    n_pair : int, default=2
        Number of features to combine in oblique splits.

        - Only used when `use_oblique=True`.

    gamma : float, default=1.0
        Separation strength parameter for oblique splits.

        - Only used when `use_oblique=True`.

    max_iter : int, default=100
        Maximum iterations for L-BFGS optimization in oblique splits.

        - Only used when `use_oblique=True`.

    relative_change : float, default=0.001
        Early stopping threshold for L-BFGS optimization.

        - Only used when `use_oblique=True`.
    """

    def __init__(
        self,
        use_oblique: bool = True,
        max_depth: int = -1,
        min_samples_leaf: int = 1,
        min_samples_split: int = 2,
        min_impurity_decrease: float = 0.0,
        ccp_alpha: float = 0.0,
        categories: Optional[List[int]] = None,
        random_state: Optional[int] = None,
        n_pair: int = 2,
        gamma: float = 1.0,
        max_iter: int = 100,
        relative_change: float = 0.001,
    ):
        """
        Configure a regression tree.

        Every argument is described in the class docstring. Validation is
        delegated to ``BaseTree.__init__``, called with ``task=True``
        (regression).
        """
        super().__init__(
            task=True,
            use_oblique=use_oblique,
            max_depth=max_depth,
            min_samples_leaf=min_samples_leaf,
            min_samples_split=min_samples_split,
            min_impurity_decrease=min_impurity_decrease,
            ccp_alpha=ccp_alpha,
            categories=categories,
            random_state=random_state,
            n_pair=n_pair,
            gamma=gamma,
            max_iter=max_iter,
            relative_change=relative_change,
        )

    def fit(
        self, X: ArrayLike, y: ArrayLike, sample_weight: Optional[ArrayLike] = None
    ) -> "Regressor":
        """
        Grow the regression tree on the training data (X, y).

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training samples.
        y : array-like of shape (n_samples,)
            Continuous target values.
        sample_weight : array-like of shape (n_samples,), optional, default=None
            Per-sample weights; equal weights when omitted.

        Returns
        -------
        self : Regressor
            The fitted estimator, allowing call chaining.
        """
        return super().fit(X, y, sample_weight)

    def predict(self, X: ArrayLike) -> NDArray:
        """
        Estimate the target value for every row of X.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Samples to evaluate.

        Returns
        -------
        y : NDArray of shape (n_samples,)
            One predicted value per sample, flattened to 1-D.
        """
        raw_output = super().predict(X)
        # Flatten whatever shape the base returns into a 1-D vector.
        return np.ravel(raw_output)