autogluon.tabular 1.2.1b20250205__py3-none-any.whl → 1.2.1b20250207__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,4764 @@
+ from __future__ import annotations
+
+ import copy
+ import logging
+ import os
+ import shutil
+ import time
+ import traceback
+ from collections import defaultdict
+ from pathlib import Path
+ from typing import Any, Literal
+
+ import networkx as nx
+ import numpy as np
+ import pandas as pd
+
+ from autogluon.common.features.feature_metadata import FeatureMetadata
+ from autogluon.common.features.types import R_FLOAT, S_STACK
+ from autogluon.common.utils.distribute_utils import DistributedContext
+ from autogluon.common.utils.lite import disable_if_lite_mode
+ from autogluon.common.utils.log_utils import convert_time_in_s_to_log_friendly, reset_logger_for_remote_call
+ from autogluon.common.utils.resource_utils import ResourceManager, get_resource_manager
+ from autogluon.common.utils.try_import import try_import_ray, try_import_torch
+ from autogluon.core.augmentation.distill_utils import augment_data, format_distillation_labels
+ from autogluon.core.calibrate import calibrate_decision_threshold
+ from autogluon.core.calibrate.conformity_score import compute_conformity_score
+ from autogluon.core.calibrate.temperature_scaling import apply_temperature_scaling, tune_temperature_scaling
+ from autogluon.core.callbacks import AbstractCallback
+ from autogluon.core.constants import BINARY, MULTICLASS, QUANTILE, REFIT_FULL_NAME, REGRESSION, SOFTCLASS
+ from autogluon.core.data.label_cleaner import LabelCleanerMulticlassToBinary
+ from autogluon.core.metrics import Scorer, compute_metric, get_metric
+ from autogluon.core.models import (
+     AbstractModel,
+     BaggedEnsembleModel,
+     GreedyWeightedEnsembleModel,
+     SimpleWeightedEnsembleModel,
+     StackerEnsembleModel,
+     WeightedEnsembleModel,
+ )
+ from autogluon.core.pseudolabeling.pseudolabeling import assert_pseudo_column_match
+ from autogluon.core.ray.distributed_jobs_managers import ParallelFitManager
+ from autogluon.core.trainer import AbstractTrainer
+ from autogluon.core.trainer.utils import process_hyperparameters
+ from autogluon.core.utils import (
+     compute_permutation_feature_importance,
+     convert_pred_probas_to_df,
+     default_holdout_frac,
+     extract_column,
+     generate_train_test_split,
+     get_pred_from_proba,
+     infer_eval_metric,
+ )
+ from autogluon.core.utils.exceptions import (
+     InsufficientTime,
+     NoGPUError,
+     NoStackFeatures,
+     NotEnoughCudaMemoryError,
+     NotEnoughMemoryError,
+     NotValidStacker,
+     NoValidFeatures,
+     TimeLimitExceeded,
+ )
+ from autogluon.core.utils.feature_selection import FeatureSelector
+ from autogluon.core.utils.loaders import load_pkl
+ from autogluon.core.utils.savers import save_pkl
+
+
+ logger = logging.getLogger(__name__)
69
+
70
+
71
+ class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
72
+ """
73
+ AbstractTabularTrainer contains logic to train a variety of models under a variety of constraints and automatically generate a multi-layer stack ensemble.
74
+ Beyond the basic functionality, it also has support for model refitting, distillation, pseudo-labelling, unlabeled data, and much more.
75
+
76
+ It is not recommended to directly use Trainer. Instead, use Predictor or Learner which internally uses Trainer.
77
+ This documentation is for developers. Users should avoid this class.
78
+
79
+ Due to the complexity of the logic within this class, a text description will not give the full picture.
80
+ It is recommended to carefully read the code and use a debugger to understand how it works.
81
+
82
+ AbstractTabularTrainer makes much fewer assumptions about the problem than Learner and Predictor.
83
+ It expects these ambiguities to have already been resolved upstream. For example, problem_type, feature_metadata, num_classes, etc.
84
+
85
+ Parameters
86
+ ----------
87
+ path : str
88
+ Path to save and load trainer artifacts to disk.
89
+ Path should end in `/` or `os.path.sep()`.
90
+ problem_type : str
91
+ One of ['binary', 'multiclass', 'regression', 'quantile', 'softclass']
92
+ num_classes : int
93
+ The number of classes in the problem.
94
+ If problem_type is in ['regression', 'quantile'], this must be None.
95
+ If problem_type is 'binary', this must be 2.
96
+ If problem_type is in ['multiclass', 'softclass'], this must be >= 2.
97
+ feature_metadata : FeatureMetadata
98
+ FeatureMetadata for X. Sent to each model during fit.
99
+ eval_metric : Scorer, default = None
100
+ Metric to optimize. If None, a default metric is used depending on the problem_type.
101
+ quantile_levels : list[float] | np.ndarray, default = None
102
+ # TODO: Add documentation, not documented in Predictor.
103
+ Only used when problem_type=quantile
104
+ low_memory : bool, default = True
105
+ Deprecated parameter, likely to be removed in future versions.
106
+ If True, caches models to disk separately instead of containing all models within memory.
107
+ If False, may cause a variety of bugs.
108
+ k_fold : int, default = 0
109
+ If <2, then non-bagged mode is used.
110
+ If >= 2, then bagged mode is used with num_bag_folds == k_fold for each model.
111
+ Bagged mode changes the way models are trained and ensembled.
112
+ Bagged mode enables multi-layer stacking and repeated bagging.
113
+ n_repeats : int, default = 1
114
+ The maximum repeats of bagging to do when in bagged mode.
115
+ Larger values take linearly longer to train and infer, but improves quality slightly.
116
+ sample_weight : str, default = None
117
+ Column name of the sample weight in X
118
+ weight_evaluation : bool, default = False
119
+ If True, the eval_metric is calculated with sample_weight incorporated into the score.
120
+ save_data : bool, default = True
121
+ Whether to cache the data (X, y, X_val, y_val) to disk.
122
+ Required for a variety of advanced post-fit functionality.
123
+ It is recommended to keep as True.
124
+ random_state : int, default = 0
125
+ Random state for data splitting in bagged mode.
126
+ verbosity : int, default = 2
127
+ Verbosity levels range from 0 to 4 and control how much information is printed.
128
+ Higher levels correspond to more detailed print statements (you can set verbosity = 0 to suppress warnings).
129
+ If using logging, you can alternatively control amount of information printed via `logger.setLevel(L)`,
130
+ where `L` ranges from 0 to 50 (Note: higher values of `L` correspond to fewer print statements, opposite of verbosity levels).
131
+ """
+
+     distill_stackname = "distill"  # name of stack-level for distilled student models
+
+     def __init__(
+         self,
+         path: str,
+         *,
+         problem_type: str,
+         num_classes: int | None = None,
+         feature_metadata: FeatureMetadata | None = None,
+         eval_metric: Scorer | None = None,
+         quantile_levels: list[float] | np.ndarray | None = None,
+         low_memory: bool = True,
+         k_fold: int = 0,
+         n_repeats: int = 1,
+         sample_weight: str | None = None,
+         weight_evaluation: bool = False,
+         save_data: bool = False,
+         random_state: int = 0,
+         verbosity: int = 2,
+     ):
+         super().__init__(
+             path=path,
+             low_memory=low_memory,
+             save_data=save_data,
+         )
+         self._validate_num_classes(num_classes=num_classes, problem_type=problem_type)
+         self._validate_quantile_levels(quantile_levels=quantile_levels, problem_type=problem_type)
+         self.problem_type = problem_type
+         self.feature_metadata = feature_metadata
+
+         #: Integer value added to the stack level to get the random_state for kfold splits or the train/val split if bagging is disabled
+         self.random_state = random_state
+         self.verbosity = verbosity
+
+         # TODO: consider redesign where Trainer doesn't need sample_weight column name and weights are separate from X
+         self.sample_weight = sample_weight
+         self.weight_evaluation = weight_evaluation
+         if eval_metric is not None:
+             self.eval_metric = eval_metric
+         else:
+             self.eval_metric = infer_eval_metric(problem_type=self.problem_type)
+
+         logger.log(
+             20, f"AutoGluon will gauge predictive performance using evaluation metric: '{self.eval_metric.name}'"
+         )
+         if not self.eval_metric.greater_is_better_internal:
+             logger.log(
+                 20,
+                 "\tThis metric's sign has been flipped to adhere to being higher_is_better. "
+                 "The metric score can be multiplied by -1 to get the metric value.",
+             )
+         if not (self.eval_metric.needs_pred or self.eval_metric.needs_quantile):
+             logger.log(
+                 20,
+                 "\tThis metric expects predicted probabilities rather than predicted class labels, "
+                 "so you'll need to use predict_proba() instead of predict()",
+             )
+
+         logger.log(20, "\tTo change this, specify the eval_metric parameter of Predictor()")
+         self.num_classes = num_classes
+         self.quantile_levels = quantile_levels
+
+         #: will be set to True if feature-pruning is turned on.
+         self.feature_prune = False
+
+         self.bagged_mode = True if k_fold >= 2 else False
+         if self.bagged_mode:
+             #: int number of folds to do model bagging, < 2 means disabled
+             self.k_fold = k_fold
+             self.n_repeats = n_repeats
+         else:
+             self.k_fold = 0
+             self.n_repeats = 1
+
+         #: Internal float of the total time limit allowed for a given fit call. Used in logging statements.
+         self._time_limit = None
+         #: Internal timestamp of the time training started for a given fit call. Used in logging statements.
+         self._time_train_start = None
+         #: Same as `self._time_train_start` except it is not reset to None after the fit call completes.
+         self._time_train_start_last = None
+
+         self._num_rows_train = None
+         self._num_cols_train = None
+         self._num_rows_val = None
+         self._num_rows_test = None
+
+         self.is_data_saved = False
+         self._X_saved = False
+         self._y_saved = False
+         self._X_val_saved = False
+         self._y_val_saved = False
+         # Initialize the test-data flags as well; `save_X_test` / `save_y_test` set them to True later.
+         self._X_test_saved = False
+         self._y_test_saved = False
+
+         #: custom split indices
+         self._groups = None
+
+         #: whether to treat regression predictions as class-probabilities (during distillation)
+         self._regress_preds_asprobas = False
+
+         #: dict of model name -> model failure metadata
+         self._models_failed_to_train_errors = dict()
+
+         # self._exceptions_list = []  # TODO: Keep exceptions list for debugging during benchmarking.
+
+         self.callbacks: list[AbstractCallback] = []
+         self._callback_early_stop = False
+
+     @property
+     def _path_attr(self) -> str:
+         """Path to cached model graph attributes"""
+         return os.path.join(self.path_utils, "attr")
+
+     @property
+     def has_val(self) -> bool:
+         """Whether the trainer uses validation data"""
+         return self._num_rows_val is not None
+
+     @property
+     def num_rows_val_for_calibration(self) -> int:
+         """The number of rows available to optimize model calibration"""
+         if self._num_rows_val is not None:
+             return self._num_rows_val
+         elif self.bagged_mode:
+             assert self._num_rows_train is not None
+             return self._num_rows_train
+         else:
+             return 0
+
+     @property
+     def time_left(self) -> float | None:
+         """
+         Remaining time left in the fit call.
+         None if time_limit was unspecified.
+         An illustrative guard is sketched in the comment after this property.
+         """
+         if self._time_train_start is None:
+             return None
+         elif self._time_limit is None:
+             return None
+         time_elapsed = time.time() - self._time_train_start
+         time_left = self._time_limit - time_elapsed
+         return time_left
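+     # Illustrative guard (sketch, not original source): long-running fit steps can consult
+     # `time_left` and abort once the budget is exhausted, e.g.
+     #   if trainer.time_left is not None and trainer.time_left <= 0:
+     #       raise TimeLimitExceeded  # imported above from autogluon.core.utils.exceptions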
+
+     @property
+     def logger(self) -> logging.Logger:
+         return logger
+
+     def log(self, level: int, msg, *args, **kwargs):
+         self.logger.log(level, msg, *args, **kwargs)
+
+     def load_X(self):
+         if self._X_saved:
+             path = os.path.join(self.path_data, "X.pkl")
+             return load_pkl.load(path=path)
+         return None
+
+     def load_X_val(self):
+         if self._X_val_saved:
+             path = os.path.join(self.path_data, "X_val.pkl")
+             return load_pkl.load(path=path)
+         return None
+
+     def load_y(self):
+         if self._y_saved:
+             path = os.path.join(self.path_data, "y.pkl")
+             return load_pkl.load(path=path)
+         return None
+
+     def load_y_val(self):
+         if self._y_val_saved:
+             path = os.path.join(self.path_data, "y_val.pkl")
+             return load_pkl.load(path=path)
+         return None
+
+     def load_data(self):
+         X = self.load_X()
+         y = self.load_y()
+         X_val = self.load_X_val()
+         y_val = self.load_y_val()
+
+         return X, y, X_val, y_val
+
+     def save_X(self, X, verbose=True):
+         path = os.path.join(self.path_data, "X.pkl")
+         save_pkl.save(path=path, object=X, verbose=verbose)
+         self._X_saved = True
+
+     def save_X_val(self, X, verbose=True):
+         path = os.path.join(self.path_data, "X_val.pkl")
+         save_pkl.save(path=path, object=X, verbose=verbose)
+         self._X_val_saved = True
+
+     def save_X_test(self, X, verbose=True):
+         path = os.path.join(self.path_data, "X_test.pkl")
+         save_pkl.save(path=path, object=X, verbose=verbose)
+         self._X_test_saved = True
+
+     def save_y(self, y, verbose=True):
+         path = os.path.join(self.path_data, "y.pkl")
+         save_pkl.save(path=path, object=y, verbose=verbose)
+         self._y_saved = True
+
+     def save_y_val(self, y, verbose=True):
+         path = os.path.join(self.path_data, "y_val.pkl")
+         save_pkl.save(path=path, object=y, verbose=verbose)
+         self._y_val_saved = True
+
+     def save_y_test(self, y, verbose=True):
+         path = os.path.join(self.path_data, "y_test.pkl")
+         save_pkl.save(path=path, object=y, verbose=verbose)
+         self._y_test_saved = True
+
+     def get_model_names(
+         self,
+         stack_name: list[str] | str | None = None,
+         level: list[int] | int | None = None,
+         can_infer: bool | None = None,
+         models: list[str] | None = None,
+     ) -> list[str]:
+         if models is None:
+             models = list(self.model_graph.nodes)
+         if stack_name is not None:
+             if not isinstance(stack_name, list):
+                 stack_name = [stack_name]
+             node_attributes: dict = self.get_models_attribute_dict(attribute="stack_name", models=models)
+             models = [model_name for model_name in models if node_attributes[model_name] in stack_name]
+         if level is not None:
+             if not isinstance(level, list):
+                 level = [level]
+             node_attributes: dict = self.get_models_attribute_dict(attribute="level", models=models)
+             models = [model_name for model_name in models if node_attributes[model_name] in level]
+         # TODO: can_infer is technically more complicated; if an ancestor can't infer, then the model can't infer.
+         if can_infer is not None:
+             node_attributes = self.get_models_attribute_full(attribute="can_infer", models=models, func=min)
+             models = [model for model in models if node_attributes[model] == can_infer]
+         return models
+
+     def get_max_level(self, stack_name: str | None = None, models: list[str] | None = None) -> int:
+         models = self.get_model_names(stack_name=stack_name, models=models)
+         models_attribute_dict = self.get_models_attribute_dict(attribute="level", models=models)
+         if models_attribute_dict:
+             return max(list(models_attribute_dict.values()))
+         else:
+             return -1
+
+     def construct_model_templates(self, hyperparameters: dict[str, Any]) -> tuple[list[AbstractModel], dict]:
+         """Constructs a list of unfit models based on the hyperparameters dict."""
+         raise NotImplementedError
+
+     def construct_model_templates_distillation(self, hyperparameters: dict, **kwargs) -> tuple[list[AbstractModel], dict]:
+         """Constructs a list of unfit models based on the hyperparameters dict for softclass distillation."""
+         raise NotImplementedError
+
+     def get_model_level(self, model_name: str) -> int:
+         return self.get_model_attribute(model=model_name, attribute="level")
+
+     def fit(self, X, y, hyperparameters: dict, X_val=None, y_val=None, **kwargs):
+         raise NotImplementedError
+
+     # TODO: Enable easier re-mapping of trained models -> hyperparameters input (they don't share a key, since the name can change)
+     def train_multi_levels(
+         self,
+         X,
+         y,
+         hyperparameters: dict,
+         X_val=None,
+         y_val=None,
+         X_test=None,
+         y_test=None,
+         X_unlabeled=None,
+         base_model_names: list[str] | None = None,
+         core_kwargs: dict | None = None,
+         aux_kwargs: dict | None = None,
+         level_start=1,
+         level_end=1,
+         time_limit=None,
+         name_suffix: str | None = None,
+         relative_stack=True,
+         level_time_modifier=0.333,
+         infer_limit=None,
+         infer_limit_batch_size=None,
+         callbacks: list[AbstractCallback] | None = None,
+     ) -> list[str]:
+         """
+         Trains a multi-layer stack ensemble on the input data using the hyperparameters dict input.
+         hyperparameters is used to determine the models used in each stack layer.
+         If continuing a stack ensemble with level_start > 1, ensure that base_model_names is set to the appropriate base models that will be used by the level_start level models.
+         Trains both core and aux models.
+         Core models are standard models fit on the data features. Core models will also use model predictions as features if base_model_names was specified or if level != 1.
+         Aux models are ensemble models which only use the predictions of core models as features. These models never use the original features.
+
+         level_time_modifier : float, default 0.333
+             The amount of extra time given to early stack levels relative to later stack levels.
+             If 0, then all stack levels are given 100%/L of the time, where L is the number of stack levels.
+             If 1, then all stack levels are given 100% of the time, meaning if the first level uses all of the time given to it, the other levels won't train.
+             Time given to a level = remaining_time / remaining_levels * (1 + level_time_modifier), capped by the total remaining time.
+             A worked example is sketched in the comment below.
+
+         Returns a list of the model names that were trained from this method call, in order of fit.
+         """
+         self._fit_setup(time_limit=time_limit, callbacks=callbacks)
+         time_train_start = self._time_train_start
+         assert time_train_start is not None
+
+         if self.callbacks:
+             callback_classes = [c.__class__.__name__ for c in self.callbacks]
+             logger.log(20, f"User-specified callbacks ({len(self.callbacks)}): {callback_classes}")
+
+         hyperparameters = self._process_hyperparameters(hyperparameters=hyperparameters)
+
+         if relative_stack:
+             if level_start != 1:
+                 raise AssertionError(f"level_start must be 1 when `relative_stack=True`. (level_start = {level_start})")
+             level_add = 0
+             if base_model_names:
+                 max_base_model_level = self.get_max_level(models=base_model_names)
+                 level_start = max_base_model_level + 1
+                 level_add = level_start - 1
+                 level_end += level_add
+             if level_start != 1:
+                 hyperparameters_relative = {}
+                 for key in hyperparameters:
+                     if isinstance(key, int):
+                         hyperparameters_relative[key + level_add] = hyperparameters[key]
+                     else:
+                         hyperparameters_relative[key] = hyperparameters[key]
+                 hyperparameters = hyperparameters_relative
+
+         core_kwargs = {} if core_kwargs is None else core_kwargs.copy()
+         aux_kwargs = {} if aux_kwargs is None else aux_kwargs.copy()
+
+         self._callbacks_setup(
+             X=X,
+             y=y,
+             hyperparameters=hyperparameters,
+             X_val=X_val,
+             y_val=y_val,
+             X_unlabeled=X_unlabeled,
+             level_start=level_start,
+             level_end=level_end,
+             time_limit=time_limit,
+             base_model_names=base_model_names,
+             core_kwargs=core_kwargs,
+             aux_kwargs=aux_kwargs,
+             name_suffix=name_suffix,
+             level_time_modifier=level_time_modifier,
+             infer_limit=infer_limit,
+             infer_limit_batch_size=infer_limit_batch_size,
+         )
+         # TODO: Add logic for callbacks to specify that the rest of the trainer logic should be skipped in the case where they are overriding the trainer logic.
+
+         model_names_fit = []
+         if level_start != level_end:
+             logger.log(20, f"AutoGluon will fit {level_end - level_start + 1} stack levels (L{level_start} to L{level_end}) ...")
+         for level in range(level_start, level_end + 1):
+             core_kwargs_level = core_kwargs.copy()
+             aux_kwargs_level = aux_kwargs.copy()
+             full_weighted_ensemble = aux_kwargs_level.pop("fit_full_last_level_weighted_ensemble", True) and (level == level_end) and (level > 1)
+             additional_full_weighted_ensemble = aux_kwargs_level.pop("full_weighted_ensemble_additionally", False) and full_weighted_ensemble
+             if time_limit is not None:
+                 time_train_level_start = time.time()
+                 levels_left = level_end - level + 1
+                 time_left = time_limit - (time_train_level_start - time_train_start)
+                 time_limit_for_level = min(time_left / levels_left * (1 + level_time_modifier), time_left)
+                 time_limit_core = time_limit_for_level
+                 time_limit_aux = max(time_limit_for_level * 0.1, min(time_limit, 360))  # Allows aux to go over time_limit, but only by a small amount
+                 core_kwargs_level["time_limit"] = core_kwargs_level.get("time_limit", time_limit_core)
+                 aux_kwargs_level["time_limit"] = aux_kwargs_level.get("time_limit", time_limit_aux)
+             base_model_names, aux_models = self.stack_new_level(
+                 X=X,
+                 y=y,
+                 X_val=X_val,
+                 y_val=y_val,
+                 X_test=X_test,
+                 y_test=y_test,
+                 X_unlabeled=X_unlabeled,
+                 models=hyperparameters,
+                 level=level,
+                 base_model_names=base_model_names,
+                 core_kwargs=core_kwargs_level,
+                 aux_kwargs=aux_kwargs_level,
+                 name_suffix=name_suffix,
+                 infer_limit=infer_limit,
+                 infer_limit_batch_size=infer_limit_batch_size,
+                 full_weighted_ensemble=full_weighted_ensemble,
+                 additional_full_weighted_ensemble=additional_full_weighted_ensemble,
+             )
+             model_names_fit += base_model_names + aux_models
+         if (self.model_best is None or infer_limit is not None) and len(model_names_fit) != 0:
+             self.model_best = self.get_model_best(infer_limit=infer_limit, infer_limit_as_child=True)
+         self._callbacks_conclude()
+         self._fit_cleanup()
+         self.save()
+         return model_names_fit
+
+     def _fit_setup(self, time_limit: float | None = None, callbacks: list[AbstractCallback] | None = None):
+         """
+         Prepare the trainer state at the start of / prior to a fit call.
+         Should be paired with a `self._fit_cleanup()` at the conclusion of the fit call.
+         """
+         self._time_train_start = time.time()
+         self._time_train_start_last = self._time_train_start
+         self._time_limit = time_limit
+         self.reset_callbacks()
+         if callbacks is not None:
+             assert isinstance(callbacks, list), f"`callbacks` must be a list. Found invalid type: `{type(callbacks)}`."
+             for callback in callbacks:
+                 assert isinstance(
+                     callback, AbstractCallback
+                 ), f"Elements in `callbacks` must be of type AbstractCallback. Found invalid type: `{type(callback)}`."
+         else:
+             callbacks = []
+         self.callbacks = callbacks
+
+     def _fit_cleanup(self):
+         """
+         Cleanup the trainer state after the fit call completes.
+         This ensures that future fit calls are not corrupted by prior fit calls.
+         Should be paired with an earlier `self._fit_setup()` call.
+         """
+         self._time_limit = None
+         self._time_train_start = None
+         self.reset_callbacks()
+
+     def _callbacks_setup(self, **kwargs):
+         for callback in self.callbacks:
+             callback.before_trainer_fit(trainer=self, **kwargs)
+
+     def _callbacks_conclude(self):
+         for callback in self.callbacks:
+             callback.after_trainer_fit(trainer=self)
+
+     def reset_callbacks(self):
+         """Deletes callback objects and resets `self._callback_early_stop` to False."""
+         self.callbacks = []
+         self._callback_early_stop = False
+
+     # TODO: Consider a better greedy approximation method, such as fitting a weighted ensemble to evaluate the value of a subset.
+     def _filter_base_models_via_infer_limit(
+         self,
+         base_model_names: list[str],
+         infer_limit: float | None,
+         infer_limit_modifier: float = 1.0,
+         as_child: bool = True,
+         verbose: bool = True,
+     ) -> list[str]:
+         """
+         Returns a subset of base_model_names whose combined prediction time for 1 row of data does not exceed infer_limit seconds.
+         The goal is to select the valid subset that is most valuable to the stack ensembles that use these models as base models.
+         This is a variant of the constrained knapsack problem: it is NP-hard, and solving it exactly is infeasible even with fewer than 10 models.
+         For practical purposes, this method applies a greedy approximation approach to selecting the subset
+         by simply removing models in reverse order of validation score until the remaining subset is valid.
+         A small worked sketch follows in the comment below.
+
+         Parameters
+         ----------
+         base_model_names: list[str]
+             List of model names. These models must already be added to the trainer.
+         infer_limit: float, optional
+             Inference limit in seconds for 1 row of data. This is compared against values pre-computed during fit for the models.
+         infer_limit_modifier: float, default = 1.0
+             Modifier to multiply infer_limit by.
+             Set to < 1.0 to provide headroom for stack models that take the returned subset as base models,
+             so that the stack models are less likely to exceed infer_limit.
+         as_child: bool, default = True
+             If True, use the inference time of only 1 child model for bags instead of the overall inference time of the bag.
+             This is useful if the intent is to refit the models, as this will best estimate the inference time of the refit model.
+         verbose: bool, default = True
+             Whether to log the models that are removed.
+
+         Returns
+         -------
+         Returns the valid subset of models that satisfies the constraints.
+         """
+         if infer_limit is None or not base_model_names:
+             return base_model_names
+
+         base_model_names = base_model_names.copy()
+         num_models_og = len(base_model_names)
+         infer_limit_threshold = infer_limit * infer_limit_modifier  # Add headroom
+
+         if as_child:
+             attribute = "predict_1_child_time"
+         else:
+             attribute = "predict_1_time"
+
+         predict_1_time_full_set = self.get_model_attribute_full(model=base_model_names, attribute=attribute)
+
+         messages_to_log = []
+
+         base_model_names_copy = base_model_names.copy()
+         # Prune models that by themselves have a larger inference latency than the infer_limit, as they can never be valid
+         for base_model_name in base_model_names_copy:
+             predict_1_time_full = self.get_model_attribute_full(model=base_model_name, attribute=attribute)
+             if predict_1_time_full >= infer_limit_threshold:
+                 predict_1_time_full_set_old = predict_1_time_full_set
+                 base_model_names.remove(base_model_name)
+                 predict_1_time_full_set = self.get_model_attribute_full(model=base_model_names, attribute=attribute)
+                 if verbose:
+                     predict_1_time_full_set_log, time_unit = convert_time_in_s_to_log_friendly(time_in_sec=predict_1_time_full_set)
+                     predict_1_time_full_set_old_log, time_unit_old = convert_time_in_s_to_log_friendly(time_in_sec=predict_1_time_full_set_old)
+                     messages_to_log.append(
+                         f"\t{round(predict_1_time_full_set_old_log, 3)}{time_unit_old}\t-> {round(predict_1_time_full_set_log, 3)}{time_unit}\t({base_model_name})"
+                     )
+
+         score_val_dict = self.get_models_attribute_dict(attribute="val_score", models=base_model_names)
+         sorted_scores = sorted(score_val_dict.items(), key=lambda x: x[1])
+         i = 0
+         # Prune models by ascending validation score until the remaining subset's combined inference latency satisfies infer_limit
+         while base_model_names and (predict_1_time_full_set >= infer_limit_threshold):
+             # TODO: Incorporate the score vs inference speed tradeoff in a smarter way
+             base_model_to_remove = sorted_scores[i][0]
+             predict_1_time_full_set_old = predict_1_time_full_set
+             base_model_names.remove(base_model_to_remove)
+             i += 1
+             predict_1_time_full_set = self.get_model_attribute_full(model=base_model_names, attribute=attribute)
+             if verbose:
+                 predict_1_time_full_set_log, time_unit = convert_time_in_s_to_log_friendly(time_in_sec=predict_1_time_full_set)
+                 predict_1_time_full_set_old_log, time_unit_old = convert_time_in_s_to_log_friendly(time_in_sec=predict_1_time_full_set_old)
+                 messages_to_log.append(
+                     f"\t{round(predict_1_time_full_set_old_log, 3)}{time_unit_old}\t-> {round(predict_1_time_full_set_log, 3)}{time_unit}\t({base_model_to_remove})"
+                 )
+
+         if messages_to_log:
+             infer_limit_threshold_log, time_unit_threshold = convert_time_in_s_to_log_friendly(time_in_sec=infer_limit_threshold)
+             logger.log(
+                 20,
+                 f"Removing {len(messages_to_log)}/{num_models_og} base models to satisfy inference constraint "
+                 f"(constraint={round(infer_limit_threshold_log, 3)}{time_unit_threshold}) ...",
+             )
+             for msg in messages_to_log:
+                 logger.log(20, msg)
+
+         return base_model_names
+
+     def stack_new_level(
+         self,
+         X,
+         y,
+         models: list[AbstractModel] | dict,
+         X_val=None,
+         y_val=None,
+         X_test=None,
+         y_test=None,
+         X_unlabeled=None,
+         level=1,
+         base_model_names: list[str] | None = None,
+         core_kwargs: dict | None = None,
+         aux_kwargs: dict | None = None,
+         name_suffix: str | None = None,
+         infer_limit=None,
+         infer_limit_batch_size=None,
+         full_weighted_ensemble: bool = False,
+         additional_full_weighted_ensemble: bool = False,
+     ) -> tuple[list[str], list[str]]:
+         """
+         Similar to calling self.stack_new_level_core, except auxiliary models will also be trained via a call to self.stack_new_level_aux, with the models trained from self.stack_new_level_core used as base models.
+         """
+         if base_model_names is None:
+             base_model_names = []
+         core_kwargs = {} if core_kwargs is None else core_kwargs.copy()
+         aux_kwargs = {} if aux_kwargs is None else aux_kwargs.copy()
+         if level < 1:
+             raise AssertionError(f"Stack level must be >= 1, but level={level}.")
+         if base_model_names and level == 1:
+             raise AssertionError(f"Stack level 1 models cannot have base models, but base_model_names={base_model_names}.")
+         if name_suffix:
+             core_kwargs["name_suffix"] = core_kwargs.get("name_suffix", "") + name_suffix
+             aux_kwargs["name_suffix"] = aux_kwargs.get("name_suffix", "") + name_suffix
+         core_models = self.stack_new_level_core(
+             X=X,
+             y=y,
+             X_val=X_val,
+             y_val=y_val,
+             X_test=X_test,
+             y_test=y_test,
+             X_unlabeled=X_unlabeled,
+             models=models,
+             level=level,
+             infer_limit=infer_limit,
+             infer_limit_batch_size=infer_limit_batch_size,
+             base_model_names=base_model_names,
+             **core_kwargs,
+         )
+
+         aux_models = []
+         if full_weighted_ensemble:
+             full_aux_kwargs = aux_kwargs.copy()
+             if additional_full_weighted_ensemble:
+                 full_aux_kwargs["name_extra"] = "_ALL"
+             all_base_model_names = self.get_model_names(stack_name="core")  # Fit the weighted ensemble on all previously fitted core models
+             aux_models += self._stack_new_level_aux(X_val, y_val, X, y, all_base_model_names, level, infer_limit, infer_limit_batch_size, **full_aux_kwargs)
+
+         if (not full_weighted_ensemble) or additional_full_weighted_ensemble:
+             aux_models += self._stack_new_level_aux(X_val, y_val, X, y, core_models, level, infer_limit, infer_limit_batch_size, **aux_kwargs)
+
+         return core_models, aux_models
+
+     def stack_new_level_core(
+         self,
+         X,
+         y,
+         models: list[AbstractModel] | dict,
+         X_val=None,
+         y_val=None,
+         X_test=None,
+         y_test=None,
+         X_unlabeled=None,
+         level=1,
+         base_model_names: list[str] | None = None,
+         fit_strategy: Literal["sequential", "parallel"] = "sequential",
+         stack_name="core",
+         ag_args=None,
+         ag_args_fit=None,
+         ag_args_ensemble=None,
+         included_model_types=None,
+         excluded_model_types=None,
+         ensemble_type=StackerEnsembleModel,
+         name_suffix: str | None = None,
+         get_models_func=None,
+         refit_full=False,
+         infer_limit=None,
+         infer_limit_batch_size=None,
+         **kwargs,
+     ) -> list[str]:
+         """
+         Trains all models using the data provided.
+         If level > 1, then the models will use base model predictions as additional features.
+         The base models used can be specified via base_model_names.
+         If self.bagged_mode, then models will be trained as StackerEnsembleModels.
+         The data provided to this method should not contain stack features, as they will be automatically generated if necessary.
+         A sketch of the generated stack features follows in the comment below.
+         """
+         if self._callback_early_stop:
+             return []
+         if get_models_func is None:
+             get_models_func = self.construct_model_templates
+         if base_model_names is None:
+             base_model_names = []
+         if not self.bagged_mode and level != 1:
+             raise ValueError("Stack Ensembling is not valid for non-bagged mode.")
+
+         base_model_names = self._filter_base_models_via_infer_limit(
+             base_model_names=base_model_names,
+             infer_limit=infer_limit,
+             infer_limit_modifier=0.8,
+         )
+         if ag_args_fit is None:
+             ag_args_fit = {}
+         ag_args_fit = ag_args_fit.copy()
+         if infer_limit_batch_size is not None:
+             ag_args_fit["predict_1_batch_size"] = infer_limit_batch_size
+
+         if isinstance(models, dict):
+             get_models_kwargs = dict(
+                 level=level,
+                 name_suffix=name_suffix,
+                 ag_args=ag_args,
+                 ag_args_fit=ag_args_fit,
+                 included_model_types=included_model_types,
+                 excluded_model_types=excluded_model_types,
+             )
+
+             if self.bagged_mode:
+                 if level == 1:
+                     (base_model_names, base_model_paths, base_model_types) = (None, None, None)
+                 elif level > 1:
+                     base_model_names, base_model_paths, base_model_types = self._get_models_load_info(model_names=base_model_names)
+                     if len(base_model_names) == 0:  # type: ignore
+                         logger.log(20, f"No base models to train on, skipping stack level {level}...")
+                         return []
+                 else:
+                     raise AssertionError(f"Stack level cannot be less than 1! level = {level}")
+
+                 ensemble_kwargs = {
+                     "base_model_names": base_model_names,
+                     "base_model_paths_dict": base_model_paths,
+                     "base_model_types_dict": base_model_types,
+                     "base_model_types_inner_dict": self.get_models_attribute_dict(attribute="type_inner", models=base_model_names),
+                     "base_model_performances_dict": self.get_models_attribute_dict(attribute="val_score", models=base_model_names),
+                     "random_state": level + self.random_state,
+                 }
+                 get_models_kwargs.update(
+                     dict(
+                         ag_args_ensemble=ag_args_ensemble,
+                         ensemble_type=ensemble_type,
+                         ensemble_kwargs=ensemble_kwargs,
+                     )
+                 )
+             models, model_args_fit = get_models_func(hyperparameters=models, **get_models_kwargs)
+             if model_args_fit:
+                 hyperparameter_tune_kwargs = {
+                     model_name: model_args_fit[model_name]["hyperparameter_tune_kwargs"]
+                     for model_name in model_args_fit
+                     if "hyperparameter_tune_kwargs" in model_args_fit[model_name]
+                 }
+                 kwargs["hyperparameter_tune_kwargs"] = hyperparameter_tune_kwargs
+
+         logger.log(10 if ((not refit_full) and DistributedContext.is_distributed_mode()) else 20, f'Fitting {len(models)} L{level} models, fit_strategy="{fit_strategy}" ...')
+
+         X_init = self.get_inputs_to_stacker(X, base_models=base_model_names, fit=True)
+         feature_metadata = self.get_feature_metadata(use_orig_features=True, base_models=base_model_names)
+         if X_val is not None:
+             X_val = self.get_inputs_to_stacker(X_val, base_models=base_model_names, fit=False, use_val_cache=True)
+         if X_test is not None:
+             X_test = self.get_inputs_to_stacker(X_test, base_models=base_model_names, fit=False, use_val_cache=False)
+         compute_score = not refit_full
+         if refit_full and X_val is not None:
+             X_init = pd.concat([X_init, X_val])
+             y = pd.concat([y, y_val])
+             X_val = None
+             y_val = None
+         if X_unlabeled is not None:
+             X_unlabeled = self.get_inputs_to_stacker(X_unlabeled, base_models=base_model_names, fit=False)
+
+         fit_kwargs = dict(
+             num_classes=self.num_classes,
+             feature_metadata=feature_metadata,
+         )
+
+         # FIXME: TODO: v0.1 X_unlabeled isn't cached, so it won't be available during refit_full or fit_extra.
+         return self._train_multi(
+             X=X_init,
+             y=y,
+             X_val=X_val,
+             y_val=y_val,
+             X_test=X_test,
+             y_test=y_test,
+             X_unlabeled=X_unlabeled,
+             models=models,
+             level=level,
+             stack_name=stack_name,
+             compute_score=compute_score,
+             fit_kwargs=fit_kwargs,
+             fit_strategy=fit_strategy,
+             **kwargs,
+         )
+
+     def _stack_new_level_aux(self, X_val, y_val, X, y, core_models, level, infer_limit, infer_limit_batch_size, **kwargs):
+         if X_val is None:
+             aux_models = self.stack_new_level_aux(
+                 X=X, y=y, base_model_names=core_models, level=level + 1, infer_limit=infer_limit, infer_limit_batch_size=infer_limit_batch_size, **kwargs
+             )
+         else:
+             aux_models = self.stack_new_level_aux(
+                 X=X_val,
+                 y=y_val,
+                 fit=False,
+                 base_model_names=core_models,
+                 level=level + 1,
+                 infer_limit=infer_limit,
+                 infer_limit_batch_size=infer_limit_batch_size,
+                 **kwargs,
+             )
+         return aux_models
+
+     # TODO: Consider making level be auto-determined based off of max(base_model_levels)+1
+     # TODO: Remove name_suffix, hacked in
+     # TODO: X can be optional because it isn't needed if fit=True
+     def stack_new_level_aux(
+         self,
+         X,
+         y,
+         base_model_names: list[str],
+         level: int | str = "auto",
+         fit=True,
+         stack_name="aux1",
+         time_limit=None,
+         name_suffix: str | None = None,
+         get_models_func=None,
+         check_if_best=True,
+         infer_limit=None,
+         infer_limit_batch_size=None,
+         use_val_cache=True,
+         fit_weighted_ensemble: bool = True,
+         name_extra: str | None = None,
+         total_resources: dict | None = None,
+     ) -> list[str]:
+         """
+         Trains auxiliary models (currently a single weighted ensemble) using the provided base models.
+         Level must be greater than the level of any of the base models.
+         Auxiliary models never use the original features and only train with the predictions of other models as features.
+         A brief sketch follows in the comment below.
+         """
+         if self._callback_early_stop:
+             return []
+         if fit_weighted_ensemble is False:
+             # Skip fitting of aux models
+             return []
+
+         base_model_names = self._filter_base_models_via_infer_limit(base_model_names=base_model_names, infer_limit=infer_limit, infer_limit_modifier=0.95)
+
+         if len(base_model_names) == 0:
+             logger.log(20, f"No base models to train on, skipping auxiliary stack level {level}...")
+             return []
+
+         if isinstance(level, str):
+             assert level == "auto", f"level must be 'auto' if str, found: {level}"
+             levels_dict = self.get_models_attribute_dict(attribute="level", models=base_model_names)
+             base_model_level_max = None
+             for k, v in levels_dict.items():
+                 if base_model_level_max is None or v > base_model_level_max:
+                     base_model_level_max = v
+             level = base_model_level_max + 1
+
+         if infer_limit_batch_size is not None:
+             ag_args_fit = dict()
+             ag_args_fit["predict_1_batch_size"] = infer_limit_batch_size
+         else:
+             ag_args_fit = None
+         X_stack_preds = self.get_inputs_to_stacker(X, base_models=base_model_names, fit=fit, use_orig_features=False, use_val_cache=use_val_cache)
+         if self.weight_evaluation:
+             X, w = extract_column(X, self.sample_weight)  # TODO: consider redesign with w as a separate arg instead of bundled inside X
+             if w is not None:
+                 X_stack_preds[self.sample_weight] = w.values / w.mean()
+         child_hyperparameters = None
+         if name_extra is not None:
+             child_hyperparameters = {"ag_args": {"name_suffix": name_extra}}
+         return self.generate_weighted_ensemble(
+             X=X_stack_preds,
+             y=y,
+             level=level,
+             base_model_names=base_model_names,
+             k_fold=1,
+             n_repeats=1,
+             ag_args_fit=ag_args_fit,
+             stack_name=stack_name,
+             time_limit=time_limit,
+             name_suffix=name_suffix,
+             get_models_func=get_models_func,
+             check_if_best=check_if_best,
+             child_hyperparameters=child_hyperparameters,
+             total_resources=total_resources,
+         )
+
+     def predict(self, X: pd.DataFrame, model: str | None = None) -> np.ndarray:
+         if model is None:
+             model = self._get_best()
+         return self._predict_model(X=X, model=model)
+
+     def predict_proba(self, X: pd.DataFrame, model: str | None = None) -> np.ndarray:
+         if model is None:
+             model = self._get_best()
+         return self._predict_proba_model(X=X, model=model)
+
+     def _get_best(self) -> str:
+         if self.model_best is not None:
+             return self.model_best
+         else:
+             return self.get_model_best()
+
+     # Note: model_pred_proba_dict is mutated in this function to minimize memory usage
+     def get_inputs_to_model(
+         self,
+         model: str | AbstractModel,
+         X: pd.DataFrame,
+         model_pred_proba_dict: dict[str, np.ndarray] | None = None,
+         fit: bool = False,
+         preprocess_nonadaptive: bool = False,
+     ) -> pd.DataFrame:
+         """
+         For the output X:
+             If preprocess_nonadaptive=False, it is intended for `model.predict(X)`.
+             If preprocess_nonadaptive=True, it is intended for `model.predict(X, preprocess_nonadaptive=False)`.
+         """
+         if isinstance(model, str):
+             # TODO: Remove unnecessary load when no stacking
+             model = self.load_model(model)
+         model_level = self.get_model_level(model.name)
+         if model_level > 1 and isinstance(model, StackerEnsembleModel):
+             if fit:
+                 model_pred_proba_dict = None
+             else:
+                 model_set = self.get_minimum_model_set(model)
+                 model_set = [m for m in model_set if m != model.name]  # TODO: Can probably be faster, get this result from graph
+                 model_pred_proba_dict = self.get_model_pred_proba_dict(X=X, models=model_set, model_pred_proba_dict=model_pred_proba_dict)
+             X = model.preprocess(X=X, preprocess_nonadaptive=preprocess_nonadaptive, fit=fit, model_pred_proba_dict=model_pred_proba_dict)
+         elif preprocess_nonadaptive:
+             X = model.preprocess(X=X, preprocess_stateful=False)
+         return X
+
+     def score(
+         self,
+         X: pd.DataFrame,
+         y: np.ndarray,
+         model: str | None = None,
+         metric: Scorer | None = None,
+         weights: np.ndarray | None = None,
+         as_error: bool = False,
+     ) -> float:
+         if metric is None:
+             metric = self.eval_metric
+         if metric.needs_pred or metric.needs_quantile:
+             y_pred = self.predict(X=X, model=model)
+             y_pred_proba = None
+         else:
+             y_pred = None
+             y_pred_proba = self.predict_proba(X=X, model=model)
+         return compute_metric(
+             y=y,
+             y_pred=y_pred,
+             y_pred_proba=y_pred_proba,
+             metric=metric,
+             weights=weights,
+             weight_evaluation=self.weight_evaluation,
+             as_error=as_error,
+             quantile_levels=self.quantile_levels,
+         )
+
+     def score_with_y_pred_proba(
+         self,
+         y: np.ndarray,
+         y_pred_proba: np.ndarray,
+         metric: Scorer | None = None,
+         weights: np.ndarray | None = None,
+         as_error: bool = False,
+     ) -> float:
+         if metric is None:
+             metric = self.eval_metric
+         if metric.needs_pred or metric.needs_quantile:
+             y_pred = get_pred_from_proba(y_pred_proba=y_pred_proba, problem_type=self.problem_type)
+             y_pred_proba = None
+         else:
+             y_pred = None
+         return compute_metric(
+             y=y,
+             y_pred=y_pred,
+             y_pred_proba=y_pred_proba,
+             metric=metric,
+             weights=weights,
+             weight_evaluation=self.weight_evaluation,
+             as_error=as_error,
+             quantile_levels=self.quantile_levels,
+         )
+
+     def score_with_y_pred(
+         self,
+         y: np.ndarray,
+         y_pred: np.ndarray,
+         weights: np.ndarray | None = None,
+         metric: Scorer | None = None,
+         as_error: bool = False,
+     ) -> float:
+         if metric is None:
+             metric = self.eval_metric
+         return compute_metric(
+             y=y,
+             y_pred=y_pred,
+             y_pred_proba=None,
+             metric=metric,
+             weights=weights,
+             weight_evaluation=self.weight_evaluation,
+             as_error=as_error,
+             quantile_levels=self.quantile_levels,
+         )
+
+     # TODO: Slow if large ensemble with many models, could cache output result to speed up during inference
+     def _construct_model_pred_order(self, models: list[str]) -> list[str]:
+         """
+         Constructs a list of model names in order of the inference calls required to infer on all the models.
+         A small example is sketched in the comment below.
+
+         Parameters
+         ----------
+         models : list[str]
+             The list of models to construct the prediction order from.
+             If a model has dependencies, the dependency models will be put earlier in the output list.
+             Models explicitly mentioned in the `models` input will be placed as early as possible in the output list.
+             Models appearing earlier in `models` are placed earlier in the output list than those appearing later.
+             It is recommended that earlier elements not have dependency models that are listed later in `models`.
+
+         Returns
+         -------
+         Returns the list of models in inference call order, including the dependency models of those specified in the input.
+         """
+         model_set = set()
+         model_order = []
+         for model in models:
+             if model in model_set:
+                 continue
+             min_models_set = set(self.get_minimum_model_set(model))
+             models_to_load = list(min_models_set.difference(model_set))
+             subgraph = nx.subgraph(self.model_graph, models_to_load)
+             model_pred_order = list(nx.lexicographical_topological_sort(subgraph))
+             model_order += [m for m in model_pred_order if m not in model_set]
+             model_set = set(model_order)
+         return model_order
+
+     def _construct_model_pred_order_with_pred_dict(
+         self, models: list[str], models_to_ignore: list[str] | None = None
+     ) -> list[str]:
+         """
+         Constructs a list of model names in order of the inference calls required to infer on all the models.
+         Unlike `_construct_model_pred_order`, this method's output is in an undefined order when multiple models are valid to infer on at the same time.
+
+         Parameters
+         ----------
+         models : list[str]
+             The list of models to construct the prediction order from.
+             If a model has dependencies, the dependency models will be put earlier in the output list.
+         models_to_ignore : list[str], optional
+             A list of models that have already been computed and can be ignored.
+             Models in this list and their dependencies (if not depended on by other models in `models`) will be pruned from the final output.
+
+         Returns
+         -------
+         Returns the list of models in inference call order, including the dependency models of those specified in the input.
+         """
+         model_set = set()
+         for model in models:
+             if model in model_set:
+                 continue
+             min_model_set = set(self.get_minimum_model_set(model))
+             model_set = model_set.union(min_model_set)
+         if models_to_ignore is not None:
+             model_set = model_set.difference(set(models_to_ignore))
+         models_to_load = list(model_set)
+         subgraph = nx.DiGraph(nx.subgraph(self.model_graph, models_to_load))  # Wrap the subgraph in a DiGraph to unfreeze it
+         # For each model in models_to_ignore, remove the model's node from the graph, along with all ancestors that have no remaining descendants and are not in `models`
+         models_to_ignore = [model for model in models_to_load if (model not in models) and (not list(subgraph.successors(model)))]
+         while models_to_ignore:
+             model = models_to_ignore[0]
+             predecessors = list(subgraph.predecessors(model))
+             subgraph.remove_node(model)
+             models_to_ignore = models_to_ignore[1:]
+             for predecessor in predecessors:
+                 if (predecessor not in models) and (not list(subgraph.successors(predecessor))) and (predecessor not in models_to_ignore):
+                     models_to_ignore.append(predecessor)
+
+         # Get the model prediction order
+         return list(nx.lexicographical_topological_sort(subgraph))
+
+     def get_models_attribute_dict(self, attribute: str, models: list | None = None) -> dict[str, Any]:
+         """Returns a dictionary of model name -> attribute value for the provided attribute."""
+         models_attribute_dict = nx.get_node_attributes(self.model_graph, attribute)
+         if models is not None:
+             model_names = []
+             for model in models:
+                 if not isinstance(model, str):
+                     model = model.name
+                 model_names.append(model)
+             if attribute == "path":
+                 models_attribute_dict = {key: os.path.join(*val) for key, val in models_attribute_dict.items() if key in model_names}
+             else:
+                 models_attribute_dict = {key: val for key, val in models_attribute_dict.items() if key in model_names}
+         return models_attribute_dict
+
+     # TODO: Consider adding persist-to-disk functionality for the pred_proba dictionary to lessen the memory burden on large multiclass problems.
+     #  For datasets with 100+ classes, this function could potentially run the system OOM due to each pred_proba numpy array taking significant amounts of space.
+     #  This issue already existed in the previous level-based version, but that version only had the minimum required predictions in memory at a time, whereas this one has all model predictions in memory.
+     # TODO: Add memory-optimal topological ordering -> Minimize the amount of pred_probas in memory at a time, delete pred_probas that are no longer required
+     def get_model_pred_proba_dict(
+         self,
+         X: pd.DataFrame,
+         models: list[str],
+         model_pred_proba_dict: dict | None = None,
+         model_pred_time_dict: dict | None = None,
+         record_pred_time: bool = False,
+         use_val_cache: bool = False,
+     ):
+         """
+         Optimally computes pred_probas (or predictions if regression) for each model in `models`.
+         Will compute each necessary model only once and store predictions in a `model_pred_proba_dict` dictionary.
+         Note: Mutates the model_pred_proba_dict and model_pred_time_dict inputs if present to minimize memory usage.
+         A usage sketch follows the docstring in the comment below.
+
+         Parameters
+         ----------
+         X : pd.DataFrame
+             Input data to predict on.
+         models : list[str]
+             The list of models to predict with.
+             Note that if models have dependency models, their dependencies will also be predicted with and included in the output.
+         model_pred_proba_dict : dict, optional
+             A dict of predict_probas that could have been computed by a prior call to `get_model_pred_proba_dict` to avoid redundant computations.
+             Models already present in model_pred_proba_dict will not be predicted on.
+             get_model_pred_proba_dict(X, models=['A', 'B', 'C']) is equivalent to
+             get_model_pred_proba_dict(X, models=['C'], model_pred_proba_dict=get_model_pred_proba_dict(X, models=['A', 'B']))
+             Note: Mutated in-place to minimize memory usage.
+         model_pred_time_dict : dict, optional
+             If `record_pred_time==True`, this is a dict of model name to marginal time taken in seconds for the prediction of X.
+             Must be specified alongside `model_pred_proba_dict` if `record_pred_time=True` and `model_pred_proba_dict != None`.
+             Ignored if `record_pred_time=False`.
+             Note: Mutated in-place to minimize memory usage.
+         record_pred_time : bool, default = False
+             Whether to store the marginal inference times of each model as an extra output `model_pred_time_dict`.
+         use_val_cache : bool, default = False
+             Whether to fetch cached val prediction probabilities for models instead of predicting on the data.
+             Only set to True if X is equal to the validation data and you want to skip live predictions.
+
+         Returns
+         -------
+         If `record_pred_time==True`, outputs a tuple of dicts (model_pred_proba_dict, model_pred_time_dict), else outputs only model_pred_proba_dict.
+         """
1220
+ if model_pred_proba_dict is None:
1221
+ model_pred_proba_dict = {}
1222
+ if model_pred_time_dict is None:
1223
+ model_pred_time_dict = {}
1224
+
1225
+ if use_val_cache:
1226
+ _, model_pred_proba_dict = self._update_pred_proba_dict_with_val_cache(model_set=set(models), model_pred_proba_dict=model_pred_proba_dict)
1227
+ if not model_pred_proba_dict:
1228
+ model_pred_order = self._construct_model_pred_order(models)
1229
+ else:
1230
+ model_pred_order = self._construct_model_pred_order_with_pred_dict(models, models_to_ignore=list(model_pred_proba_dict.keys()))
1231
+ if use_val_cache:
1232
+ model_set, model_pred_proba_dict = self._update_pred_proba_dict_with_val_cache(
1233
+ model_set=set(model_pred_order), model_pred_proba_dict=model_pred_proba_dict
1234
+ )
1235
+ model_pred_order = [model for model in model_pred_order if model in model_set]
1236
+
1237
+ # Compute model predictions in topological order
1238
+ for model_name in model_pred_order:
1239
+ if record_pred_time:
1240
+ time_start = time.time()
1241
+
1242
+ model = self.load_model(model_name=model_name)
1243
+ if isinstance(model, StackerEnsembleModel):
1244
+ preprocess_kwargs = dict(infer=False, model_pred_proba_dict=model_pred_proba_dict)
1245
+ model_pred_proba_dict[model_name] = model.predict_proba(X, **preprocess_kwargs)
1246
+ else:
1247
+ model_pred_proba_dict[model_name] = model.predict_proba(X)
1248
+
1249
+ if record_pred_time:
1250
+ time_end = time.time()
1251
+ model_pred_time_dict[model_name] = time_end - time_start
1252
+
1253
+ if record_pred_time:
1254
+ return model_pred_proba_dict, model_pred_time_dict
1255
+ else:
1256
+ return model_pred_proba_dict
1257
+
1258
+ def get_model_oof_dict(self, models: list[str]) -> dict:
1259
+ """
1260
+ Returns a dictionary of out-of-fold prediction probabilities, keyed by model name
1261
+ """
1262
+ return {model: self.get_model_oof(model) for model in models}
1263
+
1264
+ def get_model_pred_dict(self, X: pd.DataFrame, models: list[str], record_pred_time: bool = False, **kwargs):
1265
+ """
1266
+ Optimally computes predictions for each model in `models`.
1267
+ Will compute each necessary model only once and store predictions in a `model_pred_dict` dictionary.
1268
+ Note: Mutates model_pred_proba_dict and model_pred_time_dict input if present to minimize memory usage.
1269
+
1270
+ Acts as a wrapper to `self.get_model_pred_proba_dict`, converting the output to predictions.
1271
+
1272
+ Parameters
1273
+ ----------
1274
+ X : pd.DataFrame
1275
+ Input data to predict on.
1276
+ models : list[str]
1277
+ The list of models to predict with.
1278
+ Note that if models have dependency models, their dependencies will also be predicted with and included in the output.
1279
+ record_pred_time : bool, default = False
1280
+ Whether to store marginal inference times of each model as an extra output `model_pred_time_dict`.
1281
+ **kwargs : dict, optional
1282
+ Refer to `self.get_model_pred_proba_dict` for documentation of remaining arguments.
1283
+ This method shares identical arguments.
1284
+
1285
+ Returns
1286
+ -------
1287
+ If `record_pred_time==True`, outputs tuple of dicts (model_pred_dict, model_pred_time_dict), else output only model_pred_dict
1288
+ """
1289
+ model_pred_proba_dict = self.get_model_pred_proba_dict(X=X, models=models, record_pred_time=record_pred_time, **kwargs)
1290
+ if record_pred_time:
1291
+ model_pred_proba_dict, model_pred_time_dict = model_pred_proba_dict
1292
+ else:
1293
+ model_pred_time_dict = None
1294
+
1295
+ model_pred_dict = {}
1296
+ for m in model_pred_proba_dict:
1297
+ # Convert pred_proba to pred
1298
+ model_pred_dict[m] = get_pred_from_proba(y_pred_proba=model_pred_proba_dict[m], problem_type=self.problem_type)
1299
+
1300
+ if record_pred_time:
1301
+ return model_pred_dict, model_pred_time_dict
1302
+ else:
1303
+ return model_pred_dict
1304
+
1305
+ def get_model_oof(self, model: str, use_refit_parent: bool = False) -> np.ndarray:
1306
+ """
1307
+ Gets the out-of-fold prediction probabilities for a bagged ensemble model.
1308
+
1309
+ Parameters
1310
+ ----------
1311
+ model : str
1312
+ Name of the model to get OOF.
1313
+ use_refit_parent : bool, default = False
1314
+ If True and the model is a refit model, will instead return the parent model's OOF.
1315
+ If False and the model is a refit model, an exception will be raised.
1316
+
1317
+ Returns
1318
+ -------
1319
+ np.ndarray
1320
+ model OOF prediction probabilities (if classification) or predictions (if regression)
1321
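+
+ Examples
+ --------
+ A minimal sketch (assuming "LightGBM_BAG_L1" is a fitted bagged model; the name is illustrative):
+
+ >>> oof_pred_proba = trainer.get_model_oof(model="LightGBM_BAG_L1")  # doctest: +SKIP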
+ """
1322
+ if use_refit_parent and self.get_model_attribute(model=model, attribute="refit_full", default=False):
1323
+ model = self.get_model_attribute(model=model, attribute="refit_full_parent")
1324
+ model_type = self.get_model_attribute(model=model, attribute="type")
1325
+ if issubclass(model_type, BaggedEnsembleModel):
1326
+ model_path = self.get_model_attribute(model=model, attribute="path")
1327
+ return model_type.load_oof(path=os.path.join(self.path, model_path))
1328
+ else:
1329
+ raise AssertionError(f"Model {model} must be a BaggedEnsembleModel to return oof_pred_proba")
1330
+
1331
+ def get_model_learning_curves(self, model: str) -> dict:
1332
+ model_type = self.get_model_attribute(model=model, attribute="type")
1333
+ model_path = self.get_model_attribute(model=model, attribute="path")
1334
+ return model_type.load_learning_curves(path=os.path.join(self.path, model_path))
1335
+
1336
+ def _update_pred_proba_dict_with_val_cache(self, model_set: set, model_pred_proba_dict):
1337
+ """For each model in model_set, check if y_pred_proba_val is cached to disk. If so, load and add it to model_pred_proba_dict"""
1338
+ for model in model_set:
1339
+ y_pred_proba = self.get_model_attribute(model, attribute="cached_y_pred_proba_val", default=None)
1340
+ if isinstance(y_pred_proba, bool):
1341
+ if y_pred_proba:
1342
+ try:
1343
+ y_pred_proba = self._load_model_y_pred_proba_val(model)
1344
+ except FileNotFoundError:
1345
+ y_pred_proba = None
1346
+ else:
1347
+ y_pred_proba = None
1348
+ if y_pred_proba is not None:
1349
+ model_pred_proba_dict[model] = y_pred_proba
1350
+ model_set = model_set.difference(set(model_pred_proba_dict.keys()))
1351
+ return model_set, model_pred_proba_dict
1352
+
1353
+ def get_inputs_to_stacker(
1354
+ self,
1355
+ X: pd.DataFrame,
1356
+ *,
1357
+ model: str | None = None,
1358
+ base_models: list[str] | None = None,
1359
+ model_pred_proba_dict: dict | None = None,
1360
+ fit: bool = False,
1361
+ use_orig_features: bool = True,
1362
+ use_val_cache: bool = False,
1363
+ ) -> pd.DataFrame:
1364
+ """
1365
+ Returns the valid X input for a stacker model with base models equal to `base_models`.
1366
+ Pairs with `feature_metadata = self.get_feature_metadata(...)`. The contents of the returned `X` should reflect `feature_metadata`.
1367
+
1368
+ Parameters
1369
+ ----------
1370
+ X : pd.DataFrame
1371
+ Input data to augment.
1372
+ model : str, default = None
1373
+ The model to derive `base_models` from.
1374
+ Cannot be specified alongside `base_models`.
1375
+ base_models : list[str], default = None
1376
+ The list of base models to augment X with.
1377
+ Base models will add their prediction probabilities as extra features to X.
1378
+ Cannot be specified alongside `model`.
1379
+ model_pred_proba_dict : dict, optional
1380
+ A dict of predict_probas that could have been computed by a prior call to `get_model_pred_proba_dict` to avoid redundant computations.
1381
+ Models already present in model_pred_proba_dict will not be predicted on.
1382
+ Note: Mutated in-place to minimize memory usage
1383
+ fit : bool, default = False
1384
+ If True, X represents the training data and the models will return their out-of-fold prediction probabilities.
1385
+ If False, X represents validation or test data and the models will predict directly on X to generate their prediction probabilities.
1386
+ use_orig_features : bool, default = True
1387
+ If True, the output DataFrame will include X's original features in addition to the new stack features.
1388
+ If False, the output DataFrame will only contain the new stack features.
1389
+ use_val_cache : bool, default = False
1390
+ Whether to fetch cached val prediction probabilities for models instead of predicting on the data.
1391
+ Only set to True if X is equal to the validation data and you want to skip live predictions.
1392
+
1393
+ Returns
1394
+ -------
1395
+ X : pd.DataFrame
+ The input DataFrame augmented with the additional stack features from `base_models`.
1396
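+
+ Examples
+ --------
+ A minimal sketch (assuming a fitted trainer with a stacker model named "WeightedEnsemble_L2"; the name is illustrative):
+
+ >>> X_aug = trainer.get_inputs_to_stacker(X_test, model="WeightedEnsemble_L2")  # doctest: +SKIP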
+ """
1397
+ if model is not None and base_models is not None:
1398
+ raise AssertionError("Only one of `model`, `base_models` is allowed to be set.")
1399
+
1400
+ if model is not None and base_models is None:
1401
+ base_models = self.get_base_model_names(model)
1402
+ if not base_models:
1403
+ return X
1404
+ if fit:
1405
+ model_pred_proba_dict = self.get_model_oof_dict(models=base_models)
1406
+ else:
1407
+ model_pred_proba_dict = self.get_model_pred_proba_dict(
1408
+ X=X, models=base_models, model_pred_proba_dict=model_pred_proba_dict, use_val_cache=use_val_cache
1409
+ )
1410
+ pred_proba_list = [model_pred_proba_dict[model] for model in base_models]
1411
+ stack_column_names, _ = self._get_stack_column_names(models=base_models)
1412
+ X_stacker = convert_pred_probas_to_df(pred_proba_list=pred_proba_list, problem_type=self.problem_type, columns=stack_column_names, index=X.index)
1413
+ if use_orig_features:
1414
+ X = pd.concat([X_stacker, X], axis=1)
1415
+ else:
1416
+ X = X_stacker
1417
+ return X
1418
+
1419
+ def get_feature_metadata(self, use_orig_features: bool = True, model: str | None = None, base_models: list[str] | None = None) -> FeatureMetadata:
1420
+ """
1421
+ Returns the FeatureMetadata input to a `model.fit` call.
1422
+ Pairs with `X = self.get_inputs_to_stacker(...)`. The returned FeatureMetadata should reflect the contents of `X`.
1423
+
1424
+ Parameters
1425
+ ----------
1426
+ use_orig_features : bool, default = True
1427
+ If True, will include the original features in the FeatureMetadata.
1428
+ If False, will only include the stack features in the FeatureMetadata.
1429
+ model : str, default = None
1430
+ If specified, it must be an already existing model.
1431
+ `base_models` will be set to the base models of `model`.
1432
+ base_models : list[str], default = None
1433
+ If specified, will add the stack features of the `base_models` to FeatureMetadata.
1434
+
1435
+ Returns
1436
+ -------
1437
+ FeatureMetadata
1438
+ The FeatureMetadata that should be passed into a `model.fit` call.
1439
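+
+ Examples
+ --------
+ A minimal sketch (model names are illustrative):
+
+ >>> fm = trainer.get_feature_metadata(base_models=["LightGBM_BAG_L1", "CatBoost_BAG_L1"])  # doctest: +SKIP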
+ """
1440
+ if model is not None and base_models is not None:
1441
+ raise AssertionError("Only one of `model`, `base_models` is allowed to be set.")
1442
+ if model is not None and base_models is None:
1443
+ base_models = self.get_base_model_names(model)
1444
+
1445
+ feature_metadata = None
1446
+ if use_orig_features:
1447
+ feature_metadata = self.feature_metadata
1448
+ if base_models:
1449
+ stack_column_names, _ = self._get_stack_column_names(models=base_models)
1450
+ stacker_type_map_raw = {column: R_FLOAT for column in stack_column_names}
1451
+ stacker_type_group_map_special = {S_STACK: stack_column_names}
1452
+ stacker_feature_metadata = FeatureMetadata(type_map_raw=stacker_type_map_raw, type_group_map_special=stacker_type_group_map_special)
1453
+ if feature_metadata is not None:
1454
+ feature_metadata = feature_metadata.join_metadata(stacker_feature_metadata)
1455
+ else:
1456
+ feature_metadata = stacker_feature_metadata
1457
+ if feature_metadata is None:
1458
+ feature_metadata = FeatureMetadata(type_map_raw={})
1459
+ return feature_metadata
1460
+
1461
+ def _get_stack_column_names(self, models: list[str]) -> tuple[list[str], int]:
1462
+ """
1463
+ Get the stack column names generated when the provided models are used as base models in a stack ensemble.
1464
+ Additionally outputs the number of columns per model as an int.
1465
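+ E.g. (illustrative) for a 3-class multiclass problem with models=["m1", "m2"], returns
+ (["m1_0", "m1_1", "m1_2", "m2_0", "m2_1", "m2_2"], 3).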
+ """
1466
+ if self.problem_type in [MULTICLASS, SOFTCLASS]:
1467
+ stack_column_names = [stack_column_prefix + "_" + str(cls) for stack_column_prefix in models for cls in range(self.num_classes)]
1468
+ num_columns_per_model = self.num_classes
1469
+ elif self.problem_type == QUANTILE:
1470
+ stack_column_names = [stack_column_prefix + "_" + str(q) for stack_column_prefix in models for q in self.quantile_levels]
1471
+ num_columns_per_model = len(self.quantile_levels)
1472
+ else:
1473
+ stack_column_names = models
1474
+ num_columns_per_model = 1
1475
+ return stack_column_names, num_columns_per_model
1476
+
1477
+ # You must have previously called fit() with cache_data=True
1478
+ # Fits _FULL versions of specified models, but does NOT link them (_FULL stackers will still use normal models as input)
1479
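+ # Example sketch (the model name is illustrative):
+ # trainer.refit_single_full(models=["LightGBM_BAG_L1"]) fits a _FULL variant of the model
+ # on all of the training data and returns the list of newly fit model names.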
+ def refit_single_full(
1480
+ self,
1481
+ X=None,
1482
+ y=None,
1483
+ X_val=None,
1484
+ y_val=None,
1485
+ X_unlabeled=None,
1486
+ models=None,
1487
+ fit_strategy: Literal["sequential", "parallel"] = "sequential",
1488
+ **kwargs,
1489
+ ) -> list[str]:
1490
+ if fit_strategy == "parallel":
1491
+ logger.log(30, f"Note: refit_full does not yet support fit_strategy='parallel', switching to 'sequential'...")
1492
+ fit_strategy = "sequential"
1493
+ if X is None:
1494
+ X = self.load_X()
1495
+ if X_val is None:
1496
+ X_val = self.load_X_val()
1497
+ if y is None:
1498
+ y = self.load_y()
1499
+ if y_val is None:
1500
+ y_val = self.load_y_val()
1501
+
1502
+ if models is None:
1503
+ models = self.get_model_names()
1504
+
1505
+ model_levels = dict()
1506
+ ignore_models = []
1507
+ ignore_stack_names = [REFIT_FULL_NAME]
1508
+ for stack_name in ignore_stack_names:
1509
+ ignore_models += self.get_model_names(stack_name=stack_name) # get_model_names returns [] if stack_name does not exist
1510
+ models = [model for model in models if model not in ignore_models]
1511
+ for model in models:
1512
+ model_level = self.get_model_level(model)
1513
+ if model_level not in model_levels:
1514
+ model_levels[model_level] = []
1515
+ model_levels[model_level].append(model)
1516
+
1517
+ levels = sorted(model_levels.keys())
1518
+ models_trained_full = []
1519
+ model_refit_map = {} # FIXME: is this even used, remove?
1520
+
1521
+ if fit_strategy == "sequential":
1522
+ for level in levels:
1523
+ models_level = model_levels[level]
1524
+ for model in models_level:
1525
+ model_name, models_trained = _detached_refit_single_full(
1526
+ _self=self,
1527
+ model=model,
1528
+ X=X,
1529
+ y=y,
1530
+ X_val=X_val,
1531
+ y_val=y_val,
1532
+ X_unlabeled=X_unlabeled,
1533
+ level=level,
1534
+ kwargs=kwargs,
1535
+ fit_strategy=fit_strategy,
1536
+ )
1537
+ if len(models_trained) == 1:
1538
+ model_refit_map[model_name] = models_trained[0]
1539
+ for model_trained in models_trained:
1540
+ self._update_model_attr(
1541
+ model_trained,
1542
+ refit_full=True,
1543
+ refit_full_parent=model_name,
1544
+ refit_full_parent_val_score=self.get_model_attribute(model_name, "val_score"),
1545
+ )
1546
+ models_trained_full += models_trained
1547
+ elif fit_strategy == "parallel":
1548
+ # -- Parallel refit
1549
+ ray = try_import_ray()
1550
+
1551
+ # FIXME: Need a common utility class for initializing ray so we don't duplicate code
1552
+ if not ray.is_initialized():
1553
+ ray.init(log_to_driver=False, logging_level=logging.ERROR)
1554
+
1555
+ distributed_manager = ParallelFitManager(
1556
+ mode="refit",
1557
+ func=_remote_refit_single_full,
1558
+ func_kwargs=dict(fit_strategy=fit_strategy),
1559
+ func_put_kwargs=dict(
1560
+ _self=self,
1561
+ X=X,
1562
+ y=y,
1563
+ X_val=X_val,
1564
+ y_val=y_val,
1565
+ X_unlabeled=X_unlabeled,
1566
+ kwargs=kwargs,
1567
+ ),
1568
+ # TODO: check if this is available in the kwargs
1569
+ num_cpus=kwargs.get("total_resources", {}).get("num_cpus", 1),
1570
+ num_gpus=kwargs.get("total_resources", {}).get("num_gpus", 0),
1571
+ get_model_attribute_func=self.get_model_attribute,
1572
+ X=X,
1573
+ y=y,
1574
+ )
1575
+
1576
+ for level in levels:
1577
+ models_trained_full_level = []
1578
+ distributed_manager.job_kwargs["level"] = level
1579
+ models_level = model_levels[level]
1580
+
1581
+ logger.log(20, f"Scheduling distributed model-workers for refitting {len(models_level)} L{level} models...")
1582
+ unfinished_job_refs = distributed_manager.schedule_jobs(models_to_fit=models_level)
1583
+
1584
+ while unfinished_job_refs:
1585
+ finished, unfinished_job_refs = ray.wait(unfinished_job_refs, num_returns=1)
1586
+ refit_full_parent, model_trained, model_path, model_type = ray.get(finished[0])
1587
+
1588
+ self._add_model(
1589
+ model_type.load(path=os.path.join(self.path, model_path), reset_paths=self.reset_paths),
1590
+ stack_name=REFIT_FULL_NAME,
1591
+ level=level,
1592
+ _is_refit=True
1593
+ )
1594
+ model_refit_map[refit_full_parent] = model_trained
1595
+ self._update_model_attr(
1596
+ model_trained,
1597
+ refit_full=True,
1598
+ refit_full_parent=refit_full_parent,
1599
+ refit_full_parent_val_score=self.get_model_attribute(refit_full_parent,"val_score"),
1600
+ )
1601
+ models_trained_full_level.append(model_trained)
1602
+
1603
+ logger.log(20,f"Finished refit model for {refit_full_parent}")
1604
+ unfinished_job_refs += distributed_manager.schedule_jobs()
1605
+
1606
+ logger.log(20, f"Finished distributed refitting for {len(models_trained_full_level)} L{level} models.")
1607
+ models_trained_full += models_trained_full_level
1608
+ distributed_manager.clean_job_state(unfinished_job_refs=unfinished_job_refs)
1609
+
1610
+ distributed_manager.clean_up_ray()
1611
+ else:
1612
+ raise ValueError(f"Invalid value for fit_strategy: '{fit_strategy}'")
1613
+
1614
+ keys_to_del = []
1615
+ for model in model_refit_map.keys():
1616
+ if model_refit_map[model] not in models_trained_full:
1617
+ keys_to_del.append(model)
1618
+ for key in keys_to_del:
1619
+ del model_refit_map[key]
1620
+ self.save() # TODO: This could be more efficient by passing in arg to not save if called by refit_ensemble_full since it saves anyways later.
1621
+ return models_trained_full
1622
+
1623
+ # Fits _FULL models and links them in the stack so _FULL models only use other _FULL models as input during stacking
1624
+ # If model is specified, will fit all _FULL models that are ancestors of the provided model, automatically linking them.
1625
+ # If no model is specified, all models are refit and linked appropriately.
1626
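+ # Example sketch: trainer.refit_ensemble_full(model="best") refits the best model and all of its
+ # ancestors, then returns the updated refit map ({original model name -> refit model name}).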
+ def refit_ensemble_full(self, model: str | list[str] = "all", **kwargs) -> dict:
1627
+ if model == "all":
1628
+ ensemble_set = self.get_model_names()
1629
+ elif isinstance(model, list):
1630
+ ensemble_set = self.get_minimum_models_set(model)
1631
+ else:
1632
+ if model == "best":
1633
+ model = self.get_model_best()
1634
+ ensemble_set = self.get_minimum_model_set(model)
1635
+ existing_models = self.get_model_names()
1636
+ ensemble_set_valid = []
1637
+ model_refit_map = self.model_refit_map()
1638
+ for model in ensemble_set:
1639
+ if model in model_refit_map and model_refit_map[model] in existing_models:
1640
+ logger.log(20, f"Model '{model}' already has a refit _FULL model: '{model_refit_map[model]}', skipping refit...")
1641
+ else:
1642
+ ensemble_set_valid.append(model)
1643
+ if ensemble_set_valid:
1644
+ models_trained_full = self.refit_single_full(models=ensemble_set_valid, **kwargs)
1645
+ else:
1646
+ models_trained_full = []
1647
+
1648
+ model_refit_map = self.model_refit_map()
1649
+ for model_full in models_trained_full:
1650
+ # TODO: Consider moving base model info to a separate pkl file so that it can be edited without having to load/save the model again
1651
+ # Downside: Slower inference speed when models are not persisted in memory prior.
1652
+ model_loaded = self.load_model(model_full)
1653
+ if isinstance(model_loaded, StackerEnsembleModel):
1654
+ for stack_column_prefix in model_loaded.stack_column_prefix_lst:
1655
+ base_model = model_loaded.stack_column_prefix_to_model_map[stack_column_prefix]
1656
+ new_base_model = model_refit_map[base_model]
1657
+ new_base_model_type = self.get_model_attribute(model=new_base_model, attribute="type")
1658
+ new_base_model_path = self.get_model_attribute(model=new_base_model, attribute="path")
1659
+
1660
+ model_loaded.base_model_paths_dict[new_base_model] = new_base_model_path
1661
+ model_loaded.base_model_types_dict[new_base_model] = new_base_model_type
1662
+ model_loaded.base_model_names.append(new_base_model)
1663
+ model_loaded.stack_column_prefix_to_model_map[stack_column_prefix] = new_base_model
1664
+
1665
+ model_loaded.save() # TODO: Avoid this!
1666
+
1667
+ # Remove old edges and add new edges
1668
+ edges_to_remove = list(self.model_graph.in_edges(model_loaded.name))
1669
+ self.model_graph.remove_edges_from(edges_to_remove)
1670
+ if isinstance(model_loaded, StackerEnsembleModel):
1671
+ for stack_column_prefix in model_loaded.stack_column_prefix_lst:
1672
+ base_model_name = model_loaded.stack_column_prefix_to_model_map[stack_column_prefix]
1673
+ self.model_graph.add_edge(base_model_name, model_loaded.name)
1674
+
1675
+ self.save()
1676
+ return self.model_refit_map()
1677
+
1678
+ def get_refit_full_parent(self, model: str) -> str:
1679
+ """Get refit full model's parent. If model does not have a parent, return `model`."""
1680
+ return self.get_model_attribute(model=model, attribute="refit_full_parent", default=model)
1681
+
1682
+ def get_model_best(
1683
+ self,
1684
+ can_infer: bool | None = None,
1685
+ allow_full: bool = True,
1686
+ infer_limit: float | None = None,
1687
+ infer_limit_as_child: bool = False
1688
+ ) -> str:
1689
+ """
1690
+ Returns the name of the model with the best validation score that satisfies all specified constraints.
1691
+ If no model satisfies the constraints, an AssertionError will be raised.
1692
+
1693
+ Parameters
1694
+ ----------
1695
+ can_infer: bool, default = None
1696
+ If True, only consider models that can infer.
1697
+ If False, only consider models that can't infer.
1698
+ If None, consider all models.
1699
+ allow_full: bool, default = True
1700
+ If True, consider all models.
1701
+ If False, disallow refit_full models.
1702
+ infer_limit: float, default = None
1703
+ The maximum time in seconds per sample that a model is allowed to take during inference.
1704
+ If None, consider all models.
1705
+ If specified, consider only models that have a lower predict time per sample than `infer_limit`.
1706
+ infer_limit_as_child: bool, default = False
1707
+ If True, use the predict time per sample of the (theoretical) refit version of the model.
1708
+ If the model is already refit, the predict time per sample is unchanged.
1709
+ If False, use the predict time per sample of the model.
1710
+
1711
+ Returns
1712
+ -------
1713
+ model: str
1714
+ The string name of the model with the best metric score that satisfies all constraints.
1715
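+
+ Examples
+ --------
+ A minimal sketch (the 0.001 s/row budget is illustrative):
+
+ >>> best = trainer.get_model_best()  # doctest: +SKIP
+ >>> fast = trainer.get_model_best(infer_limit=0.001, infer_limit_as_child=True)  # doctest: +SKIP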
+ """
1716
+ models = self.get_model_names(can_infer=can_infer)
1717
+ if not models:
1718
+ raise AssertionError("Trainer has no fit models that can infer.")
1719
+ models_full = self.get_models_attribute_dict(models=models, attribute="refit_full_parent")
1720
+ if not allow_full:
1721
+ models = [model for model in models if model not in models_full]
1722
+
1723
+ predict_1_time_attribute = None
1724
+ if infer_limit is not None:
1725
+ if infer_limit_as_child:
1726
+ predict_1_time_attribute = "predict_1_child_time"
1727
+ else:
1728
+ predict_1_time_attribute = "predict_1_time"
1729
+ models_predict_1_time = self.get_models_attribute_full(models=models, attribute=predict_1_time_attribute)
1730
+ models_og = copy.deepcopy(models)
1731
+ for model_key in models_predict_1_time:
1732
+ if models_predict_1_time[model_key] is None or models_predict_1_time[model_key] > infer_limit:
1733
+ models.remove(model_key)
1734
+ if models_og and not models:
1735
+ # get the fastest model
1736
+ models_predict_time_list = [models_predict_1_time[m] for m in models_og]
1737
+ min_time = np.array(models_predict_time_list).min()
1738
+ infer_limit_new = min_time * 1.2  # Give 20% leeway
1739
+ logger.log(30, f"WARNING: Impossible to satisfy infer_limit constraint. Relaxing constraint from {infer_limit} to {infer_limit_new} ...")
1740
+ models = models_og
1741
+ for model_key in models_predict_1_time:
1742
+ if models_predict_1_time[model_key] > infer_limit_new:
1743
+ models.remove(model_key)
1744
+ if not models:
1745
+ raise AssertionError(
1746
+ f"Trainer has no fit models that can infer while satisfying the constraints: (infer_limit={infer_limit}, allow_full={allow_full})."
1747
+ )
1748
+ model_performances = self.get_models_attribute_dict(models=models, attribute="val_score")
1749
+
1750
+ predict_time_attr = predict_1_time_attribute if predict_1_time_attribute is not None else "predict_time"
1751
+ models_predict_time = self.get_models_attribute_full(models=models, attribute=predict_time_attr)
1752
+
1753
+ perfs = [(m, model_performances[m], models_predict_time[m]) for m in models if model_performances[m] is not None]
1754
+ if not perfs:
1755
+ models = [m for m in models if m in models_full]
1756
+ perfs = [(m, self.get_model_attribute(model=m, attribute="refit_full_parent_val_score"), models_predict_time[m]) for m in models]
1757
+ if not perfs:
1758
+ raise AssertionError("No fit models that can infer exist with a validation score to choose the best model.")
1759
+ elif not allow_full:
1760
+ raise AssertionError(
1761
+ "No fit models that can infer exist with a validation score to choose the best model, but refit_full models exist. Set `allow_full=True` to get the best refit_full model."
1762
+ )
1763
+ return max(perfs, key=lambda i: (i[1], -i[2]))[0]
1764
+
1765
+ def save_model(self, model, reduce_memory=True):
1766
+ # TODO: In future perhaps give option for the reduce_memory_size arguments, perhaps trainer level variables specified by user?
1767
+ if reduce_memory:
1768
+ model.reduce_memory_size(remove_fit=True, remove_info=False, requires_save=True)
1769
+ if self.low_memory:
1770
+ model.save()
1771
+ else:
1772
+ self.models[model.name] = model
1773
+
1774
+ def save(self) -> None:
1775
+ models = self.models
1776
+ if self.low_memory:
1777
+ self.models = {}
1778
+ save_pkl.save(path=os.path.join(self.path, self.trainer_file_name), object=self)
1779
+ if self.low_memory:
1780
+ self.models = models
1781
+
1782
+ def compile(self, model_names="all", with_ancestors=False, compiler_configs=None) -> list[str]:
1783
+ """
1784
+ Compile a list of models for accelerated prediction.
1785
+
1786
+ Parameters
1787
+ ----------
1788
+ model_names : str or list
1789
+ A list of model names for model compilation. Alternatively, this can be 'all' or 'best'.
1790
+ compiler_configs: dict, default=None
1791
+ Model specific compiler options.
1792
+ This can be useful to specify the compiler backend for a specific model,
1793
+ e.g. {"RandomForest": {"compiler": "onnx"}}
1794
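+
+ Examples
+ --------
+ A minimal sketch (assuming the optional "onnx" compiler backend dependencies are installed):
+
+ >>> trainer.compile(model_names="all", compiler_configs={"RandomForest": {"compiler": "onnx"}})  # doctest: +SKIP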
+ """
1795
+ if model_names == "all":
1796
+ model_names = self.get_model_names(can_infer=True)
1797
+ elif model_names == "best":
1798
+ if self.model_best is not None:
1799
+ model_names = [self.model_best]
1800
+ else:
1801
+ model_names = [self.get_model_best(can_infer=True)]
1802
+ if not isinstance(model_names, list):
1803
+ raise ValueError(f"model_names must be a list of model names. Invalid value: {model_names}")
1804
+ if with_ancestors:
1805
+ model_names = self.get_minimum_models_set(model_names)
+ # Guard against a TypeError below, where compiler_configs is checked with `in`
+ if compiler_configs is None:
+ compiler_configs = {}
1806
+
1807
+ logger.log(20, f"Compiling {len(model_names)} Models ...")
1808
+ total_compile_time = 0
1809
+
1810
+ model_names_to_compile = []
1811
+ model_names_to_configs_dict = dict()
1812
+ for model_name in model_names:
1813
+ model_type_inner = self.get_model_attribute(model_name, "type_inner")
1814
+ # Get model specific compiler options
1815
+ # A compiler config can be keyed by either the model name or the inner model type
1816
+ if model_name in compiler_configs:
1817
+ config = compiler_configs[model_name]
1818
+ elif model_type_inner in compiler_configs:
1819
+ config = compiler_configs[model_type_inner]
1820
+ else:
1821
+ config = None
1822
+ if config is not None:
1823
+ model_names_to_compile.append(model_name)
1824
+ model_names_to_configs_dict[model_name] = config
1825
+ else:
1826
+ logger.log(20, f"Skipping compilation for {model_name} ... (No config specified)")
1827
+ for model_name in model_names_to_compile:
1828
+ model = self.load_model(model_name)
1829
+ config = model_names_to_configs_dict[model_name]
1830
+
1831
+ # Check if already compiled, or if can't compile due to missing dependencies,
1832
+ # or if model hasn't implemented compiling.
1833
+ if "compiler" in config and model.get_compiler_name() == config["compiler"]:
1834
+ logger.log(20, f'Skipping compilation for {model_name} ... (Already compiled with "{model.get_compiler_name()}" backend)')
1835
+ elif model.can_compile(compiler_configs=config):
1836
+ logger.log(20, f"Compiling model: {model.name} ... Config = {config}")
1837
+ compile_start_time = time.time()
1838
+ model.compile(compiler_configs=config)
1839
+ compile_end_time = time.time()
1840
+ model.compile_time = compile_end_time - compile_start_time
1841
+ compile_type = model.get_compiler_name()
1842
+ total_compile_time += model.compile_time
1843
+
1844
+ # Update model_graph in order to put compile_time into leaderboard,
1845
+ # since models are saved right after training.
1846
+ self.model_graph.nodes[model.name]["compile_time"] = model.compile_time
1847
+ self.save_model(model, reduce_memory=False)
1848
+ logger.log(20, f'\tCompiled model with "{compile_type}" backend ...')
1849
+ logger.log(20, f"\t{round(model.compile_time, 2)}s\t = Compile runtime")
1850
+ else:
1851
+ logger.log(20, f"Skipping compilation for {model.name} ... (Unable to compile with the provided config: {config})")
1852
+ logger.log(20, f"Finished compiling models, total runtime = {round(total_compile_time, 2)}s.")
1853
+ self.save()
1854
+ return model_names
1855
+
1856
+ def persist(self, model_names="all", with_ancestors=False, max_memory=None) -> list[str]:
1857
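+ """
+ Loads the specified models (and, if `with_ancestors=True`, their base models) into memory (`self.models`)
+ so that repeated predict calls avoid loading from disk. Returns the list of newly persisted model names.
+ """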
+ if model_names == "all":
1858
+ model_names = self.get_model_names()
1859
+ elif model_names == "best":
1860
+ if self.model_best is not None:
1861
+ model_names = [self.model_best]
1862
+ else:
1863
+ model_names = [self.get_model_best(can_infer=True)]
1864
+ if not isinstance(model_names, list):
1865
+ raise ValueError(f"model_names must be a list of model names. Invalid value: {model_names}")
1866
+ if with_ancestors:
1867
+ model_names = self.get_minimum_models_set(model_names)
1868
+ model_names_already_persisted = [model_name for model_name in model_names if model_name in self.models]
1869
+ if model_names_already_persisted:
1870
+ logger.log(
1871
+ 30,
1872
+ f"The following {len(model_names_already_persisted)} models were already persisted and will be ignored in the model loading process: {model_names_already_persisted}",
1873
+ )
1874
+ model_names = [model_name for model_name in model_names if model_name not in model_names_already_persisted]
1875
+ if not model_names:
1876
+ logger.log(30, f"No valid unpersisted models were specified to be persisted, so no change in model persistence was performed.")
1877
+ return []
1878
+ if max_memory is not None:
1879
+
1880
+ @disable_if_lite_mode(ret=True)
1881
+ def _check_memory():
1882
+ info = self.get_models_info(model_names)
1883
+ model_mem_size_map = {model: info[model]["memory_size"] for model in model_names}
1884
+ for model in model_mem_size_map:
1885
+ if "children_info" in info[model]:
1886
+ for child in info[model]["children_info"].values():
1887
+ model_mem_size_map[model] += child["memory_size"]
1888
+ total_mem_required = sum(model_mem_size_map.values())
1889
+ available_mem = ResourceManager.get_available_virtual_mem()
1890
+ memory_proportion = total_mem_required / available_mem
1891
+ if memory_proportion > max_memory:
1892
+ logger.log(
1893
+ 30,
1894
+ f"Models will not be persisted in memory as they are expected to require {round(memory_proportion * 100, 2)}% of memory, which is greater than the specified max_memory limit of {round(max_memory*100, 2)}%.",
1895
+ )
1896
+ logger.log(
1897
+ 30,
1898
+ f"\tModels will be loaded on-demand from disk to maintain safe memory usage, increasing inference latency. If inference latency is a concern, try to use smaller models or increase the value of max_memory.",
1899
+ )
1900
+ return False
1901
+ else:
1902
+ logger.log(20, f"Persisting {len(model_names)} models in memory. Models will require {round(memory_proportion*100, 2)}% of memory.")
1903
+ return True
1904
+
1905
+ if not _check_memory():
1906
+ return []
1907
+
1908
+ models = []
1909
+ for model_name in model_names:
1910
+ model = self.load_model(model_name)
1911
+ self.models[model.name] = model
1912
+ models.append(model)
1913
+
1914
+ for model in models:
1915
+ # TODO: Move this to model code
1916
+ if isinstance(model, BaggedEnsembleModel):
1917
+ for fold, fold_model in enumerate(model.models):
1918
+ if isinstance(fold_model, str):
1919
+ model.models[fold] = model.load_child(fold_model)
1920
+ return model_names
1921
+
1922
+ def unpersist(self, model_names="all") -> list:
1923
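+ """
+ Removes the specified models from memory (`self.models`); subsequent predictions will load them from disk.
+ Returns the list of model names that were unpersisted.
+ """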
+ if model_names == "all":
1924
+ model_names = list(self.models.keys())
1925
+ if not isinstance(model_names, list):
1926
+ raise ValueError(f"model_names must be a list of model names. Invalid value: {model_names}")
1927
+ unpersisted_models = []
1928
+ for model in model_names:
1929
+ if model in self.models:
1930
+ self.models.pop(model)
1931
+ unpersisted_models.append(model)
1932
+ if unpersisted_models:
1933
+ logger.log(20, f"Unpersisted {len(unpersisted_models)} models: {unpersisted_models}")
1934
+ else:
1935
+ logger.log(30, f"No valid persisted models were specified to be unpersisted, so no change in model persistence was performed.")
1936
+ return unpersisted_models
1937
+
1938
+ def generate_weighted_ensemble(
1939
+ self,
1940
+ X,
1941
+ y,
1942
+ level,
1943
+ base_model_names,
1944
+ k_fold=1,
1945
+ n_repeats=1,
1946
+ stack_name=None,
1947
+ hyperparameters=None,
1948
+ ag_args_fit=None,
1949
+ time_limit=None,
1950
+ name_suffix: str | None = None,
1951
+ save_bag_folds=None,
1952
+ check_if_best=True,
1953
+ child_hyperparameters=None,
1954
+ get_models_func=None,
1955
+ total_resources: dict | None = None,
1956
+ ) -> list[str]:
1957
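+ """
+ Fits a weighted ensemble at stack level `level` on the predictions of `base_model_names` and registers it in this Trainer.
+ Returns a list of the successfully trained weighted ensemble model names (empty if no base models were provided).
+ """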
+ if get_models_func is None:
1958
+ get_models_func = self.construct_model_templates
1959
+ if len(base_model_names) == 0:
1960
+ logger.log(20, "No base models to train on, skipping weighted ensemble...")
1961
+ return []
1962
+
1963
+ if child_hyperparameters is None:
1964
+ child_hyperparameters = {}
1965
+
1966
+ if save_bag_folds is None:
1967
+ can_infer_dict = self.get_models_attribute_dict("can_infer", models=base_model_names)
1968
+ if False in can_infer_dict.values():
1969
+ save_bag_folds = False
1970
+ else:
1971
+ save_bag_folds = True
1972
+
1973
+ feature_metadata = self.get_feature_metadata(use_orig_features=False, base_models=base_model_names)
1974
+
1975
+ base_model_paths_dict = self.get_models_attribute_dict(attribute="path", models=base_model_names)
1976
+ base_model_paths_dict = {key: os.path.join(self.path, val) for key, val in base_model_paths_dict.items()}
1977
+ weighted_ensemble_model, _ = get_models_func(
1978
+ hyperparameters={
1979
+ "default": {
1980
+ "ENS_WEIGHTED": [child_hyperparameters],
1981
+ }
1982
+ },
1983
+ ensemble_type=WeightedEnsembleModel,
1984
+ ensemble_kwargs=dict(
1985
+ base_model_names=base_model_names,
1986
+ base_model_paths_dict=base_model_paths_dict,
1987
+ base_model_types_dict=self.get_models_attribute_dict(attribute="type", models=base_model_names),
1988
+ base_model_types_inner_dict=self.get_models_attribute_dict(attribute="type_inner", models=base_model_names),
1989
+ base_model_performances_dict=self.get_models_attribute_dict(attribute="val_score", models=base_model_names),
1990
+ hyperparameters=hyperparameters,
1991
+ random_state=level + self.random_state,
1992
+ ),
1993
+ ag_args={"name_bag_suffix": ""},
1994
+ ag_args_fit=ag_args_fit,
1995
+ ag_args_ensemble={"save_bag_folds": save_bag_folds},
1996
+ name_suffix=name_suffix,
1997
+ level=level,
1998
+ )
1999
+ weighted_ensemble_model = weighted_ensemble_model[0]
2000
+ w = None
2001
+ if self.weight_evaluation:
2002
+ X, w = extract_column(X, self.sample_weight)
2003
+ models = self._train_multi(
2004
+ X=X,
2005
+ y=y,
2006
+ X_val=None,
2007
+ y_val=None,
2008
+ models=[weighted_ensemble_model],
2009
+ k_fold=k_fold,
2010
+ n_repeats=n_repeats,
2011
+ hyperparameter_tune_kwargs=None,
2012
+ stack_name=stack_name,
2013
+ level=level,
2014
+ time_limit=time_limit,
2015
+ ens_sample_weight=w,
2016
+ fit_kwargs=dict(feature_metadata=feature_metadata, num_classes=self.num_classes, groups=None), # FIXME: Is this the right way to do this?
2017
+ total_resources=total_resources,
2018
+ )
2019
+ for weighted_ensemble_model_name in models:
2020
+ if check_if_best and weighted_ensemble_model_name in self.get_model_names():
2021
+ if self.model_best is None:
2022
+ self.model_best = weighted_ensemble_model_name
2023
+ else:
2024
+ best_score = self.get_model_attribute(self.model_best, "val_score")
2025
+ cur_score = self.get_model_attribute(weighted_ensemble_model_name, "val_score")
2026
+ if best_score is not None and cur_score > best_score:
2027
+ # new best model
2028
+ self.model_best = weighted_ensemble_model_name
2029
+ return models
2030
+
2031
+ def _train_single(
2032
+ self,
2033
+ X: pd.DataFrame,
2034
+ y: pd.Series,
2035
+ model: AbstractModel,
2036
+ X_val: pd.DataFrame | None = None,
2037
+ y_val: pd.Series | None = None,
2038
+ X_test: pd.DataFrame | None = None,
2039
+ y_test: pd.Series | None = None,
2040
+ total_resources: dict = None,
2041
+ **model_fit_kwargs,
2042
+ ) -> AbstractModel:
2043
+ """
2044
+ Trains the model but does not add it to this Trainer.
2045
+ Returns the trained model object.
2046
+ """
2047
+ model = model.fit(X=X, y=y, X_val=X_val, y_val=y_val, X_test=X_test, y_test=y_test, total_resources=total_resources, **model_fit_kwargs)
2048
+ return model
2049
+
2050
+ def _train_and_save(
2051
+ self,
2052
+ X: pd.DataFrame,
2053
+ y: pd.Series,
2054
+ model: AbstractModel,
2055
+ X_val: pd.DataFrame | None = None,
2056
+ y_val: pd.Series | None = None,
2057
+ X_test: pd.DataFrame | None = None,
2058
+ y_test: pd.Series | None = None,
2059
+ X_pseudo: pd.DataFrame | None = None,
2060
+ y_pseudo: pd.DataFrame | None = None,
2061
+ time_limit: float | None = None,
2062
+ stack_name: str = "core",
2063
+ level: int = 1,
2064
+ compute_score: bool = True,
2065
+ total_resources: dict | None = None,
2066
+ errors: Literal["ignore", "raise"] = "ignore",
2067
+ errors_ignore: list | None = None,
2068
+ errors_raise: list | None = None,
2069
+ is_ray_worker: bool = False,
2070
+ **model_fit_kwargs,
2071
+ ) -> list[str]:
2072
+ """
2073
+ Trains the model and saves it to disk, returning a list with a single element: the name of the model. If training failed, the returned list is empty.
2074
+ If the model name is returned:
2075
+ The model can be accessed via self.load_model(model.name).
2076
+ The model will have metadata information stored in self.model_graph.
2077
+ The model's name will be appended to self.models_level[stack_name][level]
2078
+ The model will be accessible and usable through any Trainer function that takes as input 'model' or 'model_name'.
2079
+ Note: self._train_and_save should not be used outside of self._train_single_full
2080
+
2081
+ Parameters
2082
+ ----------
2083
+ errors: Literal["ignore", "raise"], default = "ignore"
2084
+ Determines how model fit exceptions are handled.
2085
+ If "ignore", will ignore all model exceptions during fit. If an exception occurs, an empty list is returned.
2086
+ If "raise", will raise the model exception if it occurs.
2087
+ Can be overwritten by `errors_ignore` and `errors_raise`.
2088
+ errors_ignore: list[str], optional
2089
+ The exception types specified in `errors_ignore` will be treated as if `errors="ignore"`.
2090
+ errors_raise: list[str], optional
2091
+ The exception types specified in `errors_raise` will be treated as if `errors="raise"`.
2092
+
2093
+ """
2094
+ fit_start_time = time.time()
2095
+ model_names_trained = []
2096
+ y_pred_proba_val = None
2097
+
2098
+ is_distributed_mode = DistributedContext.is_distributed_mode() or is_ray_worker
2099
+
2100
+ fit_log_message = f"Fitting model: {model.name} ..."
2101
+ if time_limit is not None:
2102
+ time_left_total = time_limit
2103
+ not_enough_time = False
2104
+ if time_limit <= 0:
2105
+ not_enough_time = True
2106
+ elif self._time_limit is not None and self._time_train_start is not None:
2107
+ time_left_total = self._time_limit - (fit_start_time - self._time_train_start)
2108
+ # If only a very small amount of time remains, skip training
2109
+ min_time_required = min(self._time_limit * 0.01, 10)
2110
+ if (time_left_total < min_time_required) and (time_limit < min_time_required):
2111
+ not_enough_time = True
2112
+ if not_enough_time:
2113
+ skip_msg = f"Skipping {model.name} due to lack of time remaining."
2114
+ not_enough_time_exception = InsufficientTime(skip_msg)
2115
+ if self._check_raise_exception(exception=not_enough_time_exception, errors=errors, errors_ignore=errors_ignore, errors_raise=errors_raise):
2116
+ raise not_enough_time_exception
2117
+ else:
2118
+ logger.log(15, skip_msg)
2119
+ return []
2120
+ fit_log_message += f" Training model for up to {time_limit:.2f}s of the {time_left_total:.2f}s of remaining time."
2121
+ logger.log(10 if is_distributed_mode else 20, fit_log_message)
2122
+
2123
+ if isinstance(model, BaggedEnsembleModel) and not compute_score:
2124
+ # Do not perform OOF predictions when we don't compute a score.
2125
+ model_fit_kwargs["_skip_oof"] = True
2126
+
2127
+ model_fit_kwargs = dict(
2128
+ model=model,
2129
+ X_val=X_val,
2130
+ y_val=y_val,
2131
+ X_test=X_test,
2132
+ y_test=y_test,
2133
+ time_limit=time_limit,
2134
+ total_resources=total_resources,
2135
+ **model_fit_kwargs,
2136
+ )
2137
+
2138
+ # If the model is neither bagged nor stacked, pseudolabeled data needs to be incorporated at this level.
2139
+ # Bagged models handle validation at the fit level, whereas single models do it separately; hence this
2140
+ # if statement is required.
2141
+ if not isinstance(model, BaggedEnsembleModel) and X_pseudo is not None and y_pseudo is not None and X_pseudo.columns.equals(X.columns):
2142
+ assert_pseudo_column_match(X=X, X_pseudo=X_pseudo)
2143
+ X_w_pseudo = pd.concat([X, X_pseudo])
2144
+ y_w_pseudo = pd.concat([y, y_pseudo])
2145
+ logger.log(15, f"{len(X_pseudo)} extra rows of pseudolabeled data added to training set for {model.name}")
2146
+ model_fit_kwargs["X"] = X_w_pseudo
2147
+ model_fit_kwargs["y"] = y_w_pseudo
2148
+ else:
2149
+ model_fit_kwargs["X"] = X
2150
+ model_fit_kwargs["y"] = y
2151
+ if level > 1:
2152
+ if X_pseudo is not None and y_pseudo is not None:
2153
+ logger.log(15, f"Dropping pseudo in stacking layer due to missing out-of-fold predictions")
2154
+ else:
2155
+ model_fit_kwargs["X_pseudo"] = X_pseudo
2156
+ model_fit_kwargs["y_pseudo"] = y_pseudo
2157
+
2158
+ exception = None
2159
+ try:
2160
+ model = self._train_single(**model_fit_kwargs)
2161
+
2162
+ fit_end_time = time.time()
2163
+ if self.weight_evaluation:
2164
+ w = model_fit_kwargs.get("sample_weight", None)
2165
+ w_val = model_fit_kwargs.get("sample_weight_val", None)
2166
+ else:
2167
+ w = None
2168
+ w_val = None
2169
+ if not compute_score:
2170
+ score = None
2171
+ model.predict_time = None
2172
+ elif X_val is not None and y_val is not None:
2173
+ y_pred_proba_val = model.predict_proba(X_val, record_time=True)
2174
+ score = model.score_with_y_pred_proba(y=y_val, y_pred_proba=y_pred_proba_val, sample_weight=w_val)
2175
+ elif isinstance(model, BaggedEnsembleModel):
2176
+ if model.is_valid_oof() or isinstance(model, WeightedEnsembleModel):
2177
+ score = model.score_with_oof(y=y, sample_weight=w)
2178
+ else:
2179
+ score = None
2180
+ else:
2181
+ score = None
2182
+ pred_end_time = time.time()
2183
+ if model.fit_time is None:
2184
+ model.fit_time = fit_end_time - fit_start_time
2185
+ if model.predict_time is None and score is not None:
2186
+ model.predict_time = pred_end_time - fit_end_time
2187
+ model.val_score = score
2188
+ # TODO: Add recursive=True to avoid repeatedly loading models each time this is called for bagged ensembles (especially during repeated bagging)
2189
+ self.save_model(model=model)
2190
+ except Exception as exc:
2191
+ exception = exc # required to reference exc outside of `except` statement
2192
+ del_model = True
2193
+ if isinstance(exception, TimeLimitExceeded):
2194
+ logger.log(20, f"\tTime limit exceeded... Skipping {model.name}.")
2195
+ elif isinstance(exception, NotEnoughMemoryError):
2196
+ logger.warning(f"\tNot enough memory to train {model.name}... Skipping this model.")
2197
+ elif isinstance(exception, NoStackFeatures):
2198
+ logger.warning(f"\tNo stack features to train {model.name}... Skipping this model. {exception}")
2199
+ elif isinstance(exception, NotValidStacker):
2200
+ logger.warning(f"\tStacking disabled for {model.name}... Skipping this model. {exception}")
2201
+ elif isinstance(exception, NoValidFeatures):
2202
+ logger.warning(f"\tNo valid features to train {model.name}... Skipping this model.")
2203
+ elif isinstance(exception, NoGPUError):
2204
+ logger.warning(f"\tNo GPUs available to train {model.name}... Skipping this model.")
2205
+ elif isinstance(exception, NotEnoughCudaMemoryError):
2206
+ logger.warning(f"\tNot enough CUDA memory available to train {model.name}... Skipping this model.")
2207
+ elif isinstance(exception, ImportError):
2208
+ logger.error(f"\tWarning: Exception caused {model.name} to fail during training (ImportError)... Skipping this model.")
2209
+ logger.error(f"\t\t{exception}")
2210
+ del_model = False
2211
+ if self.verbosity > 2:
2212
+ logger.exception("Detailed Traceback:")
2213
+ else: # all other exceptions
2214
+ logger.error(f"\tWarning: Exception caused {model.name} to fail during training... Skipping this model.")
2215
+ logger.error(f"\t\t{exception}")
2216
+ if self.verbosity > 0:
2217
+ logger.exception("Detailed Traceback:")
2218
+ crash_time = time.time()
2219
+ total_time = crash_time - fit_start_time
2220
+ tb = traceback.format_exc()
2221
+ model_info = self.get_model_info(model=model)
2222
+ self._models_failed_to_train_errors[model.name] = dict(
2223
+ exc_type=exception.__class__.__name__,
2224
+ exc_str=str(exception),
2225
+ exc_traceback=tb,
2226
+ model_info=model_info,
2227
+ total_time=total_time,
2228
+ )
2229
+
2230
+ if del_model:
2231
+ del model
2232
+ else:
2233
+ self._add_model(model=model, stack_name=stack_name, level=level, y_pred_proba_val=y_pred_proba_val, is_ray_worker=is_ray_worker)
2234
+ model_names_trained.append(model.name)
2235
+ if self.low_memory:
2236
+ del model
2237
+ if exception is not None:
2238
+ if self._check_raise_exception(exception=exception, errors=errors, errors_ignore=errors_ignore, errors_raise=errors_raise):
2239
+ raise exception
2240
+ return model_names_trained
2241
+
2242
+ # FIXME: v1.0 Move to AbstractModel for most fields
2243
+ def _get_model_metadata(self, model: AbstractModel, stack_name: str = "core", level: int = 1) -> dict[str, Any]:
2244
+ """
2245
+ Returns the model metadata used to initialize a node in the DAG (self.model_graph).
2246
+ """
2247
+ if isinstance(model, BaggedEnsembleModel):
2248
+ type_inner = model._child_type
2249
+ else:
2250
+ type_inner = type(model)
2251
+ num_children = len(model.models) if hasattr(model, "models") else 1
2252
+ predict_child_time = model.predict_time / num_children if model.predict_time is not None else None
2253
+ predict_1_child_time = model.predict_1_time / num_children if model.predict_1_time is not None else None
2254
+ fit_metadata = model.get_fit_metadata()
2255
+
2256
+ model_param_aux = getattr(model, "_params_aux_child", model.params_aux)
2257
+ model_metadata = dict(
2258
+ fit_time=model.fit_time,
2259
+ compile_time=model.compile_time,
2260
+ predict_time=model.predict_time,
2261
+ predict_1_time=model.predict_1_time,
2262
+ predict_child_time=predict_child_time,
2263
+ predict_1_child_time=predict_1_child_time,
2264
+ predict_n_time_per_row=model.predict_n_time_per_row,
2265
+ predict_n_size=model.predict_n_size,
2266
+ val_score=model.val_score,
2267
+ eval_metric=model.eval_metric.name,
2268
+ stopping_metric=model.stopping_metric.name,
2269
+ path=os.path.relpath(model.path, self.path).split(os.sep), # model's relative path to trainer
2270
+ type=type(model), # Outer type, can be BaggedEnsemble, StackEnsemble (Type that is able to load the model)
2271
+ type_inner=type_inner, # Inner type, if Ensemble then it is the type of the inner model (May not be able to load with this type)
2272
+ can_infer=model.can_infer(),
2273
+ can_fit=model.can_fit(),
2274
+ is_valid=model.is_valid(),
2275
+ stack_name=stack_name,
2276
+ level=level,
2277
+ num_children=num_children,
2278
+ fit_num_cpus=model.fit_num_cpus,
2279
+ fit_num_gpus=model.fit_num_gpus,
2280
+ fit_num_cpus_child=model.fit_num_cpus_child,
2281
+ fit_num_gpus_child=model.fit_num_gpus_child,
2282
+ refit_full_requires_gpu=(model.fit_num_gpus_child is not None) and (model.fit_num_gpus_child >= 1) and model._user_params.get("refit_folds", False),
2283
+ **fit_metadata,
2284
+ )
2285
+ return model_metadata
2286
+
2287
+ def _add_model(self, model: AbstractModel, stack_name: str = "core", level: int = 1, y_pred_proba_val=None, _is_refit=False, is_distributed_main=False, is_ray_worker: bool = False) -> bool:
2288
+ """
2289
+ Registers the fit model in the Trainer object. Stores information such as model performance, save path, model type, and more.
2290
+ To use a model in Trainer, self._add_model must be called.
2291
+ If self.low_memory, then the model object will be deleted after this call. Use Trainer directly to leverage the model further.
2292
+
2293
+ Parameters
2294
+ ----------
2295
+ model : AbstractModel
2296
+ Model which has been fit. This model will be registered to the Trainer.
2297
+ stack_name : str, default 'core'
2298
+ Stack name to assign the model to. This is used for advanced functionality.
2299
+ level : int, default 1
2300
+ Stack level of the stack name to assign the model to. This is used for advanced functionality.
2301
+ The model's name is appended to self.models_level[stack_name][level]
2302
+ The model's base_models (if it has any) must all be a lower level than the model.
2303
+ is_distributed_main: bool, default = False
2304
+ If True, the main process in distributed training is calling this function.
2305
+ This is used to avoid redundant logging in distributed training.
2306
+
2307
+ Returns
2308
+ -------
2309
+ boolean, True if model was registered, False if model was found to be invalid and not registered.
2310
+ """
2311
+ if model.val_score is not None and np.isnan(model.val_score):
2312
+ logger.warning(
2313
+ f"WARNING: {model.name} has a val_score of {model.val_score} (NaN)! This should never happen. The model will not be saved to avoid instability."
2314
+ )
2315
+ return False
2316
+ # TODO: Add to HPO
2317
+
2318
+ node_attributes = self._get_model_metadata(model=model, stack_name=stack_name, level=level)
2319
+ if y_pred_proba_val is not None:
2320
+ # Cache y_pred_proba_val for later reuse to avoid redundant predict calls
2321
+ self._save_model_y_pred_proba_val(model=model.name, y_pred_proba_val=y_pred_proba_val)
2322
+ node_attributes["cached_y_pred_proba_val"] = True
2323
+
2324
+ self.model_graph.add_node(
2325
+ model.name,
2326
+ **node_attributes,
2327
+ )
2328
+ if isinstance(model, StackerEnsembleModel):
2329
+ prior_models = self.get_model_names()
2330
+ # TODO: raise exception if no base models and level != 1?
2331
+ for stack_column_prefix in model.stack_column_prefix_lst:
2332
+ base_model_name = model.stack_column_prefix_to_model_map[stack_column_prefix]
2333
+ if base_model_name not in prior_models:
2334
+ raise AssertionError(
2335
+ f"Model '{model.name}' depends on model '{base_model_name}', but '{base_model_name}' is not registered as a trained model! Valid models: {prior_models}"
2336
+ )
2337
+ elif level <= self.model_graph.nodes[base_model_name]["level"]:
2338
+ raise AssertionError(
2339
+ f"Model '{model.name}' depends on model '{base_model_name}', but '{base_model_name}' is not in a lower stack level. ('{model.name}' level: {level}, '{base_model_name}' level: {self.model_graph.nodes[base_model_name]['level']})"
2340
+ )
2341
+ self.model_graph.add_edge(base_model_name, model.name)
2342
+ self._log_model_stats(model, _is_refit=_is_refit, is_distributed_main=is_distributed_main, is_ray_worker=is_ray_worker)
2343
+ if self.low_memory:
2344
+ del model
2345
+ return True
2346
+
2347
+ def _path_attr_model(self, model: str):
2348
+ """Returns directory where attributes are cached"""
2349
+ return os.path.join(self._path_attr, model)
2350
+
2351
+ def _path_to_model_attr(self, model: str, attribute: str):
2352
+ """Returns pkl file path for a cached model attribute"""
2353
+ return os.path.join(self._path_attr_model(model), f"{attribute}.pkl")
2354
+
2355
+ def _save_model_y_pred_proba_val(self, model: str, y_pred_proba_val):
2356
+ """Cache y_pred_proba_val for later reuse to avoid redundant predict calls"""
2357
+ save_pkl.save(path=self._path_to_model_attr(model=model, attribute="y_pred_proba_val"), object=y_pred_proba_val)
2358
+
2359
+ def _load_model_y_pred_proba_val(self, model: str):
2360
+ """Load cached y_pred_proba_val for a given model"""
2361
+ return load_pkl.load(path=self._path_to_model_attr(model=model, attribute="y_pred_proba_val"))
2362
+
2363
+ # TODO: Once Python min-version is 3.8, can refactor to use positional-only argument for model
2364
+ # https://peps.python.org/pep-0570/#empowering-library-authors
2365
+ # Currently this method cannot accept the attribute key 'model' without making usage ugly.
2366
+ def _update_model_attr(self, model: str, **attributes):
2367
+ """Updates model node in graph with the input attributes dictionary"""
2368
+ if model not in self.model_graph:
2369
+ raise AssertionError(f'"{model}" is not a key in self.model_graph, cannot add attributes: {attributes}')
2370
+ self.model_graph.nodes[model].update(attributes)
2371
+
2372
+ def _log_model_stats(self, model, _is_refit=False, is_distributed_main=False, is_ray_worker: bool = False):
2373
+ """Logs model fit time, val score, predict time, and predict_1_time"""
2374
+ model = self.load_model(model)
2375
+ print_weights = model._get_tags().get("print_weights", False)
2376
+
2377
+ is_log_during_distributed_fit = DistributedContext.is_distributed_mode() and (not is_distributed_main)
2378
+ if is_ray_worker:
2379
+ is_log_during_distributed_fit = True
2380
+ log_level = 10 if is_log_during_distributed_fit else 20
2381
+
2382
+ if print_weights:
2383
+ model_weights = model._get_model_weights()
2384
+ model_weights = {k: round(v, 3) for k, v in model_weights.items()}
2385
+ msg_weights = ""
2386
+ is_first = True
2387
+ for key, value in sorted(model_weights.items(), key=lambda x: x[1], reverse=True):
2388
+ if not is_first:
2389
+ msg_weights += ", "
2390
+ msg_weights += f"'{key}': {value}"
2391
+ is_first = False
2392
+ logger.log(log_level, f"\tEnsemble Weights: {{{msg_weights}}}")
2393
+ if model.val_score is not None:
2394
+ if model.eval_metric.name != self.eval_metric.name:
2395
+ logger.log(log_level, f"\tNote: model has different eval_metric than default.")
2396
+ if not model.eval_metric.greater_is_better_internal:
2397
+ sign_str = "-"
2398
+ else:
2399
+ sign_str = ""
2400
+ logger.log(log_level, f"\t{round(model.val_score, 4)}\t = Validation score ({sign_str}{model.eval_metric.name})")
2401
+ if model.fit_time is not None:
2402
+ logger.log(log_level, f"\t{round(model.fit_time, 2)}s\t = Training runtime")
2403
+ if model.predict_time is not None:
2404
+ logger.log(log_level, f"\t{round(model.predict_time, 2)}s\t = Validation runtime")
2405
+ predict_n_time_per_row = self.get_model_attribute_full(model=model.name, attribute="predict_n_time_per_row")
2406
+ predict_n_size = self.get_model_attribute_full(model=model.name, attribute="predict_n_size", func=min)
2407
+ if predict_n_time_per_row is not None and predict_n_size is not None:
2408
+ logger.log(
2409
+ 15,
2410
+ f"\t{round(1/(predict_n_time_per_row if predict_n_time_per_row else np.finfo(np.float16).eps), 1)}"
2411
+ f"\t = Inference throughput (rows/s | {int(predict_n_size)} batch size)",
2412
+ )
2413
+ if model.predict_1_time is not None:
2414
+ fit_metadata = model.get_fit_metadata()
2415
+ predict_1_batch_size = fit_metadata.get("predict_1_batch_size", None)
2416
+ assert predict_1_batch_size is not None, "predict_1_batch_size cannot be None if predict_1_time is not None"
2417
+
2418
+ if _is_refit:
2419
+ predict_1_time = self.get_model_attribute(model=model.name, attribute="predict_1_child_time")
2420
+ predict_1_time_full = self.get_model_attribute_full(model=model.name, attribute="predict_1_child_time")
2421
+ else:
2422
+ predict_1_time = model.predict_1_time
2423
+ predict_1_time_full = self.get_model_attribute_full(model=model.name, attribute="predict_1_time")
2424
+
2425
+ predict_1_time_log, time_unit = convert_time_in_s_to_log_friendly(time_in_sec=predict_1_time)
2426
+ logger.log(log_level, f"\t{round(predict_1_time_log, 3)}{time_unit}\t = Validation runtime (1 row | {predict_1_batch_size} batch size | MARGINAL)")
2427
+
2428
+ predict_1_time_full_log, time_unit = convert_time_in_s_to_log_friendly(time_in_sec=predict_1_time_full)
2429
+ logger.log(log_level, f"\t{round(predict_1_time_full_log, 3)}{time_unit}\t = Validation runtime (1 row | {predict_1_batch_size} batch size)")
2430
+
2431
+ if not _is_refit:
2432
+ predict_1_time_child = self.get_model_attribute(model=model.name, attribute="predict_1_child_time")
2433
+ predict_1_time_child_log, time_unit = convert_time_in_s_to_log_friendly(time_in_sec=predict_1_time_child)
2434
+ logger.log(
2435
+ log_level,
2436
+ f"\t{round(predict_1_time_child_log, 3)}{time_unit}\t = Validation runtime (1 row | {predict_1_batch_size} batch size | REFIT | MARGINAL)",
2437
+ )
2438
+
2439
+ predict_1_time_full_child = self.get_model_attribute_full(model=model.name, attribute="predict_1_child_time")
2440
+ predict_1_time_full_child_log, time_unit = convert_time_in_s_to_log_friendly(time_in_sec=predict_1_time_full_child)
2441
+ logger.log(
2442
+ log_level, f"\t{round(predict_1_time_full_child_log, 3)}{time_unit}\t = Validation runtime (1 row | {predict_1_batch_size} batch size | REFIT)"
2443
+ )
2444
+
2445
+ # TODO: Split this to avoid confusion, HPO should go elsewhere?
2446
+ def _train_single_full(
2447
+ self,
2448
+ X,
2449
+ y,
2450
+ model: AbstractModel,
2451
+ X_unlabeled=None,
2452
+ X_val=None,
2453
+ y_val=None,
2454
+ X_test=None,
2455
+ y_test=None,
2456
+ X_pseudo=None,
2457
+ y_pseudo=None,
2458
+ hyperparameter_tune_kwargs=None,
2459
+ stack_name="core",
2460
+ k_fold=None,
2461
+ k_fold_start=0,
2462
+ k_fold_end=None,
2463
+ n_repeats=None,
2464
+ n_repeat_start=0,
2465
+ level=1,
2466
+ time_limit=None,
2467
+ fit_kwargs=None,
2468
+ compute_score=True,
2469
+ total_resources: dict | None = None,
2470
+ errors: Literal["ignore", "raise"] = "ignore",
2471
+ errors_ignore: list | None = None,
2472
+ errors_raise: list | None = None,
2473
+ is_ray_worker: bool = False,
2474
+ **kwargs,
2475
+ ) -> list[str]:
2476
+ """
2477
+ Trains a model, with the potential to train multiple versions of this model with hyperparameter tuning and feature pruning.
2478
+ Returns a list of successfully trained and saved model names.
2479
+ Models trained from this method will be accessible in this Trainer.
2480
+
2481
+ Parameters
2482
+ ----------
2483
+ errors: Literal["ignore", "raise"], default = "ignore"
2484
+ Determines how model fit exceptions are handled.
2485
+ If "ignore", will ignore all model exceptions during fit. If an exception occurs, an empty list is returned.
2486
+ If "raise", will raise the model exception if it occurs.
2487
+ Can be overwritten by `errors_ignore` and `errors_raise`.
2488
+ errors_ignore: list[type], optional
2489
+ The exception types specified in `errors_ignore` will be treated as if `errors="ignore"`.
2490
+ errors_raise: list[type], optional
2491
+ The exception types specified in `errors_raise` will be treated as if `errors="raise"`.
2492
+ """
2493
+ if self._callback_early_stop:
2494
+ return []
2495
+ check_callbacks = k_fold_start == 0 and n_repeat_start == 0 and not is_ray_worker
2496
+ skip_model = False
2497
+ if self.callbacks and check_callbacks:
2498
+ skip_model, time_limit = self._callbacks_before_fit(
2499
+ model=model,
2500
+ time_limit=time_limit,
2501
+ stack_name=stack_name,
2502
+ level=level,
2503
+ )
2504
+ if self._callback_early_stop or skip_model:
2505
+ return []
2506
+
2507
+ model_fit_kwargs = self._get_model_fit_kwargs(
2508
+ X=X, X_val=X_val, time_limit=time_limit, k_fold=k_fold, fit_kwargs=fit_kwargs, ens_sample_weight=kwargs.get("ens_sample_weight", None)
2509
+ )
2510
+ exception = None
2511
+ if hyperparameter_tune_kwargs:
2512
+ if n_repeat_start != 0:
2513
+ raise ValueError(f"n_repeat_start must be 0 to hyperparameter_tune, value = {n_repeat_start}")
2514
+ elif k_fold_start != 0:
2515
+ raise ValueError(f"k_fold_start must be 0 to hyperparameter_tune, value = {k_fold_start}")
2516
+ # hpo_models (dict): keys = model_names, values = model_paths
2517
+ fit_log_message = f"Hyperparameter tuning model: {model.name} ..."
2518
+ if time_limit is not None:
2519
+ if time_limit <= 0:
2520
+ logger.log(15, f"Skipping {model.name} due to lack of time remaining.")
2521
+ return []
2522
+ fit_start_time = time.time()
2523
+ if self._time_limit is not None and self._time_train_start is not None:
2524
+ time_left_total = self._time_limit - (fit_start_time - self._time_train_start)
2525
+ else:
2526
+ time_left_total = time_limit
2527
+ fit_log_message += f" Tuning model for up to {round(time_limit, 2)}s of the {round(time_left_total, 2)}s of remaining time."
2528
+ logger.log(20, fit_log_message)
2529
+ try:
2530
+ if isinstance(model, BaggedEnsembleModel):
2531
+ bagged_model_fit_kwargs = self._get_bagged_model_fit_kwargs(
2532
+ k_fold=k_fold, k_fold_start=k_fold_start, k_fold_end=k_fold_end, n_repeats=n_repeats, n_repeat_start=n_repeat_start
2533
+ )
2534
+ model_fit_kwargs.update(bagged_model_fit_kwargs)
2535
+ hpo_models, hpo_results = model.hyperparameter_tune(
2536
+ X=X,
2537
+ y=y,
2538
+ model=model,
2539
+ X_val=X_val,
2540
+ y_val=y_val,
2541
+ X_unlabeled=X_unlabeled,
2542
+ stack_name=stack_name,
2543
+ level=level,
2544
+ compute_score=compute_score,
2545
+ hyperparameter_tune_kwargs=hyperparameter_tune_kwargs,
2546
+ total_resources=total_resources,
2547
+ **model_fit_kwargs,
2548
+ )
2549
+ else:
2550
+ hpo_models, hpo_results = model.hyperparameter_tune(
2551
+ X=X,
2552
+ y=y,
2553
+ X_val=X_val,
2554
+ y_val=y_val,
2555
+ hyperparameter_tune_kwargs=hyperparameter_tune_kwargs,
2556
+ total_resources=total_resources,
2557
+ **model_fit_kwargs,
2558
+ )
2559
+ if len(hpo_models) == 0:
2560
+ logger.warning(f"No model was trained during hyperparameter tuning of {model.name}... Skipping this model.")
2561
+ except Exception as exc:
2562
+ exception = exc # required to provide exc outside of `except` statement
2563
+ if isinstance(exception, NoStackFeatures):
2564
+ logger.warning(f"\tNo stack features to train {model.name}... Skipping this model. {exception}")
2565
+ elif isinstance(exception, NotValidStacker):
2566
+ logger.warning(f"\tStacking disabled for {model.name}... Skipping this model. {exception}")
2567
+ elif isinstance(exception, NoValidFeatures):
2568
+ logger.warning(f"\tNo valid features to train {model.name}... Skipping this model.")
2569
+ else:
2570
+ logger.exception(f"Warning: Exception caused {model.name} to fail during hyperparameter tuning... Skipping this model.")
2571
+ logger.warning(exception)
2572
+ del model
2573
+ model_names_trained = []
2574
+ else:
2575
+ # Commented out because it takes too much space (>>5 GB if run for an hour on a small-medium sized dataset)
2576
+ # self.hpo_results[model.name] = hpo_results
2577
+ model_names_trained = []
2578
+ self._extra_banned_names.add(model.name)
2579
+ for model_hpo_name, model_info in hpo_models.items():
2580
+ model_hpo = self.load_model(model_hpo_name, path=os.path.relpath(model_info["path"], self.path), model_type=type(model))
2581
+ logger.log(20, f"Fitted model: {model_hpo.name} ...")
2582
+ if self._add_model(model=model_hpo, stack_name=stack_name, level=level):
2583
+ model_names_trained.append(model_hpo.name)
2584
+ else:
2585
+ model_fit_kwargs.update(dict(X_pseudo=X_pseudo, y_pseudo=y_pseudo))
2586
+ if isinstance(model, BaggedEnsembleModel):
2587
+ bagged_model_fit_kwargs = self._get_bagged_model_fit_kwargs(
2588
+ k_fold=k_fold, k_fold_start=k_fold_start, k_fold_end=k_fold_end, n_repeats=n_repeats, n_repeat_start=n_repeat_start
2589
+ )
2590
+ model_fit_kwargs.update(bagged_model_fit_kwargs)
2591
+ model_names_trained = self._train_and_save(
2592
+ X=X,
2593
+ y=y,
2594
+ model=model,
2595
+ X_val=X_val,
2596
+ y_val=y_val,
2597
+ X_test=X_test,
2598
+ y_test=y_test,
2599
+ X_unlabeled=X_unlabeled,
2600
+ stack_name=stack_name,
2601
+ level=level,
2602
+ compute_score=compute_score,
2603
+ total_resources=total_resources,
2604
+ errors=errors,
2605
+ errors_ignore=errors_ignore,
2606
+ errors_raise=errors_raise,
2607
+ is_ray_worker=is_ray_worker,
2608
+ **model_fit_kwargs,
2609
+ )
2610
+ if self.callbacks and check_callbacks:
2611
+ self._callbacks_after_fit(model_names=model_names_trained, stack_name=stack_name, level=level)
2612
+ self.save()
2613
+ if exception is not None:
2614
+ if self._check_raise_exception(exception=exception, errors=errors, errors_ignore=errors_ignore, errors_raise=errors_raise):
2615
+ raise exception
2616
+ return model_names_trained
2617
+
2618
+ # TODO: Move to a utility function outside of AbstractTabularTrainer
2619
+ @staticmethod
2620
+ def _check_raise_exception(
2621
+ exception: Exception,
2622
+ errors: Literal["ignore", "raise"] = "ignore",
2623
+ errors_ignore: list | None = None,
2624
+ errors_raise: list | None = None,
2625
+ ) -> bool:
2626
+ """
2627
+ Check if an exception should be raised based on the provided error handling logic.
2628
+
2629
+ Parameters
2630
+ ----------
2631
+ exception: Exception
2632
+ The exception to check
2633
+ errors: Literal["ignore", "raise"], default = "ignore"
2634
+ Determines how exceptions are handled.
2635
+ If "ignore", will return False.
2636
+ If "raise", will return True.
2637
+ Can be overwritten by `errors_ignore` and `errors_raise`.
2638
+ errors_ignore: list[type], optional
2639
+ The exception types specified in `errors_ignore` will be treated as if `errors="ignore"`.
2640
+ errors_raise: list[type], optional
2641
+ The exception types specified in `errors_raise` will be treated as if `errors="raise"`.
2642
+
2643
+ Returns
2644
+ -------
2645
+ raise_exception: bool
2646
+ If True, indicates that the exception should be raised based on the provided error handling rules.
2647
+ """
2648
+ raise_exception = None
2649
+ if errors_raise is not None:
2650
+ for err_type in errors_raise:
2651
+ if isinstance(exception, err_type):
2652
+ raise_exception = True
2653
+ break
2654
+ if errors_ignore is not None and raise_exception is None:
2655
+ for err_type in errors_ignore:
2656
+ if isinstance(exception, err_type):
2657
+ raise_exception = False
2658
+ break
2659
+ if raise_exception is None:
2660
+ if errors == "ignore":
2661
+ raise_exception = False
2662
+ elif errors == "raise":
2663
+ raise_exception = True
2664
+ else:
2665
+ raise ValueError(f"Invalid `errors` value: {errors} (valid values: ['ignore', 'raise'])")
2666
+ return raise_exception
2667
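The precedence implemented above is: errors_raise first, then errors_ignore, then the errors default. A short illustration with a generic exception (invented toy usage, not taken from the package's tests):

# Illustrating the precedence rules of _check_raise_exception above.
exc = ValueError("boom")

# errors_raise overrides the "ignore" default:
assert AbstractTabularTrainer._check_raise_exception(exc, errors="ignore", errors_raise=[ValueError]) is True
# errors_ignore overrides the "raise" default:
assert AbstractTabularTrainer._check_raise_exception(exc, errors="raise", errors_ignore=[ValueError]) is False
# With no overrides, the `errors` default decides:
assert AbstractTabularTrainer._check_raise_exception(exc, errors="raise") is True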
+
2668
+ def _callbacks_before_fit(
2669
+ self,
2670
+ *,
2671
+ model: AbstractModel,
2672
+ time_limit: float | None,
2673
+ stack_name: str,
2674
+ level: int,
2675
+ ):
2676
+ skip_model = False
2677
+ ts = time.time()
2678
+ for callback in self.callbacks:
2679
+ callback_early_stop, callback_skip_model = callback.before_model_fit(
2680
+ trainer=self,
2681
+ model=model,
2682
+ time_limit=time_limit,
2683
+ stack_name=stack_name,
2684
+ level=level,
2685
+ )
2686
+ if callback_early_stop:
2687
+ self._callback_early_stop = True
2688
+ if callback_skip_model:
2689
+ skip_model = True
2690
+ if time_limit is not None:
2691
+ te = time.time()
2692
+ time_limit -= te - ts
2693
+ ts = te
2694
+ return skip_model, time_limit
2695
+
2696
+ def _callbacks_after_fit(
2697
+ self,
2698
+ *,
2699
+ model_names: list[str],
2700
+ stack_name: str,
2701
+ level: int,
2702
+ ):
2703
+ for callback in self.callbacks:
2704
+ callback_early_stop = callback.after_model_fit(
2705
+ self,
2706
+ model_names=model_names,
2707
+ stack_name=stack_name,
2708
+ level=level,
2709
+ )
2710
+ if callback_early_stop:
2711
+ self._callback_early_stop = True
2712
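A hedged sketch of an object compatible with the two hooks above (a real callback should derive from AbstractCallback, whose full abstract interface may require more than is shown here; the class name and the 5-second threshold are invented):

class SkipWhenLowOnTimeSketch:
    """Duck-typed illustration of the callback hooks used by the trainer."""

    def before_model_fit(self, trainer, model, time_limit=None, stack_name=None, level=None):
        # Returns (early_stop, skip_model): skip any model offered under 5s of budget.
        skip_model = time_limit is not None and time_limit < 5
        return False, skip_model

    def after_model_fit(self, trainer, model_names=None, stack_name=None, level=None):
        # Returns early_stop; never request a trainer-wide stop in this sketch.
        return False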
+
2713
+ # TODO: How to deal with models that fail during this? They have trained valid models before, but should we still use those models or remove the entire model? Currently we still use models.
2714
+ # TODO: Time allowance can be made better by only using time taken during final model training and not during HPO and feature pruning.
2715
+ # TODO: Time allowance not accurate if running from fit_continue
2716
+ # TODO: Remove level and stack_name arguments, can get them automatically
2717
+ # TODO: Make sure that pretraining on X_unlabeled only happens 1 time rather than every fold of bagging. (Do during pretrain API work?)
2718
+ def _train_multi_repeats(self, X, y, models: list, n_repeats, n_repeat_start=1, time_limit=None, time_limit_total_level=None, **kwargs) -> list[str]:
2719
+ """
2720
+ Fits bagged ensemble models with additional folds and/or bagged repeats.
2721
+ Models must have already been fit prior to entering this method.
2722
+ This method should only be called from within self._train_multi.
2723
+ Returns a list of successfully trained and saved model names.
2724
+ """
2725
+ if time_limit_total_level is None:
2726
+ time_limit_total_level = time_limit
2727
+ models_valid = models
2728
+ models_valid_next = []
2729
+ repeats_completed = 0
2730
+ time_start = time.time()
2731
+ for n in range(n_repeat_start, n_repeats):
2732
+ if not models_valid:
2733
+ break # No models to repeat
2734
+ if time_limit is not None:
2735
+ time_start_repeat = time.time()
2736
+ time_left = time_limit - (time_start_repeat - time_start)
2737
+ if n == n_repeat_start:
2738
+ time_required = time_limit_total_level * 0.575 # Require slightly over 50% to be safe
2739
+ else:
2740
+ time_required = (time_start_repeat - time_start) / repeats_completed * (0.575 / 0.425)
2741
+ if time_left < time_required:
2742
+ logger.log(15, "Not enough time left to finish repeated k-fold bagging, stopping early ...")
2743
+ break
2744
+ logger.log(20, f"Repeating k-fold bagging: {n+1}/{n_repeats}")
2745
+ for i, model in enumerate(models_valid):
2746
+ if self._callback_early_stop:
2747
+ break
2748
+ if not self.get_model_attribute(model=model, attribute="can_fit"):
2749
+ if isinstance(model, str):
2750
+ models_valid_next.append(model)
2751
+ else:
2752
+ models_valid_next.append(model.name)
2753
+ continue
2754
+
2755
+ if isinstance(model, str):
2756
+ model = self.load_model(model)
2757
+ if not isinstance(model, BaggedEnsembleModel):
2758
+ raise AssertionError(
2759
+ f"{model.name} must inherit from BaggedEnsembleModel to perform repeated k-fold bagging. Model type: {type(model).__name__}"
2760
+ )
2761
+ if time_limit is None:
2762
+ time_left = None
2763
+ else:
2764
+ time_start_model = time.time()
2765
+ time_left = time_limit - (time_start_model - time_start)
2766
+
2767
+ models_valid_next += self._train_single_full(
2768
+ X=X, y=y, model=model, k_fold_start=0, k_fold_end=None, n_repeats=n + 1, n_repeat_start=n, time_limit=time_left, **kwargs
2769
+ )
2770
+ models_valid = copy.deepcopy(models_valid_next)
2771
+ models_valid_next = []
2772
+ repeats_completed += 1
2773
+ logger.log(20, f"Completed {n_repeat_start + repeats_completed}/{n_repeats} k-fold bagging repeats ...")
2774
+ return models_valid
2775
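For intuition, the repeat-budget heuristic above works out as follows (a minimal sketch with invented timing values; the 0.575 and 0.425 constants come from the code):

# Toy numbers illustrating the repeat-budget check in _train_multi_repeats.
time_limit_total_level = 100.0  # hypothetical budget for this stack level
elapsed = 40.0                  # hypothetical time already spent on repeats
repeats_completed = 1

# First extra repeat: slightly over half the level budget must remain.
time_required_first = time_limit_total_level * 0.575                 # 57.5s
# Later repeats: extrapolate from the average repeat cost, with padding.
time_required_later = elapsed / repeats_completed * (0.575 / 0.425)  # ~54.1s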
+
2776
+ def _train_multi_initial(
2777
+ self, X, y, models: list[AbstractModel], k_fold, n_repeats, hyperparameter_tune_kwargs=None, time_limit=None, feature_prune_kwargs=None, **kwargs
2778
+ ):
2779
+ """
2780
+ Fits models that have not previously been fit.
2781
+ This method should only be called from within self._train_multi.
2782
+ Returns a list of successfully trained and saved model names.
2783
+ """
2784
+ multi_fold_time_start = time.time()
2785
+ fit_args = dict(
2786
+ X=X,
2787
+ y=y,
2788
+ k_fold=k_fold,
2789
+ )
2790
+ fit_args.update(kwargs)
2791
+
2792
+ hpo_enabled = False
2793
+ if hyperparameter_tune_kwargs:
2794
+ for key in hyperparameter_tune_kwargs:
2795
+ if hyperparameter_tune_kwargs[key] is not None:
2796
+ hpo_enabled = True
2797
+ break
2798
+
2799
+ hpo_time_ratio = 0.9
2800
+ if hpo_enabled:
2801
+ time_split = True
2802
+ else:
2803
+ time_split = False
2804
+ k_fold_start = 0
2805
+ bagged = k_fold > 0
2806
+ if not bagged:
2807
+ time_ratio = hpo_time_ratio if hpo_enabled else 1
2808
+ models = self._train_multi_fold(
2809
+ models=models,
2810
+ hyperparameter_tune_kwargs=hyperparameter_tune_kwargs,
2811
+ time_limit=time_limit,
2812
+ time_split=time_split,
2813
+ time_ratio=time_ratio,
2814
+ **fit_args,
2815
+ )
2816
+ else:
2817
+ time_ratio = hpo_time_ratio if hpo_enabled else 1
2818
+ models = self._train_multi_fold(
2819
+ models=models,
2820
+ hyperparameter_tune_kwargs=hyperparameter_tune_kwargs,
2821
+ k_fold_start=0,
2822
+ k_fold_end=k_fold,
2823
+ n_repeats=n_repeats,
2824
+ n_repeat_start=0,
2825
+ time_limit=time_limit,
2826
+ time_split=time_split,
2827
+ time_ratio=time_ratio,
2828
+ **fit_args,
2829
+ )
2830
+
2831
+ multi_fold_time_elapsed = time.time() - multi_fold_time_start
2832
+ if time_limit is not None:
2833
+ time_limit = time_limit - multi_fold_time_elapsed
2834
+
2835
+ if feature_prune_kwargs is not None and len(models) > 0:
2836
+ feature_prune_time_start = time.time()
2837
+ model_fit_kwargs = self._get_model_fit_kwargs(
2838
+ X=X,
2839
+ X_val=kwargs.get("X_val", None),
2840
+ time_limit=None,
2841
+ k_fold=k_fold,
2842
+ fit_kwargs=kwargs.get("fit_kwargs", {}),
2843
+ ens_sample_weight=kwargs.get("ens_sample_weight"),
2844
+ )
2845
+ model_fit_kwargs.update(dict(X=X, y=y, X_val=kwargs.get("X_val", None), y_val=kwargs.get("y_val", None)))
2846
+ if bagged:
2847
+ bagged_model_fit_kwargs = self._get_bagged_model_fit_kwargs(
2848
+ k_fold=k_fold, k_fold_start=k_fold_start, k_fold_end=k_fold, n_repeats=n_repeats, n_repeat_start=0
2849
+ )
2850
+ model_fit_kwargs.update(bagged_model_fit_kwargs)
2851
+
2852
+ # FIXME: v1.3: X.columns incorrectly includes sample_weight column
2853
+ # FIXME: v1.3: Move sample_weight logic into fit_stack_core level methods, currently we are editing X too many times in self._get_model_fit_kwargs
2854
+ candidate_features = self._proxy_model_feature_prune(
2855
+ time_limit=time_limit,
2856
+ layer_fit_time=multi_fold_time_elapsed,
2857
+ level=kwargs["level"],
2858
+ features=X.columns.tolist(),
2859
+ model_fit_kwargs=model_fit_kwargs,
2860
+ **feature_prune_kwargs,
2861
+ )
2862
+ if time_limit is not None:
2863
+ time_limit = time_limit - (time.time() - feature_prune_time_start)
2864
+
2865
+ fit_args["X"] = X[candidate_features]
2866
+ fit_args["X_val"] = kwargs["X_val"][candidate_features] if isinstance(kwargs.get("X_val", None), pd.DataFrame) else kwargs.get("X_val", None)
2867
+
2868
+ if len(candidate_features) < len(X.columns):
2869
+ unfit_models = []
2870
+ original_prune_map = {}
2871
+ for model in models:
2872
+ unfit_model = self.load_model(model).convert_to_template()
2873
+ unfit_model.rename(f"{unfit_model.name}_Prune")
2874
+ unfit_models.append(unfit_model)
2875
+ original_prune_map[unfit_model.name] = model
2876
+ pruned_models = self._train_multi_fold(
2877
+ models=unfit_models,
2878
+ hyperparameter_tune_kwargs=None,
2879
+ k_fold_start=k_fold_start,
2880
+ k_fold_end=k_fold,
2881
+ n_repeats=n_repeats,
2882
+ n_repeat_start=0,
2883
+ time_limit=time_limit,
2884
+ **fit_args,
2885
+ )
2886
+ force_prune = feature_prune_kwargs.get("force_prune", False)
2887
+ models = self._retain_better_pruned_models(pruned_models=pruned_models, original_prune_map=original_prune_map, force_prune=force_prune)
2888
+ return models
2889
+
2890
+ # TODO: Ban KNN from being a Stacker model outside of aux. Will need to ensemble select on all stack layers ensemble selector to make it work
2891
+ # TODO: Robert dataset, LightGBM is super good but RF and KNN take all the time away from it on 1h despite being much worse
2892
+ # TODO: Add time_limit_per_model
2893
+ # TODO: Rename for v0.1
2894
+ def _train_multi_fold(
2895
+ self,
2896
+ X: pd.DataFrame,
2897
+ y: pd.Series,
2898
+ models: list[AbstractModel],
2899
+ time_limit: float | None = None,
2900
+ time_split: bool = False,
2901
+ time_ratio: float = 1,
2902
+ hyperparameter_tune_kwargs: dict | None = None,
2903
+ fit_strategy: Literal["sequential", "parallel"] = "sequential",
2904
+ **kwargs,
2905
+ ) -> list[str]:
2906
+ """
2907
+ Trains and saves a list of models, either sequentially or in parallel depending on fit_strategy.
2908
+ This method should only be called from within self._train_multi_initial.
2909
+ Returns a list of trained model names.
2910
+ """
2911
+ time_start = time.time()
2912
+ if time_limit is not None:
2913
+ time_limit = time_limit * time_ratio
2914
+ if time_limit is not None and len(models) > 0:
2915
+ time_limit_model_split = time_limit / len(models)
2916
+ else:
2917
+ time_limit_model_split = time_limit
2918
+
2919
+ if fit_strategy == "parallel" and hyperparameter_tune_kwargs is not None and hyperparameter_tune_kwargs:
2920
+ for k, v in hyperparameter_tune_kwargs.items():
2921
+ if v is not None and (not isinstance(v, dict) or len(v) != 0):
2922
+ logger.log(
2923
+ 30,
2924
+ f"WARNING: fit_strategy='parallel', but `hyperparameter_tune_kwargs` is specified for model '{k}' with value {v}. "
2925
+ f"Hyperparameter tuning does not yet support `parallel` fit_strategy. "
2926
+ f"Falling back to fit_strategy='sequential' ... "
2927
+ )
2928
+ fit_strategy = "sequential"
2929
+ break
2930
+ if fit_strategy == "parallel":
2931
+ num_cpus = kwargs.get("total_resources", {}).get("num_cpus", "auto")
2932
+ if isinstance(num_cpus, str) and num_cpus == "auto":
2933
+ num_cpus = get_resource_manager().get_cpu_count_psutil()
2934
+ if num_cpus < 12:
2935
+ force_parallel = os.environ.get("AG_FORCE_PARALLEL", "False") == "True"
2936
+ if not force_parallel:
2937
+ logger.log(
2938
+ 30,
2939
+ f"Note: fit_strategy='parallel', but `num_cpus={num_cpus}`. "
2940
+ f"Running parallel mode with fewer than 12 CPUs is not recommended and has been disabled. "
2941
+ f'You can override this by specifying `os.environ["AG_FORCE_PARALLEL"] = "True"`. '
2942
+ f"Falling back to fit_strategy='sequential' ..."
2943
+ )
2944
+ fit_strategy = "sequential"
2945
+ if fit_strategy == "parallel":
2946
+ num_gpus = kwargs.get("total_resources", {}).get("num_gpus", 0)
2947
+ if isinstance(num_gpus, str) and num_gpus == "auto":
2948
+ num_gpus = get_resource_manager().get_gpu_count()
2949
+ if isinstance(num_gpus, (float, int)) and num_gpus > 0:
2950
+ logger.log(
2951
+ 30,
2952
+ f"WARNING: fit_strategy='parallel', but `num_gpus={num_gpus}` is specified. "
2953
+ f"GPU is not yet supported for `parallel` fit_strategy. To enable parallel, ensure you specify `num_gpus=0` in the fit call. "
2954
+ f"Falling back to fit_strategy='sequential' ... "
2955
+ )
2956
+ fit_strategy = "sequential"
2957
+ if fit_strategy == "parallel":
2958
+ try:
2959
+ try_import_ray()
2960
+ except Exception as e:
2961
+ logger.log(
2962
+ 30,
2963
+ f"WARNING: Exception encountered when trying to import ray (fit_strategy='parallel'). "
2964
+ f"ray is required for 'parallel' fit_strategy. Falling back to fit_strategy='sequential' ... "
2965
+ f"\n\tException details: {e.__class__.__name__}: {e}"
2966
+ )
2967
+ fit_strategy = "sequential"
2968
+
2969
+ if fit_strategy == "sequential":
2970
+ models_valid = []
2971
+ for model in models:
2972
+ if self._callback_early_stop:
2973
+ return models_valid
2974
+
2975
+ models_valid += _detached_train_multi_fold(
2976
+ _self=self,
2977
+ model=model,
2978
+ X=X,
2979
+ y=y,
2980
+ time_start=time_start,
2981
+ time_split=time_split,
2982
+ time_limit=time_limit,
2983
+ time_limit_model_split=time_limit_model_split,
2984
+ hyperparameter_tune_kwargs=hyperparameter_tune_kwargs,
2985
+ is_ray_worker=False,
2986
+ kwargs=kwargs,
2987
+ )
2988
+ elif fit_strategy == "parallel":
2989
+ models_valid = self._train_multi_fold_parallel(
2990
+ X=X,
2991
+ y=y,
2992
+ models=models,
2993
+ time_start=time_start,
2994
+ time_limit_model_split=time_limit_model_split,
2995
+ time_limit=time_limit,
2996
+ time_split=time_split,
2997
+ hyperparameter_tune_kwargs=hyperparameter_tune_kwargs,
2998
+ **kwargs,
2999
+ )
3000
+ else:
3001
+ raise ValueError(f"Invalid value for fit_strategy: '{fit_strategy}'")
3002
+ return models_valid
3003
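As a quick sketch of the sequential time budgeting above (invented values): the layer budget is first scaled by time_ratio, then split evenly across the models scheduled for the layer.

# Hypothetical values for the time split at the top of _train_multi_fold.
time_limit = 300.0  # layer budget handed to the method
time_ratio = 0.9    # hpo_time_ratio when HPO is enabled, otherwise 1
n_models = 5

time_limit = time_limit * time_ratio            # 270.0s for the whole layer
time_limit_model_split = time_limit / n_models  # 54.0s offered to each model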
+
3004
+ def _train_multi_fold_parallel(
3005
+ self,
3006
+ X: pd.DataFrame,
3007
+ y: pd.Series,
3008
+ models: list[AbstractModel],
3009
+ time_start: float,
3010
+ time_limit_model_split: float | None,
3011
+ time_limit: float | None = None,
3012
+ time_split: bool = False,
3013
+ hyperparameter_tune_kwargs: dict | None = None,
3014
+ **kwargs,
3015
+ ) -> list[str]:
3016
+ # -- Parallel or Distributed training
3017
+ ray = try_import_ray()
3018
+
3019
+ # FIXME: Need a common utility class for initializing ray so we don't duplicate code
3020
+ if not ray.is_initialized():
3021
+ ray.init(log_to_driver=False, logging_level=logging.ERROR)
3022
+
3023
+ models_valid = []
3024
+
3025
+ if time_limit is not None:
3026
+ # Give models less than the full time limit to account for overheads (predict, cache, ray, etc.)
3027
+ time_limit_models = time_limit * 0.9
3028
+ else:
3029
+ time_limit_models = None
3030
+
3031
+ logger.log(20, "Scheduling parallel model-workers for training...")
3032
+ distributed_manager = ParallelFitManager(
3033
+ mode="fit",
3034
+ X=X, # FIXME: REMOVE
3035
+ y=y, # FIXME: REMOVE
3036
+ func=_remote_train_multi_fold,
3037
+ func_kwargs=dict(
3038
+ time_split=time_split,
3039
+ time_limit_model_split=time_limit_model_split,
3040
+ time_limit=time_limit_models,
3041
+ time_start=time_start,
3042
+ errors="raise",
3043
+ ),
3044
+ func_put_kwargs=dict(
3045
+ _self=self,
3046
+ X=X,
3047
+ y=y,
3048
+ hyperparameter_tune_kwargs=hyperparameter_tune_kwargs,
3049
+ kwargs=kwargs,
3050
+ ),
3051
+ num_cpus=kwargs.get("total_resources", {}).get("num_cpus", 1),
3052
+ num_gpus=kwargs.get("total_resources", {}).get("num_gpus", 0),
3053
+ num_splits=kwargs.get("k_fold", 1) * kwargs.get("n_repeats", 1),
3054
+ problem_type=self.problem_type, # FIXME: Should this be passed here?
3055
+ num_classes=self.num_classes, # FIXME: Should this be passed here?
3056
+ )
3057
+ jobs_finished = 0
3058
+ jobs_total = len(models)
3059
+
3060
+ ordered_model_names = [m.name for m in models] # Use to ensure same model order is returned
3061
+ expected_model_names = set(ordered_model_names)
3062
+ unfinished_job_refs = distributed_manager.schedule_jobs(models_to_fit=models)
3063
+
3064
+ timeout = None
3065
+
3066
+ if time_limit is not None:
3067
+ # Allow between 5 and 60 seconds of overhead before force-killing jobs, to give some leniency to jobs with extra startup/teardown costs.
3068
+ time_overhead = min(max(time_limit * 0.01, 5), 60)
3069
+ min_time_required_base = min(self._time_limit * 0.01, 10) # This is checked in the worker thread, will skip if not satisfied
3070
+ # If time remaining is less than min_time_required, avoid scheduling new jobs and only wait for existing ones to finish.
3071
+ min_time_required = min_time_required_base * 1.5 + 1 # Add 50% buffer and 1 second to account for ray overhead
3072
+ else:
3073
+ time_overhead = None
3074
+ min_time_required = None
3075
+
3076
+ can_schedule_jobs = True
3077
+ while unfinished_job_refs:
3078
+ if time_limit is not None:
3079
+ time_left = time_limit - (time.time() - time_start)
3080
+ timeout = int(time_left + time_overhead) # include overhead.
3081
+ if timeout <= 0:
3082
+ logger.log(20, "Ran into timeout while waiting for model training to finish. Stopping now.")
3083
+ break
3084
+ finished, unfinished_job_refs = ray.wait(unfinished_job_refs, num_returns=1, timeout=timeout)
3085
+
3086
+ if not finished:
3087
+ logger.log(20, "Ran into timeout while waiting for model training to finish. Stopping now.")
3088
+ break
3089
+
3090
+ distributed_manager.deallocate_resources(job_ref=finished[0])
3091
+ model_name, model_path, model_type, exc, model_failure_info = ray.get(finished[0])
3092
+ assert model_name in expected_model_names, (f"Unexpected model name output during parallel fit: {model_name}\n"
3093
+ f"Valid Names: {expected_model_names}\n"
3094
+ f"This should never happen. Please create a GitHub Issue.")
3095
+ jobs_finished += 1
3096
+
3097
+ if exc is not None or model_path is None:
3098
+ if exc is None:
3099
+ if model_failure_info is not None:
3100
+ exc_type = model_failure_info["exc_type"]
3101
+ exc_str = model_failure_info["exc_str"]
3102
+ else:
3103
+ exc_type = None
3104
+ exc_str = None
3105
+ else:
3106
+ exc_type = exc.__class__
3107
+ exc_str = str(exc)
3108
+ if exc_type is not None:
3109
+ extra_log = f": {exc_type.__name__}: {exc_str}"
3110
+ else:
3111
+ extra_log = ""
3112
+ if exc_type is not None and issubclass(exc_type, InsufficientTime):
3113
+ logger.log(20, exc_str)
3114
+ else:
3115
+ logger.log(20, f"Skipping {model_name if isinstance(model_name, str) else model_name.name} due to exception{extra_log}")
3116
+ if model_failure_info is not None:
3117
+ self._models_failed_to_train_errors[model_name] = model_failure_info
3118
+ else:
3119
+ logger.log(20, f"Fitted {model_name}:")
3120
+
3121
+ # TODO: figure out a way to avoid calling _add_model in the worker-process to save overhead time.
3122
+ # - Right now, we need to call it within _add_model to be able to pass the model path to the main process without changing
3123
+ # the return signature of _train_single_full. This can be a lot of work to change.
3124
+ # TODO: determine if y_pred_proba_val was cached in the worker-process. Right now, we re-do predictions for holdout data.
3125
+ # Self object is not permanently mutated during worker execution, so we need to add model to the "main" self (again).
3126
+ # This is the synchronization point between the distributed and main processes.
3127
+ if self._add_model(
3128
+ model_type.load(path=os.path.join(self.path, model_path), reset_paths=self.reset_paths),
3129
+ stack_name=kwargs["stack_name"],
3130
+ level=kwargs["level"]
3131
+ ):
3132
+ jobs_running = len(unfinished_job_refs)
3133
+ if can_schedule_jobs:
3134
+ remaining_task_word = "pending"
3135
+ else:
3136
+ remaining_task_word = "skipped"
3137
+ parallel_status_log = (
3138
+ f"\tJobs: {jobs_running} running, "
3139
+ f"{jobs_total - (jobs_finished + jobs_running)} {remaining_task_word}, "
3140
+ f"{jobs_finished}/{jobs_total} finished"
3141
+ )
3142
+ if time_limit is not None:
3143
+ time_left = time_limit - (time.time() - time_start)
3144
+ parallel_status_log += f" | {time_left:.0f}s remaining"
3145
+ logger.log(20, parallel_status_log)
3146
+ models_valid.append(model_name)
3147
+ else:
3148
+ logger.log(40, f"Failed to add {model_name} to model graph. This should never happen. Please create a GitHub issue.")
3149
+
3150
+ if not unfinished_job_refs and not distributed_manager.models_to_schedule:
3151
+ # Completed all jobs
3152
+ break
3153
+
3154
+ # TODO: look into what this does / how this works for distributed training
3155
+ if self._callback_early_stop:
3156
+ logger.log(20, "Callback triggered in parallel setting. Stopping model training and cancelling remaining jobs.")
3157
+ break
3158
+
3159
+ # Stop due to time limit after adding model
3160
+ if time_limit is not None:
3161
+ time_elapsed = time.time() - time_start
3162
+ time_left = time_limit - time_elapsed
3163
+ time_left_models = time_limit_models - time_elapsed
3164
+ if (time_left + time_overhead) <= 0:
3165
+ logger.log(20, "Time limit reached for this stacking layer. Stopping model training and cancelling remaining jobs.")
3166
+ break
3167
+ elif time_left_models < min_time_required:
3168
+ if can_schedule_jobs:
3169
+ if len(distributed_manager.models_to_schedule) > 0:
3170
+ logger.log(
3171
+ 20,
3172
+ f"Low on time, skipping {len(distributed_manager.models_to_schedule)} "
3173
+ f"pending jobs and waiting for running jobs to finish... ({time_left:.0f}s remaining time)"
3174
+ )
3175
+ can_schedule_jobs = False
3176
+
3177
+ if can_schedule_jobs:
3178
+ # Re-schedule jobs
3179
+ unfinished_job_refs += distributed_manager.schedule_jobs()
3180
+
3181
+ distributed_manager.clean_up_ray(unfinished_job_refs=unfinished_job_refs)
3182
+ logger.log(20, "Finished all parallel work for this stacking layer.")
3183
+
3184
+ models_valid = set(models_valid)
3185
+ models_valid = [m for m in ordered_model_names if m in models_valid] # maintain original order
3186
+
3187
+ return models_valid
3188
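The loop above follows a standard ray pattern: submit tasks, then repeatedly call ray.wait for one result at a time with a deadline-derived timeout. A minimal self-contained sketch of that pattern (plain ray tasks with invented names, not AutoGluon's ParallelFitManager):

import time

import ray

ray.init(ignore_reinit_error=True, log_to_driver=False)

@ray.remote
def fit_one(name: str) -> str:
    time.sleep(1)  # stand-in for fitting a model
    return name

time_limit = 30.0
time_start = time.time()
unfinished = [fit_one.remote(f"model_{i}") for i in range(4)]
finished_names = []
while unfinished:
    timeout = time_limit - (time.time() - time_start)
    if timeout <= 0:
        break  # deadline reached; abandon remaining jobs
    finished, unfinished = ray.wait(unfinished, num_returns=1, timeout=timeout)
    if not finished:
        break  # ray.wait timed out with nothing ready
    finished_names.append(ray.get(finished[0]))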
+
3189
+ def _train_multi(
3190
+ self,
3191
+ X,
3192
+ y,
3193
+ models: list[AbstractModel],
3194
+ hyperparameter_tune_kwargs=None,
3195
+ feature_prune_kwargs=None,
3196
+ k_fold=None,
3197
+ n_repeats=None,
3198
+ n_repeat_start=0,
3199
+ time_limit=None,
3200
+ delay_bag_sets: bool = False,
3201
+ **kwargs,
3202
+ ) -> list[str]:
3203
+ """
3204
+ Train a list of models using the same data.
3205
+ Assumes that input data has already been processed in the form the models will receive as input (including stack feature generation).
3206
+ Trained models are available in the trainer object.
3207
+ Note: Consider using public APIs instead of this.
3208
+ Returns a list of trained model names.
3209
+ """
3210
+ time_limit_total_level = time_limit
3211
+ if k_fold is None:
3212
+ k_fold = self.k_fold
3213
+ if n_repeats is None:
3214
+ n_repeats = self.n_repeats
3215
+ if (k_fold == 0) and (n_repeats != 1):
3216
+ raise ValueError(f"n_repeats must be 1 when k_fold is 0, values: ({n_repeats}, {k_fold})")
3217
+ if (time_limit is None and feature_prune_kwargs is None) or (not delay_bag_sets):
3218
+ n_repeats_initial = n_repeats
3219
+ else:
3220
+ n_repeats_initial = 1
3221
+ if n_repeat_start == 0:
3222
+ time_start = time.time()
3223
+ model_names_trained = self._train_multi_initial(
3224
+ X=X,
3225
+ y=y,
3226
+ models=models,
3227
+ k_fold=k_fold,
3228
+ n_repeats=n_repeats_initial,
3229
+ hyperparameter_tune_kwargs=hyperparameter_tune_kwargs,
3230
+ feature_prune_kwargs=feature_prune_kwargs,
3231
+ time_limit=time_limit,
3232
+ **kwargs,
3233
+ )
3234
+ n_repeat_start = n_repeats_initial
3235
+ if time_limit is not None:
3236
+ time_limit = time_limit - (time.time() - time_start)
3237
+ else:
3238
+ model_names_trained = models
3239
+ if (n_repeats > 1) and (n_repeat_start < n_repeats):
3240
+ model_names_trained = self._train_multi_repeats(
3241
+ X=X,
3242
+ y=y,
3243
+ models=model_names_trained,
3244
+ k_fold=k_fold,
3245
+ n_repeats=n_repeats,
3246
+ n_repeat_start=n_repeat_start,
3247
+ time_limit=time_limit,
3248
+ time_limit_total_level=time_limit_total_level,
3249
+ **kwargs,
3250
+ )
3251
+ return model_names_trained
3252
+
3253
+ def _train_multi_and_ensemble(
3254
+ self,
3255
+ X,
3256
+ y,
3257
+ X_val,
3258
+ y_val,
3259
+ X_test=None,
3260
+ y_test=None,
3261
+ hyperparameters: dict | None = None,
3262
+ X_unlabeled=None,
3263
+ num_stack_levels=0,
3264
+ time_limit=None,
3265
+ groups=None,
3266
+ **kwargs,
3267
+ ) -> list[str]:
3268
+ """Identical to self.train_multi_levels, but also saves the data to disk. This should only ever be called once."""
3269
+ if time_limit is not None and time_limit <= 0:
3270
+ raise AssertionError(f"Not enough time left to train models. Consider specifying a larger time_limit. Time remaining: {round(time_limit, 2)}s")
3271
+ if self.save_data and not self.is_data_saved:
3272
+ self.save_X(X)
3273
+ self.save_y(y)
3274
+ if X_val is not None:
3275
+ self.save_X_val(X_val)
3276
+ if y_val is not None:
3277
+ self.save_y_val(y_val)
3278
+ if X_test is not None:
3279
+ self.save_X_test(X_test)
3280
+ if y_test is not None:
3281
+ self.save_y_test(y_test)
3282
+ self.is_data_saved = True
3283
+ if self._groups is None:
3284
+ self._groups = groups
3285
+ self._num_rows_train = len(X)
3286
+ if X_val is not None:
3287
+ self._num_rows_val = len(X_val)
3288
+ if X_test is not None:
3289
+ self._num_rows_test = len(X_test)
3290
+ self._num_cols_train = len(list(X.columns))
3291
+ model_names_fit = self.train_multi_levels(
3292
+ X,
3293
+ y,
3294
+ hyperparameters=hyperparameters,
3295
+ X_val=X_val,
3296
+ y_val=y_val,
3297
+ X_test=X_test,
3298
+ y_test=y_test,
3299
+ X_unlabeled=X_unlabeled,
3300
+ level_start=1,
3301
+ level_end=num_stack_levels + 1,
3302
+ time_limit=time_limit,
3303
+ **kwargs,
3304
+ )
3305
+ if len(self.get_model_names()) == 0:
3306
+ # TODO v1.0: Add toggle to raise exception if no models trained
3307
+ logger.log(30, "Warning: AutoGluon did not successfully train any models")
3308
+ return model_names_fit
3309
+
3310
+ def _predict_model(self, X: pd.DataFrame, model: str, model_pred_proba_dict: dict | None = None) -> np.ndarray:
3311
+ y_pred_proba = self._predict_proba_model(X=X, model=model, model_pred_proba_dict=model_pred_proba_dict)
3312
+ return get_pred_from_proba(y_pred_proba=y_pred_proba, problem_type=self.problem_type)
3313
+
3314
+ def _predict_proba_model(self, X: pd.DataFrame, model: str, model_pred_proba_dict: dict | None = None) -> np.ndarray:
3315
+ model_pred_proba_dict = self.get_model_pred_proba_dict(X=X, models=[model], model_pred_proba_dict=model_pred_proba_dict)
3316
+ if not isinstance(model, str):
3317
+ model = model.name
3318
+ return model_pred_proba_dict[model]
3319
+
3320
+ def _proxy_model_feature_prune(
3321
+ self, model_fit_kwargs: dict, time_limit: float, layer_fit_time: float, level: int, features: list[str], **feature_prune_kwargs: dict
3322
+ ) -> list[str]:
3323
+ """
3324
+ Uses the best LightGBM-based base learner of this layer to perform time-aware, permutation-importance-based feature pruning.
3325
+ If all LightGBM models failed, uses the model that achieved the highest validation score instead. Feature pruning is budgeted the smaller of the
3326
+ remaining layer time limit and k (default = 2) times the time it took to fit this layer's base learners. Note that feature pruning can
3327
+ exit earlier based on arguments in feature_prune_kwargs. The method returns the list of feature names that survived the pruning procedure.
3328
+
3329
+ Parameters
3330
+ ----------
3331
+ feature_prune_kwargs : dict
3332
+ Feature pruning keyword arguments, passed through to FeatureSelector.select_features. The following additional
3333
+ kwargs are consumed at this level: 'proxy_model_class' to use the highest-validation-score model of a particular type as the
3334
+ proxy model, 'feature_prune_time_limit' to manually cap how long the feature pruning procedure may run, 'k' to control that
3335
+ budget when 'feature_prune_time_limit' has not been set (the feature selection time budget is then k * layer_fit_time),
3336
+ and 'raise_exception' to signify that AutoGluon should throw an exception if feature pruning errors out.
3337
+ time_limit : float
3338
+ Time limit left within the current stack layer in seconds. Feature pruning should never take more than this time.
3339
+ layer_fit_time : float
3340
+ How long it took to fit all the models in this layer once. Used to calculate how long to feature prune for.
3341
+ level : int
3342
+ Level of this stack layer.
3343
+ features: list[str]
3344
+ The list of feature names in the inputted dataset.
3345
+
3346
+ Returns
3347
+ -------
3348
+ candidate_features : list[str]
3349
+ Feature names that survived the pruning procedure.
3350
+ """
3351
+ k = feature_prune_kwargs.pop("k", 2)
3352
+ proxy_model_class = feature_prune_kwargs.pop("proxy_model_class", self._get_default_proxy_model_class())
3353
+ feature_prune_time_limit = feature_prune_kwargs.pop("feature_prune_time_limit", None)
3354
+ raise_exception_on_fail = feature_prune_kwargs.pop("raise_exception", False)
3355
+
3356
+ proxy_model = self._get_feature_prune_proxy_model(proxy_model_class=proxy_model_class, level=level)
3357
+ if proxy_model is None:
3358
+ return features
3359
+
3360
+ if feature_prune_time_limit is not None:
3361
+ feature_prune_time_limit = min(max(time_limit - layer_fit_time, 0), feature_prune_time_limit)
3362
+ elif time_limit is not None:
3363
+ feature_prune_time_limit = min(max(time_limit - layer_fit_time, 0), max(k * layer_fit_time, 0.05 * time_limit))
3364
+ else:
3365
+ feature_prune_time_limit = max(k * layer_fit_time, 300)
3366
+
3367
+ if feature_prune_time_limit < 2 * proxy_model.fit_time:
3368
+ logger.warning(
3369
+ f"Insufficient time to train even a single feature pruning model (remaining: {feature_prune_time_limit}, "
3370
+ f"needed: {proxy_model.fit_time}). Skipping feature pruning."
3371
+ )
3372
+ return features
3373
+ selector = FeatureSelector(
3374
+ model=proxy_model, time_limit=feature_prune_time_limit, raise_exception=raise_exception_on_fail, problem_type=self.problem_type
3375
+ )
3376
+ candidate_features = selector.select_features(**feature_prune_kwargs, **model_fit_kwargs)
3377
+ return candidate_features
3378
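To make the budget selection above concrete, here is the arithmetic with invented inputs (mirroring the branch where time_limit is set and feature_prune_time_limit is not):

time_limit = 1000.0     # hypothetical time left in this stack layer
layer_fit_time = 120.0  # hypothetical time spent fitting the layer once
k = 2                   # default multiplier

budget = min(max(time_limit - layer_fit_time, 0), max(k * layer_fit_time, 0.05 * time_limit))
# min(880.0, max(240.0, 50.0)) -> 240.0 seconds allotted to feature pruning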
+
3379
+ def _get_default_proxy_model_class(self):
3380
+ return None
3381
+
3382
+ def _retain_better_pruned_models(self, pruned_models: list[str], original_prune_map: dict, force_prune: bool = False) -> list[str]:
3383
+ """
3384
+ Compares models fit on the pruned set of features with their counterparts fit on the full set of features.
3385
+ Keeps whichever model achieved the higher validation score and deletes the other from self.model_graph.
3386
+
3387
+ Parameters
3388
+ ----------
3389
+ pruned_models : list[str]
3390
+ A list of pruned model names.
3391
+ original_prune_map : dict
3392
+ A dictionary mapping the names of models fitted on pruned features to the names of models fitted on original features.
3393
+ force_prune : bool, default = False
3394
+ If set to True, forces all base learners to work with the pruned set of features.
3395
+
3396
+ Returns
3397
+ -------
3398
+ models : list[str]
3399
+ A list of model names.
3400
+ """
3401
+ models = []
3402
+ for pruned_model in pruned_models:
3403
+ original_model = original_prune_map[pruned_model]
3404
+ leaderboard = self.leaderboard()
3405
+ original_score = leaderboard[leaderboard["model"] == original_model]["score_val"].item()
3406
+ pruned_score = leaderboard[leaderboard["model"] == pruned_model]["score_val"].item()
3407
+ score_str = f"({round(pruned_score, 4)} vs {round(original_score, 4)})"
3408
+ if force_prune:
3409
+ logger.log(30, f"Pruned score vs original score is {score_str}. Replacing original model since force_prune=True...")
3410
+ self.delete_models(models_to_delete=original_model, dry_run=False)
3411
+ models.append(pruned_model)
3412
+ elif pruned_score > original_score:
3413
+ logger.log(30, f"Model trained with feature pruning score is better than original model's score {score_str}. Replacing original model...")
3414
+ self.delete_models(models_to_delete=original_model, dry_run=False)
3415
+ models.append(pruned_model)
3416
+ else:
3417
+ logger.log(30, f"Model trained with feature pruning score is not better than original model's score {score_str}. Keeping original model...")
3418
+ self.delete_models(models_to_delete=pruned_model, dry_run=False)
3419
+ models.append(original_model)
3420
+ return models
3421
+
3422
+ # TODO: Enable raw=True for bagged models when X=None
3423
+ # This is non-trivial to implement for multi-layer stacking ensembles on the OOF data.
3424
+ # TODO: Consider limiting X to 10k rows here instead of inside the model call
3425
+ def get_feature_importance(self, model=None, X=None, y=None, raw=True, **kwargs) -> pd.DataFrame:
3426
+ if model is None:
3427
+ model = self.model_best
3428
+ model: AbstractModel = self.load_model(model)
3429
+ if X is None and model.val_score is None:
3430
+ raise AssertionError(
3431
+ f"Model {model.name} is not valid for generating feature importances on original training data because no validation data was used during training. Please specify new test data to compute feature importances."
3432
+ )
3433
+
3434
+ if X is None:
3435
+ if isinstance(model, WeightedEnsembleModel):
3436
+ if self.bagged_mode:
3437
+ if raw:
3438
+ raise AssertionError(
3439
+ "`feature_stage='transformed'` feature importance on the original training data is not yet supported when bagging is enabled. Please specify new test data to compute feature importances."
3440
+ )
3441
+ X = None
3442
+ is_oof = True
3443
+ else:
3444
+ if raw:
3445
+ X = self.load_X_val()
3446
+ else:
3447
+ X = None
3448
+ is_oof = False
3449
+ elif isinstance(model, BaggedEnsembleModel):
3450
+ if raw:
3451
+ raise AssertionError(
3452
+ "`feature_stage='transformed'` feature importance on the original training data is not yet supported when bagging is enabled. Please specify new test data to compute feature importances."
3453
+ )
3454
+ X = self.load_X()
3455
+ X = self.get_inputs_to_model(model=model, X=X, fit=True)
3456
+ is_oof = True
3457
+ else:
3458
+ X = self.load_X_val()
3459
+ if not raw:
3460
+ X = self.get_inputs_to_model(model=model, X=X, fit=False)
3461
+ is_oof = False
3462
+ else:
3463
+ is_oof = False
3464
+ if not raw:
3465
+ X = self.get_inputs_to_model(model=model, X=X, fit=False)
3466
+
3467
+ if y is None and X is not None:
3468
+ if is_oof:
3469
+ y = self.load_y()
3470
+ else:
3471
+ y = self.load_y_val()
3472
+
3473
+ if raw:
3474
+ return self._get_feature_importance_raw(X=X, y=y, model=model, **kwargs)
3475
+ else:
3476
+ if is_oof:
3477
+ kwargs["is_oof"] = is_oof
3478
+ return model.compute_feature_importance(X=X, y=y, **kwargs)
3479
+
3480
+ # TODO: Can get feature importances of all children of model at no extra cost, requires scoring the values after predict_proba on each model
3481
+ # Could solve by adding a self.score_all() function which takes model as input and also returns scores of all children models.
3482
+ # This would be best solved after adding graph representation, it lives most naturally in AbstractModel
3483
+ # TODO: Can skip features which were pruned on all models that model depends on (Complex to implement, requires graph representation)
3484
+ # TODO: Note that raw importance will not equal non-raw importance for bagged models, even if raw features are identical to the model features.
3485
+ # This is because for non-raw, we do an optimization where each fold model calls .compute_feature_importance(), and then the feature importances are averaged across the folds.
3486
+ # This is different from raw, where the predictions of the folds are averaged and then feature importance is computed.
3487
+ # Consider aligning these methods so they produce the same result.
3488
+ # The output of this function is identical to non-raw when model is level 1 and non-bagged
3489
+ def _get_feature_importance_raw(self, X, y, model, eval_metric=None, **kwargs) -> pd.DataFrame:
3490
+ if eval_metric is None:
3491
+ eval_metric = self.eval_metric
3492
+ if model is None:
3493
+ model = self._get_best()
3494
+ if eval_metric.needs_pred:
3495
+ predict_func = self.predict
3496
+ else:
3497
+ predict_func = self.predict_proba
3498
+ model: AbstractModel = self.load_model(model)
3499
+ predict_func_kwargs = dict(model=model)
3500
+ return compute_permutation_feature_importance(
3501
+ X=X,
3502
+ y=y,
3503
+ predict_func=predict_func,
3504
+ predict_func_kwargs=predict_func_kwargs,
3505
+ eval_metric=eval_metric,
3506
+ quantile_levels=self.quantile_levels,
3507
+ **kwargs,
3508
+ )
3509
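The raw path above delegates to permutation feature importance. For intuition, a generic sketch of that technique with invented helper names (not AutoGluon's implementation, which additionally handles time limits, shuffle repeats, and quantile levels):

import numpy as np
import pandas as pd

def permutation_importance_sketch(X: pd.DataFrame, y: pd.Series, predict, score) -> pd.Series:
    # Score drop when each column is shuffled; a larger drop means a more important feature.
    rng = np.random.default_rng(0)
    base_score = score(y, predict(X))
    drops = {}
    for col in X.columns:
        X_shuffled = X.copy()
        X_shuffled[col] = rng.permutation(X_shuffled[col].values)
        drops[col] = base_score - score(y, predict(X_shuffled))
    return pd.Series(drops).sort_values(ascending=False)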
+
3510
+ def _get_models_load_info(self, model_names):
3511
+ model_names = copy.deepcopy(model_names)
3512
+ model_paths = self.get_models_attribute_dict(attribute="path", models=model_names)
3513
+ model_types = self.get_models_attribute_dict(attribute="type", models=model_names)
3514
+ return model_names, model_paths, model_types
3515
+
3516
+ def get_model_attribute_full(self, model: str | list[str], attribute: str, func=sum) -> float | int:
3517
+ """
3518
+ Aggregates (by default, sums) the attribute value across all models that the provided model depends on, including itself.
3519
+ For instance, this function can return the expected total predict_time of a model.
3520
+ attribute is the name of the desired attribute to be aggregated,
3521
+ or a dictionary of model name -> attribute value if the attribute is not present in the graph.
3522
+ """
3523
+ if isinstance(model, list):
3524
+ base_model_set = self.get_minimum_models_set(model)
3525
+ else:
3526
+ base_model_set = self.get_minimum_model_set(model)
3527
+ if isinstance(attribute, dict):
3528
+ is_dict = True
3529
+ else:
3530
+ is_dict = False
3531
+ if len(base_model_set) == 1:
3532
+ if is_dict:
3533
+ return attribute[model]
3534
+ else:
3535
+ return self.model_graph.nodes[base_model_set[0]][attribute]
3536
+ # attribute_full = 0
3537
+ attribute_lst = []
3538
+ for base_model in base_model_set:
3539
+ if is_dict:
3540
+ attribute_base_model = attribute[base_model]
3541
+ else:
3542
+ attribute_base_model = self.model_graph.nodes[base_model][attribute]
3543
+ if attribute_base_model is None:
3544
+ return None
3545
+ attribute_lst.append(attribute_base_model)
3546
+ # attribute_full += attribute_base_model
3547
+ if attribute_lst:
3548
+ attribute_full = func(attribute_lst)
3549
+ else:
3550
+ attribute_full = 0
3551
+ return attribute_full
3552
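A minimal sketch of the aggregation above, using a toy networkx graph in place of the trainer's model graph (model names and attribute values are invented; edges point base model -> dependent model, as in the trainer):

import networkx as nx

g = nx.DiGraph()
g.add_node("LightGBM", predict_time=0.2)
g.add_node("CatBoost", predict_time=0.3)
g.add_node("WeightedEnsemble", predict_time=0.05)
g.add_edge("LightGBM", "WeightedEnsemble")
g.add_edge("CatBoost", "WeightedEnsemble")

# Minimum model set = the model plus every model it depends on.
base_model_set = list(nx.ancestors(g, "WeightedEnsemble")) + ["WeightedEnsemble"]
total_predict_time = sum(g.nodes[m]["predict_time"] for m in base_model_set)  # 0.55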
+
3553
+ def get_models_attribute_full(self, models: list[str], attribute: str, func=sum):
3554
+ """
3555
+ For each model in models, returns the output of self.get_model_attribute_full mapped to a dict.
3556
+ """
3557
+ d = dict()
3558
+ for model in models:
3559
+ d[model] = self.get_model_attribute_full(model=model, attribute=attribute, func=func)
3560
+ return d
3561
+
3562
+ # Gets the minimum set of models that the provided models depend on, including themselves
3563
+ # Returns a list of model names
3564
+ def get_minimum_models_set(self, models: list) -> list:
3565
+ models_set = set()
3566
+ for model in models:
3567
+ models_set = models_set.union(self.get_minimum_model_set(model))
3568
+ return list(models_set)
3569
+
3570
+ # Gets the set of base models used directly by the provided model
3571
+ # Returns a list of model names
3572
+ def get_base_model_names(self, model) -> list:
3573
+ if not isinstance(model, str):
3574
+ model = model.name
3575
+ base_model_set = list(self.model_graph.predecessors(model))
3576
+ return base_model_set
3577
+
3578
+ def model_refit_map(self, inverse=False) -> dict[str, str]:
3579
+ """
3580
+ Returns dict of parent model -> refit model
3581
+
3582
+ If inverse=True, return dict of refit model -> parent model
3583
+ """
3584
+ model_refit_map = self.get_models_attribute_dict(attribute="refit_full_parent")
3585
+ if not inverse:
3586
+ model_refit_map = {parent: refit for refit, parent in model_refit_map.items()}
3587
+ return model_refit_map
3588
+
3589
+ def model_exists(self, model: str) -> bool:
3590
+ return model in self.get_model_names()
3591
+
3592
+ def _flatten_model_info(self, model_info: dict) -> dict:
3593
+ """
3594
+ Flattens the model_info nested dictionary into a shallow dictionary to convert to a pandas DataFrame row.
3595
+
3596
+ Parameters
3597
+ ----------
3598
+ model_info: dict
3599
+ A nested dictionary of model metadata information
3600
+
3601
+ Returns
3602
+ -------
3603
+ A flattened dictionary of model info.
3604
+ """
3605
+ model_info_keys = [
3606
+ "num_features",
3607
+ "model_type",
3608
+ "hyperparameters",
3609
+ "hyperparameters_fit",
3610
+ "ag_args_fit",
3611
+ "features",
3612
+ "is_initialized",
3613
+ "is_fit",
3614
+ "is_valid",
3615
+ "can_infer",
3616
+ ]
3617
+ model_info_flat = {k: v for k, v in model_info.items() if k in model_info_keys}
3618
+
3619
+ custom_info = {}
3620
+ bagged_info = model_info.get("bagged_info", {})
3621
+ custom_info["num_models"] = bagged_info.get("num_child_models", 1)
3622
+ custom_info["memory_size"] = bagged_info.get("max_memory_size", model_info["memory_size"])
3623
+ custom_info["memory_size_min"] = bagged_info.get("min_memory_size", model_info["memory_size"])
3624
+ custom_info["compile_time"] = bagged_info.get("compile_time", model_info["compile_time"])
3625
+ custom_info["child_model_type"] = bagged_info.get("child_model_type", None)
3626
+ custom_info["child_hyperparameters"] = bagged_info.get("child_hyperparameters", None)
3627
+ custom_info["child_hyperparameters_fit"] = bagged_info.get("child_hyperparameters_fit", None)
3628
+ custom_info["child_ag_args_fit"] = bagged_info.get("child_ag_args_fit", None)
3629
+
3630
+ model_info_keys = [
3631
+ "num_models",
3632
+ "memory_size",
3633
+ "memory_size_min",
3634
+ "compile_time",
3635
+ "child_model_type",
3636
+ "child_hyperparameters",
3637
+ "child_hyperparameters_fit",
3638
+ "child_ag_args_fit",
3639
+ ]
3640
+ for key in model_info_keys:
3641
+ model_info_flat[key] = custom_info[key]
3642
+ return model_info_flat
3643
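For intuition, a small example of the flattening above (the keys follow the code; the values are invented):

# Hypothetical nested model_info entry for a bagged model.
model_info = {
    "model_type": "StackerEnsembleModel",
    "memory_size": 1000,
    "compile_time": None,
    "bagged_info": {"num_child_models": 8, "max_memory_size": 4000, "min_memory_size": 500},
}
bagged_info = model_info.get("bagged_info", {})
flat = {
    "num_models": bagged_info.get("num_child_models", 1),                              # 8
    "memory_size": bagged_info.get("max_memory_size", model_info["memory_size"]),      # 4000
    "memory_size_min": bagged_info.get("min_memory_size", model_info["memory_size"]),  # 500
}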
+
3644
+ def leaderboard(self, extra_info=False, refit_full: bool | None = None, set_refit_score_to_parent: bool = False):
3645
+ model_names = self.get_model_names()
3646
+ models_full_dict = self.get_models_attribute_dict(models=model_names, attribute="refit_full_parent")
3647
+ if refit_full is not None:
3648
+ if refit_full:
3649
+ model_names = [model for model in model_names if model in models_full_dict]
3650
+ else:
3651
+ model_names = [model for model in model_names if model not in models_full_dict]
3652
+ score_val = []
3653
+ eval_metric = []
3654
+ stopping_metric = []
3655
+ fit_time_marginal = []
3656
+ pred_time_val_marginal = []
3657
+ stack_level = []
3658
+ fit_time = []
3659
+ pred_time_val = []
3660
+ can_infer = []
3661
+ fit_order = list(range(1, len(model_names) + 1))
3662
+ score_val_dict = self.get_models_attribute_dict("val_score")
3663
+ eval_metric_dict = self.get_models_attribute_dict("eval_metric")
3664
+ stopping_metric_dict = self.get_models_attribute_dict("stopping_metric")
3665
+ fit_time_marginal_dict = self.get_models_attribute_dict("fit_time")
3666
+ predict_time_marginal_dict = self.get_models_attribute_dict("predict_time")
3667
+ fit_time_dict = self.get_models_attribute_full(attribute="fit_time", models=model_names, func=sum)
3668
+ pred_time_val_dict = self.get_models_attribute_full(attribute="predict_time", models=model_names, func=sum)
3669
+ can_infer_dict = self.get_models_attribute_full(attribute="can_infer", models=model_names, func=min)
3670
+ for model_name in model_names:
3671
+ if set_refit_score_to_parent and (model_name in models_full_dict):
3672
+ if models_full_dict[model_name] not in score_val_dict:
3673
+ raise AssertionError(
3674
+ f"Model parent is missing from leaderboard when `set_refit_score_to_parent=True`, "
3675
+ f"this is invalid. The parent model may have been deleted. "
3676
+ f"(model='{model_name}', parent='{models_full_dict[model_name]}')"
3677
+ )
3678
+ score_val.append(score_val_dict[models_full_dict[model_name]])
3679
+ else:
3680
+ score_val.append(score_val_dict[model_name])
3681
+ eval_metric.append(eval_metric_dict[model_name])
3682
+ stopping_metric.append(stopping_metric_dict[model_name])
3683
+ fit_time_marginal.append(fit_time_marginal_dict[model_name])
3684
+ fit_time.append(fit_time_dict[model_name])
3685
+ pred_time_val_marginal.append(predict_time_marginal_dict[model_name])
3686
+ pred_time_val.append(pred_time_val_dict[model_name])
3687
+ stack_level.append(self.get_model_level(model_name))
3688
+ can_infer.append(can_infer_dict[model_name])
3689
+
3690
+ model_info_dict = defaultdict(list)
3691
+ extra_info_dict = dict()
3692
+ if extra_info:
3693
+ # TODO: feature_metadata
3694
+ # TODO: disk size
3695
+ # TODO: load time
3696
+ # TODO: Add persist_if_mem_safe() function to persist in memory all models if reasonable memory size (or a specific model+ancestors)
3697
+ # TODO: Add is_persisted() function to check which models are persisted in memory
3698
+ # TODO: package_dependencies, package_dependencies_full
3699
+
3700
+ info = self.get_info(include_model_info=True)
3701
+ model_info = info["model_info"]
3702
+ custom_model_info = {}
3703
+ for model_name in model_info:
3704
+ custom_info = {}
3705
+ bagged_info = model_info[model_name].get("bagged_info", {})
3706
+ custom_info["num_models"] = bagged_info.get("num_child_models", 1)
3707
+ custom_info["memory_size"] = bagged_info.get("max_memory_size", model_info[model_name]["memory_size"])
3708
+ custom_info["memory_size_min"] = bagged_info.get("min_memory_size", model_info[model_name]["memory_size"])
3709
+ custom_info["compile_time"] = bagged_info.get("compile_time", model_info[model_name]["compile_time"])
3710
+ custom_info["child_model_type"] = bagged_info.get("child_model_type", None)
3711
+ custom_info["child_hyperparameters"] = bagged_info.get("child_hyperparameters", None)
3712
+ custom_info["child_hyperparameters_fit"] = bagged_info.get("child_hyperparameters_fit", None)
3713
+ custom_info["child_ag_args_fit"] = bagged_info.get("child_ag_args_fit", None)
3714
+ custom_model_info[model_name] = custom_info
3715
+
3716
+ model_info_keys = ["num_features", "model_type", "hyperparameters", "hyperparameters_fit", "ag_args_fit", "features"]
3717
+ model_info_sum_keys = []
3718
+ for key in model_info_keys:
3719
+ model_info_dict[key] = [model_info[model_name][key] for model_name in model_names]
3720
+ if key in model_info_sum_keys:
3721
+ key_dict = {model_name: model_info[model_name][key] for model_name in model_names}
3722
+ model_info_dict[key + "_full"] = [self.get_model_attribute_full(model=model_name, attribute=key_dict) for model_name in model_names]
3723
+
3724
+ model_info_keys = [
3725
+ "num_models",
3726
+ "memory_size",
3727
+ "memory_size_min",
3728
+ "compile_time",
3729
+ "child_model_type",
3730
+ "child_hyperparameters",
3731
+ "child_hyperparameters_fit",
3732
+ "child_ag_args_fit",
3733
+ ]
3734
+ model_info_full_keys = {
3735
+ "memory_size": [("memory_size_w_ancestors", sum)],
3736
+ "memory_size_min": [("memory_size_min_w_ancestors", max)],
3737
+ "num_models": [("num_models_w_ancestors", sum)],
3738
+ }
3739
+ for key in model_info_keys:
3740
+ model_info_dict[key] = [custom_model_info[model_name][key] for model_name in model_names]
3741
+ if key in model_info_full_keys:
3742
+ key_dict = {model_name: custom_model_info[model_name][key] for model_name in model_names}
3743
+ for column_name, func in model_info_full_keys[key]:
3744
+ model_info_dict[column_name] = [
3745
+ self.get_model_attribute_full(model=model_name, attribute=key_dict, func=func) for model_name in model_names
3746
+ ]
3747
+
3748
+ ancestors = [list(nx.dag.ancestors(self.model_graph, model_name)) for model_name in model_names]
3749
+ descendants = [list(nx.dag.descendants(self.model_graph, model_name)) for model_name in model_names]
3750
+
3751
+ model_info_dict["num_ancestors"] = [len(ancestor_lst) for ancestor_lst in ancestors]
3752
+ model_info_dict["num_descendants"] = [len(descendant_lst) for descendant_lst in descendants]
3753
+ model_info_dict["ancestors"] = ancestors
3754
+ model_info_dict["descendants"] = descendants
3755
+
3756
+ extra_info_dict = {
3757
+ "stopping_metric": stopping_metric,
3758
+ }
3759
+
3760
+ df = pd.DataFrame(
3761
+ data={
3762
+ "model": model_names,
3763
+ "score_val": score_val,
3764
+ "eval_metric": eval_metric,
3765
+ "pred_time_val": pred_time_val,
3766
+ "fit_time": fit_time,
3767
+ "pred_time_val_marginal": pred_time_val_marginal,
3768
+ "fit_time_marginal": fit_time_marginal,
3769
+ "stack_level": stack_level,
3770
+ "can_infer": can_infer,
3771
+ "fit_order": fit_order,
3772
+ **extra_info_dict,
3773
+ **model_info_dict,
3774
+ }
3775
+ )
3776
+ df_sorted = df.sort_values(by=["score_val", "pred_time_val", "model"], ascending=[False, True, False]).reset_index(drop=True)
3777
+
3778
+ df_columns_lst = df_sorted.columns.tolist()
3779
+ explicit_order = [
3780
+ "model",
3781
+ "score_val",
3782
+ "eval_metric",
3783
+ "pred_time_val",
3784
+ "fit_time",
3785
+ "pred_time_val_marginal",
3786
+ "fit_time_marginal",
3787
+ "stack_level",
3788
+ "can_infer",
3789
+ "fit_order",
3790
+ "num_features",
3791
+ "num_models",
3792
+ "num_models_w_ancestors",
3793
+ "memory_size",
3794
+ "memory_size_w_ancestors",
3795
+ "memory_size_min",
3796
+ "memory_size_min_w_ancestors",
3797
+ "num_ancestors",
3798
+ "num_descendants",
3799
+ "model_type",
3800
+ "child_model_type",
3801
+ ]
3802
+ explicit_order = [column for column in explicit_order if column in df_columns_lst]
3803
+ df_columns_other = [column for column in df_columns_lst if column not in explicit_order]
3804
+ df_columns_new = explicit_order + df_columns_other
3805
+ df_sorted = df_sorted[df_columns_new]
3806
+
3807
+ return df_sorted
3808
+
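+ # The column-reordering idiom above generalizes to any DataFrame: keep a preferred
+ # prefix of columns, then append the remaining columns in their existing order.
+ # A minimal sketch (illustrative only; `df` and `preferred` are hypothetical names):
+ #
+ #     preferred = [c for c in ["model", "score_val"] if c in df.columns]
+ #     df = df[preferred + [c for c in df.columns if c not in preferred]]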
3809
+ def model_failures(self) -> pd.DataFrame:
3810
+ """
3811
+ [Advanced] Get the model failures that occurred during the fitting of this predictor, in the form of a pandas DataFrame.
3812
+
3813
+ This is useful for in-depth debugging of model failures and identifying bugs.
3814
+
3815
+ Returns
3816
+ -------
3817
+ model_failures_df: pd.DataFrame
3818
+ A DataFrame of model failures. Each row corresponds to a model failure, and columns correspond to meta information about that model.
3819
+ """
3820
+ model_infos = dict()
3821
+ for i, (model_name, model_info) in enumerate(self._models_failed_to_train_errors.items()):
3822
+ model_info = copy.deepcopy(model_info)
3823
+ model_info_inner = model_info["model_info"]
3824
+
3825
+ model_info_inner = self._flatten_model_info(model_info_inner)
3826
+
3827
+ valid_keys = [
3828
+ "exc_type",
3829
+ "exc_str",
3830
+ "exc_traceback",
3831
+ "total_time",
3832
+ ]
3833
+ valid_keys_inner = [
3834
+ "model_type",
3835
+ "hyperparameters",
3836
+ "hyperparameters_fit",
3837
+ "is_initialized",
3838
+ "is_fit",
3839
+ "is_valid",
3840
+ "can_infer",
3841
+ "num_features",
3842
+ "memory_size",
3843
+ "num_models",
3844
+ "child_model_type",
3845
+ "child_hyperparameters",
3846
+ "child_hyperparameters_fit",
3847
+ ]
3848
+ model_info_out = {k: v for k, v in model_info.items() if k in valid_keys}
3849
+ model_info_inner_out = {k: v for k, v in model_info_inner.items() if k in valid_keys_inner}
3850
+
3851
+ model_info_out.update(model_info_inner_out)
3852
+ model_info_out["model"] = model_name
3853
+ model_info_out["exc_order"] = i + 1
3854
+
3855
+ model_infos[model_name] = model_info_out
3856
+
3857
+ df = pd.DataFrame(
3858
+ data=model_infos,
3859
+ ).T
3860
+
3861
+ explicit_order = [
3862
+ "model",
3863
+ "exc_type",
3864
+ "total_time",
3865
+ "model_type",
3866
+ "child_model_type",
3867
+ "is_initialized",
3868
+ "is_fit",
3869
+ "is_valid",
3870
+ "can_infer",
3871
+ "num_features",
3872
+ "num_models",
3873
+ "memory_size",
3874
+ "hyperparameters",
3875
+ "hyperparameters_fit",
3876
+ "child_hyperparameters",
3877
+ "child_hyperparameters_fit",
3878
+ "exc_str",
3879
+ "exc_traceback",
3880
+ "exc_order",
3881
+ ]
3882
+
3883
+ df_columns_lst = list(df.columns)
3884
+ explicit_order = [column for column in explicit_order if column in df_columns_lst]
3885
+ df_columns_other = [column for column in df_columns_lst if column not in explicit_order]
3886
+ df_columns_new = explicit_order + df_columns_other
3887
+ df_sorted = df[df_columns_new]
3888
+ df_sorted = df_sorted.reset_index(drop=True)
3889
+
3890
+ return df_sorted
3891
+
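+ # A minimal usage sketch for `model_failures` (illustrative; `trainer` is a
+ # hypothetical fitted AbstractTabularTrainer instance with at least one failed model):
+ #
+ #     failures = trainer.model_failures()
+ #     print(failures[["model", "exc_type", "exc_str"]])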
3892
+ def get_info(self, include_model_info=False, include_model_failures=True) -> dict:
3893
+ num_models_trained = len(self.get_model_names())
3894
+ if self.model_best is not None:
3895
+ best_model = self.model_best
3896
+ else:
3897
+ try:
3898
+ best_model = self.get_model_best()
3899
+ except AssertionError:
3900
+ best_model = None
3901
+ if best_model is not None:
3902
+ best_model_score_val = self.get_model_attribute(model=best_model, attribute="val_score")
3903
+ best_model_stack_level = self.get_model_level(best_model)
3904
+ else:
3905
+ best_model_score_val = None
3906
+ best_model_stack_level = None
3907
+ # fit_time = None
3908
+ num_bag_folds = self.k_fold
3909
+ max_core_stack_level = self.get_max_level("core")
3910
+ max_stack_level = self.get_max_level()
3911
+
3912
+ problem_type = self.problem_type
3913
+ eval_metric = self.eval_metric.name
3914
+ time_train_start = self._time_train_start_last
3915
+ num_rows_train = self._num_rows_train
3916
+ num_cols_train = self._num_cols_train
3917
+ num_rows_val = self._num_rows_val
3918
+ num_rows_test = self._num_rows_test
3919
+ num_classes = self.num_classes
3920
+ # TODO:
3921
+ # Disk size of models
3922
+ # Raw feature count
3923
+ # HPO time
3924
+ # Bag time
3925
+ # Feature prune time
3926
+ # Exception count / models failed count
3927
+ # True model count (models * kfold)
3928
+ # AutoGluon version fit on
3929
+ # Max memory usage
3930
+ # CPU count used / GPU count used
3931
+
3932
+ info = {
3933
+ "time_train_start": time_train_start,
3934
+ "num_rows_train": num_rows_train,
3935
+ "num_cols_train": num_cols_train,
3936
+ "num_rows_val": num_rows_val,
3937
+ "num_rows_test": num_rows_test,
3938
+ "num_classes": num_classes,
3939
+ "problem_type": problem_type,
3940
+ "eval_metric": eval_metric,
3941
+ "best_model": best_model,
3942
+ "best_model_score_val": best_model_score_val,
3943
+ "best_model_stack_level": best_model_stack_level,
3944
+ "num_models_trained": num_models_trained,
3945
+ "num_bag_folds": num_bag_folds,
3946
+ "max_stack_level": max_stack_level,
3947
+ "max_core_stack_level": max_core_stack_level,
3948
+ }
3949
+
3950
+ if include_model_info:
3951
+ info["model_info"] = self.get_models_info()
3952
+ if include_model_failures:
3953
+ info["model_info_failures"] = copy.deepcopy(self._models_failed_to_train_errors)
3954
+
3955
+ return info
3956
+
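+ # A minimal usage sketch for `get_info` (illustrative; `trainer` is hypothetical):
+ #
+ #     info = trainer.get_info(include_model_info=False)
+ #     print(info["best_model"], info["best_model_score_val"], info["max_stack_level"])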
3957
+ def reduce_memory_size(
3958
+ self, remove_data=True, remove_fit_stack=False, remove_fit=True, remove_info=False, requires_save=True, reduce_children=False, **kwargs
3959
+ ):
3960
+ if remove_data and self.is_data_saved:
3961
+ data_files = [
3962
+ os.path.join(self.path_data, "X.pkl"),
3963
+ os.path.join(self.path_data, "X_val.pkl"),
3964
+ os.path.join(self.path_data, "y.pkl"),
3965
+ os.path.join(self.path_data, "y_val.pkl"),
3966
+ ]
3967
+ for data_file in data_files:
3968
+ try:
3969
+ os.remove(data_file)
3970
+ except FileNotFoundError:
3971
+ pass
3972
+ if requires_save:
3973
+ self.is_data_saved = False
3974
+ try:
3975
+ os.rmdir(self.path_data)
3976
+ except OSError:
3977
+ pass
3978
+ shutil.rmtree(path=Path(self._path_attr), ignore_errors=True)
3979
+ try:
3980
+ os.rmdir(self.path_utils)
3981
+ except OSError:
3982
+ pass
3983
+ if remove_info and requires_save:
3984
+ # Remove model failure info artifacts
3985
+ self._models_failed_to_train_errors = dict()
3986
+ models = self.get_model_names()
3987
+ for model in models:
3988
+ model = self.load_model(model)
3989
+ model.reduce_memory_size(
3990
+ remove_fit_stack=remove_fit_stack,
3991
+ remove_fit=remove_fit,
3992
+ remove_info=remove_info,
3993
+ requires_save=requires_save,
3994
+ reduce_children=reduce_children,
3995
+ **kwargs,
3996
+ )
3997
+ if requires_save:
3998
+ self.save_model(model, reduce_memory=False)
3999
+ if requires_save:
4000
+ self.save()
4001
+
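+ # A minimal usage sketch for `reduce_memory_size` (illustrative; `trainer` is
+ # hypothetical). Note that remove_data=True deletes cached training data from disk,
+ # so this should only be called once no further fitting is planned:
+ #
+ #     trainer.reduce_memory_size(remove_data=True, remove_fit=True, requires_save=True)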
4002
+ # TODO: Also enable deletion of models which didn't succeed in training (files may still be persisted)
4003
+ # This includes the original HPO fold for stacking
4004
+ # Deletes specified models from trainer and from disk (if delete_from_disk=True).
4005
+ def delete_models(self, models_to_keep=None, models_to_delete=None, allow_delete_cascade=False, delete_from_disk=True, dry_run=True):
4006
+ if models_to_keep is not None and models_to_delete is not None:
4007
+ raise ValueError("Exactly one of [models_to_keep, models_to_delete] must be set.")
4008
+ if models_to_keep is not None:
4009
+ if not isinstance(models_to_keep, list):
4010
+ models_to_keep = [models_to_keep]
4011
+ minimum_model_set = set()
4012
+ for model in models_to_keep:
4013
+ minimum_model_set.update(self.get_minimum_model_set(model))
4014
+ minimum_model_set = list(minimum_model_set)
4015
+ models_to_remove = [model for model in self.get_model_names() if model not in minimum_model_set]
4016
+ elif models_to_delete is not None:
4017
+ if not isinstance(models_to_delete, list):
4018
+ models_to_delete = [models_to_delete]
4019
+ minimum_model_set = set(models_to_delete)
4020
+ minimum_model_set_orig = copy.deepcopy(minimum_model_set)
4021
+ for model in models_to_delete:
4022
+ minimum_model_set.update(nx.algorithms.dag.descendants(self.model_graph, model))
4023
+ if not allow_delete_cascade:
4024
+ if minimum_model_set != minimum_model_set_orig:
4025
+ raise AssertionError(
4026
+ "models_to_delete contains models which cause a delete cascade due to other models being dependent on them. Set allow_delete_cascade=True to enable the deletion."
4027
+ )
4028
+ minimum_model_set = list(minimum_model_set)
4029
+ models_to_remove = [model for model in self.get_model_names() if model in minimum_model_set]
4030
+ else:
4031
+ raise ValueError("Exactly one of [models_to_keep, models_to_delete] must be set.")
4032
+
4033
+ if dry_run:
4034
+ logger.log(30, f"Dry run enabled, AutoGluon would have deleted the following models: {models_to_remove}")
4035
+ if delete_from_disk:
4036
+ for model in models_to_remove:
4037
+ model = self.load_model(model)
4038
+ logger.log(30, f"\tDirectory {model.path} would have been deleted.")
4039
+ logger.log(30, "To perform the deletion, set dry_run=False")
4040
+ return
4041
+
4042
+ if delete_from_disk:
4043
+ for model in models_to_remove:
4044
+ model = self.load_model(model)
4045
+ model.delete_from_disk()
4046
+
4047
+ for model in models_to_remove:
4048
+ self._delete_model_from_graph(model=model)
4049
+
4050
+ models_kept = self.get_model_names()
4051
+
4052
+ if self.model_best is not None and self.model_best not in models_kept:
4053
+ try:
4054
+ self.model_best = self.get_model_best()
4055
+ except AssertionError:
4056
+ self.model_best = None
4057
+
4058
+ # TODO: Delete from all the other model dicts
4059
+ self.save()
4060
+
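+ # A minimal usage sketch for `delete_models` (illustrative; `trainer` is hypothetical):
+ # preview the deletion with a dry run, then keep only the best model and its ancestors:
+ #
+ #     best = trainer.get_model_best()
+ #     trainer.delete_models(models_to_keep=best, dry_run=True)
+ #     trainer.delete_models(models_to_keep=best, dry_run=False, delete_from_disk=True)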
4061
+ def _delete_model_from_graph(self, model: str):
4062
+ self.model_graph.remove_node(model)
4063
+ if model in self.models:
4064
+ self.models.pop(model)
4065
+ path_attr_model = Path(self._path_attr_model(model))
4066
+ shutil.rmtree(path=path_attr_model, ignore_errors=True)
4067
+
4068
+ @staticmethod
4069
+ def _process_hyperparameters(hyperparameters: dict) -> dict:
4070
+ return process_hyperparameters(hyperparameters=hyperparameters)
4071
+
4072
+ def distill(
4073
+ self,
4074
+ X=None,
4075
+ y=None,
4076
+ X_val=None,
4077
+ y_val=None,
4078
+ X_unlabeled=None,
4079
+ time_limit=None,
4080
+ hyperparameters=None,
4081
+ holdout_frac=None,
4082
+ verbosity=None,
4083
+ models_name_suffix=None,
4084
+ teacher=None,
4085
+ teacher_preds="soft",
4086
+ augmentation_data=None,
4087
+ augment_method="spunge",
4088
+ augment_args={"size_factor": 5, "max_size": int(1e5)},
4089
+ augmented_sample_weight=1.0,
4090
+ ):
4091
+ """Various distillation algorithms.
4092
+ Args:
4093
+ X, y: pd.DataFrame and pd.Series of training data.
4094
+ If None, original training data used during predictor.fit() will be loaded.
4095
+ This data is split into train/validation if X_val, y_val are None.
4096
+ X_val, y_val: pd.DataFrame and pd.Series of validation data.
4097
+ time_limit, hyperparameters, holdout_frac: defined as in predictor.fit()
4098
+ teacher (None or str):
4099
+ If None, uses the model with the highest validation score as the teacher model, otherwise use the specified model name as the teacher.
4100
+ teacher_preds (None or str): If None, we only train with original labels (no data augmentation, overrides augment_method)
4101
+ If 'hard', labels are hard teacher predictions given by: teacher.predict()
4102
+ If 'soft', labels are soft teacher predictions given by: teacher.predict_proba()
4103
+ Note: 'hard' and 'soft' are equivalent for regression problems.
4104
+ If augment_method specified, teacher predictions are only used to label augmented data (training data keeps original labels).
4105
+ To apply label-smoothing: teacher_preds='onehot' will use original training data labels converted to one-hots for multiclass (no data augmentation). # TODO: expose smoothing-hyperparameter.
4106
+ models_name_suffix (str): Suffix to append to each student model's name, new names will look like: 'MODELNAME_DSTL_SUFFIX'
4107
+ augmentation_data: pd.DataFrame of additional data to use as "augmented data" (does not contain labels).
4108
+ When specified, augment_method, augment_args are ignored, and this is the only augmented data that is used (teacher_preds cannot be None).
4109
+ augment_method (None or str): specifies which augmentation strategy to utilize. Options: [None, 'spunge', 'munge']
4110
+ If None, no augmentation gets applied.
4112
+ augment_args (dict): args passed into the augmentation function corresponding to augment_method.
4113
+ augmented_sample_weight (float): Nonnegative value indicating how much to weight augmented samples. This is only considered if sample_weight was initially specified in Predictor.
4114
+ """
4115
+ if verbosity is None:
4116
+ verbosity = self.verbosity
4117
+
4118
+ if teacher is None:
4119
+ teacher = self._get_best()
4120
+
4121
+ hyperparameter_tune = False # TODO: add as argument with scheduler options.
4122
+ if augmentation_data is not None and teacher_preds is None:
4123
+ raise ValueError("augmentation_data must be None if teacher_preds is None")
4124
+
4125
+ logger.log(20, f"Distilling with teacher='{teacher}', teacher_preds={str(teacher_preds)}, augment_method={str(augment_method)} ...")
4126
+ if teacher not in self.get_model_names(can_infer=True):
4127
+ raise AssertionError(
4128
+ f"Teacher model '{teacher}' is not a valid teacher model! Either it does not exist or it cannot infer on new data.\n"
4129
+ f"Valid teacher models: {self.get_model_names(can_infer=True)}"
4130
+ )
4131
+ if X is None:
4132
+ if y is not None:
4133
+ raise ValueError("X cannot be None when y specified.")
4134
+ X = self.load_X()
4135
+ X_val = self.load_X_val()
4136
+
4137
+ if y is None:
4138
+ y = self.load_y()
4139
+ y_val = self.load_y_val()
4140
+
4141
+ if X_val is None:
4142
+ if y_val is not None:
4143
+ raise ValueError("X_val cannot be None when y_val specified.")
4144
+ if holdout_frac is None:
4145
+ holdout_frac = default_holdout_frac(len(X), hyperparameter_tune)
4146
+ X, X_val, y, y_val = generate_train_test_split(X, y, problem_type=self.problem_type, test_size=holdout_frac)
4147
+
4148
+ y_val_og = y_val.copy()
4149
+ og_bagged_mode = self.bagged_mode
4150
+ og_verbosity = self.verbosity
4151
+ self.bagged_mode = False # turn off bagging
4152
+ self.verbosity = verbosity # change verbosity for distillation
4153
+
4154
+ if self.sample_weight is not None:
4155
+ X, w = extract_column(X, self.sample_weight)
4156
+
4157
+ if teacher_preds is None or teacher_preds == "onehot":
4158
+ augment_method = None
4159
+ logger.log(
4160
+ 20, "Training students without a teacher model. Set teacher_preds = 'soft' or 'hard' to distill using the best AutoGluon predictor as teacher."
4161
+ )
4162
+
4163
+ if teacher_preds in ["onehot", "soft"]:
4164
+ y = format_distillation_labels(y, self.problem_type, self.num_classes)
4165
+ y_val = format_distillation_labels(y_val, self.problem_type, self.num_classes)
4166
+
4167
+ if augment_method is None and augmentation_data is None:
4168
+ if teacher_preds == "hard":
4169
+ y_pred = pd.Series(self.predict(X, model=teacher))
4170
+ if (self.problem_type != REGRESSION) and (len(y_pred.unique()) < len(y.unique())): # add missing labels
4171
+ logger.log(15, "Adding missing labels to distillation dataset by including some real training examples")
4172
+ indices_to_add = []
4173
+ for clss in y.unique():
4174
+ if clss not in y_pred.unique():
4175
+ logger.log(15, f"Fetching a row with label={clss} from training data")
4176
+ clss_index = y[y == clss].index[0]
4177
+ indices_to_add.append(clss_index)
4178
+ X_extra = X.loc[indices_to_add].copy()
4179
+ y_extra = y.loc[indices_to_add].copy() # these are actually real training examples
4180
+ X = pd.concat([X, X_extra])
4181
+ y_pred = pd.concat([y_pred, y_extra])
4182
+ if self.sample_weight is not None:
4183
+ w = pd.concat([w, w[indices_to_add]])
4184
+ y = y_pred
4185
+ elif teacher_preds == "soft":
4186
+ y = self.predict_proba(X, model=teacher)
4187
+ if self.problem_type == MULTICLASS:
4188
+ y = pd.DataFrame(y)
4189
+ else:
4190
+ y = pd.Series(y)
4191
+ else:
4192
+ X_aug = augment_data(
4193
+ X=X, feature_metadata=self.feature_metadata, augmentation_data=augmentation_data, augment_method=augment_method, augment_args=augment_args
4194
+ )
4195
+ if len(X_aug) > 0:
4196
+ if teacher_preds == "hard":
4197
+ y_aug = pd.Series(self.predict(X_aug, model=teacher))
4198
+ elif teacher_preds == "soft":
4199
+ y_aug = self.predict_proba(X_aug, model=teacher)
4200
+ if self.problem_type == MULTICLASS:
4201
+ y_aug = pd.DataFrame(y_aug)
4202
+ else:
4203
+ y_aug = pd.Series(y_aug)
4204
+ else:
4205
+ raise ValueError(f"Unknown teacher_preds specified: {teacher_preds}")
4206
+
4207
+ X = pd.concat([X, X_aug])
4208
+ y = pd.concat([y, y_aug])
4209
+ if self.sample_weight is not None:
4210
+ w = pd.concat([w, pd.Series([augmented_sample_weight] * len(X_aug))])
4211
+
4212
+ X.reset_index(drop=True, inplace=True)
4213
+ y.reset_index(drop=True, inplace=True)
4214
+ if self.sample_weight is not None:
4215
+ w.reset_index(drop=True, inplace=True)
4216
+ X[self.sample_weight] = w
4217
+
4218
+ name_suffix = "_DSTL" # all student model names contain this substring
4219
+ if models_name_suffix is not None:
4220
+ name_suffix = name_suffix + "_" + models_name_suffix
4221
+
4222
+ if hyperparameters is None:
4223
+ hyperparameters = {"GBM": {}, "CAT": {}, "NN_TORCH": {}, "RF": {}}
4224
+ hyperparameters = self._process_hyperparameters(
4225
+ hyperparameters=hyperparameters
4226
+ ) # TODO: consider exposing ag_args_fit, excluded_model_types as distill() arguments.
4227
+ if teacher_preds is not None and teacher_preds != "hard" and self.problem_type != REGRESSION:
4228
+ self._regress_preds_asprobas = True
4229
+
4230
+ core_kwargs = {
4231
+ "stack_name": self.distill_stackname,
4232
+ "get_models_func": self.construct_model_templates_distillation,
4233
+ }
4234
+ aux_kwargs = {
4235
+ "get_models_func": self.construct_model_templates_distillation,
4236
+ "check_if_best": False,
4237
+ }
4238
+
4239
+ # self.bagged_mode = True # TODO: Add options for bagging
4240
+ models = self.train_multi_levels(
4241
+ X=X,
4242
+ y=y,
4243
+ X_val=X_val,
4244
+ y_val=y_val,
4245
+ hyperparameters=hyperparameters,
4246
+ time_limit=time_limit, # FIXME: Also limit augmentation time
4247
+ name_suffix=name_suffix,
4248
+ core_kwargs=core_kwargs,
4249
+ aux_kwargs=aux_kwargs,
4250
+ )
4251
+
4252
+ distilled_model_names = []
4253
+ w_val = None
4254
+ if self.weight_evaluation:
4255
+ X_val, w_val = extract_column(X_val, self.sample_weight)
4256
+ for model_name in models: # finally measure original metric on validation data and overwrite stored val_scores
4257
+ model_score = self.score(X_val, y_val_og, model=model_name, weights=w_val)
4258
+ model_obj = self.load_model(model_name)
4259
+ model_obj.val_score = model_score
4260
+ model_obj.save() # TODO: consider omitting for sake of efficiency
4261
+ self.model_graph.nodes[model_name]["val_score"] = model_score
4262
+ distilled_model_names.append(model_name)
4263
+ leaderboard = self.leaderboard()
4264
+ logger.log(20, "Distilled model leaderboard:")
4265
+ leaderboard_distilled = leaderboard[leaderboard["model"].isin(models)].reset_index(drop=True)
4266
+ with pd.option_context("display.max_rows", None, "display.max_columns", None, "display.width", 1000):
4267
+ logger.log(20, leaderboard_distilled)
4268
+
4269
+ # reset trainer to old state before distill() was called:
4270
+ self.bagged_mode = og_bagged_mode # TODO: Confirm if safe to train future models after training models in both bagged and non-bagged modes
4271
+ self.verbosity = og_verbosity
4272
+ return distilled_model_names
4273
+
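+ # A minimal usage sketch for `distill` (illustrative; `trainer` is hypothetical):
+ # distill the best model into faster students using soft labels and SPUNGE augmentation:
+ #
+ #     students = trainer.distill(time_limit=600, teacher_preds="soft",
+ #                                augment_method="spunge",
+ #                                augment_args={"size_factor": 5, "max_size": int(1e5)})
+ #     print(students)  # list of student model names containing "_DSTL"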
4274
+ def _get_model_fit_kwargs(
4275
+ self, X: pd.DataFrame, X_val: pd.DataFrame, time_limit: float, k_fold: int, fit_kwargs: dict, ens_sample_weight: list | None = None
4276
+ ) -> dict:
4277
+ # Returns kwargs to be passed to AbstractModel's fit function
4278
+ if fit_kwargs is None:
4279
+ fit_kwargs = dict()
4280
+
4281
+ model_fit_kwargs = dict(time_limit=time_limit, verbosity=self.verbosity, **fit_kwargs)
4282
+ if self.sample_weight is not None:
4283
+ X, w_train = extract_column(X, self.sample_weight)
4284
+ if w_train is not None: # may be None for ensemble
4285
+ # TODO: consider moving weight normalization into AbstractModel.fit()
4286
+ model_fit_kwargs["sample_weight"] = w_train.values / w_train.mean() # normalization can affect gradient algorithms like boosting
4287
+ if X_val is not None:
4288
+ X_val, w_val = extract_column(X_val, self.sample_weight)
4289
+ if self.weight_evaluation and w_val is not None: # ignore validation sample weights unless weight_evaluation specified
4290
+ model_fit_kwargs["sample_weight_val"] = w_val.values / w_val.mean()
4291
+ if ens_sample_weight is not None:
4292
+ model_fit_kwargs["sample_weight"] = ens_sample_weight # sample weights to use for weighted ensemble only
4293
+ if self._groups is not None and "groups" not in model_fit_kwargs:
4294
+ if k_fold == self.k_fold: # don't do this on refit full
4295
+ model_fit_kwargs["groups"] = self._groups
4296
+
4297
+ # FIXME: Sample weight `extract_column` is a hack, have to compute feature_metadata here because sample weight column could be in X upstream, extract sample weight column upstream instead.
4298
+ if "feature_metadata" not in model_fit_kwargs:
4299
+ raise AssertionError(f"Missing expected parameter 'feature_metadata'.")
4300
+ return model_fit_kwargs
4301
+
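+ # The normalization above rescales sample weights to mean 1 so that weighted losses
+ # keep the same overall scale (important for gradient-based learners like boosting).
+ # A worked sketch with numpy (illustrative values only):
+ #
+ #     w = np.array([1.0, 2.0, 5.0])
+ #     w_norm = w / w.mean()  # -> [0.375, 0.75, 1.875], mean exactly 1.0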
4302
+ def _get_bagged_model_fit_kwargs(self, k_fold: int, k_fold_start: int, k_fold_end: int, n_repeats: int, n_repeat_start: int) -> dict:
4303
+ # Returns additional kwargs (aside from _get_model_fit_kwargs) to be passed to BaggedEnsembleModel's fit function
4304
+ if k_fold is None:
4305
+ k_fold = self.k_fold
4306
+ if n_repeats is None:
4307
+ n_repeats = self.n_repeats
4308
+ return dict(
4309
+ k_fold=k_fold, k_fold_start=k_fold_start, k_fold_end=k_fold_end, n_repeats=n_repeats, n_repeat_start=n_repeat_start, compute_base_preds=False
4310
+ )
4311
+
4312
+ def _get_feature_prune_proxy_model(self, proxy_model_class: AbstractModel | None, level: int) -> AbstractModel:
4313
+ """
4314
+ Returns the proxy model to be used for feature pruning - the base learner that has the highest validation score in a particular stack layer.
4315
+ Ties are broken by fit time (the fastest model wins). If proxy_model_class is not None, take the best base learner belonging to proxy_model_class.
4316
+ proxy_model_class is an AbstractModel class (ex. LGBModel).
4317
+ """
4318
+ proxy_model = None
4319
+ if isinstance(proxy_model_class, str):
4320
+ raise AssertionError(f"proxy_model_class must be a subclass of AbstractModel. Was instead a string: {proxy_model_class}")
4321
+ banned_models = [GreedyWeightedEnsembleModel, SimpleWeightedEnsembleModel]
4322
+ assert proxy_model_class not in banned_models, "WeightedEnsemble models cannot be feature pruning proxy models."
4323
+
4324
+ leaderboard = self.leaderboard()
4325
+ banned_names = []
4326
+ candidate_model_rows = leaderboard[(~leaderboard["score_val"].isna()) & (leaderboard["stack_level"] == level)]
4327
+ candidate_models_type_inner = self.get_models_attribute_dict(attribute="type_inner", models=candidate_model_rows["model"])
4328
+ for model_name, type_inner in candidate_models_type_inner.copy().items():
4329
+ if type_inner in banned_models:
4330
+ banned_names.append(model_name)
4331
+ candidate_models_type_inner.pop(model_name, None)
4332
+ banned_names = set(banned_names)
4333
+ candidate_model_rows = candidate_model_rows[~candidate_model_rows["model"].isin(banned_names)]
4334
+ if proxy_model_class is not None:
4335
+ candidate_model_names = [model_name for model_name, model_class in candidate_models_type_inner.items() if model_class == proxy_model_class]
4336
+ candidate_model_rows = candidate_model_rows[candidate_model_rows["model"].isin(candidate_model_names)]
4337
+ if len(candidate_model_rows) == 0:
4338
+ if proxy_model_class is None:
4339
+ logger.warning(f"No models from level {level} have been successfully fit. Skipping feature pruning.")
4340
+ else:
4341
+ logger.warning(f"No models of type {proxy_model_class} have finished training in level {level}. Skipping feature pruning.")
4342
+ return proxy_model
4343
+ best_candidate_model_rows = candidate_model_rows.loc[candidate_model_rows["score_val"] == candidate_model_rows["score_val"].max()]
4344
+ return self.load_model(best_candidate_model_rows.loc[best_candidate_model_rows["fit_time"].idxmin()]["model"])
4345
+
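+ # The selection above keeps the rows with the maximal score_val, then breaks ties by
+ # taking the row with the smallest fit_time. The same idiom in isolation (illustrative;
+ # `lb` is a hypothetical leaderboard-style DataFrame):
+ #
+ #     best_rows = lb[lb["score_val"] == lb["score_val"].max()]
+ #     proxy_name = best_rows.loc[best_rows["fit_time"].idxmin()]["model"]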
4346
+ def calibrate_model(self, model_name: str | None = None, lr: float = 0.1, max_iter: int = 200, init_val: float = 1.0):
4347
+ """
4348
+ Applies temperature scaling to a model.
4349
+ Applies inverse softmax to predicted probs, then tunes a temperature scalar
4350
+ on validation data to minimize negative log likelihood.
4351
+ The inverse-softmax logits are divided by the temperature scalar,
4352
+ then softmaxed to return calibrated predicted probs.
4353
+
4354
+ Parameters
4355
+ ----------
4356
+ model_name : str, default = None
4357
+ Name of the model to tune temperature scaling on.
4358
+ If None, will tune the best model only. The best model is chosen by validation score.
4359
+ lr : float, default = 0.1
4360
+ The learning rate for the temperature scaling algorithm.
4361
+ max_iter : int, default = 200
4362
+ The number of iterations the optimizer should take to
4363
+ tune the temperature scaler.
4364
+ init_val : float, default = 1.0
4365
+ The initial value for the temperature scalar term.
4366
+ """
4367
+ # TODO: Note that temperature scaling is known to worsen calibration in the face of shifted test data.
4368
+ try:
4369
+ # FIXME: Avoid depending on torch for temp scaling
4370
+ try_import_torch()
4371
+ except ImportError:
4372
+ logger.log(30, "Warning: Torch is not installed, skipping calibration step...")
4373
+ return
4374
+
4375
+ if model_name is None:
4376
+ if self.has_val:
4377
+ can_infer = True
4378
+ else:
4379
+ can_infer = None
4380
+ if self.model_best is not None:
4381
+ models = self.get_model_names(can_infer=can_infer)
4382
+ if self.model_best in models:
4383
+ model_name = self.model_best
4384
+ if model_name is None:
4385
+ model_name = self.get_model_best(can_infer=can_infer)
4386
+
4387
+ model_refit_map = self.model_refit_map()
4388
+ model_name_og = model_name
4389
+ for m, m_full in model_refit_map.items():
4390
+ if m_full == model_name:
4391
+ model_name_og = m
4392
+ break
4393
+ if self.has_val:
4394
+ X_val = self.load_X_val()
4395
+ y_val_probs = self.predict_proba(X_val, model_name_og)
4396
+ y_val = self.load_y_val().to_numpy()
4397
+ else: # bagged mode
4398
+ y_val_probs = self.get_model_oof(model_name_og)
4399
+ y_val = self.load_y().to_numpy()
4400
+
4401
+ y_val_probs_og = y_val_probs
4402
+ if self.problem_type == BINARY:
4403
+ # Convert one-dimensional array to be in the form of a 2-class multiclass predict_proba output
4404
+ y_val_probs = LabelCleanerMulticlassToBinary.convert_binary_proba_to_multiclass_proba(y_val_probs)
4405
+
4406
+ model = self.load_model(model_name=model_name)
4407
+ if self.problem_type == QUANTILE:
4408
+ logger.log(15, f"Conformity scores being computed to calibrate model: {model_name}")
4409
+ conformalize = compute_conformity_score(y_val_pred=y_val_probs, y_val=y_val, quantile_levels=self.quantile_levels)
4410
+ model.conformalize = conformalize
4411
+ model.save()
4412
+ else:
4413
+ logger.log(15, f"Temperature scaling term being tuned for model: {model_name}")
4414
+ temp_scalar = tune_temperature_scaling(y_val_probs=y_val_probs, y_val=y_val, init_val=init_val, max_iter=max_iter, lr=lr)
4415
+ if temp_scalar is None:
4416
+ logger.log(
4417
+ 15,
4418
+ f"Warning: Infinity found during calibration, skipping calibration on {model.name}! "
4419
+ f"This can occur when the model is absolutely certain of a validation prediction (1.0 pred_proba).",
4420
+ )
4421
+ elif temp_scalar <= 0:
4422
+ logger.log(
4423
+ 30,
4424
+ f"Warning: Temperature scaling found optimal at a negative value ({temp_scalar}). Disabling temperature scaling to avoid overfitting.",
4425
+ )
4426
+ else:
4427
+ # Check that scaling improves performance for the target metric
4428
+ score_without_temp = self.score_with_y_pred_proba(y=y_val, y_pred_proba=y_val_probs_og, weights=None)
4429
+ scaled_y_val_probs = apply_temperature_scaling(y_val_probs, temp_scalar, problem_type=self.problem_type, transform_binary_proba=False)
4430
+ score_with_temp = self.score_with_y_pred_proba(y=y_val, y_pred_proba=scaled_y_val_probs, weights=None)
4431
+
4432
+ if score_with_temp > score_without_temp:
4433
+ logger.log(15, f"Temperature term found is: {temp_scalar}")
4434
+ model.params_aux["temperature_scalar"] = temp_scalar
4435
+ model.save()
4436
+ else:
4437
+ logger.log(15, "Temperature did not improve performance, skipping calibration.")
4438
+
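+ # The transformation that a tuned temperature scalar applies at inference time,
+ # sketched with numpy (illustrative only; the actual logic lives in
+ # apply_temperature_scaling). `probs` is a hypothetical (n_rows, n_classes) array:
+ #
+ #     logits = np.log(np.clip(probs, 1e-12, 1.0))  # inverse softmax up to a constant
+ #     scaled = np.exp(logits / temp_scalar)
+ #     probs_calibrated = scaled / scaled.sum(axis=1, keepdims=True)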
4439
+ def calibrate_decision_threshold(
4440
+ self,
4441
+ X: pd.DataFrame | None = None,
4442
+ y: np.ndarray | None = None,
4443
+ metric: str | Scorer | None = None,
4444
+ model: str = "best",
4445
+ weights=None,
4446
+ decision_thresholds: int | list[float] = 25,
4447
+ secondary_decision_thresholds: int | None = 19,
4448
+ verbose: bool = True,
4449
+ **kwargs,
4450
+ ) -> float:
4451
+ # TODO: Docstring
4452
+ assert self.problem_type == BINARY, f'calibrate_decision_threshold is only available for `problem_type="{BINARY}"`'
4453
+
4454
+ if metric is None:
4455
+ metric = self.eval_metric
4456
+ elif isinstance(metric, str):
4457
+ metric = get_metric(metric, self.problem_type, "eval_metric")
4458
+
4459
+ if model == "best":
4460
+ model = self.get_model_best()
4461
+
4462
+ if y is None:
4463
+ # If model is refit_full, use its parent to avoid over-fitting
4464
+ model_parent = self.get_refit_full_parent(model=model)
4465
+ if not self.model_exists(model_parent):
4466
+ raise AssertionError(
4467
+ f"Unable to calibrate the decision threshold on the internal data because the "
4468
+ f'model "{model}" is a refit_full model trained on all of the internal data, '
4469
+ f'whose parent model "{model_parent}" does not exist or was deleted.\n'
4470
+ f"It may have been deleted due to `predictor.fit(..., keep_only_best=True)`. "
4471
+ f"Ensure `keep_only_best=False` to be able to calibrate refit_full models."
4472
+ )
4473
+ model = model_parent
4474
+
4475
+ # TODO: Add helpful logging when data is not available, for example post optimize for deployment
4476
+ if self.has_val:
4477
+ # Use validation data
4478
+ X = self.load_X_val()
4479
+ if self.weight_evaluation:
4480
+ X, weights = extract_column(X=X, col_name=self.sample_weight)
4481
+ y: np.ndarray = self.load_y_val()
4482
+ y_pred_proba = self.predict_proba(X=X, model=model)
4483
+ else:
4484
+ # Use out-of-fold data
4485
+ if self.weight_evaluation:
4486
+ X = self.load_X()
4487
+ X, weights = extract_column(X=X, col_name=self.sample_weight)
4488
+ y: np.ndarray = self.load_y()
4489
+ y_pred_proba = self.get_model_oof(model=model)
4490
+ else:
4491
+ y_pred_proba = self.predict_proba(X=X, model=model)
4492
+
4493
+ if not metric.needs_pred:
4494
+ logger.warning(
4495
+ f'WARNING: The provided metric "{metric.name}" does not use class predictions for scoring, '
4496
+ f"and thus is invalid for decision threshold calibration. "
4497
+ f"Falling back to `decision_threshold=0.5`."
4498
+ )
4499
+ return 0.5
4500
+
4501
+ return calibrate_decision_threshold(
4502
+ y=y,
4503
+ y_pred_proba=y_pred_proba,
4504
+ metric=lambda y, y_pred: self.score_with_y_pred(y=y, y_pred=y_pred, weights=weights, metric=metric),
4505
+ decision_thresholds=decision_thresholds,
4506
+ secondary_decision_thresholds=secondary_decision_thresholds,
4507
+ metric_name=metric.name,
4508
+ verbose=verbose,
4509
+ **kwargs,
4510
+ )
4511
+
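+ # How the returned threshold is typically applied downstream, sketched with numpy
+ # (illustrative only; `threshold` is the value returned above and `y_pred_proba`
+ # holds positive-class probabilities):
+ #
+ #     y_pred = (y_pred_proba >= threshold).astype(int)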
4512
+ @staticmethod
4513
+ def _validate_num_classes(num_classes: int | None, problem_type: str):
4514
+ if problem_type == BINARY:
4515
+ assert num_classes is not None and num_classes == 2, f"num_classes must be 2 when problem_type='{problem_type}' (num_classes={num_classes})"
4516
+ elif problem_type in [MULTICLASS, SOFTCLASS]:
4517
+ assert num_classes is not None and num_classes >= 2, f"num_classes must be >=2 when problem_type='{problem_type}' (num_classes={num_classes})"
4518
+ elif problem_type in [REGRESSION, QUANTILE]:
4519
+ assert num_classes is None, f"num_classes must be None when problem_type='{problem_type}' (num_classes={num_classes})"
4520
+ else:
4521
+ raise AssertionError(f"Unknown problem_type: '{problem_type}'. Valid problem types: {[BINARY, MULTICLASS, REGRESSION, SOFTCLASS, QUANTILE]}")
4522
+
4523
+ @staticmethod
4524
+ def _validate_quantile_levels(quantile_levels: list[float] | np.ndarray | None, problem_type: str):
4525
+ if problem_type == QUANTILE:
4526
+ assert quantile_levels is not None, f"quantile_levels must not be None when problem_type='{problem_type}' (quantile_levels={quantile_levels})"
4527
+ assert isinstance(quantile_levels, (list, np.ndarray)), f"quantile_levels must be a list or np.ndarray (quantile_levels={quantile_levels})"
4528
+ assert len(quantile_levels) > 0, f"quantile_levels must not be an empty list (quantile_levels={quantile_levels})"
4529
+ else:
4530
+ assert quantile_levels is None, f"quantile_levels must be None when problem_type='{problem_type}' (quantile_levels={quantile_levels})"
4531
+
4532
+
4533
+ def _detached_train_multi_fold(
4534
+ *,
4535
+ _self: AbstractTabularTrainer,
4536
+ model: str | AbstractModel,
4537
+ X: pd.DataFrame,
4538
+ y: pd.Series,
4539
+ time_split: bool,
4540
+ time_start: float,
4541
+ time_limit: float | None,
4542
+ time_limit_model_split: float | None,
4543
+ hyperparameter_tune_kwargs: dict,
4544
+ is_ray_worker: bool = False,
4545
+ kwargs: dict,
4546
+ ) -> list[str]:
4547
+ """Dedicated class-detached function to train a single model on multiple folds."""
4548
+ if isinstance(model, str):
4549
+ model = _self.load_model(model)
4550
+ elif _self.low_memory:
4551
+ model = copy.deepcopy(model)
4552
+ if hyperparameter_tune_kwargs is not None and isinstance(hyperparameter_tune_kwargs, dict):
4553
+ hyperparameter_tune_kwargs_model = hyperparameter_tune_kwargs.get(model.name, None)
4554
+ else:
4555
+ hyperparameter_tune_kwargs_model = None
4556
+ # TODO: Only update scores when finished, only update model as part of final models if finished!
4557
+ if time_split:
4558
+ time_left = time_limit_model_split
4559
+ else:
4560
+ if time_limit is None:
4561
+ time_left = None
4562
+ else:
4563
+ time_start_model = time.time()
4564
+ time_left = time_limit - (time_start_model - time_start)
4565
+
4566
+ model_name_trained_lst = _self._train_single_full(
4567
+ X,
4568
+ y,
4569
+ model,
4570
+ time_limit=time_left,
4571
+ hyperparameter_tune_kwargs=hyperparameter_tune_kwargs_model,
4572
+ is_ray_worker=is_ray_worker,
4573
+ **kwargs
4574
+ )
4575
+
4576
+ if _self.low_memory:
4577
+ del model
4578
+
4579
+ return model_name_trained_lst
4580
+
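+ # The time budgeting above either hands each model a fixed per-model slice
+ # (time_split=True) or whatever remains of a shared budget. A worked sketch
+ # (illustrative values only):
+ #
+ #     time_limit, elapsed = 3600.0, 1250.0
+ #     time_left = time_limit - elapsed  # -> 2350.0 seconds for the next model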
4581
+
4582
+ def _remote_train_multi_fold(
4583
+ *,
4584
+ _self: AbstractTabularTrainer,
4585
+ model: str | AbstractModel,
4586
+ X: pd.DataFrame,
4587
+ y: pd.Series,
4588
+ time_split: bool,
4589
+ time_start: float,
4590
+ time_limit: float | None,
4591
+ time_limit_model_split: float | None,
4592
+ hyperparameter_tune_kwargs: dict,
4593
+ kwargs: dict,
4594
+ errors: Literal["ignore", "raise"] | None = None,
4595
+ ) -> tuple[str, str | None, str | None, Exception | None, dict | None]:
4596
+ reset_logger_for_remote_call(verbosity=_self.verbosity)
4597
+
4598
+ if errors is not None:
4599
+ kwargs["errors"] = errors
4600
+
4601
+ exception = None
4602
+ try:
4603
+ model_name_list = _detached_train_multi_fold(
4604
+ _self=_self,
4605
+ model=model,
4606
+ X=X,
4607
+ y=y,
4608
+ time_start=time_start,
4609
+ time_split=time_split,
4610
+ time_limit=time_limit,
4611
+ time_limit_model_split=time_limit_model_split,
4612
+ hyperparameter_tune_kwargs=hyperparameter_tune_kwargs,
4613
+ is_ray_worker=True,
4614
+ kwargs=kwargs,
4615
+ )
4616
+ except Exception as exc:
4617
+ model_name_list = []
4618
+ if errors == "raise":
4619
+ # If training fails and exception is returned, collect the exception information and return
4620
+ exception = exc # required to use in outer scope
4621
+ else:
4622
+ raise exc
4623
+
4624
+ if not model_name_list:
4625
+ model_name = model if isinstance(model, str) else model.name
4626
+ # Get model_failure metadata if it exists
4627
+ model_failure_info = None
4628
+ if model_name in _self._models_failed_to_train_errors:
4629
+ model_failure_info = _self._models_failed_to_train_errors[model_name]
4630
+ return model_name, None, None, exception, model_failure_info
4631
+
4632
+ model_name = model_name_list[0]
4637
+ return model_name, _self.get_model_attribute(model=model_name, attribute="path"), _self.get_model_attribute(model=model_name, attribute="type"), None, None
4638
+
4639
+
4640
+ def _detached_refit_single_full(
4641
+ *,
4642
+ _self: AbstractTabularTrainer,
4643
+ model: str,
4644
+ X: pd.DataFrame,
4645
+ y: pd.Series,
4646
+ X_val: pd.DataFrame,
4647
+ y_val: pd.Series,
4648
+ X_unlabeled: pd.DataFrame,
4649
+ level: int,
4650
+ kwargs: dict,
4651
+ fit_strategy: Literal["sequential", "parallel"] = "sequential",
4652
+ ) -> tuple[str, list[str]]:
4653
+ # TODO: loading the model is the reason we must allocate GPU resources for this job in cases where models require GPU when loaded from disk
4654
+ model = _self.load_model(model)
4655
+ model_name = model.name
4656
+ reuse_first_fold = False
4657
+
4658
+ if isinstance(model, BaggedEnsembleModel):
4659
+ # Reuse if model is already _FULL and no X_val
4660
+ if X_val is None:
4661
+ reuse_first_fold = not model._bagged_mode
4662
+
4663
+ if not reuse_first_fold:
4664
+ if isinstance(model, BaggedEnsembleModel):
4665
+ can_refit_full = model._get_tags_child().get("can_refit_full", False)
4666
+ else:
4667
+ can_refit_full = model._get_tags().get("can_refit_full", False)
4668
+ reuse_first_fold = not can_refit_full
4669
+
4670
+ if not reuse_first_fold:
4671
+ model_full = model.convert_to_refit_full_template()
4672
+ # Mitigates situation where bagged models barely had enough memory and refit requires more. Worst case results in OOM, but this lowers chance of failure.
4673
+ model_full._user_params_aux["max_memory_usage_ratio"] = model.params_aux["max_memory_usage_ratio"] * 1.15
4674
+ # Re-set user specified training resources.
4675
+ # FIXME: this is technically also a bug for non-distributed mode, but there it is good to use more/all resources per refit.
4676
+ # FIXME: Unsure if it is better to do model.fit_num_cpus or model.fit_num_cpus_child,
4677
+ # (Nick): I'm currently leaning towards model.fit_num_cpus, it is also less memory intensive
4678
+ # Better to not specify this for sequential fits, since we want the models to use the optimal amount of resources,
4679
+ # which could be less than the available resources (ex: LightGBM fits faster using 50% of the cores)
4680
+ if fit_strategy == "parallel":
4681
+ # FIXME: Why use `model.fit_num_cpus_child` when we can use the same values as was passed to `ray` for the process, just pass those values as kwargs. Eliminates chance of inconsistency.
4682
+ if model.fit_num_cpus_child is not None:
4683
+ model_full._user_params_aux["num_cpus"] = model.fit_num_cpus_child
4684
+ if model.fit_num_gpus_child is not None:
4685
+ model_full._user_params_aux["num_gpus"] = model.fit_num_gpus_child
4686
+ # TODO: Do it for all models in the level at once to avoid repeated processing of data?
4687
+ base_model_names = _self.get_base_model_names(model_name)
4688
+ # FIXME: Logs for inference speed (1 row) are incorrect because
4689
+ # the parents are non-refit models in this sequence and are only corrected after logging.
4690
+ # Avoiding a fix at present to minimize hacks in the code.
4691
+ # Return to this later when Trainer controls all stacking logic to map the correct parent.
4692
+ models_trained = _self.stack_new_level_core(
4693
+ X=X,
4694
+ y=y,
4695
+ X_val=X_val,
4696
+ y_val=y_val,
4697
+ X_unlabeled=X_unlabeled,
4698
+ models=[model_full],
4699
+ base_model_names=base_model_names,
4700
+ level=level,
4701
+ stack_name=REFIT_FULL_NAME,
4702
+ hyperparameter_tune_kwargs=None,
4703
+ feature_prune=False,
4704
+ k_fold=0,
4705
+ n_repeats=1,
4706
+ ensemble_type=type(model),
4707
+ refit_full=True,
4708
+ **kwargs,
4709
+ )
4710
+ if len(models_trained) == 0:
4711
+ reuse_first_fold = True
4712
+ logger.log(30, f"WARNING: Refit training failure detected for '{model_name}'... "
4713
+ f"Falling back to using first fold to avoid downstream exception."
4714
+ f"\n\tThis is likely due to an out-of-memory error or other memory related issue. "
4715
+ f"\n\tPlease create a GitHub issue if this was triggered from a non-memory related problem.")
4716
+ if not model.params.get("save_bag_folds", True):
4717
+ raise AssertionError(f"Cannot avoid training failure during refit for '{model_name}' by falling back to "
4718
+ f"copying the first fold because it does not exist! (save_bag_folds=False)"
4719
+ f"\n\tPlease specify `save_bag_folds=True` in the `.fit` call to avoid this exception.")
4720
+
4721
+ if reuse_first_fold:
4722
+ # Perform fallback black-box refit logic that doesn't retrain.
4723
+ model_full = model.convert_to_refit_full_via_copy()
4724
+ # FIXME: validation time not correct for infer 1 batch time, needed to hack _is_refit=True to fix
4725
+ logger.log(20, f"Fitting model: {model_full.name} | Skipping fit via cloning parent ...")
4726
+ _self._add_model(model_full, stack_name=REFIT_FULL_NAME, level=level, _is_refit=True)
4727
+ _self.save_model(model_full)
4728
+ models_trained = [model_full.name]
4729
+
4730
+ return model_name, models_trained
4731
+
4732
+
4733
+ def _remote_refit_single_full(
4734
+ *,
4735
+ _self: AbstractTabularTrainer,
4736
+ model: str,
4737
+ X: pd.DataFrame,
4738
+ y: pd.Series,
4739
+ X_val: pd.DataFrame,
4740
+ y_val: pd.Series,
4741
+ X_unlabeled: pd.DataFrame,
4742
+ level: int,
4743
+ kwargs: dict,
4744
+ fit_strategy: Literal["sequential", "parallel"],
4745
+ ) -> tuple[str, str, str, str]:
4746
+ reset_logger_for_remote_call(verbosity=_self.verbosity)
4747
+
4748
+ model_name, models_trained = _detached_refit_single_full(
4749
+ _self=_self,
4750
+ model=model,
4751
+ X=X,
4752
+ y=y,
4753
+ X_val=X_val,
4754
+ y_val=y_val,
4755
+ X_unlabeled=X_unlabeled,
4756
+ level=level,
4757
+ kwargs=kwargs,
4758
+ fit_strategy=fit_strategy,
4759
+ )
4760
+
4761
+ # We always just refit one model per call, so this must be the case.
4762
+ assert len(models_trained) == 1
4763
+ refitted_model_name = models_trained[0]
4764
+ return model_name, refitted_model_name, _self.get_model_attribute(model=refitted_model_name, attribute="path"), _self.get_model_attribute(model=refitted_model_name, attribute="type")