autogluon.multimodal 1.2.1b20250303__py3-none-any.whl → 1.2.1b20250304__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- autogluon/multimodal/__init__.py +4 -2
- autogluon/multimodal/configs/data/default.yaml +4 -2
- autogluon/multimodal/configs/{environment → env}/default.yaml +2 -3
- autogluon/multimodal/configs/model/default.yaml +58 -11
- autogluon/multimodal/configs/{optimization → optim}/default.yaml +21 -4
- autogluon/multimodal/constants.py +16 -5
- autogluon/multimodal/data/__init__.py +14 -2
- autogluon/multimodal/data/dataset.py +2 -2
- autogluon/multimodal/data/infer_types.py +16 -2
- autogluon/multimodal/data/label_encoder.py +3 -3
- autogluon/multimodal/{utils → data}/nlpaug.py +4 -4
- autogluon/multimodal/data/preprocess_dataframe.py +55 -38
- autogluon/multimodal/data/process_categorical.py +35 -6
- autogluon/multimodal/data/process_document.py +59 -33
- autogluon/multimodal/data/process_image.py +198 -163
- autogluon/multimodal/data/process_label.py +7 -3
- autogluon/multimodal/data/process_mmlab/process_mmdet.py +1 -8
- autogluon/multimodal/data/process_mmlab/process_mmlab_base.py +2 -9
- autogluon/multimodal/data/process_mmlab/process_mmocr.py +1 -9
- autogluon/multimodal/data/process_ner.py +192 -4
- autogluon/multimodal/data/process_numerical.py +32 -5
- autogluon/multimodal/data/process_semantic_seg_img.py +23 -28
- autogluon/multimodal/data/process_text.py +95 -58
- autogluon/multimodal/data/template_engine.py +7 -9
- autogluon/multimodal/data/templates.py +0 -2
- autogluon/multimodal/data/trivial_augmenter.py +2 -2
- autogluon/multimodal/data/utils.py +564 -338
- autogluon/multimodal/learners/__init__.py +2 -1
- autogluon/multimodal/learners/base.py +189 -189
- autogluon/multimodal/learners/ensemble.py +748 -0
- autogluon/multimodal/learners/few_shot_svm.py +6 -15
- autogluon/multimodal/learners/matching.py +59 -84
- autogluon/multimodal/learners/ner.py +23 -22
- autogluon/multimodal/learners/object_detection.py +26 -21
- autogluon/multimodal/learners/semantic_segmentation.py +16 -18
- autogluon/multimodal/models/__init__.py +12 -3
- autogluon/multimodal/models/augmenter.py +175 -0
- autogluon/multimodal/models/categorical_mlp.py +13 -8
- autogluon/multimodal/models/clip.py +92 -18
- autogluon/multimodal/models/custom_transformer.py +75 -75
- autogluon/multimodal/models/document_transformer.py +23 -9
- autogluon/multimodal/models/ft_transformer.py +40 -35
- autogluon/multimodal/models/fusion/base.py +2 -4
- autogluon/multimodal/models/fusion/fusion_mlp.py +82 -18
- autogluon/multimodal/models/fusion/fusion_ner.py +1 -1
- autogluon/multimodal/models/fusion/fusion_transformer.py +23 -23
- autogluon/multimodal/models/{huggingface_text.py → hf_text.py} +21 -2
- autogluon/multimodal/models/meta_transformer.py +336 -0
- autogluon/multimodal/models/mlp.py +6 -6
- autogluon/multimodal/models/mmocr_text_detection.py +1 -1
- autogluon/multimodal/models/mmocr_text_recognition.py +0 -1
- autogluon/multimodal/models/ner_text.py +1 -8
- autogluon/multimodal/models/numerical_mlp.py +14 -8
- autogluon/multimodal/models/sam.py +12 -2
- autogluon/multimodal/models/t_few.py +21 -5
- autogluon/multimodal/models/timm_image.py +74 -32
- autogluon/multimodal/models/utils.py +877 -16
- autogluon/multimodal/optim/__init__.py +17 -0
- autogluon/multimodal/{optimization → optim}/lit_distiller.py +2 -1
- autogluon/multimodal/{optimization → optim}/lit_matcher.py +4 -10
- autogluon/multimodal/{optimization → optim}/lit_mmdet.py +2 -10
- autogluon/multimodal/{optimization → optim}/lit_module.py +139 -14
- autogluon/multimodal/{optimization → optim}/lit_ner.py +3 -3
- autogluon/multimodal/{optimization → optim}/lit_semantic_seg.py +1 -1
- autogluon/multimodal/optim/losses/__init__.py +14 -0
- autogluon/multimodal/optim/losses/bce_loss.py +25 -0
- autogluon/multimodal/optim/losses/focal_loss.py +81 -0
- autogluon/multimodal/optim/losses/lemda_loss.py +39 -0
- autogluon/multimodal/optim/losses/rkd_loss.py +103 -0
- autogluon/multimodal/optim/losses/softmax_losses.py +177 -0
- autogluon/multimodal/optim/losses/structure_loss.py +26 -0
- autogluon/multimodal/optim/losses/utils.py +313 -0
- autogluon/multimodal/optim/lr/__init__.py +1 -0
- autogluon/multimodal/optim/lr/utils.py +332 -0
- autogluon/multimodal/optim/metrics/__init__.py +4 -0
- autogluon/multimodal/optim/metrics/coverage_metrics.py +42 -0
- autogluon/multimodal/optim/metrics/hit_rate_metrics.py +78 -0
- autogluon/multimodal/optim/metrics/ranking_metrics.py +231 -0
- autogluon/multimodal/optim/metrics/utils.py +359 -0
- autogluon/multimodal/optim/utils.py +284 -0
- autogluon/multimodal/predictor.py +51 -12
- autogluon/multimodal/utils/__init__.py +19 -45
- autogluon/multimodal/utils/cache.py +23 -2
- autogluon/multimodal/utils/checkpoint.py +58 -5
- autogluon/multimodal/utils/config.py +127 -55
- autogluon/multimodal/utils/device.py +120 -0
- autogluon/multimodal/utils/distillation.py +8 -8
- autogluon/multimodal/utils/download.py +1 -1
- autogluon/multimodal/utils/env.py +22 -0
- autogluon/multimodal/utils/export.py +3 -3
- autogluon/multimodal/utils/hpo.py +5 -5
- autogluon/multimodal/utils/inference.py +37 -4
- autogluon/multimodal/utils/install.py +91 -0
- autogluon/multimodal/utils/load.py +52 -47
- autogluon/multimodal/utils/log.py +6 -41
- autogluon/multimodal/utils/matcher.py +3 -2
- autogluon/multimodal/utils/onnx.py +0 -4
- autogluon/multimodal/utils/path.py +10 -0
- autogluon/multimodal/utils/precision.py +130 -0
- autogluon/multimodal/{presets.py → utils/presets.py} +259 -66
- autogluon/multimodal/{problem_types.py → utils/problem_types.py} +30 -1
- autogluon/multimodal/utils/save.py +47 -29
- autogluon/multimodal/utils/strategy.py +24 -0
- autogluon/multimodal/version.py +1 -1
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/METADATA +5 -5
- autogluon.multimodal-1.2.1b20250304.dist-info/RECORD +163 -0
- autogluon/multimodal/optimization/__init__.py +0 -16
- autogluon/multimodal/optimization/losses.py +0 -394
- autogluon/multimodal/optimization/utils.py +0 -1054
- autogluon/multimodal/utils/cloud_io.py +0 -80
- autogluon/multimodal/utils/data.py +0 -701
- autogluon/multimodal/utils/environment.py +0 -395
- autogluon/multimodal/utils/metric.py +0 -500
- autogluon/multimodal/utils/model.py +0 -558
- autogluon.multimodal-1.2.1b20250303.dist-info/RECORD +0 -145
- /autogluon/multimodal/{optimization → optim}/deepspeed.py +0 -0
- /autogluon/multimodal/{optimization/lr_scheduler.py → optim/lr/lr_schedulers.py} +0 -0
- /autogluon/multimodal/{optimization → optim/metrics}/semantic_seg_metrics.py +0 -0
- /autogluon/multimodal/{registry.py → utils/registry.py} +0 -0
- /autogluon.multimodal-1.2.1b20250303-py3.9-nspkg.pth → /autogluon.multimodal-1.2.1b20250304-py3.9-nspkg.pth +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/LICENSE +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/NOTICE +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/WHEEL +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/namespace_packages.txt +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/top_level.txt +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/zip-safe +0 -0
autogluon/multimodal/learners/ensemble.py (new file)
@@ -0,0 +1,748 @@
+import copy
+import json
+import logging
+import os
+import pathlib
+import pickle
+import pprint
+import time
+import warnings
+from typing import Callable, Dict, List, Optional, Union
+
+import numpy as np
+import pandas as pd
+
+from autogluon.core.metrics import Scorer
+from autogluon.core.models.greedy_ensemble.ensemble_selection import EnsembleSelection
+
+from .. import version as ag_version
+from ..constants import BINARY, LOGITS, MULTICLASS, REGRESSION, TEST, VAL, Y_PRED, Y_TRUE
+from ..optim import compute_score
+from ..utils import (
+    extract_from_output,
+    get_dir_ckpt_paths,
+    logits_to_prob,
+    on_fit_end_message,
+    update_ensemble_hyperparameters,
+)
+from .base import BaseLearner
+
+logger = logging.getLogger(__name__)
+
+
+class EnsembleLearner(BaseLearner):
+    def __init__(
+        self,
+        label: Optional[str] = None,
+        problem_type: Optional[str] = None,
+        presets: Optional[str] = "high_quality",
+        eval_metric: Optional[Union[str, Scorer]] = None,
+        hyperparameters: Optional[dict] = None,
+        path: Optional[str] = None,
+        verbosity: Optional[int] = 2,
+        warn_if_exist: Optional[bool] = True,
+        enable_progress_bar: Optional[bool] = None,
+        ensemble_size: Optional[int] = 2,
+        ensemble_mode: Optional[str] = "one_shot",
+        **kwargs,
+    ):
+        """
+        Parameters
+        ----------
+        label
+            Name of the column that contains the target variable to predict.
+        problem_type
+            Type of the prediction problem. We support standard problems like
+
+            - 'binary': Binary classification
+            - 'multiclass': Multi-class classification
+            - 'regression': Regression
+            - 'classification': Classification problems include 'binary' and 'multiclass' classification.
+        presets
+            Presets regarding model quality, e.g., best_quality, high_quality, and medium_quality.
+        eval_metric
+            Evaluation metric name. If `eval_metric = None`, it is automatically chosen based on `problem_type`.
+            Defaults to 'accuracy' for multiclass classification, 'roc_auc' for binary classification, and 'root_mean_squared_error' for regression.
+        hyperparameters
+            This is to override some default configurations.
+            For example, changing the text and image backbones can be done by formatting:
+
+            a string
+            hyperparameters = "model.hf_text.checkpoint_name=google/electra-small-discriminator model.timm_image.checkpoint_name=swin_small_patch4_window7_224"
+
+            or a list of strings
+            hyperparameters = ["model.hf_text.checkpoint_name=google/electra-small-discriminator", "model.timm_image.checkpoint_name=swin_small_patch4_window7_224"]
+
+            or a dictionary
+            hyperparameters = {
+                "model.hf_text.checkpoint_name": "google/electra-small-discriminator",
+                "model.timm_image.checkpoint_name": "swin_small_patch4_window7_224",
+            }
+        path
+            Path to directory where models and intermediate outputs should be saved.
+            If unspecified, a time-stamped folder called "AutogluonAutoMM/ag-[TIMESTAMP]"
+            will be created in the working directory to store all models.
+            Note: To call `fit()` twice and save all results of each fit,
+            you must specify different `path` locations or don't specify `path` at all.
+            Otherwise files from the first `fit()` will be overwritten by the second `fit()`.
+        verbosity
+            Verbosity levels range from 0 to 4 and control how much information is printed.
+            Higher levels correspond to more detailed print statements (you can set verbosity = 0 to suppress warnings).
+            If using logging, you can alternatively control the amount of information printed via `logger.setLevel(L)`,
+            where `L` ranges from 0 to 50
+            (Note: higher values of `L` correspond to fewer print statements, opposite of verbosity levels).
+        warn_if_exist
+            Whether to raise a warning if the specified path already exists.
+        enable_progress_bar
+            Whether to show the progress bar. It will be True by default and will also be
+            disabled if the environment variable os.environ["AUTOMM_DISABLE_PROGRESS_BAR"] is set.
+        ensemble_size
+            A multiple of the number of models in the ensembling pool (default 2). The actual ensemble size = ensemble_size * the number of models.
+        ensemble_mode
+            The mode of conducting ensembling:
+            - `one_shot`: the classic ensemble selection
+            - `sequential`: iteratively calling the classic ensemble selection, each time growing the model zoo by the next best model.
+        """
+        super().__init__(
+            label=label,
+            problem_type=problem_type,
+            presets=presets,
+            eval_metric=eval_metric,
+            hyperparameters=hyperparameters,
+            path=path,
+            verbosity=verbosity,
+            warn_if_exist=warn_if_exist,
+            enable_progress_bar=enable_progress_bar,
+        )
+        self._ensemble_size = int(ensemble_size)
+        assert ensemble_mode in ["sequential", "one_shot"]
+        self._ensemble_mode = ensemble_mode
+        self._weighted_ensemble = None
+        self._selected_learners = None
+        self._all_learners = None
+        self._selected_indices = []
+        self._relative_path = False
+
+        return
+
+    def get_learner_path(self, learner_path: str):
+        if self._relative_path:
+            learner_path = os.path.join(self._save_path, learner_path)
+        return learner_path
+
+    def get_learner_name(self, learner):
+        if isinstance(learner, str):
+            if self._relative_path:
+                learner_name = learner
+            else:
+                learner_name = pathlib.PurePath(learner).name
+        else:
+            learner_name = pathlib.PurePath(learner.path).name
+
+        return learner_name
+
+    def predict_all_for_ensembling(
+        self,
+        learners: List[Union[str, BaseLearner]],
+        data: Union[pd.DataFrame, str],
+        mode: str,
+        requires_label: Optional[bool] = False,
+        save: Optional[bool] = False,
+    ):
+        assert mode in [VAL, TEST]
+        predictions = []
+        labels = None
+        i = 0
+        for per_learner in learners:
+            i += 1
+            logger.info(f"\npredicting with learner {i}: {per_learner}\n")
+            if isinstance(per_learner, str):
+                per_learner_path = self.get_learner_path(per_learner)
+            else:
+                per_learner_path = per_learner.path
+
+            pred_file_path = os.path.join(per_learner_path, f"{mode}_predictions.npy")
+            if os.path.isfile(pred_file_path):
+                logger.info(f"{mode}_predictions.npy exists. loading it...")
+                y_pred = np.load(pred_file_path)
+            else:
+                if isinstance(per_learner, str):
+                    per_learner = BaseLearner.load(path=per_learner_path)
+                if not self._problem_type:
+                    self._problem_type = per_learner.problem_type
+                else:
+                    assert self._problem_type == per_learner.problem_type
+                outputs = per_learner.predict_per_run(
+                    data=data,
+                    realtime=False,
+                    requires_label=False,
+                )
+                y_pred = extract_from_output(outputs=outputs, ret_type=LOGITS)
+
+                if self._problem_type == REGRESSION:
+                    y_pred = per_learner._df_preprocessor.transform_prediction(y_pred=y_pred)
+                if self._problem_type in [BINARY, MULTICLASS]:
+                    y_pred = logits_to_prob(y_pred)
+                    if self._problem_type == BINARY:
+                        y_pred = y_pred[:, 1]
+
+                if save:
+                    np.save(pred_file_path, y_pred)
+
+            if requires_label:
+                label_file_path = os.path.join(per_learner_path, f"{mode}_labels.npy")
+                if os.path.isfile(label_file_path):
+                    logger.info(f"{mode}_labels.npy exists. loading it...")
+                    y_true = np.load(label_file_path)
+                else:
+                    if isinstance(per_learner, str):
+                        per_learner = BaseLearner.load(path=per_learner_path)
+                    y_true = per_learner._df_preprocessor.transform_label_for_metric(df=data)
+
+                    if save:
+                        np.save(label_file_path, y_true)
+
+                if labels is None:
+                    labels = y_true
+                else:
+                    assert np.array_equal(y_true, labels)
+
+            predictions.append(y_pred)
+
+        if requires_label:
+            return predictions, labels
+        else:
+            return predictions
+
+    @staticmethod
+    def verify_predictions_labels(predictions, learners, labels=None):
+        if labels is not None:
+            assert isinstance(labels, np.ndarray)
+        assert isinstance(predictions, list) and all(isinstance(ele, np.ndarray) for ele in predictions)
+        assert len(learners) == len(
+            predictions
+        ), f"len(learners) {len(learners)} doesn't match len(predictions) {len(predictions)}"
+
+    def fit_per_ensemble(
+        self,
+        predictions: List[np.ndarray],
+        labels: np.ndarray,
+    ):
+        weighted_ensemble = EnsembleSelection(
+            ensemble_size=self._ensemble_size * len(predictions),
+            problem_type=self._problem_type,
+            metric=self._eval_metric_func,
+        )
+        weighted_ensemble.fit(predictions=predictions, labels=labels)
+
+        return weighted_ensemble
+
+    def select_next_best(self, left_learner_indices, selected_learner_indices, predictions, labels):
+        best_regret = None
+        best_weighted_ensemble = None
+        best_next_index = None
+        for i in left_learner_indices:
+            tmp_learner_indices = selected_learner_indices + [i]
+            tmp_predictions = [predictions[j] for j in tmp_learner_indices]
+            tmp_weighted_ensemble = self.fit_per_ensemble(
+                predictions=tmp_predictions,
+                labels=labels,
+            )
+            if best_regret is None or tmp_weighted_ensemble.train_score_ < best_regret:
+                best_regret = tmp_weighted_ensemble.train_score_
+                best_weighted_ensemble = tmp_weighted_ensemble
+                best_next_index = i
+
+        return best_regret, best_next_index, best_weighted_ensemble
+
+    def sequential_ensemble(
+        self,
+        predictions: List[np.ndarray],
+        labels: np.ndarray,
+    ):
+        selected_learner_indices = []
+        all_learner_indices = list(range(len(predictions)))
+        best_regret = None
+        best_weighted_ensemble = None
+        best_selected_learner_indices = None
+        while len(selected_learner_indices) < len(all_learner_indices):
+            left_learner_indices = [i for i in all_learner_indices if i not in selected_learner_indices]
+            assert sorted(all_learner_indices) == sorted(selected_learner_indices + left_learner_indices)
+            logger.debug(f"\nleft_learner_indices: {left_learner_indices}")
+            if not left_learner_indices:
+                break
+            logger.debug(f"selected_learner_indices: {selected_learner_indices}")
+            tmp_regret, next_index, tmp_weighted_ensemble = self.select_next_best(
+                left_learner_indices=left_learner_indices,
+                selected_learner_indices=selected_learner_indices,
+                predictions=predictions,
+                labels=labels,
+            )
+            selected_learner_indices.append(next_index)
+            if best_regret is None or tmp_regret < best_regret:
+                best_regret = tmp_regret
+                best_weighted_ensemble = tmp_weighted_ensemble
+                best_selected_learner_indices = copy.deepcopy(selected_learner_indices)
+            logger.debug(f"\nbest score: {self._eval_metric_func._optimum - best_regret}")
+            logger.debug(f"best_selected_learner_indices: {best_selected_learner_indices}")
+            logger.debug(f"best_ensemble_weights: {best_weighted_ensemble.weights_}")
+
+        return best_weighted_ensemble, best_selected_learner_indices
+
+    def update_hyperparameters(self, hyperparameters: Dict):
+        if self._hyperparameters and hyperparameters:
+            self._hyperparameters.update(hyperparameters)
+        elif hyperparameters:
+            self._hyperparameters = hyperparameters
+
+        self._hyperparameters = update_ensemble_hyperparameters(
+            presets=self._presets,
+            provided_hyperparameters=self._hyperparameters,
+        )
+        # filter out meta-transformer if no local checkpoint path is provided
+        if "early_fusion" in self._hyperparameters:
+            if self._hyperparameters["early_fusion"]["model.meta_transformer.checkpoint_path"] == "null":
+                self._hyperparameters.pop("early_fusion")
+                message = (
+                    "`early_fusion` will not be used in ensembling because `early_fusion` relies on MetaTransformer, "
+                    "but no local MetaTransformer model checkpoint is provided. To use `early_fusion`, "
+                    "download its model checkpoints from https://github.com/invictus717/MetaTransformer to local "
+                    "and set the checkpoint path as follows:\n"
+                    "```python\n"
+                    "hyperparameters = {\n"
+                    '    "early_fusion": {\n'
+                    '        "model.meta_transformer.checkpoint_path": args.meta_transformer_ckpt_path,\n'
+                    "    }\n"
+                    "}\n"
+                    "```\n"
+                    "Note that presets `high_quality` (default) and `medium_quality` need the base model, while preset "
+                    "`best_quality` requires the large model. Make sure to download the right MetaTransformer version. "
+                    "We recommend using the download links under tag `国内下载源` because the corresponding "
+                    "downloaded models are not compressed and can be loaded directly.\n"
+                )
+
+                logger.warning(message)
+
+    def fit_all(
+        self,
+        train_data,
+        tuning_data,
+        hyperparameters,
+        column_types,
+        holdout_frac,
+        time_limit,
+        seed,
+        standalone,
+        clean_ckpts,
+    ):
+        self._relative_path = True
+        self.update_hyperparameters(hyperparameters=hyperparameters)
+
+        learners = []
+        assert (
+            len(self._hyperparameters) > 1
+        ), f"Ensembling requires training more than 1 learner, but got {len(self._hyperparameters)} sets of hyperparameters."
+        logger.info(
+            f"Will ensemble {len(self._hyperparameters)} models with the following configs:\n {pprint.pformat(self._hyperparameters)}"
+        )
+        for per_name, per_hparams in self._hyperparameters.items():
+            per_learner_path = os.path.join(self._save_path, per_name)
+            if not os.path.isdir(per_learner_path):
+                logger.info(f"\nfitting learner {per_name}")
+                logger.debug(f"hyperparameters: {per_hparams}")
+                per_learner = BaseLearner(
+                    label=self._label_column,
+                    problem_type=self._problem_type,
+                    presets=self._presets,
+                    eval_metric=self._eval_metric_func,
+                    hyperparameters=per_hparams,
+                    path=per_learner_path,
+                    verbosity=self._verbosity,
+                    warn_if_exist=self._warn_if_exist,
+                    enable_progress_bar=self._enable_progress_bar,
+                    pretrained=self._pretrained,
+                    validation_metric=self._validation_metric_name,
+                )
+                per_learner.fit(
+                    train_data=train_data,
+                    tuning_data=tuning_data,
+                    time_limit=time_limit,
+                    column_types=column_types,
+                    holdout_frac=holdout_frac,
+                    seed=seed,
+                    standalone=standalone,
+                    clean_ckpts=clean_ckpts,
+                )
+            learners.append(per_name)
+
+        return learners
+
+    def on_fit_end(
+        self,
+        training_start: float,
+        **kwargs,
+    ):
+        self._fit_called = True
+        training_end = time.time()
+        self._total_train_time = training_end - training_start
+        logger.info(on_fit_end_message(self._save_path))
+
+    def update_attributes_by_first_learner(self, learners: List):
+        # load df preprocessor from the first learner
+        if isinstance(learners[0], str):
+            first_learner_path = self.get_learner_path(learners[0])
+            dir_path, ckpt_path = get_dir_ckpt_paths(path=first_learner_path)
+            assert os.path.isdir(dir_path), f"'{dir_path}' must be an existing directory."
+            first_learner = BaseLearner(label="dummy_label")
+            first_learner = BaseLearner._load_metadata(learner=first_learner, path=dir_path)
+        else:
+            first_learner = learners[0]
+
+        self._df_preprocessor = first_learner._df_preprocessor
+        self._eval_metric_func = first_learner._eval_metric_func
+        self._eval_metric_name = first_learner._eval_metric_name
+        self._problem_type = first_learner._problem_type
+
+    def fit_ensemble(
+        self,
+        predictions: Optional[List[np.ndarray]] = None,
+        labels: Optional[np.ndarray] = None,
+        learners: Optional[List[Union[str, BaseLearner]]] = None,
+        train_data: Optional[Union[pd.DataFrame, str]] = None,
+        tuning_data: Optional[Union[pd.DataFrame, str]] = None,
+        holdout_frac: Optional[float] = None,
+        seed: Optional[int] = 0,
+    ):
+        if not predictions or labels is None:
+            self.prepare_train_tuning_data(
+                train_data=train_data,
+                tuning_data=tuning_data,
+                holdout_frac=holdout_frac,
+                seed=seed,
+            )
+            predictions, labels = self.predict_all_for_ensembling(
+                learners=learners,
+                data=self._tuning_data,
+                mode=VAL,
+                requires_label=True,
+                save=True,
+            )
+
+        self.verify_predictions_labels(
+            predictions=predictions,
+            labels=labels,
+            learners=learners,
+        )
+
+        if self._ensemble_mode == "sequential":
+            weighted_ensemble, selected_learner_indices = self.sequential_ensemble(
+                predictions=predictions,
+                labels=labels,
+            )
+        elif self._ensemble_mode == "one_shot":
+            weighted_ensemble = self.fit_per_ensemble(
+                predictions=predictions,
+                labels=labels,
+            )
+            selected_learner_indices = list(range(len(learners)))
+        else:
+            raise ValueError(f"Unsupported ensemble_mode: {self._ensemble_mode}")
+
+        predictions = [predictions[j] for j in selected_learner_indices]
+        predictions = weighted_ensemble.predict_proba(predictions)
+
+        # for regression, the transform_prediction() is already called in predict_all()
+        if self._eval_metric_func.needs_pred and self._problem_type != REGRESSION:
+            predictions = self._df_preprocessor.transform_prediction(
+                y_pred=predictions,
+                inverse_categorical=False,
+            )
+        metric_data = {
+            Y_PRED: predictions,
+            Y_TRUE: labels,
+        }
+        score = compute_score(
+            metric_data=metric_data,
+            metric=self._eval_metric_func,
+        )
+
+        logger.debug(f"\nEnsembling score on validation data: {score}")
+
+        return weighted_ensemble, selected_learner_indices
+
+    def fit(
+        self,
+        train_data: Union[pd.DataFrame, str],
+        presets: Optional[str] = None,
+        tuning_data: Optional[Union[pd.DataFrame, str]] = None,
+        time_limit: Optional[int] = None,
+        save_path: Optional[str] = None,
+        hyperparameters: Optional[Union[str, Dict, List[str]]] = None,
+        column_types: Optional[Dict] = None,
+        holdout_frac: Optional[float] = None,
+        teacher_learner: Union[str, BaseLearner] = None,
+        seed: Optional[int] = 0,
+        standalone: Optional[bool] = True,
+        hyperparameter_tune_kwargs: Optional[Dict] = None,
+        clean_ckpts: Optional[bool] = True,
+        learners: Optional[List[Union[str, BaseLearner]]] = None,
+        predictions: Optional[List[np.ndarray]] = None,
+        labels: Optional[np.ndarray] = None,
+        **kwargs,
+    ):
+        self.setup_save_path(save_path=save_path)
+        training_start = self.on_fit_start(presets=presets)
+        if learners is None:
+            learners = self.fit_all(
+                train_data=train_data,
+                tuning_data=tuning_data,
+                hyperparameters=hyperparameters,
+                column_types=column_types,
+                holdout_frac=holdout_frac,
+                time_limit=time_limit,
+                seed=seed,
+                standalone=standalone,
+                clean_ckpts=clean_ckpts,
+            )
+        assert len(learners) > 1, f"Ensembling requires more than 1 learner, but got {len(learners)}."
+
+        self.update_attributes_by_first_learner(learners=learners)
+        weighted_ensemble, selected_learner_indices = self.fit_ensemble(
+            predictions=predictions,
+            labels=labels,
+            learners=learners,
+            train_data=train_data,
+            tuning_data=tuning_data,
+            holdout_frac=holdout_frac,
+            seed=seed,
+        )
+
+        assert len(selected_learner_indices) == len(weighted_ensemble.weights_)
+        self._weighted_ensemble = weighted_ensemble
+        self._selected_learners = [learners[i] for i in selected_learner_indices]
+        self._all_learners = learners
+        self._selected_indices = selected_learner_indices
+
+        self.on_fit_end(training_start=training_start)
+        self.save(path=self._save_path)
+
+        return self
+
+    def predict(
+        self,
+        data: Union[pd.DataFrame, dict, list, str],
+        predictions: Optional[List[np.ndarray]] = None,
+        as_pandas: Optional[bool] = None,
+        **kwargs,
+    ):
+        self.on_predict_start()
+        if not predictions:
+            predictions = self.predict_all_for_ensembling(
+                learners=self._selected_learners,
+                data=data,
+                mode=TEST,
+                requires_label=False,
+                save=False,
+            )
+        else:
+            predictions = [predictions[i] for i in self._selected_indices]
+
+        self.verify_predictions_labels(
+            predictions=predictions,
+            learners=self._selected_learners,
+        )
+        pred = self._weighted_ensemble.predict_proba(predictions)
+        # for regression, the transform_prediction() is already called in predict_all()
+        if self._problem_type in [BINARY, MULTICLASS]:
+            pred = self._df_preprocessor.transform_prediction(
+                y_pred=pred,
+                inverse_categorical=True,
+            )
+        if (as_pandas is None and isinstance(data, pd.DataFrame)) or as_pandas is True:
+            pred = self._as_pandas(data=data, to_be_converted=pred)
+
+        return pred
+
+    def predict_proba(
+        self,
+        data: Union[pd.DataFrame, dict, list],
+        predictions: Optional[List[np.ndarray]] = None,
+        as_pandas: Optional[bool] = None,
+        as_multiclass: Optional[bool] = True,
+        **kwargs,
+    ):
+        self.on_predict_start()
+        assert self._problem_type not in [
+            REGRESSION,
+        ], f"Problem {self._problem_type} has no probability output."
+
+        if not predictions:
+            predictions = self.predict_all_for_ensembling(
+                learners=self._selected_learners,
+                data=data,
+                mode=TEST,
+                requires_label=False,
+                save=False,
+            )
+        else:
+            predictions = [predictions[i] for i in self._selected_indices]
+
+        self.verify_predictions_labels(
+            predictions=predictions,
+            learners=self._selected_learners,
+        )
+        prob = self._weighted_ensemble.predict_proba(predictions)
+        if as_multiclass and self._problem_type == BINARY:
+            prob = np.column_stack((1 - prob, prob))
+
+        if (as_pandas is None and isinstance(data, pd.DataFrame)) or as_pandas is True:
+            prob = self._as_pandas(data=data, to_be_converted=prob)
+
+        return prob
+
+    def evaluate(
+        self,
+        data: Union[pd.DataFrame, dict, list, str],
+        predictions: Optional[List[np.ndarray]] = None,
+        labels: Optional[np.ndarray] = None,
+        save_all: Optional[bool] = True,
+        **kwargs,
+    ):
+        self.on_predict_start()
+        if not predictions or labels is None:
+            if save_all:
+                learners = self._all_learners
+            else:
+                learners = self._selected_learners
+            predictions, labels = self.predict_all_for_ensembling(
+                learners=learners,
+                data=data,
+                mode=TEST,
+                requires_label=True,
+                save=True,
+            )
+            if save_all:
+                predictions = [predictions[i] for i in self._selected_indices]
+        else:
+            predictions = [predictions[i] for i in self._selected_indices]
+
+        self.verify_predictions_labels(
+            predictions=predictions,
+            labels=labels,
+            learners=self._selected_learners,
+        )
+        all_scores = dict()
+        for per_predictions, per_learner in zip(predictions, self._selected_learners):
+            if not isinstance(per_learner, str):
+                per_learner = per_learner.path
+            metric_data = {
+                Y_PRED: per_predictions,
+                Y_TRUE: labels,
+            }
+            all_scores[per_learner] = compute_score(
+                metric_data=metric_data,
+                metric=self._eval_metric_func,
+            )
+
+        predictions = self._weighted_ensemble.predict_proba(predictions)
+        # for regression, the transform_prediction() is already called in predict_all()
+        if self._eval_metric_func.needs_pred and self._problem_type != REGRESSION:
+            predictions = self._df_preprocessor.transform_prediction(
+                y_pred=predictions,
+                inverse_categorical=False,
+            )
+        metric_data = {
+            Y_PRED: predictions,
+            Y_TRUE: labels,
+        }
+        all_scores["ensemble"] = compute_score(
+            metric_data=metric_data,
+            metric=self._eval_metric_func,
+        )
+
+        return all_scores
+
+    def extract_embedding(
+        self,
+        data: Union[pd.DataFrame, dict, list],
+        return_masks: Optional[bool] = False,
+        as_tensor: Optional[bool] = False,
+        as_pandas: Optional[bool] = False,
+        realtime: Optional[bool] = False,
+        **kwargs,
+    ):
+        raise ValueError("EnsembleLearner doesn't support extracting embedding yet.")
+
+    def save(
+        self,
+        path: str,
+        **kwargs,
+    ):
+        selected_learner_names = [self.get_learner_name(per_learner) for per_learner in self._selected_learners]
+        all_learner_names = [self.get_learner_name(per_learner) for per_learner in self._all_learners]
+
+        os.makedirs(path, exist_ok=True)
+        with open(os.path.join(path, "assets.json"), "w") as fp:
+            json.dump(
+                {
+                    "learner_class": self.__class__.__name__,
+                    "ensemble_size": self._ensemble_size,
+                    "ensemble_mode": self._ensemble_mode,
+                    "selected_learners": selected_learner_names,
+                    "all_learners": all_learner_names,
+                    "selected_indices": self._selected_indices,
+                    "ensemble_weights": self._weighted_ensemble.weights_.tolist(),
+                    "save_path": path,
+                    "relative_path": True,
+                    "fit_called": self._fit_called,
+                    "version": ag_version.__version__,
+                    "hyperparameters": self._hyperparameters,
+                },
+                fp,
+                ensure_ascii=True,
+            )
+
+        with open(os.path.join(path, "ensemble.pkl"), "wb") as fp:
+            pickle.dump(self._weighted_ensemble, fp)
+
+        # save each learner
+        for per_learner in self._all_learners:
+            per_learner_name = self.get_learner_name(per_learner)
+            if isinstance(per_learner, str):
+                per_learner_path = self.get_learner_path(per_learner)
+                per_learner = BaseLearner.load(per_learner_path)
+
+            per_learner_save_path = os.path.join(path, per_learner_name)
+            per_learner.save(per_learner_save_path)
+
+        return
+
+    @classmethod
+    def load(
+        cls,
+        path: str,
+        **kwargs,
+    ):
+        dir_path, ckpt_path = get_dir_ckpt_paths(path=path)
+        assert os.path.isdir(dir_path), f"'{dir_path}' must be an existing directory."
+        with open(os.path.join(dir_path, "assets.json"), "r") as fp:
+            assets = json.load(fp)
+
+        learner = cls(
+            hyperparameters=assets["hyperparameters"],
+        )
+        learner._ensemble_size = assets["ensemble_size"]
+        learner._ensemble_mode = assets["ensemble_mode"]
+        learner._selected_learners = assets["selected_learners"]
+        learner._all_learners = assets["all_learners"]
+        learner._selected_indices = assets["selected_indices"]
+        learner._save_path = path  # in case the original exp dir is copied to somewhere else
+        learner._relative_path = assets["relative_path"]
+        learner._fit_called = assets["fit_called"]
+
+        with open(os.path.join(path, "ensemble.pkl"), "rb") as fp:
+            learner._weighted_ensemble = pickle.load(fp)  # nosec B301
+
+        learner.update_attributes_by_first_learner(learners=learner._selected_learners)
+
+        return learner