perpetual 1.0.35__cp314-cp314-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
perpetual/booster.py ADDED
@@ -0,0 +1,1915 @@
1
+ import inspect
2
+ import json
3
+ import warnings
4
+ from types import FunctionType
5
+ from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union, cast
6
+
7
+ import numpy as np
8
+ from typing_extensions import Self
9
+
10
+ from perpetual.data import Node
11
+ from perpetual.perpetual import (
12
+ MultiOutputBooster as CrateMultiOutputBooster, # type: ignore
13
+ )
14
+ from perpetual.perpetual import (
15
+ PerpetualBooster as CratePerpetualBooster, # type: ignore
16
+ )
17
+ from perpetual.serialize import BaseSerializer, ObjectSerializer
18
+ from perpetual.types import BoosterType, MultiOutputBoosterType
19
+ from perpetual.utils import (
20
+ CONTRIBUTION_METHODS,
21
+ convert_input_array,
22
+ convert_input_frame,
23
+ convert_input_frame_columnar,
24
+ transform_input_frame,
25
+ transform_input_frame_columnar,
26
+ type_df,
27
+ )
28
+
29
+
30
+ class PerpetualBooster:
31
+ # Define the metadata parameters that are
32
+ # present on all instances of this class.
33
+ # These are parameters that should be loaded
34
+ # and set as attributes on the booster
35
+ # after it is loaded.
36
+ metadata_attributes: Dict[str, BaseSerializer] = {
37
+ "feature_names_in_": ObjectSerializer(),
38
+ "n_features_": ObjectSerializer(),
39
+ "feature_importance_method": ObjectSerializer(),
40
+ "cat_mapping": ObjectSerializer(),
41
+ "classes_": ObjectSerializer(),
42
+ }
43
+
44
+ def __init__(
45
+ self,
46
+ *,
47
+ objective: Union[
48
+ str, Tuple[FunctionType, FunctionType, FunctionType]
49
+ ] = "LogLoss",
50
+ budget: float = 0.5,
51
+ num_threads: Optional[int] = None,
52
+ monotone_constraints: Union[Dict[Any, int], None] = None,
53
+ force_children_to_bound_parent: bool = False,
54
+ missing: float = np.nan,
55
+ allow_missing_splits: bool = True,
56
+ create_missing_branch: bool = False,
57
+ terminate_missing_features: Optional[Iterable[Any]] = None,
58
+ missing_node_treatment: str = "None",
59
+ log_iterations: int = 0,
60
+ feature_importance_method: str = "Gain",
61
+ quantile: Optional[float] = None,
62
+ reset: Optional[bool] = None,
63
+ categorical_features: Union[Iterable[int], Iterable[str], str, None] = "auto",
64
+ timeout: Optional[float] = None,
65
+ iteration_limit: Optional[int] = None,
66
+ memory_limit: Optional[float] = None,
67
+ stopping_rounds: Optional[int] = None,
68
+ max_bin: int = 256,
69
+ max_cat: int = 1000,
70
+ ):
71
+ """
72
+ Gradient Boosting Machine with Perpetual Learning.
73
+
74
+ A self-generalizing gradient boosting machine that doesn't need hyperparameter optimization.
75
+ It automatically finds the best configuration based on the provided budget.
76
+
77
+ Parameters
78
+ ----------
79
+ objective : str or tuple, default="LogLoss"
80
+ Learning objective function to be used for optimization. Valid options are:
81
+
82
+ - "LogLoss": logistic loss for binary classification.
83
+ - "SquaredLoss": squared error for regression.
84
+ - "QuantileLoss": quantile error for quantile regression.
85
+ - "HuberLoss": Huber loss for robust regression.
86
+ - "AdaptiveHuberLoss": adaptive Huber loss for robust regression.
87
+ - "ListNetLoss": ListNet loss for ranking.
88
+ - custom objective: a tuple of (grad, hess, init) functions.
89
+
90
+ budget : float, default=0.5
91
+ A positive number for the fitting budget. Increasing it typically results in
92
+ more boosting rounds and higher predictive power.
93
+ num_threads : int, optional
94
+ Number of threads to be used during training and prediction.
95
+ monotone_constraints : dict, optional
96
+ Constraints to enforce a specific relationship between features and target.
97
+ Keys are feature indices or names, values are -1, 1, or 0.
98
+ force_children_to_bound_parent : bool, default=False
99
+ Whether to restrict child nodes to be within the parent's range.
100
+ missing : float, default=np.nan
101
+ Value to consider as missing data.
102
+ allow_missing_splits : bool, default=True
103
+ Whether to allow splits that separate missing from non-missing values.
104
+ create_missing_branch : bool, default=False
105
+ Whether to create a separate branch for missing values (ternary trees).
106
+ terminate_missing_features : iterable, optional
107
+ Features for which missing branches will always be terminated if
108
+ ``create_missing_branch`` is True.
109
+ missing_node_treatment : str, default="None"
110
+ How to handle weights for missing nodes if ``create_missing_branch`` is True.
111
+ Options: "None", "AssignToParent", "AverageLeafWeight", "AverageNodeWeight".
112
+ log_iterations : int, default=0
113
+ Logging frequency (every N iterations). 0 disables logging.
114
+ feature_importance_method : str, default="Gain"
115
+ Method for calculating feature importance. Options: "Gain", "Weight", "Cover",
116
+ "TotalGain", "TotalCover".
117
+ quantile : float, optional
118
+ Target quantile for quantile regression (objective="QuantileLoss").
119
+ reset : bool, optional
120
+ Whether to reset the model or continue training on subsequent calls to fit.
121
+ categorical_features : str or iterable, default="auto"
122
+ Feature indices or names to treat as categorical.
123
+ timeout : float, optional
124
+ Time limit for fitting in seconds.
125
+ iteration_limit : int, optional
126
+ Maximum number of boosting iterations.
127
+ memory_limit : float, optional
128
+ Memory limit for training in GB.
129
+ stopping_rounds : int, optional
130
+ Early stopping rounds.
131
+ max_bin : int, default=256
132
+ Maximum number of bins for feature discretization.
133
+ max_cat : int, default=1000
134
+ Maximum unique categories before a feature is treated as numerical.
135
+
136
+ Attributes
137
+ ----------
138
+ feature_names_in_ : list of str
139
+ Names of features seen during :meth:`fit`.
140
+ n_features_ : int
141
+ Number of features seen during :meth:`fit`.
142
+ classes_ : list
143
+ Class labels for classification tasks.
144
+ feature_importances_ : ndarray of shape (n_features,)
145
+ Feature importances calculated via ``feature_importance_method``.
146
+
147
+ See Also
148
+ --------
149
+ perpetual.sklearn.PerpetualClassifier : Scikit-learn compatible classifier.
150
+ perpetual.sklearn.PerpetualRegressor : Scikit-learn compatible regressor.
151
+
152
+ Examples
153
+ --------
154
+ Basic usage for binary classification:
155
+
156
+ >>> from perpetual import PerpetualBooster
157
+ >>> from sklearn.datasets import make_classification
158
+ >>> X, y = make_classification(n_samples=1000, n_features=20)
159
+ >>> model = PerpetualBooster(objective="LogLoss")
160
+ >>> model.fit(X, y)
161
+ >>> preds = model.predict(X[:5])
162
+ """
163
+
164
+ terminate_missing_features_ = (
165
+ set() if terminate_missing_features is None else terminate_missing_features
166
+ )
167
+ monotone_constraints_ = (
168
+ {} if monotone_constraints is None else monotone_constraints
169
+ )
170
+
171
+ if isinstance(objective, str):
172
+ self.objective = objective
173
+ self.loss = None
174
+ self.grad = None
175
+ self.init = None
176
+ else:
177
+ self.objective = None
178
+ self.loss = objective[0]
179
+ self.grad = objective[1]
180
+ self.init = objective[2]
181
+ self.budget = budget
182
+ self.num_threads = num_threads
183
+ self.monotone_constraints = monotone_constraints_
184
+ self.force_children_to_bound_parent = force_children_to_bound_parent
185
+ self.allow_missing_splits = allow_missing_splits
186
+ self.missing = missing
187
+ self.create_missing_branch = create_missing_branch
188
+ self.terminate_missing_features = terminate_missing_features_
189
+ self.missing_node_treatment = missing_node_treatment
190
+ self.log_iterations = log_iterations
191
+ self.feature_importance_method = feature_importance_method
192
+ self.quantile = quantile
193
+ self.reset = reset
194
+ self.categorical_features = categorical_features
195
+ self.timeout = timeout
196
+ self.iteration_limit = iteration_limit
197
+ self.memory_limit = memory_limit
198
+ self.stopping_rounds = stopping_rounds
199
+ self.max_bin = max_bin
200
+ self.max_cat = max_cat
201
+
202
+ booster = CratePerpetualBooster(
203
+ objective=self.objective,
204
+ budget=self.budget,
205
+ max_bin=self.max_bin,
206
+ num_threads=self.num_threads,
207
+ monotone_constraints=dict(),
208
+ force_children_to_bound_parent=self.force_children_to_bound_parent,
209
+ missing=self.missing,
210
+ allow_missing_splits=self.allow_missing_splits,
211
+ create_missing_branch=self.create_missing_branch,
212
+ terminate_missing_features=set(),
213
+ missing_node_treatment=self.missing_node_treatment,
214
+ log_iterations=self.log_iterations,
215
+ quantile=self.quantile,
216
+ reset=self.reset,
217
+ categorical_features=set(),
218
+ timeout=self.timeout,
219
+ iteration_limit=self.iteration_limit,
220
+ memory_limit=self.memory_limit,
221
+ stopping_rounds=self.stopping_rounds,
222
+ loss=self.loss,
223
+ grad=self.grad,
224
+ init=self.init,
225
+ )
226
+ self.booster = cast(BoosterType, booster)
227
+
228
+ def fit(self, X, y, sample_weight=None, group=None) -> Self:
229
+ """
230
+ Fit the gradient booster on a provided dataset.
231
+
232
+ Parameters
233
+ ----------
234
+ X : array-like of shape (n_samples, n_features)
235
+ Training data. Can be a Polars or Pandas DataFrame, or a 2D Numpy array.
236
+ Polars DataFrames use a zero-copy columnar path for efficiency.
237
+ y : array-like of shape (n_samples,) or (n_samples, n_targets)
238
+ Target values.
239
+ sample_weight : array-like of shape (n_samples,), optional
240
+ Individual weights for each sample. If None, all samples are weighted equally.
241
+ group : array-like, optional
242
+ Group labels for ranking objectives.
243
+
244
+ Returns
245
+ -------
246
+ self : object
247
+ Returns self.
248
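+
+ Examples
+ --------
+ A minimal, illustrative sketch (data shapes and values are hypothetical):
+
+ >>> import numpy as np
+ >>> X = np.random.rand(200, 5)
+ >>> y = (X[:, 0] > 0.5).astype(int)
+ >>> model = PerpetualBooster(objective="LogLoss")
+ >>> _ = model.fit(X, y)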
+ """
249
+
250
+ # Check if input is a Polars DataFrame for zero-copy columnar path
251
+ is_polars = type_df(X) == "polars_df"
252
+
253
+ if is_polars:
254
+ # Use columnar path for Polars DataFrames (true zero-copy)
255
+ (
256
+ features_,
257
+ columns, # list of 1D arrays instead of flat_data
258
+ masks,
259
+ rows,
260
+ cols,
261
+ categorical_features_,
262
+ cat_mapping,
263
+ ) = convert_input_frame_columnar(X, self.categorical_features, self.max_cat)
264
+ else:
265
+ # Use existing flat path for pandas and numpy
266
+ (
267
+ features_,
268
+ flat_data,
269
+ rows,
270
+ cols,
271
+ categorical_features_,
272
+ cat_mapping,
273
+ ) = convert_input_frame(X, self.categorical_features, self.max_cat)
274
+
275
+ self.n_features_ = cols
276
+ self.cat_mapping = cat_mapping
277
+ self.feature_names_in_ = features_
278
+
279
+ y_, classes_ = convert_input_array(y, self.objective, is_target=True)
280
+ self.classes_ = np.array(classes_).tolist()
281
+
282
+ if sample_weight is None:
283
+ sample_weight_ = None
284
+ else:
285
+ sample_weight_, _ = convert_input_array(sample_weight, self.objective)
286
+
287
+ if group is None:
288
+ group_ = None
289
+ else:
290
+ group_, _ = convert_input_array(group, self.objective, is_int=True)
291
+
292
+ # Convert the monotone constraints into the form needed
293
+ # by the rust code.
294
+ crate_mc = self._standardize_monotonicity_map(X)
295
+ crate_tmf = self._standardize_terminate_missing_features(X)
296
+
297
+ if (len(classes_) <= 2) or (
298
+ len(classes_) > 1 and self.objective == "SquaredLoss"
299
+ ):
300
+ booster = CratePerpetualBooster(
301
+ objective=self.objective,
302
+ budget=self.budget,
303
+ max_bin=self.max_bin,
304
+ num_threads=self.num_threads,
305
+ monotone_constraints=crate_mc,
306
+ force_children_to_bound_parent=self.force_children_to_bound_parent,
307
+ missing=self.missing,
308
+ allow_missing_splits=self.allow_missing_splits,
309
+ create_missing_branch=self.create_missing_branch,
310
+ terminate_missing_features=crate_tmf,
311
+ missing_node_treatment=self.missing_node_treatment,
312
+ log_iterations=self.log_iterations,
313
+ quantile=self.quantile,
314
+ reset=self.reset,
315
+ categorical_features=categorical_features_,
316
+ timeout=self.timeout,
317
+ iteration_limit=self.iteration_limit,
318
+ memory_limit=self.memory_limit,
319
+ stopping_rounds=self.stopping_rounds,
320
+ loss=self.loss,
321
+ grad=self.grad,
322
+ init=self.init,
323
+ )
324
+ self.booster = cast(BoosterType, booster)
325
+ else:
326
+ booster = CrateMultiOutputBooster(
327
+ n_boosters=len(classes_),
328
+ objective=self.objective,
329
+ budget=self.budget,
330
+ max_bin=self.max_bin,
331
+ num_threads=self.num_threads,
332
+ monotone_constraints=crate_mc,
333
+ force_children_to_bound_parent=self.force_children_to_bound_parent,
334
+ missing=self.missing,
335
+ allow_missing_splits=self.allow_missing_splits,
336
+ create_missing_branch=self.create_missing_branch,
337
+ terminate_missing_features=crate_tmf,
338
+ missing_node_treatment=self.missing_node_treatment,
339
+ log_iterations=self.log_iterations,
340
+ quantile=self.quantile,
341
+ reset=self.reset,
342
+ categorical_features=categorical_features_,
343
+ timeout=self.timeout,
344
+ iteration_limit=self.iteration_limit,
345
+ memory_limit=self.memory_limit,
346
+ stopping_rounds=self.stopping_rounds,
347
+ loss=self.loss,
348
+ grad=self.grad,
349
+ init=self.init,
350
+ )
351
+ self.booster = cast(MultiOutputBoosterType, booster)
352
+
353
+ self._set_metadata_attributes("n_features_", self.n_features_)
354
+ self._set_metadata_attributes("cat_mapping", self.cat_mapping)
355
+ self._set_metadata_attributes("feature_names_in_", self.feature_names_in_)
356
+ self._set_metadata_attributes(
357
+ "feature_importance_method", self.feature_importance_method
358
+ )
359
+ self._set_metadata_attributes("classes_", self.classes_)
360
+
361
+ self.categorical_features = categorical_features_
362
+
363
+ if is_polars:
364
+ # Use columnar fit for Polars (zero-copy)
365
+ self.booster.fit_columnar(
366
+ columns=columns,
367
+ masks=masks,
368
+ rows=rows,
369
+ y=y_,
370
+ sample_weight=sample_weight_, # type: ignore
371
+ group=group_,
372
+ )
373
+ else:
374
+ # Use standard fit for pandas/numpy
375
+ self.booster.fit(
376
+ flat_data=flat_data,
377
+ rows=rows,
378
+ cols=cols,
379
+ y=y_,
380
+ sample_weight=sample_weight_, # type: ignore
381
+ group=group_,
382
+ )
383
+
384
+ return self
385
+
386
+ def prune(self, X, y, sample_weight=None, group=None) -> Self:
387
+ """
388
+ Prune the gradient booster on a provided dataset.
389
+
390
+ This removes nodes that do not contribute to a reduction in loss on the provided
391
+ validation set.
392
+
393
+ Parameters
394
+ ----------
395
+ X : array-like of shape (n_samples, n_features)
396
+ Validation data.
397
+ y : array-like of shape (n_samples,)
398
+ Validation targets.
399
+ sample_weight : array-like of shape (n_samples,), optional
400
+ Weights for validation samples.
401
+ group : array-like, optional
402
+ Group labels for ranking objectives.
403
+
404
+ Returns
405
+ -------
406
+ self : object
407
+ Returns self.
408
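+
+ Examples
+ --------
+ A minimal sketch, assuming a held-out validation split (``X_valid`` and
+ ``y_valid`` are hypothetical names):
+
+ >>> _ = model.fit(X_train, y_train)
+ >>> _ = model.prune(X_valid, y_valid)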
+ """
409
+
410
+ _, flat_data, rows, cols = transform_input_frame(X, self.cat_mapping)
411
+
412
+ y_, _ = convert_input_array(y, self.objective)
413
+
414
+ if sample_weight is None:
415
+ sample_weight_ = None
416
+ else:
417
+ sample_weight_, _ = convert_input_array(sample_weight, self.objective)
418
+
419
+ if group is None:
420
+ group_ = None
421
+ else:
422
+ group_, _ = convert_input_array(group, self.objective, is_int=True)
423
+
424
+ self.booster.prune(
425
+ flat_data=flat_data,
426
+ rows=rows,
427
+ cols=cols,
428
+ y=y_,
429
+ sample_weight=sample_weight_, # type: ignore
430
+ group=group_,
431
+ )
432
+
433
+ return self
434
+
435
+ def calibrate(
436
+ self, X_train, y_train, X_cal, y_cal, alpha, sample_weight=None, group=None
437
+ ) -> Self:
438
+ """
439
+ Calibrate the gradient booster for prediction intervals.
440
+
441
+ Uses the provided training and calibration sets to compute scaling factors
442
+ for intervals.
443
+
444
+ Parameters
445
+ ----------
446
+ X_train : array-like
447
+ Data used to train the base model.
448
+ y_train : array-like
449
+ Targets for training data.
450
+ X_cal : array-like
451
+ Independent calibration dataset.
452
+ y_cal : array-like
453
+ Targets for calibration data.
454
+ alpha : float or array-like
455
+ Significance level(s) for the intervals (1 - coverage).
456
+ sample_weight : array-like, optional
457
+ Sample weights.
458
+ group : array-like, optional
459
+ Group labels.
460
+
461
+ Returns
462
+ -------
463
+ self : object
464
+ Returns self.
465
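+
+ Examples
+ --------
+ A minimal sketch, assuming pre-existing train/calibration/test splits
+ (all data names are hypothetical):
+
+ >>> model = PerpetualBooster(objective="SquaredLoss")
+ >>> _ = model.fit(X_train, y_train)
+ >>> _ = model.calibrate(X_train, y_train, X_cal, y_cal, alpha=0.1)
+ >>> intervals = model.predict_intervals(X_test)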
+ """
466
+
467
+ is_polars = type_df(X_train) == "polars_df"
468
+ if is_polars:
469
+ features_train, cols_train, masks_train, rows_train, _ = (
470
+ transform_input_frame_columnar(X_train, self.cat_mapping)
471
+ )
472
+ self._validate_features(features_train)
473
+ features_cal, cols_cal, masks_cal, rows_cal, _ = (
474
+ transform_input_frame_columnar(X_cal, self.cat_mapping)
475
+ )
476
+ # Use columnar calibration
477
+ y_train_, _ = convert_input_array(y_train, self.objective)
478
+ y_cal_, _ = convert_input_array(y_cal, self.objective)
479
+ if sample_weight is None:
480
+ sample_weight_ = None
481
+ else:
482
+ sample_weight_, _ = convert_input_array(sample_weight, self.objective)
483
+
484
+ self.booster.calibrate_columnar(
485
+ columns=cols_train,
486
+ masks=masks_train,
487
+ rows=rows_train,
488
+ y=y_train_,
489
+ columns_cal=cols_cal,
490
+ masks_cal=masks_cal,
491
+ rows_cal=rows_cal,
492
+ y_cal=y_cal_,
493
+ alpha=np.array(alpha),
494
+ sample_weight=sample_weight_, # type: ignore
495
+ group=group,
496
+ )
497
+ else:
498
+ _, flat_data_train, rows_train, cols_train = transform_input_frame(
499
+ X_train, self.cat_mapping
500
+ )
501
+
502
+ y_train_, _ = convert_input_array(y_train, self.objective)
503
+
504
+ _, flat_data_cal, rows_cal, cols_cal = transform_input_frame(
505
+ X_cal, self.cat_mapping
506
+ )
507
+
508
+ y_cal_, _ = convert_input_array(y_cal, self.objective)
509
+
510
+ if sample_weight is None:
511
+ sample_weight_ = None
512
+ else:
513
+ sample_weight_, _ = convert_input_array(sample_weight, self.objective)
514
+
515
+ self.booster.calibrate(
516
+ flat_data=flat_data_train,
517
+ rows=rows_train,
518
+ cols=cols_train,
519
+ y=y_train_,
520
+ flat_data_cal=flat_data_cal,
521
+ rows_cal=rows_cal,
522
+ cols_cal=cols_cal,
523
+ y_cal=y_cal_,
524
+ alpha=np.array(alpha),
525
+ sample_weight=sample_weight_, # type: ignore
526
+ group=group,
527
+ )
528
+
529
+ return self
530
+
531
+ def _validate_features(self, features: List[str]):
532
+ if len(features) > 0 and hasattr(self, "feature_names_in_"):
533
+ if features[0] != "0" and self.feature_names_in_[0] != "0":
534
+ if features != self.feature_names_in_:
535
+ raise ValueError(
536
+ f"Columns mismatch between data {features} passed, and data {self.feature_names_in_} used at fit."
537
+ )
538
+
539
+ def predict_intervals(self, X, parallel: Union[bool, None] = None) -> dict:
540
+ """
541
+ Predict intervals with the fitted booster on new data.
542
+
543
+ Parameters
544
+ ----------
545
+ X : array-like of shape (n_samples, n_features)
546
+ New data for prediction.
547
+ parallel : bool, optional
548
+ Whether to run prediction in parallel. If None, uses class default.
549
+
550
+ Returns
551
+ -------
552
+ intervals : dict
553
+ A dictionary containing lower and upper bounds for the specified alpha levels.
554
+ """
555
+ is_polars = type_df(X) == "polars_df"
556
+ if is_polars:
557
+ features_, columns, masks, rows, cols = transform_input_frame_columnar(
558
+ X, self.cat_mapping
559
+ )
560
+ self._validate_features(features_)
561
+ return self.booster.predict_intervals_columnar(
562
+ columns=columns, masks=masks, rows=rows, parallel=parallel
563
+ )
564
+
565
+ features_, flat_data, rows, cols = transform_input_frame(X, self.cat_mapping)
566
+ self._validate_features(features_)
567
+
568
+ return self.booster.predict_intervals(
569
+ flat_data=flat_data,
570
+ rows=rows,
571
+ cols=cols,
572
+ parallel=parallel,
573
+ )
574
+
575
+ def predict(self, X, parallel: Union[bool, None] = None) -> np.ndarray:
576
+ """
577
+ Predict with the fitted booster on new data.
578
+
579
+ Parameters
580
+ ----------
581
+ X : array-like of shape (n_samples, n_features)
582
+ Input features.
583
+ parallel : bool, optional
584
+ Whether to run prediction in parallel.
585
+
586
+ Returns
587
+ -------
588
+ predictions : ndarray of shape (n_samples,)
589
+ The predicted values: class labels for classification, raw values for regression.
590
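+
+ Examples
+ --------
+ A minimal sketch on a fitted classifier (``X`` is hypothetical):
+
+ >>> preds = model.predict(X[:5])
+ >>> preds.shape
+ (5,)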
+ """
591
+ is_polars = type_df(X) == "polars_df"
592
+ if is_polars:
593
+ features_, columns, masks, rows, cols = transform_input_frame_columnar(
594
+ X, self.cat_mapping
595
+ )
596
+ else:
597
+ features_, flat_data, rows, cols = transform_input_frame(
598
+ X, self.cat_mapping
599
+ )
600
+ self._validate_features(features_)
601
+
602
+ if len(self.classes_) == 0:
603
+ if is_polars:
604
+ return self.booster.predict_columnar(
605
+ columns=columns, masks=masks, rows=rows, parallel=parallel
606
+ )
607
+ return self.booster.predict(
608
+ flat_data=flat_data, rows=rows, cols=cols, parallel=parallel
609
+ )
610
+ elif len(self.classes_) == 2:
611
+ if is_polars:
612
+ return np.rint(
613
+ self.booster.predict_proba_columnar(
614
+ columns=columns, masks=masks, rows=rows, parallel=parallel
615
+ )
616
+ ).astype(int)
617
+ return np.rint(
618
+ self.booster.predict_proba(
619
+ flat_data=flat_data, rows=rows, cols=cols, parallel=parallel
620
+ )
621
+ ).astype(int)
622
+ else:
623
+ if is_polars:
624
+ preds = self.booster.predict_columnar(
625
+ columns=columns, masks=masks, rows=rows, parallel=parallel
626
+ )
627
+ else:
628
+ preds = self.booster.predict(
629
+ flat_data=flat_data, rows=rows, cols=cols, parallel=parallel
630
+ )
631
+ preds_matrix = preds.reshape((-1, len(self.classes_)), order="F")
632
+ indices = np.argmax(preds_matrix, axis=1)
633
+ return np.array([self.classes_[i] for i in indices])
634
+
635
+ def predict_proba(self, X, parallel: Union[bool, None] = None) -> np.ndarray:
636
+ """
637
+ Predict class probabilities with the fitted booster on new data.
638
+
639
+ Only valid for classification tasks.
640
+
641
+ Parameters
642
+ ----------
643
+ X : array-like of shape (n_samples, n_features)
644
+ Input features.
645
+ parallel : bool, optional
646
+ Whether to run prediction in parallel.
647
+
648
+ Returns
649
+ -------
650
+ probabilities : ndarray of shape (n_samples, n_classes)
651
+ The class probabilities.
652
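+
+ Examples
+ --------
+ A minimal sketch on a fitted classifier (``X`` is hypothetical):
+
+ >>> proba = model.predict_proba(X)
+ >>> proba.shape  # (n_samples, n_classes)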
+ """
653
+ is_polars = type_df(X) == "polars_df"
654
+ if is_polars:
655
+ features_, columns, masks, rows, cols = transform_input_frame_columnar(
656
+ X, self.cat_mapping
657
+ )
658
+ else:
659
+ features_, flat_data, rows, cols = transform_input_frame(
660
+ X, self.cat_mapping
661
+ )
662
+ self._validate_features(features_)
663
+
664
+ if len(self.classes_) > 2:
665
+ if is_polars:
666
+ probabilities = self.booster.predict_proba_columnar(
667
+ columns=columns, masks=masks, rows=rows, parallel=parallel
668
+ )
669
+ else:
670
+ probabilities = self.booster.predict_proba(
671
+ flat_data=flat_data, rows=rows, cols=cols, parallel=parallel
672
+ )
673
+ return probabilities.reshape((-1, len(self.classes_)), order="C")
674
+ elif len(self.classes_) == 2:
675
+ if is_polars:
676
+ probabilities = self.booster.predict_proba_columnar(
677
+ columns=columns, masks=masks, rows=rows, parallel=parallel
678
+ )
679
+ else:
680
+ probabilities = self.booster.predict_proba(
681
+ flat_data=flat_data, rows=rows, cols=cols, parallel=parallel
682
+ )
683
+ return np.concatenate(
684
+ [(1.0 - probabilities).reshape(-1, 1), probabilities.reshape(-1, 1)],
685
+ axis=1,
686
+ )
687
+ else:
688
+ warnings.warn(
689
+ f"predict_proba not implemented for regression. n_classes = {len(self.classes_)}"
690
+ )
691
+ return np.ones((rows, 1))
692
+
693
+ def predict_log_proba(self, X, parallel: Union[bool, None] = None) -> np.ndarray:
694
+ """
695
+ Predict class log-probabilities with the fitted booster on new data.
696
+
697
+ Only valid for classification tasks.
698
+
699
+ Parameters
700
+ ----------
701
+ X : array-like of shape (n_samples, n_features)
702
+ Input features.
703
+ parallel : bool, optional
704
+ Whether to run prediction in parallel.
705
+
706
+ Returns
707
+ -------
708
+ log_probabilities : ndarray of shape (n_samples, n_classes)
709
+ The log-probabilities of each class.
710
+ """
711
+ is_polars = type_df(X) == "polars_df"
712
+ if is_polars:
713
+ features_, columns, masks, rows, cols = transform_input_frame_columnar(
714
+ X, self.cat_mapping
715
+ )
716
+ else:
717
+ features_, flat_data, rows, cols = transform_input_frame(
718
+ X, self.cat_mapping
719
+ )
720
+ self._validate_features(features_)
721
+
722
+ if len(self.classes_) > 2:
723
+ if is_polars:
724
+ preds = self.booster.predict_columnar(
725
+ columns=columns, masks=masks, rows=rows, parallel=parallel
726
+ )
727
+ else:
728
+ preds = self.booster.predict(
729
+ flat_data=flat_data,
730
+ rows=rows,
731
+ cols=cols,
732
+ parallel=parallel,
733
+ )
734
+ return preds.reshape((-1, len(self.classes_)), order="F")
735
+ elif len(self.classes_) == 2:
736
+ if is_polars:
737
+ return self.booster.predict_columnar(
738
+ columns=columns, masks=masks, rows=rows, parallel=parallel
739
+ )
740
+ return self.booster.predict(
741
+ flat_data=flat_data,
742
+ rows=rows,
743
+ cols=cols,
744
+ parallel=parallel,
745
+ )
746
+ else:
747
+ warnings.warn("predict_log_proba not implemented for regression.")
748
+ return np.ones((rows, 1))
749
+
750
+ def predict_nodes(self, X, parallel: Union[bool, None] = None) -> List:
751
+ """
752
+ Predict leaf node indices with the fitted booster on new data.
753
+
754
+ Parameters
755
+ ----------
756
+ X : array-like of shape (n_samples, n_features)
757
+ Input features.
758
+ parallel : bool, optional
759
+ Whether to run prediction in parallel.
760
+
761
+ Returns
762
+ -------
763
+ node_indices : list of ndarray
764
+ A list where each element corresponds to a tree and contains node indices
765
+ for each sample.
766
+ """
767
+ is_polars = type_df(X) == "polars_df"
768
+ if is_polars:
769
+ features_, columns, masks, rows, cols = transform_input_frame_columnar(
770
+ X, self.cat_mapping
771
+ )
772
+ self._validate_features(features_)
773
+ return self.booster.predict_nodes_columnar(
774
+ columns=columns, masks=masks, rows=rows, parallel=parallel
775
+ )
776
+
777
+ features_, flat_data, rows, cols = transform_input_frame(X, self.cat_mapping)
778
+ self._validate_features(features_)
779
+
780
+ return self.booster.predict_nodes(
781
+ flat_data=flat_data, rows=rows, cols=cols, parallel=parallel
782
+ )
783
+
784
+ @property
785
+ def feature_importances_(self) -> np.ndarray:
786
+ vals = self.calculate_feature_importance(
787
+ method=self.feature_importance_method, normalize=True
788
+ )
789
+ if hasattr(self, "feature_names_in_"):
790
+ vals = cast(Dict[str, float], vals)
791
+ return np.array([vals.get(ft, 0.0) for ft in self.feature_names_in_])
792
+ else:
793
+ vals = cast(Dict[int, float], vals)
794
+ return np.array([vals.get(ft, 0.0) for ft in range(self.n_features_)])
795
+
796
+ def predict_contributions(
797
+ self, X, method: str = "Average", parallel: Union[bool, None] = None
798
+ ) -> np.ndarray:
799
+ """
800
+ Predict feature contributions (SHAP-like values) for new data.
801
+
802
+ Parameters
803
+ ----------
804
+ X : array-like of shape (n_samples, n_features)
805
+ Input features.
806
+ method : str, default="Average"
807
+ Method to calculate contributions. Options:
808
+
809
+ - "Average": Internal node averages.
810
+ - "Shapley": Exact tree SHAP values.
811
+ - "Weight": Saabas-style leaf weights.
812
+ - "BranchDifference": Difference between chosen and other branch.
813
+ - "MidpointDifference": Weighted difference between branches.
814
+ - "ModeDifference": Difference from the most frequent node.
815
+ - "ProbabilityChange": Change in probability (LogLoss only).
816
+
817
+ parallel : bool, optional
818
+ Whether to run prediction in parallel.
819
+
820
+ Returns
821
+ -------
822
+ contributions : ndarray of shape (n_samples, n_features + 1)
823
+ The contribution of each feature to the prediction. The last column
824
+ is the bias term.
825
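+
+ Examples
+ --------
+ A minimal sketch on a fitted binary model (``X`` is hypothetical):
+
+ >>> contribs = model.predict_contributions(X, method="Shapley")
+ >>> contribs.shape  # (n_samples, n_features + 1); last column is the bias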
+ """
826
+ is_polars = type_df(X) == "polars_df"
827
+ if is_polars:
828
+ features_, columns, masks, rows, cols = transform_input_frame_columnar(
829
+ X, self.cat_mapping
830
+ )
831
+ self._validate_features(features_)
832
+ contributions = self.booster.predict_contributions_columnar(
833
+ columns=columns,
834
+ masks=masks,
835
+ rows=rows,
836
+ method=CONTRIBUTION_METHODS.get(method, method),
837
+ parallel=parallel,
838
+ )
839
+ else:
840
+ features_, flat_data, rows, cols = transform_input_frame(
841
+ X, self.cat_mapping
842
+ )
843
+ self._validate_features(features_)
844
+
845
+ contributions = self.booster.predict_contributions(
846
+ flat_data=flat_data,
847
+ rows=rows,
848
+ cols=cols,
849
+ method=CONTRIBUTION_METHODS.get(method, method),
850
+ parallel=parallel,
851
+ )
852
+
853
+ if len(self.classes_) > 2:
854
+ return (
855
+ np.reshape(contributions, (len(self.classes_), rows, cols + 1))
856
+ .transpose(1, 0, 2)
857
+ .reshape(rows, -1)
858
+ )
859
+ return np.reshape(contributions, (rows, cols + 1))
860
+
861
+ def partial_dependence(
862
+ self,
863
+ X,
864
+ feature: Union[str, int],
865
+ samples: Optional[int] = 100,
866
+ exclude_missing: bool = True,
867
+ percentile_bounds: Tuple[float, float] = (0.2, 0.98),
868
+ ) -> np.ndarray:
869
+ """
870
+ Calculate the partial dependence values of a feature.
871
+
872
+ For each sampled value of the feature, this gives an estimate of the predicted
873
+ value at that feature value, with the effects of all other features averaged out.
874
+
875
+ Parameters
876
+ ----------
877
+ X : array-like
878
+ Data used to calculate partial dependence. Should be the same format
879
+ as passed to :meth:`fit`.
880
+ feature : str or int
881
+ The feature for which to calculate partial dependence.
882
+ samples : int, optional, default=100
883
+ Number of evenly spaced samples to select. If None, all unique values are used.
884
+ exclude_missing : bool, optional, default=True
885
+ Whether to exclude missing values from the calculation.
886
+ percentile_bounds : tuple of float, optional, default=(0.2, 0.98)
887
+ Lower and upper percentiles for sample selection.
888
+
889
+ Returns
890
+ -------
891
+ pd_values : ndarray of shape (n_samples, 2)
892
+ The first column contains the feature values, and the second column
893
+ contains the partial dependence values.
894
+
895
+ Examples
896
+ --------
897
+ >>> import matplotlib.pyplot as plt
898
+ >>> pd_values = model.partial_dependence(X, feature="age")
899
+ >>> plt.plot(pd_values[:, 0], pd_values[:, 1])
900
+ """
901
+ if isinstance(feature, str):
902
+ is_polars = type_df(X) == "polars_df"
903
+ if not (type_df(X) == "pandas_df" or is_polars):
904
+ raise ValueError(
905
+ "If `feature` is a string, then the object passed as `X` must be a pandas or polars DataFrame."
906
+ )
907
+ if is_polars:
908
+ values = X[feature].to_numpy()
909
+ else:
910
+ values = X.loc[:, feature].to_numpy()
911
+
912
+ if hasattr(self, "feature_names_in_") and self.feature_names_in_[0] != "0":
913
+ [feature_idx] = [
914
+ i for i, v in enumerate(self.feature_names_in_) if v == feature
915
+ ]
916
+ else:
917
+ w_msg = (
918
+ "No feature names were provided at fit, but feature was a string, attempting to "
919
+ + "determine feature index from DataFrame Column, "
920
+ + "ensure columns are the same order as data passed when fit."
921
+ )
922
+ warnings.warn(w_msg)
923
+ features = X.columns if is_polars else X.columns.to_list()
924
+ [feature_idx] = [i for i, v in enumerate(features) if v == feature]
925
+ elif isinstance(feature, int):
926
+ feature_idx = feature
927
+ if type_df(X) == "pandas_df":
928
+ values = X.to_numpy()[:, feature]
929
+ elif type_df(X) == "polars_df":
930
+ values = X.to_numpy(allow_copy=False)[:, feature]
931
+ else:
932
+ values = X[:, feature]
933
+ else:
934
+ raise ValueError(
935
+ f"The parameter `feature` must be a string, or an int, however an object of type {type(feature)} was passed."
936
+ )
937
+ min_p, max_p = percentile_bounds
938
+ values = values[~(np.isnan(values) | (values == self.missing))]
939
+ if samples is None:
940
+ search_values = np.sort(np.unique(values))
941
+ else:
942
+ # Exclude missing from this calculation.
943
+ search_values = np.quantile(values, np.linspace(min_p, max_p, num=samples))
944
+
945
+ # Add the missing value back, if requested.
946
+ if not exclude_missing:
947
+ search_values = np.append([self.missing], search_values)
948
+
949
+ res = []
950
+ for v in search_values:
951
+ res.append(
952
+ (v, self.booster.value_partial_dependence(feature=feature_idx, value=v))
953
+ )
954
+ return np.array(res)
955
+
956
+ def calculate_feature_importance(
957
+ self, method: str = "Gain", normalize: bool = True
958
+ ) -> Union[Dict[int, float], Dict[str, float]]:
959
+ """
960
+ Calculate feature importance for the model.
961
+
962
+ Parameters
963
+ ----------
964
+ method : str, optional, default="Gain"
965
+ Importance method. Options:
966
+
967
+ - "Weight": Number of times a feature is used in splits.
968
+ - "Gain": Average improvement in loss brought by a feature.
969
+ - "Cover": Average number of samples affected by splits on a feature.
970
+ - "TotalGain": Total improvement in loss brought by a feature.
971
+ - "TotalCover": Total number of samples affected by splits on a feature.
972
+
973
+ normalize : bool, optional, default=True
974
+ Whether to normalize importance scores to sum to 1.
975
+
976
+ Returns
977
+ -------
978
+ importance : dict
979
+ A dictionary mapping feature names (or indices) to importance scores.
980
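+
+ Examples
+ --------
+ A minimal sketch on a fitted model:
+
+ >>> importance = model.calculate_feature_importance(method="Gain")
+ >>> sorted(importance.items(), key=lambda kv: kv[1], reverse=True)[:3]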
+ """
981
+ importance_: Dict[int, float] = self.booster.calculate_feature_importance(
982
+ method=method,
983
+ normalize=normalize,
984
+ )
985
+ if hasattr(self, "feature_names_in_"):
986
+ feature_map: Dict[int, str] = {
987
+ i: f for i, f in enumerate(self.feature_names_in_)
988
+ }
989
+ return {feature_map[i]: v for i, v in importance_.items()}
990
+ return importance_
991
+
992
+ def text_dump(self) -> List[str]:
993
+ """
994
+ Return the booster model in a human-readable text format.
995
+
996
+ Returns
997
+ -------
998
+ dump : list of str
999
+ A list where each string represents a tree in the ensemble.
1000
+ """
1001
+ return self.booster.text_dump()
1002
+
1003
+ def json_dump(self) -> str:
1004
+ """
1005
+ Return the booster model in JSON format.
1006
+
1007
+ Returns
1008
+ -------
1009
+ dump : str
1010
+ The JSON representation of the model.
1011
+ """
1012
+ return self.booster.json_dump()
1013
+
1014
+ @classmethod
1015
+ def load_booster(cls, path: str) -> Self:
1016
+ """
1017
+ Load a booster model from a file.
1018
+
1019
+ Parameters
1020
+ ----------
1021
+ path : str
1022
+ Path to the saved booster (JSON format).
1023
+
1024
+ Returns
1025
+ -------
1026
+ model : PerpetualBooster
1027
+ The loaded booster object.
1028
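+
+ Examples
+ --------
+ A minimal sketch of a save/load round trip (the path is illustrative):
+
+ >>> model.save_booster("model.json")
+ >>> loaded = PerpetualBooster.load_booster("model.json")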
+ """
1029
+ try:
1030
+ booster = CratePerpetualBooster.load_booster(str(path))
1031
+ except ValueError:
1032
+ booster = CrateMultiOutputBooster.load_booster(str(path))
1033
+
1034
+ params = booster.get_params()
1035
+ with warnings.catch_warnings():
1036
+ warnings.simplefilter("ignore")
1037
+ c = cls(**params)
1038
+ c.booster = booster
1039
+ for m in c.metadata_attributes:
1040
+ try:
1041
+ m_ = c._get_metadata_attributes(m)
1042
+ setattr(c, m, m_)
1043
+ # If "feature_names_in_" is present, we know a
1044
+ # DataFrame was used for fitting. In this case,
1045
+ # get back the original monotonicity map, with the
1046
+ # feature names as keys.
1047
+ if m == "feature_names_in_" and c.feature_names_in_[0] != "0":
1048
+ if c.monotone_constraints is not None:
1049
+ c.monotone_constraints = {
1050
+ ft: c.monotone_constraints[i]
1051
+ for i, ft in enumerate(c.feature_names_in_)
1052
+ }
1053
+ except KeyError:
1054
+ pass
1055
+ return c
1056
+
1057
+ def save_booster(self, path: str):
1058
+ """
1059
+ Save the booster model to a file.
1060
+
1061
+ The model is saved in a JSON-based format.
1062
+
1063
+ Parameters
1064
+ ----------
1065
+ path : str
1066
+ Path where the model will be saved.
1067
+ """
1068
+ self.booster.save_booster(str(path))
1069
+
1070
+ def _standardize_monotonicity_map(
1071
+ self,
1072
+ X,
1073
+ ) -> Dict[int, Any]:
1074
+ if isinstance(X, np.ndarray):
1075
+ return self.monotone_constraints
1076
+ else:
1077
+ feature_map = {f: i for i, f in enumerate(X.columns)}
1078
+ return {feature_map[f]: v for f, v in self.monotone_constraints.items()}
1079
+
1080
+ def _standardize_terminate_missing_features(
1081
+ self,
1082
+ X,
1083
+ ) -> Set[int]:
1084
+ if isinstance(X, np.ndarray):
1085
+ return set(self.terminate_missing_features)
1086
+ else:
1087
+ feature_map = {f: i for i, f in enumerate(X.columns)}
1088
+ return set(feature_map[f] for f in self.terminate_missing_features)
1089
+
1090
+ def insert_metadata(self, key: str, value: str):
1091
+ """
1092
+ Insert metadata into the model.
1093
+
1094
+ Metadata is saved alongside the model and can be retrieved later.
1095
+
1096
+ Parameters
1097
+ ----------
1098
+ key : str
1099
+ The key for the metadata item.
1100
+ value : str
1101
+ The value for the metadata item.
1102
+ """
1103
+ self.booster.insert_metadata(key=key, value=value)
1104
+
1105
+ def get_metadata(self, key: str) -> str:
1106
+ """
1107
+ Get metadata associated with a given key.
1108
+
1109
+ Parameters
1110
+ ----------
1111
+ key : str
1112
+ The key to look up in the metadata.
1113
+
1114
+ Returns
1115
+ -------
1116
+ value : str
1117
+ The value associated with the key.
1118
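+
+ Examples
+ --------
+ A minimal sketch with an illustrative key/value pair:
+
+ >>> model.insert_metadata("run_id", "experiment-1")
+ >>> model.get_metadata("run_id")
+ 'experiment-1'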
+ """
1119
+ v = self.booster.get_metadata(key=key)
1120
+ return v
1121
+
1122
+ def _set_metadata_attributes(self, key: str, value: Any) -> None:
1123
+ value_ = self.metadata_attributes[key].serialize(value)
1124
+ self.insert_metadata(key=key, value=value_)
1125
+
1126
+ def _get_metadata_attributes(self, key: str) -> Any:
1127
+ value = self.get_metadata(key)
1128
+ return self.metadata_attributes[key].deserialize(value)
1129
+
1130
+ @property
1131
+ def base_score(self) -> Union[float, Iterable[float]]:
1132
+ """
1133
+ The base score(s) of the model.
1134
+
1135
+ Returns
1136
+ -------
1137
+ score : float or iterable of float
1138
+ The initial prediction value(s) of the model.
1139
+ """
1140
+ return self.booster.base_score
1141
+
1142
+ @property
1143
+ def number_of_trees(self) -> Union[int, Iterable[int]]:
1144
+ """
1145
+ The number of trees in the ensemble.
1146
+
1147
+ Returns
1148
+ -------
1149
+ n_trees : int or iterable of int
1150
+ Total number of trees.
1151
+ """
1152
+ return self.booster.number_of_trees
1153
+
1154
+ # Make picklable with getstate and setstate
1155
+ def __getstate__(self) -> Dict[Any, Any]:
1156
+ booster_json = self.json_dump()
1157
+ # Exclude the booster from the pickled state.
1158
+ # Building a filtered dict (instead of deleting the attribute) keeps the live instance intact.
1159
+ res = {k: v for k, v in self.__dict__.items() if k != "booster"}
1160
+ res["__booster_json_file__"] = booster_json
1161
+ return res
1162
+
1163
+ def __setstate__(self, d: Dict[Any, Any]) -> None:
1164
+ # Load the booster object from the pickled JSON string.
1165
+ try:
1166
+ booster_object = CratePerpetualBooster.from_json(d["__booster_json_file__"])
1167
+ except ValueError:
1168
+ booster_object = CrateMultiOutputBooster.from_json(
1169
+ d["__booster_json_file__"]
1170
+ )
1171
+ d["booster"] = booster_object
1172
+ # Are there any new parameters that need to be added to the Python object
1173
+ # that would have been loaded in as defaults on the JSON object?
1174
+ # This makes sure that defaults set with a serde default function get
1175
+ # carried through to the Python object.
1176
+ for p, v in booster_object.get_params().items():
1177
+ if p not in d:
1178
+ d[p] = v
1179
+ del d["__booster_json_file__"]
1180
+ self.__dict__ = d
1181
+
1182
+ # Functions for scikit-learn compatibility. These are added manually for now;
1183
+ # if that becomes too unwieldy, scikit-learn may be added as a dependency.
1184
+ def get_params(self, deep=True) -> Dict[str, Any]:
1185
+ """
1186
+ Get parameters for this booster.
1187
+
1188
+ Parameters
1189
+ ----------
1190
+ deep : bool, default=True
1191
+ Currently ignored, exists for scikit-learn compatibility.
1192
+
1193
+ Returns
1194
+ -------
1195
+ params : dict
1196
+ Parameter names mapped to their values.
1197
+ """
1198
+ args = inspect.getfullargspec(PerpetualBooster).kwonlyargs
1199
+ return {param: getattr(self, param) for param in args}
1200
+
1201
+ def set_params(self, **params: Any) -> Self:
1202
+ """
1203
+ Set parameters for this booster.
1204
+
1205
+ Parameters
1206
+ ----------
1207
+ **params : dict
1208
+ Booster parameters.
1209
+
1210
+ Returns
1211
+ -------
1212
+ self : object
1213
+ Returns self.
1214
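+
+ Examples
+ --------
+ A minimal sketch; ``budget`` is one of the constructor parameters:
+
+ >>> model = PerpetualBooster()
+ >>> _ = model.set_params(budget=1.0)
+ >>> model.get_params()["budget"]
+ 1.0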
+ """
1215
+ old_params = self.get_params()
1216
+ old_params.update(params)
1217
+ PerpetualBooster.__init__(self, **old_params)
1218
+ return self
1219
+
1220
+ def get_node_lists(self, map_features_names: bool = True) -> List[List[Node]]:
1221
+ """
1222
+ Return tree structures as lists of node objects.
1223
+
1224
+ Parameters
1225
+ ----------
1226
+ map_features_names : bool, default=True
1227
+ Whether to use feature names instead of indices.
1228
+
1229
+ Returns
1230
+ -------
1231
+ trees : list of list of Node
1232
+ Each inner list represents a tree.
1233
+ """
1234
+ dump = json.loads(self.json_dump())
1235
+ if "trees" in dump:
1236
+ all_booster_trees = [dump["trees"]]
1237
+ else:
1238
+ # Multi-output
1239
+ all_booster_trees = [b["trees"] for b in dump["boosters"]]
1240
+
1241
+ feature_map: Union[Dict[int, str], Dict[int, int]]
1242
+ leaf_split_feature: Union[str, int]
1243
+ if map_features_names and hasattr(self, "feature_names_in_"):
1244
+ feature_map = {i: ft for i, ft in enumerate(self.feature_names_in_)}
1245
+ leaf_split_feature = ""
1246
+ else:
1247
+ feature_map = {i: i for i in range(self.n_features_)}
1248
+ leaf_split_feature = -1
1249
+
1250
+ trees = []
1251
+ for booster_trees in all_booster_trees:
1252
+ for t in booster_trees:
1253
+ nodes = []
1254
+ for node in t["nodes"].values():
1255
+ if not node["is_leaf"]:
1256
+ node["split_feature"] = feature_map[node["split_feature"]]
1257
+ else:
1258
+ node["split_feature"] = leaf_split_feature
1259
+ nodes.append(Node(**node))
1260
+ trees.append(nodes)
1261
+ return trees
1262
+
1263
+ def trees_to_dataframe(self) -> Any:
1264
+ """
1265
+ Return the tree structures as a DataFrame.
1266
+
1267
+ Returns
1268
+ -------
1269
+ df : DataFrame
1270
+ A Polars or Pandas DataFrame containing tree information.
1271
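+
+ Examples
+ --------
+ A minimal sketch of inspecting a fitted model's trees (the concrete
+ DataFrame type depends on whether Polars or Pandas is installed):
+
+ >>> df = model.trees_to_dataframe()
+ >>> df.columns  # Tree, Node, ID, Feature, Split, ...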
+ """
1272
+
1273
+ def node_to_row(
1274
+ n: Node,
1275
+ tree_n: int,
1276
+ ) -> Dict[str, Any]:
1277
+ def _id(i: int) -> str:
1278
+ return f"{tree_n}-{i}"
1279
+
1280
+ return dict(
1281
+ Tree=tree_n,
1282
+ Node=n.num,
1283
+ ID=_id(n.num),
1284
+ Feature="Leaf" if n.is_leaf else str(n.split_feature),
1285
+ Split=None if n.is_leaf else n.split_value,
1286
+ Yes=None if n.is_leaf else _id(n.left_child),
1287
+ No=None if n.is_leaf else _id(n.right_child),
1288
+ Missing=None if n.is_leaf else _id(n.missing_node),
1289
+ Gain=n.weight_value if n.is_leaf else n.split_gain,
1290
+ Cover=n.hessian_sum,
1291
+ Left_Cats=n.left_cats,
1292
+ Right_Cats=n.right_cats,
1293
+ )
1294
+
1295
+ # Flatten list of lists using list comprehension
1296
+ vals = [
1297
+ node_to_row(n, i)
1298
+ for i, tree in enumerate(self.get_node_lists())
1299
+ for n in tree
1300
+ ]
1301
+
1302
+ try:
1303
+ import polars as pl
1304
+
1305
+ return pl.from_records(vals).sort(
1306
+ ["Tree", "Node"], descending=[False, False]
1307
+ )
1308
+ except ImportError:
1309
+ import pandas as pd
1310
+
1311
+ return pd.DataFrame.from_records(vals).sort_values(
1312
+ ["Tree", "Node"], ascending=[True, True]
1313
+ )
1314
+
1315
+ def _to_xgboost_json(self) -> Dict[str, Any]:
1316
+ """Convert the Perpetual model to an XGBoost JSON model structure."""
1317
+
1318
+ # Check if it's a multi-output model
1319
+ is_multi = len(self.classes_) > 2
1320
+
1321
+ # Get raw dump
1322
+ raw_dump = json.loads(self.json_dump())
1323
+
1324
+ # Initialize XGBoost structure
1325
+ xgb_json = {
1326
+ "learner": {
1327
+ "attributes": {},
1328
+ "feature_names": [],
1329
+ "feature_types": [],
1330
+ "gradient_booster": {
1331
+ "model": {
1332
+ "gbtree_model_param": {
1333
+ "num_parallel_tree": "1",
1334
+ },
1335
+ "trees": [],
1336
+ "tree_info": [],
1337
+ "iteration_indptr": [],
1338
+ "cats": {
1339
+ "enc": [],
1340
+ "feature_segments": [],
1341
+ "sorted_idx": [],
1342
+ },
1343
+ },
1344
+ "name": "gbtree",
1345
+ },
1346
+ "learner_model_param": {
1347
+ "boost_from_average": "1",
1348
+ "num_feature": str(self.n_features_),
1349
+ },
1350
+ "objective": {
1351
+ "name": "binary:logistic",
1352
+ },
1353
+ },
1354
+ "version": [3, 1, 3], # Use a reasonably recent version
1355
+ }
1356
+
1357
+ # Fill feature names if available
1358
+ if hasattr(self, "feature_names_in_"):
1359
+ xgb_json["learner"]["feature_names"] = self.feature_names_in_
1360
+ xgb_json["learner"]["feature_types"] = ["float"] * self.n_features_
1361
+ else:
1362
+ xgb_json["learner"]["feature_names"] = [
1363
+ f"f{i}" for i in range(self.n_features_)
1364
+ ]
1365
+ xgb_json["learner"]["feature_types"] = ["float"] * self.n_features_
1366
+
1367
+ # Objective and Base Score Handling
1368
+ if is_multi:
1369
+ # Multi-class
1370
+ n_classes = len(self.classes_)
1371
+ xgb_json["learner"]["objective"]["name"] = "multi:softprob"
1372
+ xgb_json["learner"]["objective"]["softmax_multiclass_param"] = {
1373
+ "num_class": str(n_classes)
1374
+ }
1375
+ xgb_json["learner"]["learner_model_param"]["num_class"] = str(n_classes)
1376
+ xgb_json["learner"]["learner_model_param"]["num_target"] = "1"
1377
+
1378
+ # Base score vector [0.5, 0.5, ...]
1379
+ # 5.0E-1
1380
+ base_score_str = ",".join(["5.0E-1"] * n_classes)
1381
+ xgb_json["learner"]["learner_model_param"]["base_score"] = (
1382
+ f"[{base_score_str}]"
1383
+ )
1384
+
1385
+ boosters = raw_dump["boosters"]
1386
+
1387
+ trees = []
1388
+ tree_info = []
1389
+ # For multi-class output, the trees must be interleaved to match the XGBoost layout:
1390
+ # XGBoost expects the trees for iteration i to be contiguous, and
1391
+ # iteration_indptr depends on that ordering.
1392
+ # Perpetual stores the boosters separately:
1393
+ # booster 0 holds the trees for class 0, booster 1 for class 1, and so on.
1394
+ # They are rearranged by round: Round 0 (C0, C1, C2), Round 1 (C0, C1, C2), ...
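+ # Illustrative layout (hypothetical numbers: 2 rounds, 3 classes):
+ #   trees            = [T0_C0, T0_C1, T0_C2, T1_C0, T1_C1, T1_C2]
+ #   tree_info        = [0, 1, 2, 0, 1, 2]
+ #   iteration_indptr = [0, 3, 6]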
1395
+
1396
+ # Boosters are not guaranteed to have the same number of trees.
1397
+ num_trees_per_booster = [len(b["trees"]) for b in boosters]
1398
+ max_trees = max(num_trees_per_booster) if num_trees_per_booster else 0
1399
+
1400
+ # Iteration pointers follow the pattern 0, 3, 6, ... for three classes.
1401
+ # A booster may end up with fewer trees than the others, e.g. if early
1402
+ # stopping triggers per booster: "MultiOutputBooster::fit" trains the
1403
+ # boosters sequentially, so their tree counts are not forced to match.
1404
+ # XGBoost, however, expects a consistent num_class trees per round, so
1405
+ # simply concatenating the boosters' trees could misalign the rounds.
1406
+
1407
+ # Align the trees by round instead.
1408
+
1409
+ iteration_indptr = [0]
1410
+ current_ptr = 0
1411
+
1412
+ for round_idx in range(max_trees):
1413
+ # For each class
1414
+ for group_id, booster_dump in enumerate(boosters):
1415
+ booster_trees = booster_dump["trees"]
1416
+ if round_idx < len(booster_trees):
1417
+ tree = booster_trees[round_idx]
1418
+ base_score = booster_dump["base_score"]
1419
+
1420
+ xgb_tree = self._convert_tree(tree, current_ptr)
1421
+
1422
+ if round_idx == 0:
1423
+ self._adjust_tree_leaves(xgb_tree, base_score)
1424
+
1425
+ trees.append(xgb_tree)
1426
+ tree_info.append(group_id)
1427
+ current_ptr += 1
1428
+ else:
1429
+ # This class has no tree in this round (unequal tree counts, see above).
1430
+ # A dummy tree with zero predictions could be inserted here instead.
1431
+ # For now the entry is simply skipped; tree_info still records the
1432
+ # class of every tree that is emitted.
1433
+ pass
1434
+
1435
+ iteration_indptr.append(current_ptr)
1436
+
1437
+ xgb_json["learner"]["gradient_booster"]["model"]["trees"] = trees
1438
+ xgb_json["learner"]["gradient_booster"]["model"]["tree_info"] = tree_info
1439
+ xgb_json["learner"]["gradient_booster"]["model"]["gbtree_model_param"][
1440
+ "num_trees"
1441
+ ] = str(len(trees))
1442
+ xgb_json["learner"]["gradient_booster"]["model"]["iteration_indptr"] = (
1443
+ iteration_indptr
1444
+ )
1445
+
1446
+ else:
1447
+ # Binary or Regression
1448
+ if self.objective == "LogLoss":
1449
+ xgb_json["learner"]["objective"]["name"] = "binary:logistic"
1450
+ xgb_json["learner"]["objective"]["reg_loss_param"] = {
1451
+ "scale_pos_weight": "1"
1452
+ }
1453
+ xgb_json["learner"]["learner_model_param"]["num_class"] = "0"
1454
+ xgb_json["learner"]["learner_model_param"]["num_target"] = "1"
1455
+
1456
+ # Base Score
1457
+ base_score_val = 1.0 / (1.0 + np.exp(-raw_dump["base_score"]))
1458
+ xgb_json["learner"]["learner_model_param"]["base_score"] = (
1459
+ f"[{base_score_val:.6E}]"
1460
+ )
1461
+
1462
+ elif self.objective == "SquaredLoss":
1463
+ xgb_json["learner"]["objective"]["name"] = "reg:squarederror"
1464
+ xgb_json["learner"]["objective"]["reg_loss_param"] = {}
1465
+ xgb_json["learner"]["learner_model_param"]["num_class"] = "0"
1466
+ xgb_json["learner"]["learner_model_param"]["num_target"] = "1"
1467
+ xgb_json["learner"]["learner_model_param"]["base_score"] = (
1468
+ f"[{raw_dump['base_score']:.6E}]"
1469
+ )
1470
+ else:
1471
+ warnings.warn(
1472
+ f"Objective {self.objective} not explicitly supported for XGBoost export. Defaulting to reg:squarederror."
1473
+ )
1474
+ xgb_json["learner"]["objective"]["name"] = "reg:squarederror"
1475
+ xgb_json["learner"]["objective"]["reg_loss_param"] = {}
1476
+ xgb_json["learner"]["learner_model_param"]["num_class"] = "0"
1477
+ xgb_json["learner"]["learner_model_param"]["num_target"] = "1"
1478
+ xgb_json["learner"]["learner_model_param"]["base_score"] = (
1479
+ f"[{raw_dump['base_score']:.6E}]"
1480
+ )
1481
+
1482
+ trees = []
1483
+ tree_info = []
1484
+ for tree_idx, tree in enumerate(raw_dump["trees"]):
1485
+ xgb_tree = self._convert_tree(tree, tree_idx)
1486
+ trees.append(xgb_tree)
1487
+ tree_info.append(0)
1488
+
1489
+ xgb_json["learner"]["gradient_booster"]["model"]["trees"] = trees
1490
+ xgb_json["learner"]["gradient_booster"]["model"]["tree_info"] = tree_info
1491
+ xgb_json["learner"]["gradient_booster"]["model"]["gbtree_model_param"][
1492
+ "num_trees"
1493
+ ] = str(len(trees))
1494
+ xgb_json["learner"]["gradient_booster"]["model"]["iteration_indptr"] = list(
1495
+ range(len(trees) + 1)
1496
+ )
1497
+
1498
+ return xgb_json
1499
+
1500
+ def _convert_tree(self, tree: Dict[str, Any], group_id: int) -> Dict[str, Any]:
1501
+ """Convert a single Perpetual tree to XGBoost dictionary format."""
1502
+
1503
+ nodes_dict = tree["nodes"]
1504
+ # Convert keys to int and sort
1505
+ sorted_keys = sorted(nodes_dict.keys(), key=lambda x: int(x))
1506
+
1507
+ # Mapping from Perpetual ID (int) to XGBoost Array Index (int)
1508
+ node_map = {int(k): i for i, k in enumerate(sorted_keys)}
1509
+
1510
+ num_nodes = len(sorted_keys)
1511
+ # print(f"DEBUG: Converting tree group={group_id}. num_nodes={num_nodes}")
1512
+
1513
+ left_children = [-1] * num_nodes
1514
+ right_children = [-1] * num_nodes
1515
+ parents = [2147483647] * num_nodes
1516
+ split_indices = [0] * num_nodes
1517
+ split_conditions = [0.0] * num_nodes
1518
+ split_type = [0] * num_nodes
1519
+ sum_hessian = [0.0] * num_nodes
1520
+ loss_changes = [0.0] * num_nodes
1521
+ base_weights = [0.0] * num_nodes
1522
+ default_left = [0] * num_nodes
1523
+
1524
+ categories = []
1525
+ categories_nodes = []
1526
+ categories_segments = []
1527
+ categories_sizes = []
1528
+
1529
+ for i, k in enumerate(sorted_keys):
1530
+ node = nodes_dict[k]
1531
+ nid = int(node["num"])
1532
+ idx = node_map[nid]
1533
+
1534
+ # print(f" DEBUG: Node {i} nid={nid} idx={idx}")
1535
+
1536
+ sum_hessian[idx] = node["hessian_sum"]
1537
+ base_weights[idx] = node["weight_value"]
1538
+ loss_changes[idx] = node.get("split_gain", 0.0)
1539
+
1540
+ if node["is_leaf"]:
1541
+ left_children[idx] = -1
1542
+ right_children[idx] = -1
1543
+ split_indices[idx] = 0
1544
+ split_conditions[idx] = node["weight_value"]
1545
+ else:
1546
+ left_id = node["left_child"]
1547
+ right_id = node["right_child"]
1548
+
1549
+ left_idx = node_map[left_id]
1550
+ right_idx = node_map[right_id]
1551
+
1552
+ left_children[idx] = left_idx
1553
+ right_children[idx] = right_idx
1554
+ parents[left_idx] = idx
1555
+ parents[right_idx] = idx
1556
+
1557
+ split_indices[idx] = node["split_feature"]
1558
+ split_conditions[idx] = node["split_value"]
1559
+
1560
+ # Missing handling
1561
+ # If missing_node goes left
1562
+ if node["missing_node"] == left_id:
1563
+ default_left[idx] = 1
1564
+ else:
1565
+ default_left[idx] = 0
1566
+
1567
+ if (
1568
+ "left_cats" in node
1569
+ and node["left_cats"] is not None
1570
+ and len(node["left_cats"]) > 0
1571
+ ):
1572
+ # It's a categorical split
1573
+ cats = node["left_cats"]
1574
+ # XGBoost marks categorical splits with split_type=1 and by listing the
1575
+ # node in categories_nodes.
1576
+ # Docs: split_type [default=0]: 0=numerical, 1=categorical.
1577
+ split_type[idx] = 1
1578
+
1579
+ # Update categorical arrays
1580
+ categories_nodes.append(idx)
1581
+ categories_sizes.append(len(cats))
1582
+ # Each segment is the start index of a node's categories within the
1583
+ # flattened `categories` list: 0 for the first categorical node,
1584
+ # otherwise the previous segment plus the previous size.
1585
+ # Segments are therefore an exclusive scan of the sizes:
1586
+ # [0, len0, len0 + len1, ...]
1587
+ # segments[i] points to the start of the categories for the i-th
1588
+ # entry of categories_nodes.
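+ # Illustrative values (hypothetical): two categorical nodes with 2 and 3
+ # categories respectively would give
+ #   categories_sizes    = [2, 3]
+ #   categories_segments = [0, 2]
+ #   categories          = the 5 category codes, flattened in node order.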
1589
+
1590
+ next_segment = (
1591
+ (categories_segments[-1] + categories_sizes[-2])
1592
+ if categories_segments
1593
+ else 0
1594
+ )
1595
+ categories_segments.append(next_segment)
1596
+
1597
+ categories.extend(sorted(cats))
1598
+
1599
+ # split_condition is not meaningful for a categorical split; the XGBoost
1600
+ # JSON parser is expected to ignore it when split_type is categorical,
1601
+ # so the numeric value already stored above is left in place.
1602
+
1603
+ return {
1604
+ "base_weights": base_weights,
1605
+ "default_left": default_left,
1606
+ "id": group_id,
1607
+ "left_children": left_children,
1608
+ "loss_changes": loss_changes,
1609
+ "parents": parents,
1610
+ "right_children": right_children,
1611
+ "split_conditions": split_conditions,
1612
+ "split_indices": split_indices,
1613
+ "split_type": split_type,
1614
+ "sum_hessian": sum_hessian,
1615
+ "tree_param": {
1616
+ "num_deleted": "0",
1617
+ "num_feature": str(self.n_features_),
1618
+ "num_nodes": str(num_nodes),
1619
+ "size_leaf_vector": "1",
1620
+ },
1621
+ "categories": categories,
1622
+ "categories_nodes": categories_nodes,
1623
+ "categories_segments": categories_segments,
1624
+ "categories_sizes": categories_sizes,
1625
+ }
1626
+
1627
+ def _adjust_tree_leaves(self, xgb_tree: Dict[str, Any], adjustment: float):
1628
+ """Add adjustment value to all leaves in an XGBoost tree dict."""
1629
+ left_children = xgb_tree["left_children"]
1630
+ split_conditions = xgb_tree["split_conditions"]
1631
+ base_weights = xgb_tree["base_weights"]
1632
+
1633
+ for i, left in enumerate(left_children):
1634
+ if left == -1: # Leaf
1635
+ split_conditions[i] += adjustment
1636
+ base_weights[i] += adjustment
1637
+
1638
+ def save_as_xgboost(self, path: str):
1639
+ """
1640
+ Save the model in XGBoost JSON format.
1641
+
1642
+ Parameters
1643
+ ----------
1644
+ path : str
1645
+ The path where the XGBoost-compatible model will be saved.
1646
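+
+ Examples
+ --------
+ A minimal sketch, assuming the ``xgboost`` package is available to load
+ the exported file (the path is illustrative):
+
+ >>> model.save_as_xgboost("model_xgb.json")
+ >>> import xgboost as xgb
+ >>> bst = xgb.Booster()
+ >>> bst.load_model("model_xgb.json")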
+ """
1647
+ xgboost_json = self._to_xgboost_json()
1648
+ with open(path, "w") as f:
1649
+ json.dump(xgboost_json, f, indent=2)
1650
+
1651
+ def save_as_onnx(self, path: str, name: str = "perpetual_model"):
1652
+ """
1653
+ Save the model in ONNX format.
1654
+
1655
+ Parameters
1656
+ ----------
1657
+ path : str
1658
+ The path where the ONNX model will be saved.
1659
+ name : str, optional, default="perpetual_model"
1660
+ The name of the graph in the exported model.
1661
+ """
1662
+ import json
1663
+
1664
+ import onnx
1665
+ from onnx import TensorProto, helper
1666
+
1667
+ raw_dump = json.loads(self.json_dump())
1668
+ is_classifier = len(self.classes_) >= 2
1669
+ is_multi = is_classifier and len(self.classes_) > 2
1670
+ n_classes = len(self.classes_) if is_classifier else 1
1671
+
1672
+ if "trees" in raw_dump:
1673
+ booster_data = [{"trees": raw_dump["trees"]}]
1674
+ else:
1675
+ booster_data = raw_dump["boosters"]
1676
+
1677
+ feature_map_inverse = (
1678
+ {v: k for k, v in enumerate(self.feature_names_in_)}
1679
+ if hasattr(self, "feature_names_in_")
1680
+ else None
1681
+ )
1682
+
1683
+ nodes_treeids = []
1684
+ nodes_nodeids = []
1685
+ nodes_featureids = []
1686
+ nodes_values = []
1687
+ nodes_modes = []
1688
+ nodes_truenodeids = []
1689
+ nodes_falsenodeids = []
1690
+ nodes_missing_value_tracks_true = []
1691
+
1692
+ target_treeids = []
1693
+ target_nodeids = []
1694
+ target_ids = []
1695
+ target_weights = []
1696
+
1697
+ # Base score handling
1698
+ base_score = self.base_score
1699
+ if is_classifier:
1700
+ if is_multi:
1701
+ base_values = [float(b) for b in base_score]
1702
+ else:
1703
+ base_values = [float(base_score)]
1704
+ else:
1705
+ base_values = [float(base_score)]
1706
+
1707
+ global_tree_idx = 0
1708
+ for b_idx, booster in enumerate(booster_data):
1709
+ for tree_data in booster["trees"]:
1710
+ nodes_dict = tree_data["nodes"]
1711
+ node_keys = sorted(nodes_dict.keys(), key=lambda x: int(x))
1712
+
1713
+ node_id_to_idx = {}
1714
+ for i, k in enumerate(node_keys):
1715
+ node_id_to_idx[int(k)] = i
1716
+
1717
+ for k in node_keys:
1718
+ node_dict = nodes_dict[k]
1719
+ nid = int(node_dict["num"])
1720
+ idx_for_onnx = node_id_to_idx[nid]
1721
+
1722
+ nodes_treeids.append(global_tree_idx)
1723
+ nodes_nodeids.append(idx_for_onnx)
1724
+
1725
+ if node_dict["is_leaf"]:
1726
+ nodes_modes.append("LEAF")
1727
+ nodes_featureids.append(0)
1728
+ nodes_values.append(0.0)
1729
+ nodes_truenodeids.append(0)
1730
+ nodes_falsenodeids.append(0)
1731
+ nodes_missing_value_tracks_true.append(0)
1732
+
1733
+ target_treeids.append(global_tree_idx)
1734
+ target_nodeids.append(idx_for_onnx)
1735
+ target_ids.append(b_idx if is_multi else 0)
1736
+ target_weights.append(float(node_dict["weight_value"]))
1737
+ else:
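+ # ONNX "BRANCH_LT" takes the true branch when feature < split_value;
+ # missing_value_tracks_true = 1 routes missing values to the true
+ # (left) branch, matching perpetual's missing_node behaviour.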
1738
+ nodes_modes.append("BRANCH_LT")
1739
+ feat_val = node_dict["split_feature"]
1740
+ f_idx = 0
1741
+ if isinstance(feat_val, int):
1742
+ f_idx = feat_val
1743
+ elif feature_map_inverse and feat_val in feature_map_inverse:
1744
+ f_idx = feature_map_inverse[feat_val]
1745
+ elif isinstance(feat_val, str) and feat_val.isdigit():
1746
+ f_idx = int(feat_val)
1747
+
1748
+ nodes_featureids.append(f_idx)
1749
+ nodes_values.append(float(node_dict["split_value"]))
1750
+
1751
+ tracks_true = 0
1752
+ if node_dict["missing_node"] == node_dict["left_child"]:
1753
+ tracks_true = 1
1754
+ nodes_missing_value_tracks_true.append(tracks_true)
1755
+
1756
+ nodes_truenodeids.append(
1757
+ node_id_to_idx[int(node_dict["left_child"])]
1758
+ )
1759
+ nodes_falsenodeids.append(
1760
+ node_id_to_idx[int(node_dict["right_child"])]
1761
+ )
1762
+
1763
+ global_tree_idx += 1
1764
+
1765
+ input_name = "input"
1766
+ input_type = helper.make_tensor_value_info(
1767
+ input_name, TensorProto.FLOAT, [None, self.n_features_]
1768
+ )
1769
+
1770
+ raw_scores_name = "raw_scores"
1771
+ reg_node = helper.make_node(
1772
+ "TreeEnsembleRegressor",
1773
+ inputs=[input_name],
1774
+ outputs=[raw_scores_name],
1775
+ domain="ai.onnx.ml",
1776
+ nodes_treeids=nodes_treeids,
1777
+ nodes_nodeids=nodes_nodeids,
1778
+ nodes_featureids=nodes_featureids,
1779
+ nodes_values=nodes_values,
1780
+ nodes_modes=nodes_modes,
1781
+ nodes_truenodeids=nodes_truenodeids,
1782
+ nodes_falsenodeids=nodes_falsenodeids,
1783
+ nodes_missing_value_tracks_true=nodes_missing_value_tracks_true,
1784
+ target_treeids=target_treeids,
1785
+ target_nodeids=target_nodeids,
1786
+ target_ids=target_ids,
1787
+ target_weights=target_weights,
1788
+ base_values=base_values,
1789
+ n_targets=n_classes if is_multi else 1,
1790
+ name="PerpetualTreeEnsemble",
1791
+ )
1792
+
1793
+ ops = [reg_node]
1794
+ if is_classifier:
1795
+ # Prepare class labels mapping
1796
+ classes = self.classes_
1797
+ if all(isinstance(c, (int, np.integer)) for c in classes):
1798
+ tensor_type = TensorProto.INT64
1799
+ classes_array = np.array(classes, dtype=np.int64)
1800
+ elif all(isinstance(c, (float, np.floating)) for c in classes):
1801
+ tensor_type = TensorProto.FLOAT
1802
+ classes_array = np.array(classes, dtype=np.float32)
1803
+ else:
1804
+ tensor_type = TensorProto.STRING
1805
+ classes_array = np.array([str(c) for c in classes], dtype=object)
1806
+
1807
+ classes_name = "class_labels"
1808
+ if tensor_type == TensorProto.STRING:
1809
+ classes_const_node = helper.make_node(
1810
+ "Constant",
1811
+ [],
1812
+ [classes_name],
1813
+ value=helper.make_tensor(
1814
+ name="classes_tensor",
1815
+ data_type=tensor_type,
1816
+ dims=[len(classes)],
1817
+ vals=[s.encode("utf-8") for s in classes_array],
1818
+ ),
1819
+ )
1820
+ else:
1821
+ classes_const_node = helper.make_node(
1822
+ "Constant",
1823
+ [],
1824
+ [classes_name],
1825
+ value=helper.make_tensor(
1826
+ name="classes_tensor",
1827
+ data_type=tensor_type,
1828
+ dims=[len(classes)],
1829
+ vals=classes_array.flatten().tolist(),
1830
+ ),
1831
+ )
1832
+ ops.append(classes_const_node)
1833
+
1834
+ if is_multi:
1835
+ prob_name = "probabilities"
1836
+ softmax_node = helper.make_node(
1837
+ "Softmax", [raw_scores_name], [prob_name], axis=1
1838
+ )
1839
+ label_idx_name = "label_idx"
1840
+ argmax_node = helper.make_node(
1841
+ "ArgMax", [prob_name], [label_idx_name], axis=1, keepdims=0
1842
+ )
1843
+ label_name = "label"
1844
+ gather_node = helper.make_node(
1845
+ "Gather", [classes_name, label_idx_name], [label_name], axis=0
1846
+ )
1847
+ ops.extend([softmax_node, argmax_node, gather_node])
1848
+ outputs = [
1849
+ helper.make_tensor_value_info(label_name, tensor_type, [None]),
1850
+ helper.make_tensor_value_info(
1851
+ prob_name, TensorProto.FLOAT, [None, n_classes]
1852
+ ),
1853
+ ]
1854
+ else:
1855
+ p_name = "p"
1856
+ sigmoid_node = helper.make_node("Sigmoid", [raw_scores_name], [p_name])
1857
+ one_name = "one"
1858
+ one_node = helper.make_node(
1859
+ "Constant",
1860
+ [],
1861
+ [one_name],
1862
+ value=helper.make_tensor("one_v", TensorProto.FLOAT, [1, 1], [1.0]),
1863
+ )
1864
+ one_minus_p_name = "one_minus_p"
1865
+ sub_node = helper.make_node(
1866
+ "Sub", [one_name, p_name], [one_minus_p_name]
1867
+ )
1868
+ prob_name = "probabilities"
1869
+ concat_node = helper.make_node(
1870
+ "Concat", [one_minus_p_name, p_name], [prob_name], axis=1
1871
+ )
1872
+ label_idx_name = "label_idx"
1873
+ argmax_node = helper.make_node(
1874
+ "ArgMax", [prob_name], [label_idx_name], axis=1, keepdims=0
1875
+ )
1876
+ label_name = "label"
1877
+ gather_node = helper.make_node(
1878
+ "Gather", [classes_name, label_idx_name], [label_name], axis=0
1879
+ )
1880
+ ops.extend(
1881
+ [
1882
+ sigmoid_node,
1883
+ one_node,
1884
+ sub_node,
1885
+ concat_node,
1886
+ argmax_node,
1887
+ gather_node,
1888
+ ]
1889
+ )
1890
+ outputs = [
1891
+ helper.make_tensor_value_info(label_name, tensor_type, [None]),
1892
+ helper.make_tensor_value_info(
1893
+ prob_name, TensorProto.FLOAT, [None, 2]
1894
+ ),
1895
+ ]
1896
+ else:
1897
+ prediction_name = "prediction"
1898
+ reg_node.output[0] = prediction_name
1899
+ outputs = [
1900
+ helper.make_tensor_value_info(
1901
+ prediction_name, TensorProto.FLOAT, [None, 1]
1902
+ )
1903
+ ]
1904
+
1905
+ graph_def = helper.make_graph(ops, name, [input_type], outputs)
1906
+ model_def = helper.make_model(
1907
+ graph_def,
1908
+ producer_name="perpetual",
1909
+ opset_imports=[
1910
+ helper.make_opsetid("", 13),
1911
+ helper.make_opsetid("ai.onnx.ml", 2),
1912
+ ],
1913
+ )
1914
+ model_def.ir_version = 6
1915
+ onnx.save(model_def, path)
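# A minimal usage sketch (not part of the package): scoring the exported ONNX
# model with onnxruntime, assuming it is installed. "model" is a fitted
# PerpetualBooster and X a numeric feature matrix, both placeholders. The
# graph input is named "input" and expects float32; classifiers expose
# "label" and "probabilities" outputs, regressors a single "prediction".

import numpy as np
import onnxruntime as ort

model.save_as_onnx("model.onnx")
sess = ort.InferenceSession("model.onnx")
outputs = sess.run(None, {"input": np.asarray(X, dtype=np.float32)})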