autogluon.multimodal 1.2.1b20250303__py3-none-any.whl → 1.2.1b20250305__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries, and is provided for informational purposes only.
Files changed (126)
  1. autogluon/multimodal/__init__.py +4 -2
  2. autogluon/multimodal/configs/data/default.yaml +4 -2
  3. autogluon/multimodal/configs/{environment → env}/default.yaml +2 -3
  4. autogluon/multimodal/configs/model/default.yaml +58 -11
  5. autogluon/multimodal/configs/{optimization → optim}/default.yaml +21 -4
  6. autogluon/multimodal/constants.py +16 -5
  7. autogluon/multimodal/data/__init__.py +14 -2
  8. autogluon/multimodal/data/dataset.py +2 -2
  9. autogluon/multimodal/data/infer_types.py +16 -2
  10. autogluon/multimodal/data/label_encoder.py +3 -3
  11. autogluon/multimodal/{utils → data}/nlpaug.py +4 -4
  12. autogluon/multimodal/data/preprocess_dataframe.py +55 -38
  13. autogluon/multimodal/data/process_categorical.py +35 -6
  14. autogluon/multimodal/data/process_document.py +59 -33
  15. autogluon/multimodal/data/process_image.py +198 -163
  16. autogluon/multimodal/data/process_label.py +7 -3
  17. autogluon/multimodal/data/process_mmlab/process_mmdet.py +1 -8
  18. autogluon/multimodal/data/process_mmlab/process_mmlab_base.py +2 -9
  19. autogluon/multimodal/data/process_mmlab/process_mmocr.py +1 -9
  20. autogluon/multimodal/data/process_ner.py +192 -4
  21. autogluon/multimodal/data/process_numerical.py +32 -5
  22. autogluon/multimodal/data/process_semantic_seg_img.py +23 -28
  23. autogluon/multimodal/data/process_text.py +95 -58
  24. autogluon/multimodal/data/template_engine.py +7 -9
  25. autogluon/multimodal/data/templates.py +0 -2
  26. autogluon/multimodal/data/trivial_augmenter.py +2 -2
  27. autogluon/multimodal/data/utils.py +564 -338
  28. autogluon/multimodal/learners/__init__.py +2 -1
  29. autogluon/multimodal/learners/base.py +189 -189
  30. autogluon/multimodal/learners/ensemble.py +748 -0
  31. autogluon/multimodal/learners/few_shot_svm.py +6 -15
  32. autogluon/multimodal/learners/matching.py +59 -84
  33. autogluon/multimodal/learners/ner.py +23 -22
  34. autogluon/multimodal/learners/object_detection.py +26 -21
  35. autogluon/multimodal/learners/semantic_segmentation.py +16 -18
  36. autogluon/multimodal/models/__init__.py +12 -3
  37. autogluon/multimodal/models/augmenter.py +175 -0
  38. autogluon/multimodal/models/categorical_mlp.py +13 -8
  39. autogluon/multimodal/models/clip.py +92 -18
  40. autogluon/multimodal/models/custom_transformer.py +75 -75
  41. autogluon/multimodal/models/document_transformer.py +23 -9
  42. autogluon/multimodal/models/ft_transformer.py +40 -35
  43. autogluon/multimodal/models/fusion/base.py +2 -4
  44. autogluon/multimodal/models/fusion/fusion_mlp.py +82 -18
  45. autogluon/multimodal/models/fusion/fusion_ner.py +1 -1
  46. autogluon/multimodal/models/fusion/fusion_transformer.py +23 -23
  47. autogluon/multimodal/models/{huggingface_text.py → hf_text.py} +21 -2
  48. autogluon/multimodal/models/meta_transformer.py +336 -0
  49. autogluon/multimodal/models/mlp.py +6 -6
  50. autogluon/multimodal/models/mmocr_text_detection.py +1 -1
  51. autogluon/multimodal/models/mmocr_text_recognition.py +0 -1
  52. autogluon/multimodal/models/ner_text.py +1 -8
  53. autogluon/multimodal/models/numerical_mlp.py +14 -8
  54. autogluon/multimodal/models/sam.py +12 -2
  55. autogluon/multimodal/models/t_few.py +21 -5
  56. autogluon/multimodal/models/timm_image.py +74 -32
  57. autogluon/multimodal/models/utils.py +877 -16
  58. autogluon/multimodal/optim/__init__.py +17 -0
  59. autogluon/multimodal/{optimization → optim}/lit_distiller.py +2 -1
  60. autogluon/multimodal/{optimization → optim}/lit_matcher.py +4 -10
  61. autogluon/multimodal/{optimization → optim}/lit_mmdet.py +2 -10
  62. autogluon/multimodal/{optimization → optim}/lit_module.py +139 -14
  63. autogluon/multimodal/{optimization → optim}/lit_ner.py +3 -3
  64. autogluon/multimodal/{optimization → optim}/lit_semantic_seg.py +1 -1
  65. autogluon/multimodal/optim/losses/__init__.py +14 -0
  66. autogluon/multimodal/optim/losses/bce_loss.py +25 -0
  67. autogluon/multimodal/optim/losses/focal_loss.py +81 -0
  68. autogluon/multimodal/optim/losses/lemda_loss.py +39 -0
  69. autogluon/multimodal/optim/losses/rkd_loss.py +103 -0
  70. autogluon/multimodal/optim/losses/softmax_losses.py +177 -0
  71. autogluon/multimodal/optim/losses/structure_loss.py +26 -0
  72. autogluon/multimodal/optim/losses/utils.py +313 -0
  73. autogluon/multimodal/optim/lr/__init__.py +1 -0
  74. autogluon/multimodal/optim/lr/utils.py +332 -0
  75. autogluon/multimodal/optim/metrics/__init__.py +4 -0
  76. autogluon/multimodal/optim/metrics/coverage_metrics.py +42 -0
  77. autogluon/multimodal/optim/metrics/hit_rate_metrics.py +78 -0
  78. autogluon/multimodal/optim/metrics/ranking_metrics.py +231 -0
  79. autogluon/multimodal/optim/metrics/utils.py +359 -0
  80. autogluon/multimodal/optim/utils.py +284 -0
  81. autogluon/multimodal/predictor.py +51 -12
  82. autogluon/multimodal/utils/__init__.py +19 -45
  83. autogluon/multimodal/utils/cache.py +23 -2
  84. autogluon/multimodal/utils/checkpoint.py +58 -5
  85. autogluon/multimodal/utils/config.py +127 -55
  86. autogluon/multimodal/utils/device.py +120 -0
  87. autogluon/multimodal/utils/distillation.py +8 -8
  88. autogluon/multimodal/utils/download.py +1 -1
  89. autogluon/multimodal/utils/env.py +22 -0
  90. autogluon/multimodal/utils/export.py +3 -3
  91. autogluon/multimodal/utils/hpo.py +5 -5
  92. autogluon/multimodal/utils/inference.py +37 -4
  93. autogluon/multimodal/utils/install.py +91 -0
  94. autogluon/multimodal/utils/load.py +52 -47
  95. autogluon/multimodal/utils/log.py +6 -41
  96. autogluon/multimodal/utils/matcher.py +3 -2
  97. autogluon/multimodal/utils/onnx.py +0 -4
  98. autogluon/multimodal/utils/path.py +10 -0
  99. autogluon/multimodal/utils/precision.py +130 -0
  100. autogluon/multimodal/{presets.py → utils/presets.py} +259 -66
  101. autogluon/multimodal/{problem_types.py → utils/problem_types.py} +30 -1
  102. autogluon/multimodal/utils/save.py +47 -29
  103. autogluon/multimodal/utils/strategy.py +24 -0
  104. autogluon/multimodal/version.py +1 -1
  105. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/METADATA +5 -5
  106. autogluon.multimodal-1.2.1b20250305.dist-info/RECORD +163 -0
  107. autogluon/multimodal/optimization/__init__.py +0 -16
  108. autogluon/multimodal/optimization/losses.py +0 -394
  109. autogluon/multimodal/optimization/utils.py +0 -1054
  110. autogluon/multimodal/utils/cloud_io.py +0 -80
  111. autogluon/multimodal/utils/data.py +0 -701
  112. autogluon/multimodal/utils/environment.py +0 -395
  113. autogluon/multimodal/utils/metric.py +0 -500
  114. autogluon/multimodal/utils/model.py +0 -558
  115. autogluon.multimodal-1.2.1b20250303.dist-info/RECORD +0 -145
  116. /autogluon/multimodal/{optimization → optim}/deepspeed.py +0 -0
  117. /autogluon/multimodal/{optimization/lr_scheduler.py → optim/lr/lr_schedulers.py} +0 -0
  118. /autogluon/multimodal/{optimization → optim/metrics}/semantic_seg_metrics.py +0 -0
  119. /autogluon/multimodal/{registry.py → utils/registry.py} +0 -0
  120. /autogluon.multimodal-1.2.1b20250303-py3.9-nspkg.pth → /autogluon.multimodal-1.2.1b20250305-py3.9-nspkg.pth +0 -0
  121. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/LICENSE +0 -0
  122. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/NOTICE +0 -0
  123. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/WHEEL +0 -0
  124. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/namespace_packages.txt +0 -0
  125. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/top_level.txt +0 -0
  126. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/zip-safe +0 -0
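A theme in this list is internal reorganization: `optimization` → `optim` (entries 58-80 and 107-109), `environment` → `env`, `huggingface_text.py` → `hf_text.py`, and moves of `presets.py`, `problem_types.py`, and `registry.py` under `utils/`. Downstream code that imported these internal modules directly would need its import paths updated; below is a minimal sketch under that assumption (the top-level `MultiModalPredictor` API is unaffected by the moves themselves):

```python
# Hypothetical migration for code that reached into internal modules directly.
# Before (1.2.1b20250303):
#   from autogluon.multimodal.optimization import lit_module
#   from autogluon.multimodal import presets
# After (1.2.1b20250305): the `optim` package, with losses/metrics/lr split
# into subpackages, and presets relocated under `utils`:
from autogluon.multimodal.optim import lit_module
from autogluon.multimodal.optim.losses import focal_loss
from autogluon.multimodal.utils import presets
```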
autogluon/multimodal/learners/ensemble.py (new file; entry 30 in the list above)
@@ -0,0 +1,748 @@
+ import copy
+ import json
+ import logging
+ import os
+ import pathlib
+ import pickle
+ import pprint
+ import time
+ import warnings
+ from typing import Callable, Dict, List, Optional, Union
+
+ import numpy as np
+ import pandas as pd
+
+ from autogluon.core.metrics import Scorer
+ from autogluon.core.models.greedy_ensemble.ensemble_selection import EnsembleSelection
+
+ from .. import version as ag_version
+ from ..constants import BINARY, LOGITS, MULTICLASS, REGRESSION, TEST, VAL, Y_PRED, Y_TRUE
+ from ..optim import compute_score
+ from ..utils import (
+     extract_from_output,
+     get_dir_ckpt_paths,
+     logits_to_prob,
+     on_fit_end_message,
+     update_ensemble_hyperparameters,
+ )
+ from .base import BaseLearner
+
+ logger = logging.getLogger(__name__)
+
+
+ class EnsembleLearner(BaseLearner):
+     def __init__(
+         self,
+         label: Optional[str] = None,
+         problem_type: Optional[str] = None,
+         presets: Optional[str] = "high_quality",
+         eval_metric: Optional[Union[str, Scorer]] = None,
+         hyperparameters: Optional[dict] = None,
+         path: Optional[str] = None,
+         verbosity: Optional[int] = 2,
+         warn_if_exist: Optional[bool] = True,
+         enable_progress_bar: Optional[bool] = None,
+         ensemble_size: Optional[int] = 2,
+         ensemble_mode: Optional[str] = "one_shot",
+         **kwargs,
+     ):
+ """
50
+ Parameters
51
+ ----------
52
+ label
53
+ Name of the column that contains the target variable to predict.
54
+ problem_type
55
+ Type of the prediction problem. We support standard problems like
56
+
57
+ - 'binary': Binary classification
58
+ - 'multiclass': Multi-class classification
59
+ - 'regression': Regression
60
+ - 'classification': Classification problems include 'binary' and 'multiclass' classification.
61
+ presets
62
+ Presets regarding model quality, e.g., best_quality, high_quality, and medium_quality.
63
+ eval_metric
64
+ Evaluation metric name. If `eval_metric = None`, it is automatically chosen based on `problem_type`.
65
+ Defaults to 'accuracy' for multiclass classification, `roc_auc` for binary classification, and 'root_mean_squared_error' for regression.
66
+ hyperparameters
67
+ This is to override some default configurations.
68
+ For example, changing the text and image backbones can be done by formatting:
69
+
70
+ a string
71
+ hyperparameters = "model.hf_text.checkpoint_name=google/electra-small-discriminator model.timm_image.checkpoint_name=swin_small_patch4_window7_224"
72
+
73
+ or a list of strings
74
+ hyperparameters = ["model.hf_text.checkpoint_name=google/electra-small-discriminator", "model.timm_image.checkpoint_name=swin_small_patch4_window7_224"]
75
+
76
+ or a dictionary
77
+ hyperparameters = {
78
+ "model.hf_text.checkpoint_name": "google/electra-small-discriminator",
79
+ "model.timm_image.checkpoint_name": "swin_small_patch4_window7_224",
80
+ }
81
+ path
82
+ Path to directory where models and intermediate outputs should be saved.
83
+ If unspecified, a time-stamped folder called "AutogluonAutoMM/ag-[TIMESTAMP]"
84
+ will be created in the working directory to store all models.
85
+ Note: To call `fit()` twice and save all results of each fit,
86
+ you must specify different `path` locations or don't specify `path` at all.
87
+ Otherwise files from first `fit()` will be overwritten by second `fit()`.
88
+ verbosity
89
+ Verbosity levels range from 0 to 4 and control how much information is printed.
90
+ Higher levels correspond to more detailed print statements (you can set verbosity = 0 to suppress warnings).
91
+ If using logging, you can alternatively control amount of information printed via `logger.setLevel(L)`,
92
+ where `L` ranges from 0 to 50
93
+ (Note: higher values of `L` correspond to fewer print statements, opposite of verbosity levels)
94
+ warn_if_exist
95
+ Whether to raise warning if the specified path already exists.
96
+ enable_progress_bar
97
+ Whether to show progress bar. It will be True by default and will also be
98
+ disabled if the environment variable os.environ["AUTOMM_DISABLE_PROGRESS_BAR"] is set.
99
+ ensemble_size
100
+ A multiple of number of models in the ensembling pool (Default 2). The actual ensemble size = ensemble_size * the model number
101
+ ensemble_mode
102
+ The mode of conducting ensembling:
103
+ - `one_shot`: the classic ensemble selection
104
+ - `sequential`: iteratively calling the classic ensemble selection with each time growing the model zoo by the best next model.
105
+ """
106
+         super().__init__(
+             label=label,
+             problem_type=problem_type,
+             presets=presets,
+             eval_metric=eval_metric,
+             hyperparameters=hyperparameters,
+             path=path,
+             verbosity=verbosity,
+             warn_if_exist=warn_if_exist,
+             enable_progress_bar=enable_progress_bar,
+         )
+         self._ensemble_size = int(ensemble_size)
+         assert ensemble_mode in ["sequential", "one_shot"]
+         self._ensemble_mode = ensemble_mode
+         self._weighted_ensemble = None
+         self._selected_learners = None
+         self._all_learners = None
+         self._selected_indices = []
+         self._relative_path = False
+
+         return
+
+     def get_learner_path(self, learner_path: str):
+         if self._relative_path:
+             learner_path = os.path.join(self._save_path, learner_path)
+         return learner_path
+
+     def get_learner_name(self, learner):
+         if isinstance(learner, str):
+             if self._relative_path:
+                 learner_name = learner
+             else:
+                 learner_name = pathlib.PurePath(learner).name
+         else:
+             learner_name = pathlib.PurePath(learner.path).name
+
+         return learner_name
+
+     def predict_all_for_ensembling(
+         self,
+         learners: List[Union[str, BaseLearner]],
+         data: Union[pd.DataFrame, str],
+         mode: str,
+         requires_label: Optional[bool] = False,
+         save: Optional[bool] = False,
+     ):
+         assert mode in [VAL, TEST]
+         predictions = []
+         labels = None
+         i = 0
+         for per_learner in learners:
+             i += 1
+             logger.info(f"\npredicting with learner {i}: {per_learner}\n")
+             if isinstance(per_learner, str):
+                 per_learner_path = self.get_learner_path(per_learner)
+             else:
+                 per_learner_path = per_learner.path
+
+             pred_file_path = os.path.join(per_learner_path, f"{mode}_predictions.npy")
+             if os.path.isfile(pred_file_path):
+                 logger.info(f"{mode}_predictions.npy exists. loading it...")
+                 y_pred = np.load(pred_file_path)
+             else:
+                 if isinstance(per_learner, str):
+                     per_learner = BaseLearner.load(path=per_learner_path)
+                 if not self._problem_type:
+                     self._problem_type = per_learner.problem_type
+                 else:
+                     assert self._problem_type == per_learner.problem_type
+                 outputs = per_learner.predict_per_run(
+                     data=data,
+                     realtime=False,
+                     requires_label=False,
+                 )
+                 y_pred = extract_from_output(outputs=outputs, ret_type=LOGITS)
+
+                 if self._problem_type == REGRESSION:
+                     y_pred = per_learner._df_preprocessor.transform_prediction(y_pred=y_pred)
+                 if self._problem_type in [BINARY, MULTICLASS]:
+                     y_pred = logits_to_prob(y_pred)
+                     if self._problem_type == BINARY:
+                         y_pred = y_pred[:, 1]
+
+                 if save:
+                     np.save(pred_file_path, y_pred)
+
+             if requires_label:
+                 label_file_path = os.path.join(per_learner_path, f"{mode}_labels.npy")
+                 if os.path.isfile(label_file_path):
+                     logger.info(f"{mode}_labels.npy exists. loading it...")
+                     y_true = np.load(label_file_path)
+                 else:
+                     if isinstance(per_learner, str):
+                         per_learner = BaseLearner.load(path=per_learner_path)
+                     y_true = per_learner._df_preprocessor.transform_label_for_metric(df=data)
+
+                     if save:
+                         np.save(label_file_path, y_true)
+
+                 if labels is None:
+                     labels = y_true
+                 else:
+                     assert np.array_equal(y_true, labels)
+
+             predictions.append(y_pred)
+
+         if requires_label:
+             return predictions, labels
+         else:
+             return predictions
+
+     @staticmethod
+     def verify_predictions_labels(predictions, learners, labels=None):
+         if labels is not None:
+             assert isinstance(labels, np.ndarray)
+         assert isinstance(predictions, list) and all(isinstance(ele, np.ndarray) for ele in predictions)
+         assert len(learners) == len(
+             predictions
+         ), f"len(learners) {len(learners)} doesn't match len(predictions) {len(predictions)}"
+
+     def fit_per_ensemble(
+         self,
+         predictions: List[np.ndarray],
+         labels: np.ndarray,
+     ):
+         weighted_ensemble = EnsembleSelection(
+             ensemble_size=self._ensemble_size * len(predictions),
+             problem_type=self._problem_type,
+             metric=self._eval_metric_func,
+         )
+         weighted_ensemble.fit(predictions=predictions, labels=labels)
+
+         return weighted_ensemble
+
+     def select_next_best(self, left_learner_indices, selected_learner_indices, predictions, labels):
+         best_regret = None
+         best_weighted_ensemble = None
+         best_next_index = None
+         for i in left_learner_indices:
+             tmp_learner_indices = selected_learner_indices + [i]
+             tmp_predictions = [predictions[j] for j in tmp_learner_indices]
+             tmp_weighted_ensemble = self.fit_per_ensemble(
+                 predictions=tmp_predictions,
+                 labels=labels,
+             )
+             if best_regret is None or tmp_weighted_ensemble.train_score_ < best_regret:
+                 best_regret = tmp_weighted_ensemble.train_score_
+                 best_weighted_ensemble = tmp_weighted_ensemble
+                 best_next_index = i
+
+         return best_regret, best_next_index, best_weighted_ensemble
+
+     def sequential_ensemble(
+         self,
+         predictions: List[np.ndarray],
+         labels: np.ndarray,
+     ):
+         selected_learner_indices = []
+         all_learner_indices = list(range(len(predictions)))
+         best_regret = None
+         best_weighted_ensemble = None
+         best_selected_learner_indices = None
+         while len(selected_learner_indices) < len(all_learner_indices):
+             left_learner_indices = [i for i in all_learner_indices if i not in selected_learner_indices]
+             assert sorted(all_learner_indices) == sorted(selected_learner_indices + left_learner_indices)
+             logger.debug(f"\nleft_learner_indices: {left_learner_indices}")
+             if not left_learner_indices:
+                 break
+             logger.debug(f"selected_learner_indices: {selected_learner_indices}")
+             tmp_regret, next_index, tmp_weighted_ensemble = self.select_next_best(
+                 left_learner_indices=left_learner_indices,
+                 selected_learner_indices=selected_learner_indices,
+                 predictions=predictions,
+                 labels=labels,
+             )
+             selected_learner_indices.append(next_index)
+             if best_regret is None or tmp_regret < best_regret:
+                 best_regret = tmp_regret
+                 best_weighted_ensemble = tmp_weighted_ensemble
+                 best_selected_learner_indices = copy.deepcopy(selected_learner_indices)
+             logger.debug(f"\nbest score: {self._eval_metric_func._optimum - best_regret}")
+             logger.debug(f"best_selected_learner_indices: {best_selected_learner_indices}")
+             logger.debug(f"best_ensemble_weights: {best_weighted_ensemble.weights_}")
+
+         return best_weighted_ensemble, best_selected_learner_indices
+
+     def update_hyperparameters(self, hyperparameters: Dict):
+         if self._hyperparameters and hyperparameters:
+             self._hyperparameters.update(hyperparameters)
+         elif hyperparameters:
+             self._hyperparameters = hyperparameters
+
+         self._hyperparameters = update_ensemble_hyperparameters(
+             presets=self._presets,
+             provided_hyperparameters=self._hyperparameters,
+         )
+         # filter out meta-transformer if no local checkpoint path is provided
+         if "early_fusion" in self._hyperparameters:
+             if self._hyperparameters["early_fusion"]["model.meta_transformer.checkpoint_path"] == "null":
+                 self._hyperparameters.pop("early_fusion")
+                 message = (
+                     "`early_fusion` will not be used in ensembling because `early_fusion` relies on MetaTransformer, "
+                     "but no local MetaTransformer model checkpoint is provided. To use `early_fusion`, "
+                     "download its model checkpoints from https://github.com/invictus717/MetaTransformer to a local path "
+                     "and set the checkpoint path as follows:\n"
+                     "```python\n"
+                     "hyperparameters = {\n"
+                     '    "early_fusion": {\n'
+                     '        "model.meta_transformer.checkpoint_path": args.meta_transformer_ckpt_path,\n'
+                     "    }\n"
+                     "}\n"
+                     "```\n"
+                     "Note that presets `high_quality` (default) and `medium_quality` need the base model, while preset "
+                     "`best_quality` requires the large model. Make sure to download the right MetaTransformer version. "
+                     "We recommend using the download links under tag `国内下载源` because the corresponding "
+                     "downloaded models are not compressed and can be loaded directly.\n"
+                 )
+
+                 logger.warning(message)
+
+     def fit_all(
+         self,
+         train_data,
+         tuning_data,
+         hyperparameters,
+         column_types,
+         holdout_frac,
+         time_limit,
+         seed,
+         standalone,
+         clean_ckpts,
+     ):
+         self._relative_path = True
+         self.update_hyperparameters(hyperparameters=hyperparameters)
+
+         learners = []
+         assert (
+             len(self._hyperparameters) > 1
+         ), f"Ensembling requires training more than one learner, but got {len(self._hyperparameters)} sets of hyperparameters."
+         logger.info(
+             f"Will ensemble {len(self._hyperparameters)} models with the following configs:\n {pprint.pformat(self._hyperparameters)}"
+         )
+         for per_name, per_hparams in self._hyperparameters.items():
+             per_learner_path = os.path.join(self._save_path, per_name)
+             if not os.path.isdir(per_learner_path):
+                 logger.info(f"\nfitting learner {per_name}")
+                 logger.debug(f"hyperparameters: {per_hparams}")
+                 per_learner = BaseLearner(
+                     label=self._label_column,
+                     problem_type=self._problem_type,
+                     presets=self._presets,
+                     eval_metric=self._eval_metric_func,
+                     hyperparameters=per_hparams,
+                     path=per_learner_path,
+                     verbosity=self._verbosity,
+                     warn_if_exist=self._warn_if_exist,
+                     enable_progress_bar=self._enable_progress_bar,
+                     pretrained=self._pretrained,
+                     validation_metric=self._validation_metric_name,
+                 )
+                 per_learner.fit(
+                     train_data=train_data,
+                     tuning_data=tuning_data,
+                     time_limit=time_limit,
+                     column_types=column_types,
+                     holdout_frac=holdout_frac,
+                     seed=seed,
+                     standalone=standalone,
+                     clean_ckpts=clean_ckpts,
+                 )
+             learners.append(per_name)
+
+         return learners
+
+     def on_fit_end(
+         self,
+         training_start: float,
+         **kwargs,
+     ):
+         self._fit_called = True
+         training_end = time.time()
+         self._total_train_time = training_end - training_start
+         logger.info(on_fit_end_message(self._save_path))
+
+     def update_attributes_by_first_learner(self, learners: List):
+         # load df preprocessor from the first learner
+         if isinstance(learners[0], str):
+             first_learner_path = self.get_learner_path(learners[0])
+             dir_path, ckpt_path = get_dir_ckpt_paths(path=first_learner_path)
+             assert os.path.isdir(dir_path), f"'{dir_path}' must be an existing directory."
+             first_learner = BaseLearner(label="dummy_label")
+             first_learner = BaseLearner._load_metadata(learner=first_learner, path=dir_path)
+         else:
+             first_learner = learners[0]
+
+         self._df_preprocessor = first_learner._df_preprocessor
+         self._eval_metric_func = first_learner._eval_metric_func
+         self._eval_metric_name = first_learner._eval_metric_name
+         self._problem_type = first_learner._problem_type
+
+     def fit_ensemble(
+         self,
+         predictions: Optional[List[np.ndarray]] = None,
+         labels: Optional[np.ndarray] = None,
+         learners: Optional[List[Union[str, BaseLearner]]] = None,
+         train_data: Optional[Union[pd.DataFrame, str]] = None,
+         tuning_data: Optional[Union[pd.DataFrame, str]] = None,
+         holdout_frac: Optional[float] = None,
+         seed: Optional[int] = 0,
+     ):
+         if not predictions or labels is None:
+             self.prepare_train_tuning_data(
+                 train_data=train_data,
+                 tuning_data=tuning_data,
+                 holdout_frac=holdout_frac,
+                 seed=seed,
+             )
+             predictions, labels = self.predict_all_for_ensembling(
+                 learners=learners,
+                 data=self._tuning_data,
+                 mode=VAL,
+                 requires_label=True,
+                 save=True,
+             )
+
+         self.verify_predictions_labels(
+             predictions=predictions,
+             labels=labels,
+             learners=learners,
+         )
+
+         if self._ensemble_mode == "sequential":
+             weighted_ensemble, selected_learner_indices = self.sequential_ensemble(
+                 predictions=predictions,
+                 labels=labels,
+             )
+         elif self._ensemble_mode == "one_shot":
+             weighted_ensemble = self.fit_per_ensemble(
+                 predictions=predictions,
+                 labels=labels,
+             )
+             selected_learner_indices = list(range(len(learners)))
+         else:
+             raise ValueError(f"Unsupported ensemble_mode: {self._ensemble_mode}")
+
+         predictions = [predictions[j] for j in selected_learner_indices]
+         predictions = weighted_ensemble.predict_proba(predictions)
+
+         # for regression, the transform_prediction() is already called in predict_all_for_ensembling()
+         if self._eval_metric_func.needs_pred and self._problem_type != REGRESSION:
+             predictions = self._df_preprocessor.transform_prediction(
+                 y_pred=predictions,
+                 inverse_categorical=False,
+             )
+         metric_data = {
+             Y_PRED: predictions,
+             Y_TRUE: labels,
+         }
+         score = compute_score(
+             metric_data=metric_data,
+             metric=self._eval_metric_func,
+         )
+
+         logger.debug(f"\nEnsembling score on validation data: {score}")
+
+         return weighted_ensemble, selected_learner_indices
+
+     def fit(
+         self,
+         train_data: Union[pd.DataFrame, str],
+         presets: Optional[str] = None,
+         tuning_data: Optional[Union[pd.DataFrame, str]] = None,
+         time_limit: Optional[int] = None,
+         save_path: Optional[str] = None,
+         hyperparameters: Optional[Union[str, Dict, List[str]]] = None,
+         column_types: Optional[Dict] = None,
+         holdout_frac: Optional[float] = None,
+         teacher_learner: Optional[Union[str, BaseLearner]] = None,
+         seed: Optional[int] = 0,
+         standalone: Optional[bool] = True,
+         hyperparameter_tune_kwargs: Optional[Dict] = None,
+         clean_ckpts: Optional[bool] = True,
+         learners: Optional[List[Union[str, BaseLearner]]] = None,
+         predictions: Optional[List[np.ndarray]] = None,
+         labels: Optional[np.ndarray] = None,
+         **kwargs,
+     ):
+         self.setup_save_path(save_path=save_path)
+         training_start = self.on_fit_start(presets=presets)
+         if learners is None:
+             learners = self.fit_all(
+                 train_data=train_data,
+                 tuning_data=tuning_data,
+                 hyperparameters=hyperparameters,
+                 column_types=column_types,
+                 holdout_frac=holdout_frac,
+                 time_limit=time_limit,
+                 seed=seed,
+                 standalone=standalone,
+                 clean_ckpts=clean_ckpts,
+             )
+         assert len(learners) > 1, f"Ensembling requires more than one learner, but got {len(learners)}."
+
+         self.update_attributes_by_first_learner(learners=learners)
+         weighted_ensemble, selected_learner_indices = self.fit_ensemble(
+             predictions=predictions,
+             labels=labels,
+             learners=learners,
+             train_data=train_data,
+             tuning_data=tuning_data,
+             holdout_frac=holdout_frac,
+             seed=seed,
+         )
+
+         assert len(selected_learner_indices) == len(weighted_ensemble.weights_)
+         self._weighted_ensemble = weighted_ensemble
+         self._selected_learners = [learners[i] for i in selected_learner_indices]
+         self._all_learners = learners
+         self._selected_indices = selected_learner_indices
+
+         self.on_fit_end(training_start=training_start)
+         self.save(path=self._save_path)
+
+         return self
+
+     def predict(
+         self,
+         data: Union[pd.DataFrame, dict, list, str],
+         predictions: Optional[List[np.ndarray]] = None,
+         as_pandas: Optional[bool] = None,
+         **kwargs,
+     ):
+         self.on_predict_start()
+         if not predictions:
+             predictions = self.predict_all_for_ensembling(
+                 learners=self._selected_learners,
+                 data=data,
+                 mode=TEST,
+                 requires_label=False,
+                 save=False,
+             )
+         else:
+             predictions = [predictions[i] for i in self._selected_indices]
+
+         self.verify_predictions_labels(
+             predictions=predictions,
+             learners=self._selected_learners,
+         )
+         pred = self._weighted_ensemble.predict_proba(predictions)
+         # for regression, the transform_prediction() is already called in predict_all_for_ensembling()
+         if self._problem_type in [BINARY, MULTICLASS]:
+             pred = self._df_preprocessor.transform_prediction(
+                 y_pred=pred,
+                 inverse_categorical=True,
+             )
+         if (as_pandas is None and isinstance(data, pd.DataFrame)) or as_pandas is True:
+             pred = self._as_pandas(data=data, to_be_converted=pred)
+
+         return pred
+
+     def predict_proba(
+         self,
+         data: Union[pd.DataFrame, dict, list],
+         predictions: Optional[List[np.ndarray]] = None,
+         as_pandas: Optional[bool] = None,
+         as_multiclass: Optional[bool] = True,
+         **kwargs,
+     ):
+         self.on_predict_start()
+         assert self._problem_type not in [
+             REGRESSION,
+         ], f"Problem {self._problem_type} has no probability output."
+
+         if not predictions:
+             predictions = self.predict_all_for_ensembling(
+                 learners=self._selected_learners,
+                 data=data,
+                 mode=TEST,
+                 requires_label=False,
+                 save=False,
+             )
+         else:
+             predictions = [predictions[i] for i in self._selected_indices]
+
+         self.verify_predictions_labels(
+             predictions=predictions,
+             learners=self._selected_learners,
+         )
+         prob = self._weighted_ensemble.predict_proba(predictions)
+         if as_multiclass and self._problem_type == BINARY:
+             prob = np.column_stack((1 - prob, prob))
+
+         if (as_pandas is None and isinstance(data, pd.DataFrame)) or as_pandas is True:
+             prob = self._as_pandas(data=data, to_be_converted=prob)
+
+         return prob
+
+     def evaluate(
+         self,
+         data: Union[pd.DataFrame, dict, list, str],
+         predictions: Optional[List[np.ndarray]] = None,
+         labels: Optional[np.ndarray] = None,
+         save_all: Optional[bool] = True,
+         **kwargs,
+     ):
+         self.on_predict_start()
+         if not predictions or labels is None:
+             if save_all:
+                 learners = self._all_learners
+             else:
+                 learners = self._selected_learners
+             predictions, labels = self.predict_all_for_ensembling(
+                 learners=learners,
+                 data=data,
+                 mode=TEST,
+                 requires_label=True,
+                 save=True,
+             )
+             if save_all:
+                 # predictions cover all learners here; keep only the selected ones
+                 # (when save_all is False, they already come from the selected learners)
+                 predictions = [predictions[i] for i in self._selected_indices]
+
+         self.verify_predictions_labels(
+             predictions=predictions,
+             labels=labels,
+             learners=self._selected_learners,
+         )
+         all_scores = dict()
+         for per_predictions, per_learner in zip(predictions, self._selected_learners):
+             if not isinstance(per_learner, str):
+                 per_learner = per_learner.path
+             metric_data = {
+                 Y_PRED: per_predictions,
+                 Y_TRUE: labels,
+             }
+             all_scores[per_learner] = compute_score(
+                 metric_data=metric_data,
+                 metric=self._eval_metric_func,
+             )
+
+         predictions = self._weighted_ensemble.predict_proba(predictions)
+         # for regression, the transform_prediction() is already called in predict_all_for_ensembling()
+         if self._eval_metric_func.needs_pred and self._problem_type != REGRESSION:
+             predictions = self._df_preprocessor.transform_prediction(
+                 y_pred=predictions,
+                 inverse_categorical=False,
+             )
+         metric_data = {
+             Y_PRED: predictions,
+             Y_TRUE: labels,
+         }
+         all_scores["ensemble"] = compute_score(
+             metric_data=metric_data,
+             metric=self._eval_metric_func,
+         )
+
+         return all_scores
+
+     def extract_embedding(
+         self,
+         data: Union[pd.DataFrame, dict, list],
+         return_masks: Optional[bool] = False,
+         as_tensor: Optional[bool] = False,
+         as_pandas: Optional[bool] = False,
+         realtime: Optional[bool] = False,
+         **kwargs,
+     ):
+ raise ValueError(f"EnsembleLearner doesn't support extracting embedding yet.")
675
+
+     def save(
+         self,
+         path: str,
+         **kwargs,
+     ):
+         selected_learner_names = [self.get_learner_name(per_learner) for per_learner in self._selected_learners]
+         all_learner_names = [self.get_learner_name(per_learner) for per_learner in self._all_learners]
+
+         os.makedirs(path, exist_ok=True)
+ with open(os.path.join(path, f"assets.json"), "w") as fp:
686
+             json.dump(
+                 {
+                     "learner_class": self.__class__.__name__,
+                     "ensemble_size": self._ensemble_size,
+                     "ensemble_mode": self._ensemble_mode,
+                     "selected_learners": selected_learner_names,
+                     "all_learners": all_learner_names,
+                     "selected_indices": self._selected_indices,
+                     "ensemble_weights": self._weighted_ensemble.weights_.tolist(),
+                     "save_path": path,
+                     "relative_path": True,
+                     "fit_called": self._fit_called,
+                     "version": ag_version.__version__,
+                     "hyperparameters": self._hyperparameters,
+                 },
+                 fp,
+                 ensure_ascii=True,
+             )
+
+         with open(os.path.join(path, "ensemble.pkl"), "wb") as fp:
+             pickle.dump(self._weighted_ensemble, fp)
+
+         # save each learner
+         for per_learner in self._all_learners:
+             per_learner_name = self.get_learner_name(per_learner)
+             if isinstance(per_learner, str):
+                 per_learner_path = self.get_learner_path(per_learner)
+                 per_learner = BaseLearner.load(per_learner_path)
+
+             per_learner_save_path = os.path.join(path, per_learner_name)
+             per_learner.save(per_learner_save_path)
+
+         return
+
+     @classmethod
+     def load(
+         cls,
+         path: str,
+         **kwargs,
+     ):
+         dir_path, ckpt_path = get_dir_ckpt_paths(path=path)
+         assert os.path.isdir(dir_path), f"'{dir_path}' must be an existing directory."
+         with open(os.path.join(dir_path, "assets.json"), "r") as fp:
+             assets = json.load(fp)
+
+         learner = cls(
+             hyperparameters=assets["hyperparameters"],
+         )
+         learner._ensemble_size = assets["ensemble_size"]
+         learner._ensemble_mode = assets["ensemble_mode"]
+         learner._selected_learners = assets["selected_learners"]
+         learner._all_learners = assets["all_learners"]
+         learner._selected_indices = assets["selected_indices"]
+         learner._save_path = path  # in case the original exp dir is copied to somewhere else
+         learner._relative_path = assets["relative_path"]
+         learner._fit_called = assets["fit_called"]
+
+         with open(os.path.join(path, "ensemble.pkl"), "rb") as fp:
+             learner._weighted_ensemble = pickle.load(fp)  # nosec B301
+
+         learner.update_attributes_by_first_learner(learners=learner._selected_learners)
+
+         return learner
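
For orientation, here is a minimal usage sketch of the new `EnsembleLearner` pieced together from the signatures above. The CSV paths and the label column are placeholders, and the import assumes the class is re-exported from `autogluon.multimodal.learners` (whose `__init__.py` also changes in this diff):

```python
import pandas as pd

from autogluon.multimodal.learners import EnsembleLearner  # assumed export

train_data = pd.read_csv("train.csv")  # placeholder data with a "label" column
test_data = pd.read_csv("test.csv")

learner = EnsembleLearner(
    label="label",              # placeholder target column
    presets="high_quality",     # the default preset
    ensemble_mode="one_shot",   # or "sequential" for greedy forward selection
    ensemble_size=2,            # multiplier on the number of pooled models
)
learner.fit(train_data=train_data, time_limit=3600)

scores = learner.evaluate(data=test_data)      # per-learner scores plus "ensemble"
proba = learner.predict_proba(data=test_data)  # classification only
preds = learner.predict(data=test_data)
```

With `one_shot`, `fit()` trains one `BaseLearner` per hyperparameter set, then fits a single weighted ensemble over their validation predictions; `sequential` instead grows the pool one model at a time, keeping the best-scoring prefix.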