autogluon.multimodal 1.2.1b20250303__py3-none-any.whl → 1.2.1b20250305__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- autogluon/multimodal/__init__.py +4 -2
- autogluon/multimodal/configs/data/default.yaml +4 -2
- autogluon/multimodal/configs/{environment → env}/default.yaml +2 -3
- autogluon/multimodal/configs/model/default.yaml +58 -11
- autogluon/multimodal/configs/{optimization → optim}/default.yaml +21 -4
- autogluon/multimodal/constants.py +16 -5
- autogluon/multimodal/data/__init__.py +14 -2
- autogluon/multimodal/data/dataset.py +2 -2
- autogluon/multimodal/data/infer_types.py +16 -2
- autogluon/multimodal/data/label_encoder.py +3 -3
- autogluon/multimodal/{utils → data}/nlpaug.py +4 -4
- autogluon/multimodal/data/preprocess_dataframe.py +55 -38
- autogluon/multimodal/data/process_categorical.py +35 -6
- autogluon/multimodal/data/process_document.py +59 -33
- autogluon/multimodal/data/process_image.py +198 -163
- autogluon/multimodal/data/process_label.py +7 -3
- autogluon/multimodal/data/process_mmlab/process_mmdet.py +1 -8
- autogluon/multimodal/data/process_mmlab/process_mmlab_base.py +2 -9
- autogluon/multimodal/data/process_mmlab/process_mmocr.py +1 -9
- autogluon/multimodal/data/process_ner.py +192 -4
- autogluon/multimodal/data/process_numerical.py +32 -5
- autogluon/multimodal/data/process_semantic_seg_img.py +23 -28
- autogluon/multimodal/data/process_text.py +95 -58
- autogluon/multimodal/data/template_engine.py +7 -9
- autogluon/multimodal/data/templates.py +0 -2
- autogluon/multimodal/data/trivial_augmenter.py +2 -2
- autogluon/multimodal/data/utils.py +564 -338
- autogluon/multimodal/learners/__init__.py +2 -1
- autogluon/multimodal/learners/base.py +189 -189
- autogluon/multimodal/learners/ensemble.py +748 -0
- autogluon/multimodal/learners/few_shot_svm.py +6 -15
- autogluon/multimodal/learners/matching.py +59 -84
- autogluon/multimodal/learners/ner.py +23 -22
- autogluon/multimodal/learners/object_detection.py +26 -21
- autogluon/multimodal/learners/semantic_segmentation.py +16 -18
- autogluon/multimodal/models/__init__.py +12 -3
- autogluon/multimodal/models/augmenter.py +175 -0
- autogluon/multimodal/models/categorical_mlp.py +13 -8
- autogluon/multimodal/models/clip.py +92 -18
- autogluon/multimodal/models/custom_transformer.py +75 -75
- autogluon/multimodal/models/document_transformer.py +23 -9
- autogluon/multimodal/models/ft_transformer.py +40 -35
- autogluon/multimodal/models/fusion/base.py +2 -4
- autogluon/multimodal/models/fusion/fusion_mlp.py +82 -18
- autogluon/multimodal/models/fusion/fusion_ner.py +1 -1
- autogluon/multimodal/models/fusion/fusion_transformer.py +23 -23
- autogluon/multimodal/models/{huggingface_text.py → hf_text.py} +21 -2
- autogluon/multimodal/models/meta_transformer.py +336 -0
- autogluon/multimodal/models/mlp.py +6 -6
- autogluon/multimodal/models/mmocr_text_detection.py +1 -1
- autogluon/multimodal/models/mmocr_text_recognition.py +0 -1
- autogluon/multimodal/models/ner_text.py +1 -8
- autogluon/multimodal/models/numerical_mlp.py +14 -8
- autogluon/multimodal/models/sam.py +12 -2
- autogluon/multimodal/models/t_few.py +21 -5
- autogluon/multimodal/models/timm_image.py +74 -32
- autogluon/multimodal/models/utils.py +877 -16
- autogluon/multimodal/optim/__init__.py +17 -0
- autogluon/multimodal/{optimization → optim}/lit_distiller.py +2 -1
- autogluon/multimodal/{optimization → optim}/lit_matcher.py +4 -10
- autogluon/multimodal/{optimization → optim}/lit_mmdet.py +2 -10
- autogluon/multimodal/{optimization → optim}/lit_module.py +139 -14
- autogluon/multimodal/{optimization → optim}/lit_ner.py +3 -3
- autogluon/multimodal/{optimization → optim}/lit_semantic_seg.py +1 -1
- autogluon/multimodal/optim/losses/__init__.py +14 -0
- autogluon/multimodal/optim/losses/bce_loss.py +25 -0
- autogluon/multimodal/optim/losses/focal_loss.py +81 -0
- autogluon/multimodal/optim/losses/lemda_loss.py +39 -0
- autogluon/multimodal/optim/losses/rkd_loss.py +103 -0
- autogluon/multimodal/optim/losses/softmax_losses.py +177 -0
- autogluon/multimodal/optim/losses/structure_loss.py +26 -0
- autogluon/multimodal/optim/losses/utils.py +313 -0
- autogluon/multimodal/optim/lr/__init__.py +1 -0
- autogluon/multimodal/optim/lr/utils.py +332 -0
- autogluon/multimodal/optim/metrics/__init__.py +4 -0
- autogluon/multimodal/optim/metrics/coverage_metrics.py +42 -0
- autogluon/multimodal/optim/metrics/hit_rate_metrics.py +78 -0
- autogluon/multimodal/optim/metrics/ranking_metrics.py +231 -0
- autogluon/multimodal/optim/metrics/utils.py +359 -0
- autogluon/multimodal/optim/utils.py +284 -0
- autogluon/multimodal/predictor.py +51 -12
- autogluon/multimodal/utils/__init__.py +19 -45
- autogluon/multimodal/utils/cache.py +23 -2
- autogluon/multimodal/utils/checkpoint.py +58 -5
- autogluon/multimodal/utils/config.py +127 -55
- autogluon/multimodal/utils/device.py +120 -0
- autogluon/multimodal/utils/distillation.py +8 -8
- autogluon/multimodal/utils/download.py +1 -1
- autogluon/multimodal/utils/env.py +22 -0
- autogluon/multimodal/utils/export.py +3 -3
- autogluon/multimodal/utils/hpo.py +5 -5
- autogluon/multimodal/utils/inference.py +37 -4
- autogluon/multimodal/utils/install.py +91 -0
- autogluon/multimodal/utils/load.py +52 -47
- autogluon/multimodal/utils/log.py +6 -41
- autogluon/multimodal/utils/matcher.py +3 -2
- autogluon/multimodal/utils/onnx.py +0 -4
- autogluon/multimodal/utils/path.py +10 -0
- autogluon/multimodal/utils/precision.py +130 -0
- autogluon/multimodal/{presets.py → utils/presets.py} +259 -66
- autogluon/multimodal/{problem_types.py → utils/problem_types.py} +30 -1
- autogluon/multimodal/utils/save.py +47 -29
- autogluon/multimodal/utils/strategy.py +24 -0
- autogluon/multimodal/version.py +1 -1
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/METADATA +5 -5
- autogluon.multimodal-1.2.1b20250305.dist-info/RECORD +163 -0
- autogluon/multimodal/optimization/__init__.py +0 -16
- autogluon/multimodal/optimization/losses.py +0 -394
- autogluon/multimodal/optimization/utils.py +0 -1054
- autogluon/multimodal/utils/cloud_io.py +0 -80
- autogluon/multimodal/utils/data.py +0 -701
- autogluon/multimodal/utils/environment.py +0 -395
- autogluon/multimodal/utils/metric.py +0 -500
- autogluon/multimodal/utils/model.py +0 -558
- autogluon.multimodal-1.2.1b20250303.dist-info/RECORD +0 -145
- /autogluon/multimodal/{optimization → optim}/deepspeed.py +0 -0
- /autogluon/multimodal/{optimization/lr_scheduler.py → optim/lr/lr_schedulers.py} +0 -0
- /autogluon/multimodal/{optimization → optim/metrics}/semantic_seg_metrics.py +0 -0
- /autogluon/multimodal/{registry.py → utils/registry.py} +0 -0
- /autogluon.multimodal-1.2.1b20250303-py3.9-nspkg.pth → /autogluon.multimodal-1.2.1b20250305-py3.9-nspkg.pth +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/LICENSE +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/NOTICE +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/WHEEL +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/namespace_packages.txt +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/top_level.txt +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/zip-safe +0 -0
autogluon/multimodal/predictor.py

@@ -8,6 +8,7 @@ import os
 import warnings
 from typing import Dict, List, Optional, Union
 
+import numpy as np
 import pandas as pd
 import transformers
 
@@ -17,14 +18,15 @@ from autogluon.core.metrics import Scorer
 from .constants import AUTOMM_TUTORIAL_MODE, FEW_SHOT_CLASSIFICATION, NER, OBJECT_DETECTION, SEMANTIC_SEGMENTATION
 from .learners import (
     BaseLearner,
+    EnsembleLearner,
     FewShotSVMLearner,
-
+    MatchingLearner,
     NERLearner,
     ObjectDetectionLearner,
     SemanticSegmentationLearner,
 )
-from .problem_types import PROBLEM_TYPES_REG
 from .utils import get_dir_ckpt_paths
+from .utils.problem_types import PROBLEM_TYPES_REG
 
 pl_logger = logging.getLogger("lightning")
 pl_logger.propagate = False  # https://github.com/Lightning-AI/lightning/issues/4621
@@ -64,6 +66,9 @@ class MultiModalPredictor:
         pretrained: Optional[bool] = True,
         validation_metric: Optional[str] = None,
         sample_data_path: Optional[str] = None,
+        use_ensemble: Optional[bool] = False,
+        ensemble_size: Optional[int] = 2,
+        ensemble_mode: Optional[str] = "one_shot",
     ):
         """
         Parameters
@@ -164,6 +169,15 @@ class MultiModalPredictor:
             If not provided, it would be automatically chosen based on the problem type.
         sample_data_path
             The path to sample data from which we can infer num_classes or classes used for object detection.
+        use_ensemble
+            Whether to use ensembling when fitting the predictor (Default False).
+            Currently, it works only on multimodal data (image+text, image+tabular, text+tabular, image+text+tabular) with classification or regression tasks.
+        ensemble_size
+            A multiple of number of models in the ensembling pool (Default 2). The actual ensemble size = ensemble_size * the model number
+        ensemble_mode
+            The mode of conducting ensembling:
+            - `one_shot`: the classic ensemble selection
+            - `sequential`: iteratively calling the classic ensemble selection with each time growing the model zoo by the best next model.
         """
         if problem_type is not None:
             problem_type = problem_type.lower()
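
Taken together with the hunks below, these parameters wire ensembling into the public constructor. A minimal usage sketch based on the new signature (the file name and label column are illustrative):

    import pandas as pd
    from autogluon.multimodal import MultiModalPredictor

    train_data = pd.read_csv("train.csv")  # hypothetical table with image/text/tabular columns

    predictor = MultiModalPredictor(
        label="label",             # hypothetical label column
        use_ensemble=True,         # routes construction to EnsembleLearner (see the hunk at -204 below)
        ensemble_size=2,           # final pool = 2 * the number of candidate models
        ensemble_mode="one_shot",  # or "sequential"
    )
    predictor.fit(train_data=train_data, time_limit=600)  # total runtime ~ time_limit * N models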
@@ -192,7 +206,7 @@
         self._verbosity = verbosity
 
         if problem_property and problem_property.is_matching:
-            learner_class =
+            learner_class = MatchingLearner
         elif problem_type == OBJECT_DETECTION:
             learner_class = ObjectDetectionLearner
         elif problem_type == NER:
@@ -204,6 +218,9 @@
         else:
             learner_class = BaseLearner
 
+        if use_ensemble:
+            learner_class = EnsembleLearner
+
         self._learner = learner_class(
             label=label,
             problem_type=problem_type,
@@ -222,6 +239,8 @@
             query=query,
             response=response,
             match_label=match_label,
+            ensemble_size=ensemble_size,
+            ensemble_mode=ensemble_mode,
         )
 
     @property
@@ -413,6 +432,9 @@
         standalone: Optional[bool] = True,
         hyperparameter_tune_kwargs: Optional[dict] = None,
         clean_ckpts: Optional[bool] = True,
+        predictions: Optional[List[np.ndarray]] = None,
+        labels: Optional[np.ndarray] = None,
+        predictors: Optional[List[Union[str, MultiModalPredictor]]] = None,
     ):
         """
         Fit models to predict a column of a data table (label) based on the other columns (features).
@@ -435,6 +457,8 @@
         time_limit
             How long `fit()` should run for (wall clock time in seconds).
             If not specified, `fit()` will run until the model has completed training.
+            Note that, if use_ensemble=True, the total running time would be time_limit * N,
+            where N is the number of models in the ensemble.
         save_path
             Path to directory where models and artifacts should be saved.
         hyperparameters
@@ -506,6 +530,13 @@
             teacher_learner = teacher_predictor
         else:
             teacher_learner = teacher_predictor._learner
+
+        if predictors is None:
+            learners = None
+        else:
+            assert isinstance(predictors, list)
+            learners = [ele if isinstance(ele, str) else ele._learner for ele in predictors]
+
         self._learner.fit(
             train_data=train_data,
             presets=presets,
@@ -522,6 +553,9 @@
             hyperparameter_tune_kwargs=hyperparameter_tune_kwargs,
             clean_ckpts=clean_ckpts,
             id_mappings=id_mappings,
+            predictions=predictions,
+            labels=labels,
+            learners=learners,
        )
 
         return self
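
The new predictors argument accepts either saved-predictor paths or MultiModalPredictor objects; as the hunk above shows, objects are unwrapped to their _learner before being forwarded as learners. A sketch, assuming an ensemble fit that reuses existing member runs (paths are placeholders):

    from autogluon.multimodal import MultiModalPredictor

    base = MultiModalPredictor.load("runs/text_run")  # a previously fitted member, placeholder path

    ensemble = MultiModalPredictor(label="label", use_ensemble=True)
    ensemble.fit(
        train_data=train_data,                # as in the first sketch
        predictors=["runs/image_run", base],  # strings pass through as-is; objects become ele._learner
    )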
@@ -540,6 +574,8 @@
         return_pred: Optional[bool] = False,
         realtime: Optional[bool] = False,
         eval_tool: Optional[str] = None,
+        predictions: Optional[List[np.ndarray]] = None,
+        labels: Optional[np.ndarray] = None,
     ):
         """
         Evaluate the model on a given dataset.
@@ -595,6 +631,8 @@
             similarity_type=similarity_type,
             cutoffs=cutoffs,
             label=label,
+            predictions=predictions,
+            labels=labels,
         )
 
     def predict(
@@ -807,18 +845,19 @@
 
         with open(os.path.join(dir_path, "assets.json"), "r") as fp:
             assets = json.load(fp)
-
-
-
+        learner_class = BaseLearner
+        if assets["learner_class"] == "MatchingLearner":
+            learner_class = MatchingLearner
+        elif assets["learner_class"] == "EnsembleLearner":
+            learner_class = EnsembleLearner
+        elif assets["learner_class"] == "FewShotSVMLearner":
+            learner_class = FewShotSVMLearner
+        elif assets["learner_class"] == "ObjectDetectionLearner":
             learner_class = ObjectDetectionLearner
-        elif assets["
+        elif assets["learner_class"] == "NERLearner":
             learner_class = NERLearner
-        elif assets["
-            learner_class = FewShotSVMLearner
-        elif assets["problem_type"] == SEMANTIC_SEGMENTATION:
+        elif assets["learner_class"] == "SemanticSegmentationLearner":
             learner_class = SemanticSegmentationLearner
-        else:
-            learner_class = BaseLearner
 
         predictor._learner = learner_class.load(path=path, resume=resume, verbosity=verbosity)
         return predictor
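
load() now dispatches purely on the learner_class name recorded in assets.json, instead of mixing learner_class and problem_type checks. A round-trip sketch (the save path is a placeholder, and assets.json is assumed to be written by save() with the learner class as implied above):

    from autogluon.multimodal import MultiModalPredictor

    predictor = MultiModalPredictor(label="label", use_ensemble=True)
    predictor.fit(train_data=train_data, time_limit=600)  # train_data as in the first sketch
    predictor.save("runs/ensemble")                       # assets.json carries "learner_class": "EnsembleLearner"

    restored = MultiModalPredictor.load("runs/ensemble")  # resolves to EnsembleLearner.load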
autogluon/multimodal/utils/__init__.py

@@ -16,66 +16,34 @@ from .config import (
     update_config_by_rules,
     update_hyperparameters,
     update_tabular_config_by_resources,
-
-)
-from .data import (
-    assign_feature_column_names,
-    create_data_processor,
-    create_fusion_data_processors,
-    data_to_df,
-    get_mixup,
-    infer_dtypes_by_model_names,
-    infer_scarcity_mode_by_data_size,
-    init_df_preprocessor,
-    split_train_tuning_data,
-    turn_on_off_feature_column_info,
+    update_ensemble_hyperparameters,
 )
+from .device import compute_num_gpus, get_available_devices, move_to_device
 from .distillation import DistillationMixin
 from .download import download, is_url
-from .environment import (
-    check_if_packages_installed,
-    compute_inference_batch_size,
-    compute_num_gpus,
-    get_available_devices,
-    get_precision_context,
-    infer_precision,
-    is_interactive_env,
-    is_interactive_strategy,
-    move_to_device,
-    run_ddp_only_once,
-)
 from .export import ExportMixin
 from .hpo import hyperparameter_tune
-from .inference import RealtimeMixin, extract_from_output
-from .load import CustomUnpickler, get_dir_ckpt_paths, get_load_ckpt_paths
+from .inference import RealtimeMixin, compute_inference_batch_size, extract_from_output
+from .load import CustomUnpickler, protected_zip_extraction, get_dir_ckpt_paths, get_load_ckpt_paths
 from .log import (
     LogFilter,
     apply_log_filter,
     get_gpu_message,
-    make_exp_dir,
     on_fit_end_message,
     on_fit_per_run_start_message,
     on_fit_start_message,
 )
 from .matcher import compute_semantic_similarity, convert_data_for_ranking, create_siamese_model, semantic_search
-from .
-
-
-
-
-
-
+from .misc import (
+    logits_to_prob,
+    path_expander,
+    path_to_base64str_expander,
+    path_to_bytearray_expander,
+    shopee_dataset,
+    tensor_to_ndarray,
+    merge_bio_format,
 )
-from .misc import logits_to_prob, merge_bio_format, shopee_dataset, tensor_to_ndarray
 from .mmcv import CollateMMDet, CollateMMOcr
-from .model import (
-    create_fusion_model,
-    create_model,
-    is_lazy_weight_tensor,
-    list_timm_models,
-    modify_duplicate_model_names,
-    select_model,
-)
 from .object_detection import (
     COCODataset,
     bbox_ratio_xywh_to_index_xyxy,
@@ -94,5 +62,11 @@ from .object_detection import (
     save_result_voc_format,
     visualize_detection,
 )
-from .
+from .precision import get_precision_context, infer_precision
+from .presets import get_basic_config, get_ensemble_presets, get_presets, list_presets, matcher_presets
+from .problem_types import PROBLEM_TYPES_REG, infer_problem_type_by_eval_metric
+from .save import process_save_path, setup_save_path, make_exp_dir
+from .strategy import is_interactive_strategy, run_ddp_only_once
+from .env import is_interactive_env
 from .visualizer import NERVisualizer, ObjectDetectionVisualizer, SemanticSegmentationVisualizer, visualize_ner
+from .install import check_if_packages_installed
autogluon/multimodal/utils/cache.py

@@ -10,7 +10,18 @@ import lightning.pytorch as pl
 import torch
 from lightning.pytorch.callbacks import BasePredictionWriter
 
-from ..constants import
+from ..constants import (
+    AUG_LOGITS,
+    BBOX,
+    LOGIT_SCALE,
+    MULTIMODAL_FEATURES,
+    MULTIMODAL_FEATURES_POST_AUG,
+    MULTIMODAL_FEATURES_PRE_AUG,
+    ORI_LOGITS,
+    VAE_MEAN,
+    VAE_VAR,
+    WEIGHT,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -146,7 +157,17 @@ class DDPPredictionWriter(BasePredictionWriter):
             return dict()
 
         for k, v in x[0].items():
-            if k in [
+            if k in [
+                WEIGHT,
+                LOGIT_SCALE,
+                MULTIMODAL_FEATURES,
+                MULTIMODAL_FEATURES_PRE_AUG,
+                MULTIMODAL_FEATURES_POST_AUG,
+                ORI_LOGITS,
+                AUG_LOGITS,
+                VAE_MEAN,
+                VAE_VAR,
+            ]:  # ignore the keys
                 continue
             elif isinstance(v, dict):
                 results[k] = self.collate([i[k] for i in x])
autogluon/multimodal/utils/checkpoint.py

@@ -1,16 +1,30 @@
+"""
+Some utilities are copied from
+https://github.com/Lightning-AI/lightning/blob/master/src/lightning/fabric/utilities/cloud_io.py
+to address warnings:
+LightningDeprecationWarning: lightning.pytorch.utilities.cloud_io.atomic_save has been
+deprecated in v1.8.0 and will be removed in v1.10.0. This function is internal but you
+can copy over its implementation.
+"""
+
+import io
 import logging
 import os
 import re
 import shutil
-from
+from pathlib import Path
+from typing import IO, Any, Callable, Dict, List, Optional, Tuple, Union
 
+import fsspec
 import lightning.pytorch as pl
 import torch
 from lightning.pytorch.strategies import DeepSpeedStrategy
 from lightning.pytorch.utilities.rank_zero import rank_zero_warn
 
-from .
-
+from .env import get_filesystem
+
+_DEVICE = Union[torch.device, str, int]
+_MAP_LOCATION_TYPE = Optional[Union[_DEVICE, Callable[[_DEVICE], _DEVICE], Dict[_DEVICE, _DEVICE]]]
 
 logger = logging.getLogger(__name__)
 
@@ -65,6 +79,45 @@ def average_checkpoints(
     return avg_state_dict
 
 
+def pl_load(
+    path_or_url: Union[IO, str, Path],
+    map_location: _MAP_LOCATION_TYPE = None,
+) -> Any:
+    """Loads a checkpoint.
+
+    Args:
+        path_or_url: Path or URL of the checkpoint.
+        map_location: a function, ``torch.device``, string or a dict specifying how to remap storage locations.
+    """
+    if not isinstance(path_or_url, (str, Path)):
+        # any sort of BytesIO or similar
+        return torch.load(path_or_url, map_location=map_location)  # nosec B614
+    if str(path_or_url).startswith("http"):
+        return torch.hub.load_state_dict_from_url(
+            str(path_or_url),
+            map_location=map_location,  # type: ignore[arg-type] # upstream annotation is not correct
+        )
+    fs = get_filesystem(path_or_url)
+    with fs.open(path_or_url, "rb") as f:
+        return torch.load(f, map_location=map_location)  # nosec B614
+
+
+def pl_save(checkpoint: Dict[str, Any], filepath: Union[str, Path]) -> None:
+    """Saves a checkpoint atomically, avoiding the creation of incomplete checkpoints.
+
+    Args:
+        checkpoint: The object to save.
+            Built to be used with the ``dump_checkpoint`` method, but can deal with anything which ``torch.save``
+            accepts.
+        filepath: The path to which the checkpoint will be saved.
+            This points to the file that the checkpoint will be stored in.
+    """
+    bytesbuffer = io.BytesIO()
+    torch.save(checkpoint, bytesbuffer)  # nosec B614
+    with fsspec.open(filepath, "wb") as f:
+        f.write(bytesbuffer.getvalue())
+
+
 class AutoMMModelCheckpointIO(pl.plugins.CheckpointIO):
     """
     Class that customizes how checkpoints are saved. Saves either the entire model or only parameters that have been explicitly updated during training. The latter reduces memory footprint substantially when training very large models with parameter-efficient finetuning methods.
@@ -124,14 +177,14 @@ class AutoMMModelCheckpointIO(pl.plugins.CheckpointIO):
         fs.makedirs(os.path.dirname(path), exist_ok=True)
         try:
             # write the checkpoint dictionary on the file
-
+            pl_save(checkpoint, path)
         except AttributeError as err:
             # todo (sean): is this try catch necessary still?
             # https://github.com/Lightning-AI/lightning/pull/431
             key = pl.LightningModule.CHECKPOINT_HYPER_PARAMS_KEY
             checkpoint.pop(key, None)
             rank_zero_warn(f"Warning, `{key}` dropped from checkpoint. An attribute is not picklable: {err}")
-
+            pl_save(checkpoint, path)
 
     def load_checkpoint(self, path, map_location: Optional[Any] = None) -> Dict[str, Any]:
         """
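
A minimal round trip with the vendored helpers, assuming a local path (fsspec also resolves remote URLs such as s3://):

    import torch
    from autogluon.multimodal.utils.checkpoint import pl_load, pl_save

    state = {"state_dict": {"weight": torch.zeros(3)}}
    pl_save(state, "model.ckpt")  # serialize to an in-memory buffer, then write once via fsspec
    restored = pl_load("model.ckpt", map_location="cpu")
    assert torch.equal(restored["state_dict"]["weight"], state["state_dict"]["weight"])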
autogluon/multimodal/utils/config.py

@@ -5,22 +5,11 @@ import re
 import warnings
 from typing import Callable, Dict, List, Optional, Tuple, Union
 
-from omegaconf import DictConfig, OmegaConf
-from packaging import version
+from omegaconf import DictConfig, ListConfig, OmegaConf
 from torch import nn
 
-from ..constants import
-
-    DATA,
-    FT_TRANSFORMER,
-    FUSION_TRANSFORMER,
-    HF_MODELS,
-    MODEL,
-    REGRESSION,
-    VALID_CONFIG_KEYS,
-)
-from ..presets import get_automm_presets, get_basic_automm_config
-from .data import get_detected_data_types
+from ..constants import DATA, FT_TRANSFORMER, FUSION_TRANSFORMER, HF_MODELS, MODEL, REGRESSION, VALID_CONFIG_KEYS
+from .presets import get_basic_config, get_ensemble_presets, get_presets
 
 logger = logging.getLogger(__name__)
 
@@ -68,7 +57,7 @@ def get_default_config(config: Optional[Union[Dict, DictConfig]] = None, extra:
     Parameters
     ----------
     config
-        A dictionary including four keys: "model", "data", "
+        A dictionary including four keys: "model", "data", "optim", and "env".
         If any key is not given, we will fill in with the default value.
     extra
         A list of extra config keys.
@@ -83,7 +72,7 @@ def get_default_config(config: Optional[Union[Dict, DictConfig]] = None, extra:
     if config is None:
         config = {}
 
-    basic_config =
+    basic_config = get_basic_config(extra=extra)
     for k, default_value in basic_config.items():
         if k not in config:
             config[k] = default_value
@@ -119,7 +108,7 @@ def get_config(
     extra: Optional[List[str]] = None,
 ):
     """
-    Construct configurations for model, data,
+    Construct configurations for model, data, optim, and env.
     It supports to overrides some default configurations.
 
     Parameters
@@ -129,29 +118,29 @@
     presets
         Presets regarding model quality, e.g., best_quality, high_quality, and medium_quality.
     config
-        A dictionary including four keys: "model", "data", "
+        A dictionary including four keys: "model", "data", "optim", and "env".
         If any key is not given, we will fill in with the default value.
 
         The value of each key can be a string, yaml path, or DictConfig object. For example:
        config = {
            "model": "default",
            "data": "default",
-            "
-            "
+            "optim": "default",
+            "env": "default",
        }
        or
        config = {
            "model": "/path/to/model/config.yaml",
            "data": "/path/to/data/config.yaml",
-            "
-            "
+            "optim": "/path/to/optim/config.yaml",
+            "env": "/path/to/env/config.yaml",
        }
        or
        config = {
            "model": OmegaConf.load("/path/to/model/config.yaml"),
            "data": OmegaConf.load("/path/to/data/config.yaml"),
-            "
-            "
+            "optim": OmegaConf.load("/path/to/optim/config.yaml"),
+            "env": OmegaConf.load("/path/to/env/config.yaml"),
        }
     overrides
         This is to override some default configurations.
@@ -185,7 +174,7 @@
     if presets is None:
         preset_overrides = None
     else:
-        preset_overrides, _ =
+        preset_overrides, _ = get_presets(problem_type=problem_type, presets=presets)
 
     config = get_default_config(config, extra=extra)
     # apply the preset's overrides
@@ -404,33 +393,6 @@ def get_local_pretrained_config_paths(config: DictConfig, path: str) -> DictConfig:
     return config
 
 
-def upgrade_config(config, loaded_version):
-    """Upgrade outdated configurations
-
-    Parameters
-    ----------
-    config
-        The configuration
-    loaded_version
-        The version of the config that has been loaded
-
-    Returns
-    -------
-    config
-        The upgraded configuration
-    """
-    # backward compatibility for variable image size.
-    if version.parse(loaded_version) <= version.parse("0.6.2"):
-        logger.info(f"Start to upgrade the previous configuration trained by AutoMM version={loaded_version}.")
-        if OmegaConf.select(config, "model.timm_image") is not None:
-            logger.warning(
-                "Loading a model that has been trained via AutoGluon Multimodal<=0.6.2. "
-                "Setting config.model.timm_image.image_size = None."
-            )
-            config.model.timm_image.image_size = None
-    return config
-
-
 def parse_dotlist_conf(conf):
     """
     Parse the config files that is potentially in the dotlist format to a dictionary.
@@ -499,6 +461,7 @@ def apply_omegaconf_overrides(
     The updated configuration.
     """
     overrides = parse_dotlist_conf(overrides)
+    overrides = make_overrides_backward_compatible(overrides)
 
     def _check_exist_dotlist(C, key_in_dotlist):
         if not isinstance(key_in_dotlist, list):
@@ -519,10 +482,34 @@
            f"overrides={overrides}"
        )
     override_conf = OmegaConf.from_dotlist([f"{ele[0]}={ele[1]}" for ele in overrides.items()])
+    replace_none_str(override_conf)
     conf = OmegaConf.merge(conf, override_conf)
     return conf
 
 
+def replace_none_str(config: Union[DictConfig, ListConfig, dict, list]):
+    """
+    In-place replace "None" and "none" strings in the config with None.
+
+    Parameters
+    ----------
+    config
+        A config of type DictConfig, ListConfig, dict, or list.
+    """
+    if isinstance(config, (dict, DictConfig)):
+        for key, value in config.items():
+            if isinstance(value, str) and value.lower() == "none":
+                config[key] = None
+            elif isinstance(value, (dict, list, DictConfig, ListConfig)):
+                replace_none_str(value)
+    elif isinstance(config, (list, ListConfig)):
+        for i, value in enumerate(config):
+            if isinstance(value, str) and value.lower() == "none":
+                config[i] = None
+            elif isinstance(value, (dict, list, DictConfig, ListConfig)):
+                replace_none_str(value)
+
+
 def update_config_by_rules(
     problem_type: str,
     config: DictConfig,
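
Dotlist overrides arrive as strings, so a user-supplied None would otherwise be merged as the literal string "None"; replace_none_str converts such values in place after OmegaConf.from_dotlist. A small sketch of the helper on its own:

    from omegaconf import OmegaConf
    from autogluon.multimodal.utils.config import replace_none_str

    conf = OmegaConf.create({"model": {"names": ["None", "timm_image"]}, "optim": {"loss_func": "none"}})
    replace_none_str(conf)  # recursive and in-place
    assert conf.model.names[0] is None and conf.optim.loss_func is None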
@@ -542,11 +529,11 @@
     -------
     The modified config.
     """
-    loss_func =
+    loss_func = config.optim.loss_func
     if loss_func is not None:
         if problem_type == REGRESSION and "bce" in loss_func.lower():
             # To use BCELoss for regression problems, need to first scale the labels.
-            config.data.label.
+            config.data.label.numerical_preprocessing = "minmaxscaler"
 
     return config
 
@@ -658,7 +645,7 @@
     -------
     The updated hyperparameters and hyperparameter_tune_kwargs.
     """
-    hyperparameters, hyperparameter_tune_kwargs =
+    hyperparameters, hyperparameter_tune_kwargs = get_presets(problem_type=problem_type, presets=presets)
 
     if hyperparameter_tune_kwargs and provided_hyperparameter_tune_kwargs:
         hyperparameter_tune_kwargs.update(provided_hyperparameter_tune_kwargs)
@@ -732,6 +719,8 @@
     # Filter models whose data types are not detected.
     # Avoid sampling unused checkpoints, e.g., hf_text models for image classification, to run jobs,
     # which wastes resources and time.
+    from ..data.utils import get_detected_data_types
+
     detected_data_types = get_detected_data_types(column_types)
     selected_model_names = []
     for model_name in hyperparameters[model_names_key]:
@@ -796,3 +785,86 @@ def split_hyperparameters(hyperparameters: Dict):
         raise ValueError(f"transform_types {v} contain neither all strings nor all callable objects.")
 
     return hyperparameters, advanced_hyperparameters
+
+
+def update_ensemble_hyperparameters(
+    presets,
+    provided_hyperparameters,
+):
+    presets_hyperparameters, _ = get_ensemble_presets(presets=presets)
+    if provided_hyperparameters:
+        learner_names = provided_hyperparameters.pop("learner_names", None)
+        if learner_names:
+            assert isinstance(
+                learner_names, list
+            ), f"learner_names should be a list, but got type {type(learner_names)}"
+            presets_hyperparameters = {k: v for k, v in presets_hyperparameters.items() if k in learner_names}
+            provided_hyperparameters = {k: v for k, v in provided_hyperparameters.items() if k in learner_names}
+
+        hyperparameters = copy.deepcopy(provided_hyperparameters)
+        for k, v in presets_hyperparameters.items():
+            if k not in hyperparameters:
+                hyperparameters[k] = v
+            else:
+                for kk, vv in presets_hyperparameters[k].items():
+                    if kk not in hyperparameters[k]:  # don't use presets to overwrite user-provided
+                        hyperparameters[k][kk] = vv
+    else:
+        hyperparameters = presets_hyperparameters
+
+    return hyperparameters
+
+
+def make_overrides_backward_compatible(overrides: Dict):
+    """
+    Some config keys were changed in PR https://github.com/autogluon/autogluon/pull/4737
+    This function is to make the changes backward compatible.
+
+    Parameters
+    ----------
+    overrides
+        A dictionary containing the user-provided hyperparameters,
+        which may contain old config keys.
+
+    Returns
+    -------
+    Overrides with up-to-date config keys.
+    """
+    key_pairs = {
+        "optim.learning_rate": "optim.lr",
+        "optim.efficient_finetune": "optim.peft",
+        "optim.loss_function": "optim.loss_func",
+        "env.num_workers_evaluation": "env.num_workers_inference",
+        "env.eval_batch_size_ratio": "env.inference_batch_size_ratio",
+        "data.label.numerical_label_preprocessing": "data.label.numerical_preprocessing",
+        "model.categorical_mlp.drop_rate": "model.categorical_mlp.dropout",
+        "model.numerical_mlp.drop_rate": "model.numerical_mlp.dropout",
+        "model.numerical_mlp.d_token": "model.numerical_mlp.token_dim",
+        "model.timm_image.max_img_num_per_col": "model.timm_image.max_image_num_per_column",
+        "model.clip.max_img_num_per_col": "model.clip.max_image_num_per_column",
+        "model.clip_image.max_img_num_per_col": "model.clip_image.max_image_num_per_column",
+        "model.fusion_mlp.weight": "model.fusion_mlp.aux_loss_weight",
+        "model.fusion_mlp.drop_rate": "model.fusion_mlp.dropout",
+        "model.fusion_transformer.n_blocks": "model.fusion_transformer.num_blocks",
+        "model.fusion_transformer.attention_n_heads": "model.fusion_transformer.attention_num_heads",
+        "model.fusion_transformer.ffn_d_hidden": "model.fusion_transformer.ffn_hidden_size",
+        "model.ft_transformer.attention_n_heads": "model.ft_transformer.attention_num_heads",
+    }
+    for k in list(overrides.keys()):
+        provided_k = k
+        if k.startswith("optimization."):
+            k = "optim." + k[len("optimization.") :]
+            logger.warning(
+                f"The provided hyperparameter name {provided_k} contains a deprecated key `optimization.`. "
+                f"Please replace `optimization.` with `optim.` when customizing the optimization hyperparameters."
+            )
+
+        if k in key_pairs:
+            overrides[key_pairs[k]] = overrides.pop(provided_k)
+            logger.warning(
+                f"The hyperparameter name {provided_k} is depreciated. "
+                f"We recommend using the new name {key_pairs[k]} instead."
+                f"The deprecated hyperparameter will raise an exception starting in AutoGluon 1.4.0"
+            )
+
+    return overrides
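
A quick sketch of the key translation performed by make_overrides_backward_compatible (values are illustrative):

    from autogluon.multimodal.utils.config import make_overrides_backward_compatible

    overrides = {
        "optimization.learning_rate": 1e-4,  # old section name plus renamed key
        "env.num_workers_evaluation": 2,     # renamed key
        "optim.max_epochs": 10,              # already current; passes through unchanged
    }
    overrides = make_overrides_backward_compatible(overrides)
    # {"optim.max_epochs": 10, "optim.lr": 0.0001, "env.num_workers_inference": 2}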