autogluon.core 1.2.1b20250115__py3-none-any.whl → 1.2.1b20250130__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- autogluon/core/_setup_utils.py +2 -2
- autogluon/core/learner/abstract_learner.py +1 -1
- autogluon/core/trainer/__init__.py +2 -0
- autogluon/core/trainer/abstract_trainer.py +380 -289
- autogluon/core/version.py +2 -1
- {autogluon.core-1.2.1b20250115.dist-info → autogluon.core-1.2.1b20250130.dist-info}/METADATA +7 -5
- {autogluon.core-1.2.1b20250115.dist-info → autogluon.core-1.2.1b20250130.dist-info}/RECORD +14 -14
- /autogluon.core-1.2.1b20250115-py3.8-nspkg.pth → /autogluon.core-1.2.1b20250130-py3.9-nspkg.pth +0 -0
- {autogluon.core-1.2.1b20250115.dist-info → autogluon.core-1.2.1b20250130.dist-info}/LICENSE +0 -0
- {autogluon.core-1.2.1b20250115.dist-info → autogluon.core-1.2.1b20250130.dist-info}/NOTICE +0 -0
- {autogluon.core-1.2.1b20250115.dist-info → autogluon.core-1.2.1b20250130.dist-info}/WHEEL +0 -0
- {autogluon.core-1.2.1b20250115.dist-info → autogluon.core-1.2.1b20250130.dist-info}/namespace_packages.txt +0 -0
- {autogluon.core-1.2.1b20250115.dist-info → autogluon.core-1.2.1b20250130.dist-info}/top_level.txt +0 -0
- {autogluon.core-1.2.1b20250115.dist-info → autogluon.core-1.2.1b20250130.dist-info}/zip-safe +0 -0
autogluon/core/trainer/abstract_trainer.py (the only file with substantive changes; several removed `-` lines below were truncated by the source rendering and are shown as captured):

```diff
@@ -4,39 +4,43 @@ import copy
 import logging
 import os
 import shutil
-import sys
 import time
 import traceback
-import typing
 from collections import defaultdict
 from pathlib import Path
-from typing import Any,
+from typing import Any, Generic, Literal, Optional, Type, TypeVar
 
 import networkx as nx
 import numpy as np
 import pandas as pd
+from typing_extensions import Self
 
 from autogluon.common.features.feature_metadata import FeatureMetadata
 from autogluon.common.features.types import R_FLOAT, S_STACK
 from autogluon.common.utils.distribute_utils import DistributedContext
 from autogluon.common.utils.lite import disable_if_lite_mode
 from autogluon.common.utils.log_utils import convert_time_in_s_to_log_friendly, reset_logger_for_remote_call
-from autogluon.common.utils.path_converter import PathConverter
 from autogluon.common.utils.resource_utils import ResourceManager, get_resource_manager
 from autogluon.common.utils.try_import import try_import_ray, try_import_torch
-
-from
-from
-from
-from
-from
-from
-from
-from
-
-
-
-
+from autogluon.core.augmentation.distill_utils import augment_data, format_distillation_labels
+from autogluon.core.calibrate import calibrate_decision_threshold
+from autogluon.core.calibrate.conformity_score import compute_conformity_score
+from autogluon.core.calibrate.temperature_scaling import apply_temperature_scaling, tune_temperature_scaling
+from autogluon.core.callbacks import AbstractCallback
+from autogluon.core.constants import BINARY, MULTICLASS, QUANTILE, REFIT_FULL_NAME, REGRESSION, SOFTCLASS
+from autogluon.core.data.label_cleaner import LabelCleanerMulticlassToBinary
+from autogluon.core.metrics import Scorer, compute_metric, get_metric
+from autogluon.core.models import (
+    AbstractModel,
+    BaggedEnsembleModel,
+    GreedyWeightedEnsembleModel,
+    SimpleWeightedEnsembleModel,
+    StackerEnsembleModel,
+    WeightedEnsembleModel,
+)
+from autogluon.core.pseudolabeling.pseudolabeling import assert_pseudo_column_match
+from autogluon.core.ray.distributed_jobs_managers import ParallelFitManager
+from autogluon.core.utils import (
     compute_permutation_feature_importance,
     convert_pred_probas_to_df,
     default_holdout_frac,
@@ -45,18 +49,208 @@ from ..utils import (
     get_pred_from_proba,
     infer_eval_metric,
 )
-from
-
-
-
+from autogluon.core.utils.exceptions import (
+    InsufficientTime,
+    NoGPUError,
+    NoStackFeatures,
+    NotEnoughCudaMemoryError,
+    NotEnoughMemoryError,
+    NotValidStacker,
+    NoValidFeatures,
+    TimeLimitExceeded,
+)
+from autogluon.core.utils.feature_selection import FeatureSelector
+from autogluon.core.utils.loaders import load_pkl
+from autogluon.core.utils.savers import save_json, save_pkl
+
 from .utils import process_hyperparameters
 
 logger = logging.getLogger(__name__)
 
 
-class AbstractTrainer:
+ModelTypeT = TypeVar("ModelTypeT", bound=AbstractModel)
+
+
+class AbstractTrainer(Generic[ModelTypeT]):
+    trainer_file_name = "trainer.pkl"
+    trainer_info_name = "info.pkl"
+    trainer_info_json_name = "info.json"
+
+    def __init__(self, path: str, *, low_memory: bool, save_data: bool):
+        self.path = path
+        self.reset_paths = False
+
+        self.low_memory: bool = low_memory
+        self.save_data: bool = save_data
+
+        self.models: dict[str, Any] = {}
+        self.model_graph = nx.DiGraph()
+        self.model_best: str | None = None
+
+        self._extra_banned_names: set[str] = set()
+
+    def _get_banned_model_names(self) -> list[str]:
+        """Gets all model names which would cause model files to be overwritten if a new model
+        was trained with the name
+        """
+        return self.get_model_names() + list(self._extra_banned_names)
+
+    @property
+    def path_root(self) -> str:
+        """directory containing learner.pkl"""
+        return os.path.dirname(self.path)
+
+    @property
+    def path_utils(self) -> str:
+        return os.path.join(self.path_root, "utils")
+
+    @property
+    def path_data(self) -> str:
+        return os.path.join(self.path_utils, "data")
+
+    def set_contexts(self, path_context: str) -> None:
+        self.path = self.create_contexts(path_context)
+
+    def create_contexts(self, path_context: str) -> str:
+        path = path_context
+        return path
+
+    def save_model(self, model: ModelTypeT, **kwargs) -> None:
+        model.save()
+        if not self.low_memory:
+            self.models[model.name] = model
+
+    def get_models_attribute_dict(self, attribute: str, models: list[str] | None = None) -> dict[str, Any]:
+        raise NotImplementedError
+
+    def get_model_attribute(self, model: str | ModelTypeT, attribute: str, **kwargs) -> Any:
+        """Return model attribute value.
+        If `default` is specified, return default value if attribute does not exist.
+        If `default` is not specified, raise ValueError if attribute does not exist.
+        """
+        if not isinstance(model, str):
+            model = model.name
+        if model not in self.model_graph.nodes:
+            raise ValueError(f"Model does not exist: (model={model})")
+        if attribute not in self.model_graph.nodes[model]:
+            if "default" in kwargs:
+                return kwargs["default"]
+            else:
+                raise ValueError(f"Model does not contain attribute: (model={model}, attribute={attribute})")
+        if attribute == "path":
+            return os.path.join(*self.model_graph.nodes[model][attribute])
+        return self.model_graph.nodes[model][attribute]
+
+    def set_model_attribute(self, model: str | ModelTypeT, attribute: str, val: Any):
+        if not isinstance(model, str):
+            model = model.name
+        self.model_graph.nodes[model][attribute] = val
+
+    def get_minimum_model_set(self, model: str | ModelTypeT, include_self: bool = True) -> list:
+        """Gets the minimum set of models that the provided model depends on, including itself
+        Returns a list of model names
+        """
+        if not isinstance(model, str):
+            model = model.name
+        minimum_model_set = list(nx.bfs_tree(self.model_graph, model, reverse=True))
+        if not include_self:
+            minimum_model_set = [m for m in minimum_model_set if m != model]
+        return minimum_model_set
+
+    def get_model_info(self, model: str | ModelTypeT) -> dict[str, Any]:
+        if isinstance(model, str):
+            if model in self.models.keys():
+                model = self.models[model]
+        if isinstance(model, str):
+            model_type = self.get_model_attribute(model=model, attribute="type")
+            model_path = self.get_model_attribute(model=model, attribute="path")
+            model_info = model_type.load_info(path=os.path.join(self.path, model_path))
+        else:
+            model_info = model.get_info()
+        return model_info
+
+    def get_model_names(self, **kwargs) -> list[str]:
+        """Get all model names that are registered in the model graph, in no particular order."""
+        return list(self.model_graph.nodes)
+
+    def get_models_info(self, models: list[str | ModelTypeT] | None = None) -> dict[str, dict[str, Any]]:
+        models_ = self.get_model_names() if models is None else models
+        model_info_dict = dict()
+        for model in models_:
+            model_name = model if isinstance(model, str) else model.name
+            model_info_dict[model_name] = self.get_model_info(model=model)
+        return model_info_dict
+
+    # TODO: model_name change to model in params
+    def load_model(self, model_name: str | ModelTypeT, path: str | None = None, model_type: Type[ModelTypeT] | None = None) -> ModelTypeT:
+        if isinstance(model_name, AbstractModel):
+            return model_name
+        if model_name in self.models.keys():
+            return self.models[model_name]
+        else:
+            if path is None:
+                path = self.get_model_attribute(model=model_name, attribute="path")  # get relative location of the model to the trainer
+            assert path is not None
+            if model_type is None:
+                model_type = self.get_model_attribute(model=model_name, attribute="type")
+            assert model_type is not None
+            return model_type.load(path=os.path.join(self.path, path), reset_paths=self.reset_paths)
+
+    @classmethod
+    def load_info(cls, path: str, reset_paths: bool = False, load_model_if_required: bool = True) -> dict[str, Any]:
+        load_path = os.path.join(path, cls.trainer_info_name)
+        try:
+            return load_pkl.load(path=load_path)
+        except:
+            if load_model_if_required:
+                trainer = cls.load(path=path, reset_paths=reset_paths)
+                return trainer.get_info()
+            else:
+                raise
+
+    def save_info(self, include_model_info: bool = False) -> dict[str, Any]:
+        info = self.get_info(include_model_info=include_model_info)
+
+        save_pkl.save(path=os.path.join(self.path, self.trainer_info_name), object=info)
+        save_json.save(path=os.path.join(self.path, self.trainer_info_json_name), obj=info)
+        return info
+
+    def construct_model_templates(
+        self, hyperparameters: str | dict[str, Any], **kwargs
+    ) -> tuple[list[ModelTypeT], dict] | list[ModelTypeT]:
+        raise NotImplementedError
+
+    def get_model_best(self, *args, **kwargs) -> str:
+        raise NotImplementedError
+
+    def get_info(self, include_model_info: bool = False, **kwargs) -> dict[str, Any]:
+        raise NotImplementedError
+
+    def save(self) -> None:
+        raise NotImplementedError
+
+    @classmethod
+    def load(cls, path: str, reset_paths: bool = False) -> Self:
+        load_path = os.path.join(path, cls.trainer_file_name)
+        if not reset_paths:
+            return load_pkl.load(path=load_path)
+        else:
+            obj = load_pkl.load(path=load_path)
+            obj.set_contexts(path)
+            obj.reset_paths = reset_paths
+            return obj
+
+    def fit(self, *args, **kwargs):
+        raise NotImplementedError
+
+    def predict(self, *args, **kwargs) -> Any:
+        raise NotImplementedError
+
+
+# TODO: This class will be moved to autogluon.tabular
+class AbstractTabularTrainer(AbstractTrainer[AbstractModel]):
     """
-
+    AbstractTabularTrainer contains logic to train a variety of models under a variety of constraints and automatically generate a multi-layer stack ensemble.
     Beyond the basic functionality, it also has support for model refitting, distillation, pseudo-labelling, unlabeled data, and much more.
 
     It is not recommended to directly use Trainer. Instead, use Predictor or Learner which internally uses Trainer.
```
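The dominant change above is structural: the previously monolithic trainer is split into a minimal `AbstractTrainer` base class that is generic over its model type (`ModelTypeT`, bound to `AbstractModel`), with the tabular-specific logic moving to `AbstractTabularTrainer(AbstractTrainer[AbstractModel])`. A minimal sketch of how a downstream trainer could pin the new type parameter, based only on the signatures visible in this diff — `MyModel` and `MyTrainer` are hypothetical names, not part of this release:

```python
from autogluon.core.models import AbstractModel
from autogluon.core.trainer.abstract_trainer import AbstractTrainer


class MyModel(AbstractModel):  # hypothetical model type, for illustration only
    pass


class MyTrainer(AbstractTrainer[MyModel]):  # pins ModelTypeT to MyModel
    pass


# __init__ now takes low_memory/save_data as keyword-only args, per the diff above.
trainer = MyTrainer(path="AutogluonModels/", low_memory=True, save_data=False)
# load_model()/save_model() are now typed against MyModel rather than a bare AbstractModel.
```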
```diff
@@ -65,7 +259,7 @@ class AbstractTrainer:
     Due to the complexity of the logic within this class, a text description will not give the full picture.
     It is recommended to carefully read the code and use a debugger to understand how it works.
 
-
+    AbstractTabularTrainer makes much fewer assumptions about the problem than Learner and Predictor.
     It expects these ambiguities to have already been resolved upstream. For example, problem_type, feature_metadata, num_classes, etc.
 
     Parameters
@@ -84,7 +278,7 @@ class AbstractTrainer:
         FeatureMetadata for X. Sent to each model during fit.
     eval_metric : Scorer, default = None
         Metric to optimize. If None, a default metric is used depending on the problem_type.
-    quantile_levels :
+    quantile_levels : list[float] | np.ndarray, default = None
         # TODO: Add documentation, not documented in Predictor.
         Only used when problem_type=quantile
     low_memory : bool, default = True
@@ -116,9 +310,6 @@ class AbstractTrainer:
         where `L` ranges from 0 to 50 (Note: higher values of `L` correspond to fewer print statements, opposite of verbosity levels).
     """
 
-    trainer_file_name = "trainer.pkl"
-    trainer_info_name = "info.pkl"
-    trainer_info_json_name = "info.json"
     distill_stackname = "distill"  # name of stack-level for distilled student models
 
     def __init__(
@@ -126,22 +317,26 @@ class AbstractTrainer:
         path: str,
         *,
         problem_type: str,
-        num_classes: int = None,
-        feature_metadata: FeatureMetadata = None,
-        eval_metric: Scorer = None,
-        quantile_levels:
+        num_classes: int | None = None,
+        feature_metadata: FeatureMetadata | None = None,
+        eval_metric: Scorer | None = None,
+        quantile_levels: list[float] | np.ndarray | None = None,
         low_memory: bool = True,
         k_fold: int = 0,
         n_repeats: int = 1,
-        sample_weight: str = None,
+        sample_weight: str | None = None,
         weight_evaluation: bool = False,
         save_data: bool = False,
         random_state: int = 0,
         verbosity: int = 2,
     ):
+        super().__init__(
+            path=path,
+            low_memory=low_memory,
+            save_data=save_data,
+        )
         self._validate_num_classes(num_classes=num_classes, problem_type=problem_type)
         self._validate_quantile_levels(quantile_levels=quantile_levels, problem_type=problem_type)
-        self.path = path
         self.problem_type = problem_type
         self.feature_metadata = feature_metadata
         self.save_data = save_data
@@ -185,7 +380,7 @@ class AbstractTrainer:
 
         self.model_best = None
 
-        self.models = {}  #
+        self.models = {}  # dict of model name -> model object. A key, value pair only exists if a model is persisted in memory.  # TODO: v0.1 Rename and consider making private
         self.model_graph = nx.DiGraph()  # Directed Acyclic Graph (DAG) of model interactions. Describes how certain models depend on the predictions of certain other models. Contains numerous metadata regarding each model.
         self.reset_paths = False
@@ -210,31 +405,18 @@ class AbstractTrainer:
 
         self._extra_banned_names = set()  # Names which are banned but are not used by a trained model.
 
-        self._models_failed_to_train_errors = dict()  #
+        self._models_failed_to_train_errors = dict()  # dict of model name -> model failure metadata
 
         # self._exceptions_list = []  # TODO: Keep exceptions list for debugging during benchmarking.
 
-        self.callbacks:
+        self.callbacks: list[AbstractCallback] = []
         self._callback_early_stop = False
 
-    # path_root is the directory containing learner.pkl
-    @property
-    def path_root(self) -> str:
-        return os.path.dirname(self.path)
-
-    @property
-    def path_utils(self) -> str:
-        return os.path.join(self.path_root, "utils")
-
     @property
     def _path_attr(self) -> str:
         """Path to cached model graph attributes"""
         return os.path.join(self.path_utils, "attr")
 
-    @property
-    def path_data(self) -> str:
-        return os.path.join(self.path_utils, "data")
-
     @property
     def has_val(self) -> bool:
         """Whether the trainer uses validation data"""
@@ -246,6 +428,7 @@ class AbstractTrainer:
         if self._num_rows_val is not None:
             return self._num_rows_val
         elif self.bagged_mode:
+            assert self._num_rows_train is not None
             return self._num_rows_train
         else:
             return 0
@@ -334,8 +517,12 @@ class AbstractTrainer:
         self._y_test_saved = True
 
     def get_model_names(
-        self,
-
+        self,
+        stack_name: list[str] | str | None = None,
+        level: list[int] | int | None = None,
+        can_infer: bool | None = None,
+        models: list[str] | None = None
+    ) -> list[str]:
         if models is None:
             models = list(self.model_graph.nodes)
         if stack_name is not None:
@@ -354,7 +541,7 @@ class AbstractTrainer:
         models = [model for model in models if node_attributes[model] == can_infer]
         return models
 
-    def get_max_level(self, stack_name: str = None, models:
+    def get_max_level(self, stack_name: str | None = None, models: list[str] | None = None) -> int:
         models = self.get_model_names(stack_name=stack_name, models=models)
         models_attribute_dict = self.get_models_attribute_dict(attribute="level", models=models)
         if models_attribute_dict:
@@ -362,25 +549,17 @@ class AbstractTrainer:
         else:
             return -1
 
-    def construct_model_templates(self, hyperparameters: dict, **kwargs) ->
+    def construct_model_templates(self, hyperparameters: dict, **kwargs) -> tuple[list[AbstractModel], dict]:
         """Constructs a list of unfit models based on the hyperparameters dict."""
         raise NotImplementedError
 
-    def construct_model_templates_distillation(self, hyperparameters: dict, **kwargs) ->
+    def construct_model_templates_distillation(self, hyperparameters: dict, **kwargs) -> tuple[list[AbstractModel], dict]:
         """Constructs a list of unfit models based on the hyperparameters dict for softclass distillation."""
         raise NotImplementedError
 
     def get_model_level(self, model_name: str) -> int:
         return self.get_model_attribute(model=model_name, attribute="level")
 
-    def set_contexts(self, path_context):
-        self.path = self.create_contexts(path_context)
-
-    def create_contexts(self, path_context: str) -> str:
-        path = path_context
-
-        return path
-
     def fit(self, X, y, hyperparameters: dict, X_val=None, y_val=None, **kwargs):
         raise NotImplementedError
 
@@ -395,19 +574,19 @@ class AbstractTrainer:
         X_test=None,
         y_test=None,
         X_unlabeled=None,
-        base_model_names:
-        core_kwargs: dict = None,
-        aux_kwargs: dict = None,
+        base_model_names: list[str] | None = None,
+        core_kwargs: dict | None = None,
+        aux_kwargs: dict | None = None,
         level_start=1,
         level_end=1,
         time_limit=None,
-        name_suffix: str = None,
+        name_suffix: str | None = None,
         relative_stack=True,
         level_time_modifier=0.333,
         infer_limit=None,
         infer_limit_batch_size=None,
-        callbacks:
-    ) ->
+        callbacks: list[AbstractCallback] | None = None,
+    ) -> list[str]:
         """
         Trains a multi-layer stack ensemble using the input data on the hyperparameters dict input.
         hyperparameters is used to determine the models used in each stack layer.
@@ -426,6 +605,8 @@ class AbstractTrainer:
         """
         self._fit_setup(time_limit=time_limit, callbacks=callbacks)
         time_train_start = self._time_train_start
+        assert time_train_start is not None
+
         if self.callbacks:
             callback_classes = [c.__class__.__name__ for c in self.callbacks]
             logger.log(20, f"User-specified callbacks ({len(self.callbacks)}): {callback_classes}")
```
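Several of the added lines in these hunks (`assert self._num_rows_train is not None`, `assert time_train_start is not None`, and the `assert path is not None` in the new base class) pair with the stricter `X | None` annotations this release introduces: the assert narrows the Optional away before the value is used, so type checkers accept the arithmetic that follows. A minimal sketch of the pattern, with illustrative names only:

```python
import time


def time_left(time_limit: float | None, time_train_start: float | None) -> float | None:
    # Returns remaining seconds, or None when no limit was set.
    if time_limit is None:
        return None
    assert time_train_start is not None  # narrows float | None to float for the type checker
    return time_limit - (time.time() - time_train_start)
```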
```diff
@@ -517,7 +698,7 @@ class AbstractTrainer:
         self.save()
         return model_names_fit
 
-    def _fit_setup(self, time_limit: float | None = None, callbacks:
+    def _fit_setup(self, time_limit: float | None = None, callbacks: list[AbstractCallback] | None = None):
         """
         Prepare the trainer state at the start of / prior to a fit call.
         Should be paired with a `self._fit_cleanup()` at the conclusion of the fit call.
@@ -561,8 +742,13 @@ class AbstractTrainer:
 
     # TODO: Consider better greedy approximation method such as via fitting a weighted ensemble to evaluate the value of a subset.
     def _filter_base_models_via_infer_limit(
-        self,
-
+        self,
+        base_model_names: list[str],
+        infer_limit: float | None,
+        infer_limit_modifier: float = 1.0,
+        as_child: bool = True,
+        verbose: bool = True,
+    ) -> list[str]:
         """
         Returns a subset of base_model_names whose combined prediction time for 1 row of data does not exceed infer_limit seconds.
         With the goal of selecting the best valid subset that is most valuable to stack ensembles who use them as base models,
@@ -572,9 +758,9 @@ class AbstractTrainer:
 
         Parameters
         ----------
-        base_model_names:
-
-        infer_limit: float
+        base_model_names: list[str]
+            list of model names. These models must already be added to the trainer.
+        infer_limit: float, optional
             Inference limit in seconds for 1 row of data. This is compared against values pre-computed during fit for the models.
         infer_limit_modifier: float, default = 1.0
             Modifier to multiply infer_limit by.
@@ -655,22 +841,22 @@ class AbstractTrainer:
         self,
         X,
         y,
-        models:
+        models: list[AbstractModel] | dict,
         X_val=None,
         y_val=None,
         X_test=None,
         y_test=None,
         X_unlabeled=None,
         level=1,
-        base_model_names:
-        core_kwargs: dict = None,
-        aux_kwargs: dict = None,
-        name_suffix: str = None,
+        base_model_names: list[str] | None = None,
+        core_kwargs: dict | None = None,
+        aux_kwargs: dict | None = None,
+        name_suffix: str | None = None,
         infer_limit=None,
         infer_limit_batch_size=None,
         full_weighted_ensemble: bool = False,
         additional_full_weighted_ensemble: bool = False,
-    ) ->
+    ) -> tuple[list[str], list[str]]:
         """
         Similar to calling self.stack_new_level_core, except auxiliary models will also be trained via a call to self.stack_new_level_aux, with the models trained from self.stack_new_level_core used as base models.
         """
@@ -718,14 +904,14 @@ class AbstractTrainer:
         self,
         X,
         y,
-        models:
+        models: list[AbstractModel] | dict,
         X_val=None,
         y_val=None,
         X_test=None,
         y_test=None,
         X_unlabeled=None,
         level=1,
-        base_model_names:
+        base_model_names: list[str] | None = None,
         fit_strategy: Literal["sequential", "parallel"] = "sequential",
         stack_name="core",
         ag_args=None,
@@ -734,13 +920,13 @@ class AbstractTrainer:
         included_model_types=None,
         excluded_model_types=None,
         ensemble_type=StackerEnsembleModel,
-        name_suffix: str = None,
+        name_suffix: str | None = None,
         get_models_func=None,
         refit_full=False,
         infer_limit=None,
         infer_limit_batch_size=None,
         **kwargs,
-    ) ->
+    ) -> list[str]:
         """
         Trains all models using the data provided.
         If level > 1, then the models will use base model predictions as additional features.
@@ -757,7 +943,11 @@ class AbstractTrainer:
         if not self.bagged_mode and level != 1:
             raise ValueError("Stack Ensembling is not valid for non-bagged mode.")
 
-        base_model_names = self._filter_base_models_via_infer_limit(
+        base_model_names = self._filter_base_models_via_infer_limit(
+            base_model_names=base_model_names,
+            infer_limit=infer_limit,
+            infer_limit_modifier=0.8,
+        )
         if ag_args_fit is None:
             ag_args_fit = {}
         ag_args_fit = ag_args_fit.copy()
@@ -779,7 +969,7 @@ class AbstractTrainer:
             (base_model_names, base_model_paths, base_model_types) = (None, None, None)
         elif level > 1:
             base_model_names, base_model_paths, base_model_types = self._get_models_load_info(model_names=base_model_names)
-            if len(base_model_names) == 0:
+            if len(base_model_names) == 0:  # type: ignore
                 logger.log(20, f"No base models to train on, skipping stack level {level}...")
                 return []
         else:
@@ -874,7 +1064,7 @@ class AbstractTrainer:
         self,
         X,
         y,
-        base_model_names:
+        base_model_names: list[str],
         level: int | str = "auto",
         fit=True,
         stack_name="aux1",
@@ -888,7 +1078,7 @@ class AbstractTrainer:
         fit_weighted_ensemble: bool = True,
         name_extra: str | None = None,
         total_resources: dict | None = None,
-    ) ->
+    ) -> list[str]:
         """
         Trains auxiliary models (currently a single weighted ensemble) using the provided base models.
         Level must be greater than the level of any of the base models.
@@ -1038,13 +1228,13 @@ class AbstractTrainer:
         )
 
     # TODO: Slow if large ensemble with many models, could cache output result to speed up during inference
-    def _construct_model_pred_order(self, models:
+    def _construct_model_pred_order(self, models: list[str]) -> list[str]:
         """
         Constructs a list of model names in order of inference calls required to infer on all the models.
 
         Parameters
         ----------
-        models :
+        models : list[str]
             The list of models to construct the prediction order from.
             If a model has dependencies, the dependency models will be put earlier in the output list.
             Models explicitly mentioned in the `models` input will be placed as early as possible in the output list.
@@ -1068,17 +1258,17 @@ class AbstractTrainer:
         model_set = set(model_order)
         return model_order
 
-    def _construct_model_pred_order_with_pred_dict(self, models:
+    def _construct_model_pred_order_with_pred_dict(self, models: list[str], models_to_ignore: list[str] = None) -> list[str]:
         """
         Constructs a list of model names in order of inference calls required to infer on all the models.
         Unlike `_construct_model_pred_order`, this method's output is in undefined order when multiple models are valid to infer at the same time.
 
         Parameters
         ----------
-        models :
+        models : list[str]
             The list of models to construct the prediction order from.
             If a model has dependencies, the dependency models will be put earlier in the output list.
-        models_to_ignore :
+        models_to_ignore : list[str], optional
             A list of models that have already been computed and can be ignored.
             Models in this list and their dependencies (if not depended on by other models in `models`) will be pruned from the final output.
 
@@ -1109,7 +1299,23 @@ class AbstractTrainer:
 
         # Get model prediction order
         return list(nx.lexicographical_topological_sort(subgraph))
-
+
+    def get_models_attribute_dict(self, attribute: str, models: list | None = None) -> dict[str, Any]:
+        """Returns dictionary of model name -> attribute value for the provided attribute.
+        """
+        models_attribute_dict = nx.get_node_attributes(self.model_graph, attribute)
+        if models is not None:
+            model_names = []
+            for model in models:
+                if not isinstance(model, str):
+                    model = model.name
+                model_names.append(model)
+            if attribute == "path":
+                models_attribute_dict = {key: os.path.join(*val) for key, val in models_attribute_dict.items() if key in model_names}
+            else:
+                models_attribute_dict = {key: val for key, val in models_attribute_dict.items() if key in model_names}
+        return models_attribute_dict
+
     # TODO: Consider adding persist to disk functionality for pred_proba dictionary to lessen memory burden on large multiclass problems.
     # For datasets with 100+ classes, this function could potentially run the system OOM due to each pred_proba numpy array taking significant amounts of space.
     # This issue already existed in the previous level-based version but only had the minimum required predictions in memory at a time, whereas this has all model predictions in memory.
```
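The `get_models_attribute_dict` implementation added above is relocated from later in the file (its old copy is deleted in a subsequent hunk); it is a thin wrapper over networkx node attributes on the trainer's model DAG. A standalone illustration of that lookup, with made-up model names:

```python
import networkx as nx

model_graph = nx.DiGraph()
model_graph.add_node("LightGBM", level=1, can_infer=True)
model_graph.add_node("WeightedEnsemble_L2", level=2, can_infer=True)
model_graph.add_edge("LightGBM", "WeightedEnsemble_L2")  # ensemble depends on LightGBM's predictions

print(nx.get_node_attributes(model_graph, "level"))
# {'LightGBM': 1, 'WeightedEnsemble_L2': 2}
```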
@@ -1117,7 +1323,7 @@ class AbstractTrainer:
|
|
1117
1323
|
def get_model_pred_proba_dict(
|
1118
1324
|
self,
|
1119
1325
|
X: pd.DataFrame,
|
1120
|
-
models:
|
1326
|
+
models: list[str],
|
1121
1327
|
model_pred_proba_dict: dict = None,
|
1122
1328
|
model_pred_time_dict: dict = None,
|
1123
1329
|
record_pred_time: bool = False,
|
@@ -1132,7 +1338,7 @@ class AbstractTrainer:
|
|
1132
1338
|
----------
|
1133
1339
|
X : pd.DataFrame
|
1134
1340
|
Input data to predict on.
|
1135
|
-
models :
|
1341
|
+
models : list[str]
|
1136
1342
|
The list of models to predict with.
|
1137
1343
|
Note that if models have dependency models, their dependencies will also be predicted with and included in the output.
|
1138
1344
|
model_pred_proba_dict : dict, optional
|
@@ -1194,13 +1400,13 @@ class AbstractTrainer:
|
|
1194
1400
|
else:
|
1195
1401
|
return model_pred_proba_dict
|
1196
1402
|
|
1197
|
-
def get_model_oof_dict(self, models:
|
1403
|
+
def get_model_oof_dict(self, models: list[str]) -> dict:
|
1198
1404
|
"""
|
1199
1405
|
Returns a dictionary of out-of-fold prediction probabilities, keyed by model name
|
1200
1406
|
"""
|
1201
1407
|
return {model: self.get_model_oof(model) for model in models}
|
1202
1408
|
|
1203
|
-
def get_model_pred_dict(self, X: pd.DataFrame, models:
|
1409
|
+
def get_model_pred_dict(self, X: pd.DataFrame, models: list[str], record_pred_time: bool = False, **kwargs):
|
1204
1410
|
"""
|
1205
1411
|
Optimally computes predictions for each model in `models`.
|
1206
1412
|
Will compute each necessary model only once and store predictions in a `model_pred_dict` dictionary.
|
@@ -1212,7 +1418,7 @@ class AbstractTrainer:
|
|
1212
1418
|
----------
|
1213
1419
|
X : pd.DataFrame
|
1214
1420
|
Input data to predict on.
|
1215
|
-
models :
|
1421
|
+
models : list[str]
|
1216
1422
|
The list of models to predict with.
|
1217
1423
|
Note that if models have dependency models, their dependencies will also be predicted with and included in the output.
|
1218
1424
|
record_pred_time : bool, default = False
|
@@ -1293,8 +1499,8 @@ class AbstractTrainer:
|
|
1293
1499
|
self,
|
1294
1500
|
X: pd.DataFrame,
|
1295
1501
|
*,
|
1296
|
-
model: str = None,
|
1297
|
-
base_models:
|
1502
|
+
model: str | None = None,
|
1503
|
+
base_models: list[str] | None = None,
|
1298
1504
|
model_pred_proba_dict: Optional[dict] = None,
|
1299
1505
|
fit: bool = False,
|
1300
1506
|
use_orig_features: bool = True,
|
@@ -1311,7 +1517,7 @@ class AbstractTrainer:
|
|
1311
1517
|
model : str, default = None
|
1312
1518
|
The model to derive `base_models` from.
|
1313
1519
|
Cannot be specified alongside `base_models`.
|
1314
|
-
base_models :
|
1520
|
+
base_models : list[str], default = None
|
1315
1521
|
The list of base models to augment X with.
|
1316
1522
|
Base models will add their prediction probabilities as extra features to X.
|
1317
1523
|
Cannot be specified alongside `model`.
|
@@ -1355,7 +1561,7 @@ class AbstractTrainer:
|
|
1355
1561
|
X = X_stacker
|
1356
1562
|
return X
|
1357
1563
|
|
1358
|
-
def get_feature_metadata(self, use_orig_features: bool = True, model: str | None = None, base_models:
|
1564
|
+
def get_feature_metadata(self, use_orig_features: bool = True, model: str | None = None, base_models: list[str] | None = None) -> FeatureMetadata:
|
1359
1565
|
"""
|
1360
1566
|
Returns the FeatureMetadata input to a `model.fit` call.
|
1361
1567
|
Pairs with `X = self.get_inputs_to_stacker(...)`. The returned FeatureMetadata should reflect the contents of `X`.
|
@@ -1368,7 +1574,7 @@ class AbstractTrainer:
|
|
1368
1574
|
model : str, default = None
|
1369
1575
|
If specified, it must be an already existing model.
|
1370
1576
|
`base_models` will be set to the base models of `model`.
|
1371
|
-
base_models :
|
1577
|
+
base_models : list[str], default = None
|
1372
1578
|
If specified, will add the stack features of the `base_models` to FeatureMetadata.
|
1373
1579
|
|
1374
1580
|
Returns
|
@@ -1397,7 +1603,7 @@ class AbstractTrainer:
|
|
1397
1603
|
feature_metadata = FeatureMetadata(type_map_raw={})
|
1398
1604
|
return feature_metadata
|
1399
1605
|
|
1400
|
-
def _get_stack_column_names(self, models:
|
1606
|
+
def _get_stack_column_names(self, models: list[str]) -> tuple[list[str], int]:
|
1401
1607
|
"""
|
1402
1608
|
Get the stack column names generated when the provided models are used as base models in a stack ensemble.
|
1403
1609
|
Additionally output the number of columns per model as an int.
|
@@ -1562,7 +1768,7 @@ class AbstractTrainer:
|
|
1562
1768
|
# Fits _FULL models and links them in the stack so _FULL models only use other _FULL models as input during stacking
|
1563
1769
|
# If model is specified, will fit all _FULL models that are ancestors of the provided model, automatically linking them.
|
1564
1770
|
# If no model is specified, all models are refit and linked appropriately.
|
1565
|
-
def refit_ensemble_full(self, model: str |
|
1771
|
+
def refit_ensemble_full(self, model: str | list[str] = "all", **kwargs) -> dict:
|
1566
1772
|
if model == "all":
|
1567
1773
|
ensemble_set = self.get_model_names()
|
1568
1774
|
elif isinstance(model, list):
|
@@ -1618,7 +1824,13 @@ class AbstractTrainer:
|
|
1618
1824
|
"""Get refit full model's parent. If model does not have a parent, return `model`."""
|
1619
1825
|
return self.get_model_attribute(model=model, attribute="refit_full_parent", default=model)
|
1620
1826
|
|
1621
|
-
def get_model_best(
|
1827
|
+
def get_model_best(
|
1828
|
+
self,
|
1829
|
+
can_infer: bool | None = None,
|
1830
|
+
allow_full: bool = True,
|
1831
|
+
infer_limit: float | None = None,
|
1832
|
+
infer_limit_as_child: bool = False
|
1833
|
+
) -> str:
|
1622
1834
|
"""
|
1623
1835
|
Returns the name of the model with the best validation score that satisfies all specified constraints.
|
1624
1836
|
If no model satisfies the constraints, an AssertionError will be raised.
|
@@ -1704,7 +1916,7 @@ class AbstractTrainer:
|
|
1704
1916
|
else:
|
1705
1917
|
self.models[model.name] = model
|
1706
1918
|
|
1707
|
-
def save(self):
|
1919
|
+
def save(self) -> None:
|
1708
1920
|
models = self.models
|
1709
1921
|
if self.low_memory:
|
1710
1922
|
self.models = {}
|
@@ -1712,7 +1924,7 @@ class AbstractTrainer:
|
|
1712
1924
|
if self.low_memory:
|
1713
1925
|
self.models = models
|
1714
1926
|
|
1715
|
-
def compile(self, model_names="all", with_ancestors=False, compiler_configs=None) ->
|
1927
|
+
def compile(self, model_names="all", with_ancestors=False, compiler_configs=None) -> list[str]:
|
1716
1928
|
"""
|
1717
1929
|
Compile a list of models for accelerated prediction.
|
1718
1930
|
|
@@ -1786,7 +1998,7 @@ class AbstractTrainer:
|
|
1786
1998
|
self.save()
|
1787
1999
|
return model_names
|
1788
2000
|
|
1789
|
-
def persist(self, model_names="all", with_ancestors=False, max_memory=None) ->
|
2001
|
+
def persist(self, model_names="all", with_ancestors=False, max_memory=None) -> list[str]:
|
1790
2002
|
if model_names == "all":
|
1791
2003
|
model_names = self.get_model_names()
|
1792
2004
|
elif model_names == "best":
|
@@ -1852,19 +2064,6 @@ class AbstractTrainer:
|
|
1852
2064
|
model.models[fold] = model.load_child(fold_model)
|
1853
2065
|
return model_names
|
1854
2066
|
|
1855
|
-
# TODO: model_name change to model in params
|
1856
|
-
def load_model(self, model_name: str, path: str = None, model_type=None) -> AbstractModel:
|
1857
|
-
if isinstance(model_name, AbstractModel):
|
1858
|
-
return model_name
|
1859
|
-
if model_name in self.models.keys():
|
1860
|
-
return self.models[model_name]
|
1861
|
-
else:
|
1862
|
-
if path is None:
|
1863
|
-
path = self.get_model_attribute(model=model_name, attribute="path") # get relative location of the model to the trainer
|
1864
|
-
if model_type is None:
|
1865
|
-
model_type = self.get_model_attribute(model=model_name, attribute="type")
|
1866
|
-
return model_type.load(path=os.path.join(self.path, path), reset_paths=self.reset_paths)
|
1867
|
-
|
1868
2067
|
def unpersist(self, model_names="all") -> list:
|
1869
2068
|
if model_names == "all":
|
1870
2069
|
model_names = list(self.models.keys())
|
@@ -1893,13 +2092,13 @@ class AbstractTrainer:
|
|
1893
2092
|
hyperparameters=None,
|
1894
2093
|
ag_args_fit=None,
|
1895
2094
|
time_limit=None,
|
1896
|
-
name_suffix: str = None,
|
2095
|
+
name_suffix: str | None = None,
|
1897
2096
|
save_bag_folds=None,
|
1898
2097
|
check_if_best=True,
|
1899
2098
|
child_hyperparameters=None,
|
1900
2099
|
get_models_func=None,
|
1901
2100
|
total_resources: dict | None = None,
|
1902
|
-
) ->
|
2101
|
+
) -> list[str]:
|
1903
2102
|
if get_models_func is None:
|
1904
2103
|
get_models_func = self.construct_model_templates
|
1905
2104
|
if len(base_model_names) == 0:
|
@@ -1979,10 +2178,10 @@ class AbstractTrainer:
|
|
1979
2178
|
X: pd.DataFrame,
|
1980
2179
|
y: pd.Series,
|
1981
2180
|
model: AbstractModel,
|
1982
|
-
X_val: pd.DataFrame = None,
|
1983
|
-
y_val: pd.Series = None,
|
1984
|
-
X_test: pd.DataFrame = None,
|
1985
|
-
y_test: pd.Series = None,
|
2181
|
+
X_val: pd.DataFrame | None = None,
|
2182
|
+
y_val: pd.Series | None = None,
|
2183
|
+
X_test: pd.DataFrame | None = None,
|
2184
|
+
y_test: pd.Series | None = None,
|
1986
2185
|
total_resources: dict = None,
|
1987
2186
|
**model_fit_kwargs,
|
1988
2187
|
) -> AbstractModel:
|
@@ -1998,20 +2197,20 @@ class AbstractTrainer:
|
|
1998
2197
|
X: pd.DataFrame,
|
1999
2198
|
y: pd.Series,
|
2000
2199
|
model: AbstractModel,
|
2001
|
-
X_val: pd.DataFrame = None,
|
2002
|
-
y_val: pd.Series = None,
|
2003
|
-
X_test: pd.DataFrame = None,
|
2004
|
-
y_test: pd.Series = None,
|
2005
|
-
X_pseudo: pd.DataFrame = None,
|
2006
|
-
y_pseudo: pd.DataFrame = None,
|
2007
|
-
time_limit: float = None,
|
2200
|
+
X_val: pd.DataFrame | None = None,
|
2201
|
+
y_val: pd.Series | None = None,
|
2202
|
+
X_test: pd.DataFrame | None = None,
|
2203
|
+
y_test: pd.Series | None = None,
|
2204
|
+
X_pseudo: pd.DataFrame | None = None,
|
2205
|
+
y_pseudo: pd.DataFrame | None = None,
|
2206
|
+
time_limit: float | None = None,
|
2008
2207
|
stack_name: str = "core",
|
2009
2208
|
level: int = 1,
|
2010
2209
|
compute_score: bool = True,
|
2011
|
-
total_resources: dict = None,
|
2210
|
+
total_resources: dict | None = None,
|
2012
2211
|
errors: Literal["ignore", "raise"] = "ignore",
|
2013
|
-
errors_ignore: list = None,
|
2014
|
-
errors_raise: list = None,
|
2212
|
+
errors_ignore: list | None = None,
|
2213
|
+
errors_raise: list | None = None,
|
2015
2214
|
is_ray_worker: bool = False,
|
2016
2215
|
**model_fit_kwargs,
|
2017
2216
|
) -> list[str]:
|
@@ -2186,7 +2385,7 @@ class AbstractTrainer:
|
|
2186
2385
|
return model_names_trained
|
2187
2386
|
|
2188
2387
|
# FIXME: v1.0 Move to AbstractModel for most fields
|
2189
|
-
def _get_model_metadata(self, model: AbstractModel, stack_name: str = "core", level: int = 1) ->
|
2388
|
+
def _get_model_metadata(self, model: AbstractModel, stack_name: str = "core", level: int = 1) -> dict[str, Any]:
|
2190
2389
|
"""
|
2191
2390
|
Returns the model metadata used to initialize a node in the DAG (self.model_graph).
|
2192
2391
|
"""
|
@@ -2415,11 +2614,11 @@ class AbstractTrainer:
|
|
2415
2614
|
compute_score=True,
|
2416
2615
|
total_resources: dict | None = None,
|
2417
2616
|
errors: Literal["ignore", "raise"] = "ignore",
|
2418
|
-
errors_ignore: list = None,
|
2419
|
-
errors_raise: list = None,
|
2617
|
+
errors_ignore: list | None = None,
|
2618
|
+
errors_raise: list | None = None,
|
2420
2619
|
is_ray_worker: bool = False,
|
2421
2620
|
**kwargs,
|
2422
|
-
) ->
|
2621
|
+
) -> list[str]:
|
2423
2622
|
"""
|
2424
2623
|
Trains a model, with the potential to train multiple versions of this model with hyperparameter tuning and feature pruning.
|
2425
2624
|
Returns a list of successfully trained and saved model names.
|
@@ -2562,7 +2761,7 @@ class AbstractTrainer:
|
|
2562
2761
|
raise exception
|
2563
2762
|
return model_names_trained
|
2564
2763
|
|
2565
|
-
# TODO: Move to a utility function outside of
|
2764
|
+
# TODO: Move to a utility function outside of AbstractTabularTrainer
|
2566
2765
|
@staticmethod
|
2567
2766
|
def _check_raise_exception(
|
2568
2767
|
exception: Exception,
|
@@ -2643,7 +2842,7 @@ class AbstractTrainer:
|
|
2643
2842
|
def _callbacks_after_fit(
|
2644
2843
|
self,
|
2645
2844
|
*,
|
2646
|
-
model_names:
|
2845
|
+
model_names: list[str],
|
2647
2846
|
stack_name: str,
|
2648
2847
|
level: int,
|
2649
2848
|
):
|
@@ -2662,7 +2861,7 @@ class AbstractTrainer:
|
|
2662
2861
|
# TODO: Time allowance not accurate if running from fit_continue
|
2663
2862
|
# TODO: Remove level and stack_name arguments, can get them automatically
|
2664
2863
|
# TODO: Make sure that pretraining on X_unlabeled only happens 1 time rather than every fold of bagging. (Do during pretrain API work?)
|
2665
|
-
def _train_multi_repeats(self, X, y, models: list, n_repeats, n_repeat_start=1, time_limit=None, time_limit_total_level=None, **kwargs) ->
|
2864
|
+
def _train_multi_repeats(self, X, y, models: list, n_repeats, n_repeat_start=1, time_limit=None, time_limit_total_level=None, **kwargs) -> list[str]:
|
2666
2865
|
"""
|
2667
2866
|
Fits bagged ensemble models with additional folds and/or bagged repeats.
|
2668
2867
|
Models must have already been fit prior to entering this method.
|
@@ -2721,7 +2920,7 @@ class AbstractTrainer:
|
|
2721
2920
|
return models_valid
|
2722
2921
|
|
2723
2922
|
def _train_multi_initial(
|
2724
|
-
self, X, y, models:
|
2923
|
+
self, X, y, models: list[AbstractModel], k_fold, n_repeats, hyperparameter_tune_kwargs=None, time_limit=None, feature_prune_kwargs=None, **kwargs
|
2725
2924
|
):
|
2726
2925
|
"""
|
2727
2926
|
Fits models that have not previously been fit.
|
@@ -3137,7 +3336,7 @@ class AbstractTrainer:
|
|
3137
3336
|
self,
|
3138
3337
|
X,
|
3139
3338
|
y,
|
3140
|
-
models:
|
3339
|
+
models: list[AbstractModel],
|
3141
3340
|
hyperparameter_tune_kwargs=None,
|
3142
3341
|
feature_prune_kwargs=None,
|
3143
3342
|
k_fold=None,
|
@@ -3146,7 +3345,7 @@ class AbstractTrainer:
|
|
3146
3345
|
time_limit=None,
|
3147
3346
|
delay_bag_sets: bool = False,
|
3148
3347
|
**kwargs,
|
3149
|
-
) ->
|
3348
|
+
) -> list[str]:
|
3150
3349
|
"""
|
3151
3350
|
Train a list of models using the same data.
|
3152
3351
|
Assumes that input data has already been processed in the form the models will receive as input (including stack feature generation).
|
@@ -3205,13 +3404,13 @@ class AbstractTrainer:
|
|
3205
3404
|
y_val,
|
3206
3405
|
X_test=None,
|
3207
3406
|
y_test=None,
|
3208
|
-
hyperparameters: dict = None,
|
3407
|
+
hyperparameters: dict | None = None,
|
3209
3408
|
X_unlabeled=None,
|
3210
3409
|
num_stack_levels=0,
|
3211
3410
|
time_limit=None,
|
3212
3411
|
groups=None,
|
3213
3412
|
**kwargs,
|
3214
|
-
) ->
|
3413
|
+
) -> list[str]:
|
3215
3414
|
"""Identical to self.train_multi_levels, but also saves the data to disk. This should only ever be called once."""
|
3216
3415
|
if time_limit is not None and time_limit <= 0:
|
3217
3416
|
raise AssertionError(f"Not enough time left to train models. Consider specifying a larger time_limit. Time remaining: {round(time_limit, 2)}s")
|
@@ -3254,19 +3453,19 @@ class AbstractTrainer:
|
|
3254
3453
|
logger.log(30, "Warning: AutoGluon did not successfully train any models")
|
3255
3454
|
return model_names_fit
|
3256
3455
|
|
3257
|
-
def _predict_model(self, X: pd.DataFrame, model: str, model_pred_proba_dict: dict = None) -> np.ndarray:
|
3456
|
+
def _predict_model(self, X: pd.DataFrame, model: str, model_pred_proba_dict: dict | None = None) -> np.ndarray:
|
3258
3457
|
y_pred_proba = self._predict_proba_model(X=X, model=model, model_pred_proba_dict=model_pred_proba_dict)
|
3259
3458
|
return get_pred_from_proba(y_pred_proba=y_pred_proba, problem_type=self.problem_type)
|
3260
3459
|
|
3261
|
-
def _predict_proba_model(self, X: pd.DataFrame, model: str, model_pred_proba_dict: dict = None) -> np.ndarray:
|
3460
|
+
def _predict_proba_model(self, X: pd.DataFrame, model: str, model_pred_proba_dict: dict | None = None) -> np.ndarray:
|
3262
3461
|
model_pred_proba_dict = self.get_model_pred_proba_dict(X=X, models=[model], model_pred_proba_dict=model_pred_proba_dict)
|
3263
3462
|
if not isinstance(model, str):
|
3264
3463
|
model = model.name
|
3265
3464
|
return model_pred_proba_dict[model]
|
3266
3465
|
|
3267
3466
|
def _proxy_model_feature_prune(
|
3268
|
-
self, model_fit_kwargs: dict, time_limit: float, layer_fit_time: float, level: int, features:
|
3269
|
-
) ->
|
3467
|
+
self, model_fit_kwargs: dict, time_limit: float, layer_fit_time: float, level: int, features: list[str], **feature_prune_kwargs: dict
|
3468
|
+
) -> list[str]:
|
3270
3469
|
"""
|
3271
3470
|
Uses the best LightGBM-based base learner of this layer to perform time-aware permutation feature importance based feature pruning.
|
3272
3471
|
If all LightGBM models fail, use the model that achieved the highest validation accuracy. Feature pruning gets the smaller of the
|
@@ -3287,12 +3486,12 @@ class AbstractTrainer:
|
|
3287
3486
|
How long it took to fit all the models in this layer once. Used to calculate how long to feature prune for.
|
3288
3487
|
level : int
|
3289
3488
|
Level of this stack layer.
|
3290
|
-
features:
|
3489
|
+
features: list[str]
|
3291
3490
|
The list of feature names in the inputted dataset.
|
3292
3491
|
|
3293
3492
|
Returns
|
3294
3493
|
-------
|
3295
|
-
candidate_features :
|
3494
|
+
candidate_features : list[str]
|
3296
3495
|
Feature names that survived the pruning procedure.
|
3297
3496
|
"""
|
3298
3497
|
k = feature_prune_kwargs.pop("k", 2)
|
@@ -3326,14 +3525,14 @@ class AbstractTrainer:
|
|
3326
3525
|
def _get_default_proxy_model_class(self):
|
3327
3526
|
return None
|
3328
3527
|
|
3329
|
-
def _retain_better_pruned_models(self, pruned_models:
|
3528
|
+
def _retain_better_pruned_models(self, pruned_models: list[str], original_prune_map: dict, force_prune: bool = False) -> list[str]:
|
3330
3529
|
"""
|
3331
3530
|
Compares models fit on the pruned set of features with their counterpart, models fit on full set of features.
|
3332
3531
|
Take the model that achieved a higher validation set score and delete the other from self.model_graph.
|
3333
3532
|
|
3334
3533
|
Parameters
|
3335
3534
|
----------
|
3336
|
-
pruned_models :
|
3535
|
+
pruned_models : list[str]
|
3337
3536
|
A list of pruned model names.
|
3338
3537
|
original_prune_map : dict
|
3339
3538
|
A dictionary mapping the names of models fitted on pruned features to the names of models fitted on original features.
|
@@ -3342,7 +3541,7 @@ class AbstractTrainer:
|
|
3342
3541
|
|
3343
3542
|
Returns
|
3344
3543
|
----------
|
3345
|
-
models :
|
3544
|
+
models : list[str]
|
3346
3545
|
A list of model names.
|
3347
3546
|
"""
|
3348
3547
|
models = []
|
@@ -3460,7 +3659,7 @@ class AbstractTrainer:
|
|
3460
3659
|
model_types = self.get_models_attribute_dict(attribute="type", models=model_names)
|
3461
3660
|
return model_names, model_paths, model_types
|
3462
3661
|
|
3463
|
-
def get_model_attribute_full(self, model:
|
3662
|
+
def get_model_attribute_full(self, model: str | list[str], attribute: str, func=sum) -> float | int:
|
3464
3663
|
"""
|
3465
3664
|
Sums the attribute value across all models that the provided model depends on, including itself.
|
3466
3665
|
For instance, this function can return the expected total predict_time of a model.
|
@@ -3497,7 +3696,7 @@ class AbstractTrainer:
|
|
3497
3696
|
attribute_full = 0
|
3498
3697
|
return attribute_full
|
3499
3698
|
|
3500
|
-
def get_models_attribute_full(self, models:
|
3699
|
+
def get_models_attribute_full(self, models: list[str], attribute: str, func=sum):
|
3501
3700
|
"""
|
3502
3701
|
For each model in models, returns the output of self.get_model_attribute_full mapped to a dict.
|
3503
3702
|
"""
|
@@ -3506,57 +3705,6 @@ class AbstractTrainer:
|
|
3506
3705
|
d[model] = self.get_model_attribute_full(model=model, attribute=attribute, func=func)
|
3507
3706
|
return d
|
3508
3707
|
|
3509
|
-
def get_models_attribute_dict(self, attribute: str, models: list = None) -> Dict[str, Any]:
|
3510
|
-
"""
|
3511
|
-
Returns dictionary of model name -> attribute value for the provided attribute.
|
3512
|
-
"""
|
3513
|
-
models_attribute_dict = nx.get_node_attributes(self.model_graph, attribute)
|
3514
|
-
if models is not None:
|
3515
|
-
model_names = []
|
3516
|
-
for model in models:
|
3517
|
-
if not isinstance(model, str):
|
3518
|
-
model = model.name
|
3519
|
-
model_names.append(model)
|
3520
|
-
if attribute == "path":
|
3521
|
-
models_attribute_dict = {key: os.path.join(*val) for key, val in models_attribute_dict.items() if key in model_names}
|
3522
|
-
else:
|
3523
|
-
models_attribute_dict = {key: val for key, val in models_attribute_dict.items() if key in model_names}
|
3524
|
-
return models_attribute_dict
|
3525
|
-
|
3526
|
-
def get_model_attribute(self, model, attribute: str, **kwargs) -> Any:
|
3527
|
-
"""
|
3528
|
-
Return model attribute value.
|
3529
|
-
If `default` is specified, return default value if attribute does not exist.
|
3530
|
-
If `default` is not specified, raise ValueError if attribute does not exist.
|
3531
|
-
"""
|
3532
|
-
if not isinstance(model, str):
|
3533
|
-
model = model.name
|
3534
|
-
if model not in self.model_graph.nodes:
|
3535
|
-
raise ValueError(f"Model does not exist: (model={model})")
|
3536
|
-
if attribute not in self.model_graph.nodes[model]:
|
3537
|
-
if "default" in kwargs:
|
3538
|
-
return kwargs["default"]
|
3539
|
-
else:
|
3540
|
-
raise ValueError(f"Model does not contain attribute: (model={model}, attribute={attribute})")
|
3541
|
-
if attribute == "path":
|
3542
|
-
return os.path.join(*self.model_graph.nodes[model][attribute])
|
3543
|
-
return self.model_graph.nodes[model][attribute]
|
3544
|
-
|
3545
|
-
def set_model_attribute(self, model, attribute: str, val):
|
3546
|
-
if not isinstance(model, str):
|
3547
|
-
model = model.name
|
3548
|
-
self.model_graph.nodes[model][attribute] = val
|
3549
|
-
|
3550
|
-
# Gets the minimum set of models that the provided model depends on, including itself
|
3551
|
-
# Returns a list of model names
|
3552
|
-
def get_minimum_model_set(self, model, include_self=True) -> list:
|
3553
|
-
if not isinstance(model, str):
|
3554
|
-
model = model.name
|
3555
|
-
minimum_model_set = list(nx.bfs_tree(self.model_graph, model, reverse=True))
|
3556
|
-
if not include_self:
|
3557
|
-
minimum_model_set = [m for m in minimum_model_set if m != model]
|
3558
|
-
return minimum_model_set
|
3559
|
-
|
3560
3708
|
# Gets the minimum set of models that the provided models depend on, including themselves
|
3561
3709
|
# Returns a list of model names
|
3562
3710
|
def get_minimum_models_set(self, models: list) -> list:
|
@@ -3573,7 +3721,7 @@ class AbstractTrainer:
|
|
3573
3721
|
base_model_set = list(self.model_graph.predecessors(model))
|
3574
3722
|
return base_model_set
|
3575
3723
|
|
3576
|
-
def model_refit_map(self, inverse=False) ->
|
3724
|
+
def model_refit_map(self, inverse=False) -> dict[str, str]:
|
3577
3725
|
"""
|
3578
3726
|
Returns dict of parent model -> refit model
|
3579
3727
|
|
@@ -3587,10 +3735,6 @@ class AbstractTrainer:
|
|
3587
3735
|
def model_exists(self, model: str) -> bool:
|
3588
3736
|
return model in self.get_model_names()
|
3589
3737
|
|
3590
|
-
def _get_banned_model_names(self) -> list:
|
3591
|
-
"""Gets all model names which would cause model files to be overwritten if a new model was trained with the name"""
|
3592
|
-
return self.get_model_names() + list(self._extra_banned_names)
|
3593
|
-
|
3594
3738
|
def _flatten_model_info(self, model_info: dict) -> dict:
|
3595
3739
|
"""
|
3596
3740
|
Flattens the model_info nested dictionary into a shallow dictionary to convert to a pandas DataFrame row.
|
@@ -3643,7 +3787,7 @@ class AbstractTrainer:
|
|
3643
3787
|
model_info_flat[key] = custom_info[key]
|
3644
3788
|
return model_info_flat
|
3645
3789
|
|
3646
|
-
def leaderboard(self, extra_info=False, refit_full: bool = None, set_refit_score_to_parent: bool = False):
|
3790
|
+
def leaderboard(self, extra_info=False, refit_full: bool | None = None, set_refit_score_to_parent: bool = False):
|
3647
3791
|
model_names = self.get_model_names()
|
3648
3792
|
models_full_dict = self.get_models_attribute_dict(models=model_names, attribute="refit_full_parent")
|
3649
3793
|
if refit_full is not None:
|
@@ -3956,30 +4100,6 @@ class AbstractTrainer:
 
         return info
 
-    def get_model_info(self, model: str | AbstractModel) -> Dict[str, Any]:
-        if isinstance(model, str):
-            if model in self.models.keys():
-                model = self.models[model]
-        if isinstance(model, str):
-            model_type = self.get_model_attribute(model=model, attribute="type")
-            model_path = self.get_model_attribute(model=model, attribute="path")
-            model_info = model_type.load_info(path=os.path.join(self.path, model_path))
-        else:
-            model_info = model.get_info()
-        return model_info
-
-    def get_models_info(self, models: List[str | AbstractModel] = None) -> Dict[str, Dict[str, Any]]:
-        if models is None:
-            models = self.get_model_names()
-        model_info_dict = dict()
-        for model in models:
-            if isinstance(model, str):
-                model_name = model
-            else:
-                model_name = model.name
-            model_info_dict[model_name] = self.get_model_info(model=model)
-        return model_info_dict
-
     def reduce_memory_size(
         self, remove_data=True, remove_fit_stack=False, remove_fit=True, remove_info=False, requires_save=True, reduce_children=False, **kwargs
     ):
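The removed `get_model_info` accepted either a model name or a model object, preferring the in-memory object and only falling back to loading metadata from disk when nothing was loaded under that name. A minimal sketch of that dispatch, with hypothetical helpers standing in for the trainer's attribute lookups:

```python
from typing import Any, Callable


def get_model_info(model, models_in_memory: dict, load_info_from_disk: Callable[[str], dict]) -> dict[str, Any]:
    """Resolve `model` (name or fitted object) to its info dict.

    `models_in_memory` maps names to loaded model objects; `load_info_from_disk`
    is a stand-in for the `model_type.load_info(path=...)` call in the removed method.
    """
    if isinstance(model, str):
        model = models_in_memory.get(model, model)  # prefer the in-memory object
    if isinstance(model, str):
        return load_info_from_disk(model)  # still a name: read its info from disk
    return model.get_info()  # fitted object: ask it directly
```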
@@ -4062,7 +4182,7 @@ class AbstractTrainer:
             for model in models_to_remove:
                 model = self.load_model(model)
                 logger.log(30, f"\tDirectory {model.path} would have been deleted.")
-            logger.log(30, f"To perform the deletion, set dry_run=False")
+            logger.log(30, "To perform the deletion, set dry_run=False")
             return
 
         if delete_from_disk:
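This hunk sits in the dry-run branch of model deletion: the trainer logs what would be removed, then returns before touching the filesystem (the change itself just drops a needless f-string prefix). A minimal sketch of the guard pattern, under assumed names:

```python
import logging

logger = logging.getLogger(__name__)


def delete_models(models_to_remove: list[str], dry_run: bool = True) -> None:
    # A dry run only reports the intended side effects; nothing is deleted.
    if dry_run:
        for model in models_to_remove:
            logger.log(30, f"\tDirectory of model {model} would have been deleted.")
        logger.log(30, "To perform the deletion, set dry_run=False")
        return
    # ... the actual deletion would follow here ...
```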
@@ -4091,37 +4211,8 @@ class AbstractTrainer:
             path_attr_model = Path(self._path_attr_model(model))
             shutil.rmtree(path=path_attr_model, ignore_errors=True)
 
-    @classmethod
-    def load(cls, path, reset_paths=False):
-        load_path = os.path.join(path, cls.trainer_file_name)
-        if not reset_paths:
-            return load_pkl.load(path=load_path)
-        else:
-            obj = load_pkl.load(path=load_path)
-            obj.set_contexts(path)
-            obj.reset_paths = reset_paths
-            return obj
-
-    @classmethod
-    def load_info(cls, path, reset_paths=False, load_model_if_required=True):
-        load_path = os.path.join(path, cls.trainer_info_name)
-        try:
-            return load_pkl.load(path=load_path)
-        except:
-            if load_model_if_required:
-                trainer = cls.load(path=path, reset_paths=reset_paths)
-                return trainer.get_info()
-            else:
-                raise
-
-    def save_info(self, include_model_info=False):
-        info = self.get_info(include_model_info=include_model_info)
-
-        save_pkl.save(path=os.path.join(self.path, self.trainer_info_name), object=info)
-        save_json.save(path=os.path.join(self.path, self.trainer_info_json_name), obj=info)
-        return info
-
-    def _process_hyperparameters(self, hyperparameters: dict) -> dict:
+    @staticmethod
+    def _process_hyperparameters(hyperparameters: dict) -> dict:
         return process_hyperparameters(hyperparameters=hyperparameters)
 
     def distill(
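The relocated `load`/`load_info`/`save_info` helpers implement a common persistence pattern: pickle the trainer itself, keep a small info dict in both pickle and JSON form, and fall back to a full trainer load when the info pickle is missing. A self-contained sketch using only the standard library (the real code goes through autogluon's `load_pkl`/`save_pkl`/`save_json` utilities; file names here are assumptions):

```python
import json
import os
import pickle


def save_info(path: str, info: dict) -> dict:
    # Persist twice: pickle for fidelity, JSON for human inspection.
    with open(os.path.join(path, "info.pkl"), "wb") as f:
        pickle.dump(info, f)
    with open(os.path.join(path, "info.json"), "w") as f:
        json.dump(info, f, default=str)
    return info


def load_info(path: str, compute_info_if_missing=None) -> dict:
    try:
        with open(os.path.join(path, "info.pkl"), "rb") as f:
            return pickle.load(f)
    except OSError:
        # Mirrors load_model_if_required=True: rebuild the info the slow way.
        if compute_info_if_missing is not None:
            return compute_info_if_missing()
        raise
```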
@@ -4327,7 +4418,7 @@ class AbstractTrainer:
         return distilled_model_names
 
     def _get_model_fit_kwargs(
-        self, X: pd.DataFrame, X_val: pd.DataFrame, time_limit: float, k_fold: int, fit_kwargs: dict, ens_sample_weight: list = None
+        self, X: pd.DataFrame, X_val: pd.DataFrame, time_limit: float, k_fold: int, fit_kwargs: dict, ens_sample_weight: list | None = None
     ) -> dict:
         # Returns kwargs to be passed to AbstractModel's fit function
         if fit_kwargs is None:
@@ -4364,7 +4455,7 @@ class AbstractTrainer:
             k_fold=k_fold, k_fold_start=k_fold_start, k_fold_end=k_fold_end, n_repeats=n_repeats, n_repeat_start=n_repeat_start, compute_base_preds=False
         )
 
-    def _get_feature_prune_proxy_model(self, proxy_model_class: AbstractModel, level: int) -> AbstractModel:
+    def _get_feature_prune_proxy_model(self, proxy_model_class: AbstractModel | None, level: int) -> AbstractModel:
         """
         Returns proxy model to be used for feature pruning - the base learner that has the highest validation score in a particular stack layer.
         Ties are broken by inference speed. If proxy_model_class is not None, take the best base learner belonging to proxy_model_class.
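The selection logic this docstring describes appears a few lines below: filter candidates down to the rows tied for the best validation score, then break the tie with `idxmin` on fit time. The same pandas idiom on a toy frame (model names and scores invented for illustration):

```python
import pandas as pd

candidates = pd.DataFrame(
    {
        "model": ["LightGBM", "CatBoost", "XGBoost"],
        "score_val": [0.91, 0.91, 0.89],
        "fit_time": [12.0, 7.5, 9.0],
    }
)

# Keep only the rows tied for the best validation score...
best = candidates.loc[candidates["score_val"] == candidates["score_val"].max()]
# ...then break the tie by the fastest fit (idxmin returns the winning row label).
proxy_model_name = best.loc[best["fit_time"].idxmin()]["model"]
print(proxy_model_name)  # CatBoost
```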
@@ -4398,7 +4489,7 @@ class AbstractTrainer:
         best_candidate_model_rows = candidate_model_rows.loc[candidate_model_rows["score_val"] == candidate_model_rows["score_val"].max()]
         return self.load_model(best_candidate_model_rows.loc[best_candidate_model_rows["fit_time"].idxmin()]["model"])
 
-    def calibrate_model(self, model_name: str = None, lr: float = 0.1, max_iter: int = 200, init_val: float = 1.0):
+    def calibrate_model(self, model_name: str | None = None, lr: float = 0.1, max_iter: int = 200, init_val: float = 1.0):
         """
         Applies temperature scaling to a model.
         Applies inverse softmax to predicted probs then trains temperature scalar
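The docstring describes classic temperature scaling: recover logits from predicted probabilities with an inverse softmax (a log, since softmax only determines logits up to a per-row constant), then divide by a learned temperature. A minimal NumPy sketch of applying a given temperature; fitting it would add a 1-D NLL optimization, which is what the method's `lr`, `max_iter`, and `init_val` parameters drive:

```python
import numpy as np


def apply_temperature(y_pred_proba: np.ndarray, temperature: float) -> np.ndarray:
    # Inverse softmax: log-probabilities equal the logits up to a per-row
    # constant, which the softmax below cancels out anyway.
    logits = np.log(np.clip(y_pred_proba, 1e-12, None))
    scaled = logits / temperature  # temperature > 1 softens, < 1 sharpens
    scaled -= scaled.max(axis=1, keepdims=True)  # numerical stability
    exp = np.exp(scaled)
    return exp / exp.sum(axis=1, keepdims=True)


probs = np.array([[0.9, 0.1], [0.6, 0.4]])
print(apply_temperature(probs, temperature=2.0))  # probabilities pulled toward 0.5
```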
@@ -4494,11 +4585,11 @@ class AbstractTrainer:
     def calibrate_decision_threshold(
         self,
         X: pd.DataFrame | None = None,
-        y: np.array = None,
+        y: np.ndarray | None = None,
         metric: str | Scorer | None = None,
         model: str = "best",
         weights=None,
-        decision_thresholds: int | List[float] = 25,
+        decision_thresholds: int | list[float] = 25,
         secondary_decision_thresholds: int | None = 19,
         verbose: bool = True,
         **kwargs,
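With `decision_thresholds: int | list[float]`, an integer presumably sets how many candidate thresholds to scan while a list fixes the grid explicitly. A minimal sketch of the underlying search for a binary problem, using accuracy as the metric (the real method evaluates any `Scorer`, and the grid construction here is an assumption):

```python
import numpy as np


def best_threshold(y_true: np.ndarray, y_proba: np.ndarray, thresholds: int | list[float]) -> float:
    # Accept either a count (evenly spaced grid over (0, 1)) or an explicit list.
    if isinstance(thresholds, int):
        thresholds = np.linspace(0.0, 1.0, thresholds * 2 + 1)[1:-1]
    # Score every candidate threshold, keeping the one with the best metric value.
    scores = [(np.mean((y_proba > t).astype(int) == y_true), t) for t in thresholds]
    return max(scores)[1]


y_true = np.array([0, 0, 1, 1, 1])
y_proba = np.array([0.2, 0.45, 0.35, 0.7, 0.9])
print(best_threshold(y_true, y_proba, thresholds=25))
```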
@@ -4565,18 +4656,18 @@ class AbstractTrainer:
         )
 
     @staticmethod
-    def _validate_num_classes(num_classes: int, problem_type: str):
+    def _validate_num_classes(num_classes: int | None, problem_type: str):
         if problem_type == BINARY:
             assert num_classes is not None and num_classes == 2, f"num_classes must be 2 when problem_type='{problem_type}' (num_classes={num_classes})"
         elif problem_type in [MULTICLASS, SOFTCLASS]:
             assert num_classes is not None and num_classes >= 2, f"num_classes must be >=2 when problem_type='{problem_type}' (num_classes={num_classes})"
         elif problem_type in [REGRESSION, QUANTILE]:
-            assert num_classes is None, f"
+            assert num_classes is None, f"num_classes must be None when problem_type='{problem_type}' (num_classes={num_classes})"
         else:
             raise AssertionError(f"Unknown problem_type: '{problem_type}'. Valid problem types: {[BINARY, MULTICLASS, REGRESSION, SOFTCLASS, QUANTILE]}")
 
     @staticmethod
-    def _validate_quantile_levels(quantile_levels: List[float] | np.ndarray, problem_type: str):
+    def _validate_quantile_levels(quantile_levels: list[float] | np.ndarray | None, problem_type: str):
         if problem_type == QUANTILE:
             assert quantile_levels is not None, f"quantile_levels must not be None when problem_type='{problem_type}' (quantile_levels={quantile_levels})"
             assert isinstance(quantile_levels, (list, np.ndarray)), f"quantile_levels must be a list or np.ndarray (quantile_levels={quantile_levels})"
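The widened signatures make the validators' contract explicit: `num_classes` must be `None` for regression and quantile problems and an `int` otherwise, so passing `None` is now well-typed rather than merely tolerated. A compact usage sketch with stand-in constants:

```python
# Hypothetical stand-ins for the BINARY/REGRESSION problem-type constants.
BINARY, REGRESSION = "binary", "regression"


def validate_num_classes(num_classes: int | None, problem_type: str) -> None:
    if problem_type == BINARY:
        assert num_classes == 2, f"num_classes must be 2 (got {num_classes})"
    elif problem_type == REGRESSION:
        assert num_classes is None, f"num_classes must be None (got {num_classes})"


validate_num_classes(2, BINARY)         # ok
validate_num_classes(None, REGRESSION)  # ok
```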
@@ -4587,13 +4678,13 @@ class AbstractTrainer:
 
 def _detached_train_multi_fold(
     *,
-    _self: AbstractTrainer,
+    _self: AbstractTabularTrainer,
     model: str | AbstractModel,
     X: pd.DataFrame,
     y: pd.Series,
     time_split: bool,
     time_start: float,
-    time_limit: float|None,
+    time_limit: float | None,
     time_limit_model_split: float | None,
     hyperparameter_tune_kwargs: dict,
     is_ray_worker: bool = False,
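These `_detached_*`/`_remote_*` pairs exist so fold training can run either in-process or on a Ray worker; the `_remote_` variants are the versions dispatched through Ray's object store. A minimal sketch of that dispatch pattern (assuming `ray` is installed; payloads and names are stand-ins for the trainer, model, and data the real functions ship to workers):

```python
import ray


def train_fold(payload: dict) -> str:
    # Stand-in for _detached_train_multi_fold: do the work, report the model name.
    return f"trained {payload['model']}"


@ray.remote
def train_fold_remote(payload: dict) -> str:
    # Stand-in for _remote_train_multi_fold: the same work, on a Ray worker.
    return train_fold(payload)


if __name__ == "__main__":
    ray.init(ignore_reinit_error=True)
    futures = [train_fold_remote.remote({"model": name}) for name in ["LightGBM", "CatBoost"]]
    print(ray.get(futures))  # ['trained LightGBM', 'trained CatBoost']
```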
@@ -4636,7 +4727,7 @@ def _detached_train_multi_fold(
 
 def _remote_train_multi_fold(
     *,
-    _self: AbstractTrainer,
+    _self: AbstractTabularTrainer,
     model: str | AbstractModel,
     X: pd.DataFrame,
     y: pd.Series,
@@ -4694,7 +4785,7 @@ def _remote_train_multi_fold(
 
 def _detached_refit_single_full(
     *,
-    _self: AbstractTrainer,
+    _self: AbstractTabularTrainer,
     model: str,
     X: pd.DataFrame,
     y: pd.Series,
@@ -4787,7 +4878,7 @@ def _detached_refit_single_full(
 
 def _remote_refit_single_full(
     *,
-    _self: AbstractTrainer,
+    _self: AbstractTabularTrainer,
     model: str,
     X: pd.DataFrame,
     y: pd.Series,