autogluon.multimodal 1.2.1b20250303__py3-none-any.whl → 1.2.1b20250305__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- autogluon/multimodal/__init__.py +4 -2
- autogluon/multimodal/configs/data/default.yaml +4 -2
- autogluon/multimodal/configs/{environment → env}/default.yaml +2 -3
- autogluon/multimodal/configs/model/default.yaml +58 -11
- autogluon/multimodal/configs/{optimization → optim}/default.yaml +21 -4
- autogluon/multimodal/constants.py +16 -5
- autogluon/multimodal/data/__init__.py +14 -2
- autogluon/multimodal/data/dataset.py +2 -2
- autogluon/multimodal/data/infer_types.py +16 -2
- autogluon/multimodal/data/label_encoder.py +3 -3
- autogluon/multimodal/{utils → data}/nlpaug.py +4 -4
- autogluon/multimodal/data/preprocess_dataframe.py +55 -38
- autogluon/multimodal/data/process_categorical.py +35 -6
- autogluon/multimodal/data/process_document.py +59 -33
- autogluon/multimodal/data/process_image.py +198 -163
- autogluon/multimodal/data/process_label.py +7 -3
- autogluon/multimodal/data/process_mmlab/process_mmdet.py +1 -8
- autogluon/multimodal/data/process_mmlab/process_mmlab_base.py +2 -9
- autogluon/multimodal/data/process_mmlab/process_mmocr.py +1 -9
- autogluon/multimodal/data/process_ner.py +192 -4
- autogluon/multimodal/data/process_numerical.py +32 -5
- autogluon/multimodal/data/process_semantic_seg_img.py +23 -28
- autogluon/multimodal/data/process_text.py +95 -58
- autogluon/multimodal/data/template_engine.py +7 -9
- autogluon/multimodal/data/templates.py +0 -2
- autogluon/multimodal/data/trivial_augmenter.py +2 -2
- autogluon/multimodal/data/utils.py +564 -338
- autogluon/multimodal/learners/__init__.py +2 -1
- autogluon/multimodal/learners/base.py +189 -189
- autogluon/multimodal/learners/ensemble.py +748 -0
- autogluon/multimodal/learners/few_shot_svm.py +6 -15
- autogluon/multimodal/learners/matching.py +59 -84
- autogluon/multimodal/learners/ner.py +23 -22
- autogluon/multimodal/learners/object_detection.py +26 -21
- autogluon/multimodal/learners/semantic_segmentation.py +16 -18
- autogluon/multimodal/models/__init__.py +12 -3
- autogluon/multimodal/models/augmenter.py +175 -0
- autogluon/multimodal/models/categorical_mlp.py +13 -8
- autogluon/multimodal/models/clip.py +92 -18
- autogluon/multimodal/models/custom_transformer.py +75 -75
- autogluon/multimodal/models/document_transformer.py +23 -9
- autogluon/multimodal/models/ft_transformer.py +40 -35
- autogluon/multimodal/models/fusion/base.py +2 -4
- autogluon/multimodal/models/fusion/fusion_mlp.py +82 -18
- autogluon/multimodal/models/fusion/fusion_ner.py +1 -1
- autogluon/multimodal/models/fusion/fusion_transformer.py +23 -23
- autogluon/multimodal/models/{huggingface_text.py → hf_text.py} +21 -2
- autogluon/multimodal/models/meta_transformer.py +336 -0
- autogluon/multimodal/models/mlp.py +6 -6
- autogluon/multimodal/models/mmocr_text_detection.py +1 -1
- autogluon/multimodal/models/mmocr_text_recognition.py +0 -1
- autogluon/multimodal/models/ner_text.py +1 -8
- autogluon/multimodal/models/numerical_mlp.py +14 -8
- autogluon/multimodal/models/sam.py +12 -2
- autogluon/multimodal/models/t_few.py +21 -5
- autogluon/multimodal/models/timm_image.py +74 -32
- autogluon/multimodal/models/utils.py +877 -16
- autogluon/multimodal/optim/__init__.py +17 -0
- autogluon/multimodal/{optimization → optim}/lit_distiller.py +2 -1
- autogluon/multimodal/{optimization → optim}/lit_matcher.py +4 -10
- autogluon/multimodal/{optimization → optim}/lit_mmdet.py +2 -10
- autogluon/multimodal/{optimization → optim}/lit_module.py +139 -14
- autogluon/multimodal/{optimization → optim}/lit_ner.py +3 -3
- autogluon/multimodal/{optimization → optim}/lit_semantic_seg.py +1 -1
- autogluon/multimodal/optim/losses/__init__.py +14 -0
- autogluon/multimodal/optim/losses/bce_loss.py +25 -0
- autogluon/multimodal/optim/losses/focal_loss.py +81 -0
- autogluon/multimodal/optim/losses/lemda_loss.py +39 -0
- autogluon/multimodal/optim/losses/rkd_loss.py +103 -0
- autogluon/multimodal/optim/losses/softmax_losses.py +177 -0
- autogluon/multimodal/optim/losses/structure_loss.py +26 -0
- autogluon/multimodal/optim/losses/utils.py +313 -0
- autogluon/multimodal/optim/lr/__init__.py +1 -0
- autogluon/multimodal/optim/lr/utils.py +332 -0
- autogluon/multimodal/optim/metrics/__init__.py +4 -0
- autogluon/multimodal/optim/metrics/coverage_metrics.py +42 -0
- autogluon/multimodal/optim/metrics/hit_rate_metrics.py +78 -0
- autogluon/multimodal/optim/metrics/ranking_metrics.py +231 -0
- autogluon/multimodal/optim/metrics/utils.py +359 -0
- autogluon/multimodal/optim/utils.py +284 -0
- autogluon/multimodal/predictor.py +51 -12
- autogluon/multimodal/utils/__init__.py +19 -45
- autogluon/multimodal/utils/cache.py +23 -2
- autogluon/multimodal/utils/checkpoint.py +58 -5
- autogluon/multimodal/utils/config.py +127 -55
- autogluon/multimodal/utils/device.py +120 -0
- autogluon/multimodal/utils/distillation.py +8 -8
- autogluon/multimodal/utils/download.py +1 -1
- autogluon/multimodal/utils/env.py +22 -0
- autogluon/multimodal/utils/export.py +3 -3
- autogluon/multimodal/utils/hpo.py +5 -5
- autogluon/multimodal/utils/inference.py +37 -4
- autogluon/multimodal/utils/install.py +91 -0
- autogluon/multimodal/utils/load.py +52 -47
- autogluon/multimodal/utils/log.py +6 -41
- autogluon/multimodal/utils/matcher.py +3 -2
- autogluon/multimodal/utils/onnx.py +0 -4
- autogluon/multimodal/utils/path.py +10 -0
- autogluon/multimodal/utils/precision.py +130 -0
- autogluon/multimodal/{presets.py → utils/presets.py} +259 -66
- autogluon/multimodal/{problem_types.py → utils/problem_types.py} +30 -1
- autogluon/multimodal/utils/save.py +47 -29
- autogluon/multimodal/utils/strategy.py +24 -0
- autogluon/multimodal/version.py +1 -1
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/METADATA +5 -5
- autogluon.multimodal-1.2.1b20250305.dist-info/RECORD +163 -0
- autogluon/multimodal/optimization/__init__.py +0 -16
- autogluon/multimodal/optimization/losses.py +0 -394
- autogluon/multimodal/optimization/utils.py +0 -1054
- autogluon/multimodal/utils/cloud_io.py +0 -80
- autogluon/multimodal/utils/data.py +0 -701
- autogluon/multimodal/utils/environment.py +0 -395
- autogluon/multimodal/utils/metric.py +0 -500
- autogluon/multimodal/utils/model.py +0 -558
- autogluon.multimodal-1.2.1b20250303.dist-info/RECORD +0 -145
- /autogluon/multimodal/{optimization → optim}/deepspeed.py +0 -0
- /autogluon/multimodal/{optimization/lr_scheduler.py → optim/lr/lr_schedulers.py} +0 -0
- /autogluon/multimodal/{optimization → optim/metrics}/semantic_seg_metrics.py +0 -0
- /autogluon/multimodal/{registry.py → utils/registry.py} +0 -0
- /autogluon.multimodal-1.2.1b20250303-py3.9-nspkg.pth → /autogluon.multimodal-1.2.1b20250305-py3.9-nspkg.pth +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/LICENSE +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/NOTICE +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/WHEEL +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/namespace_packages.txt +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/top_level.txt +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/zip-safe +0 -0
autogluon/multimodal/models/utils.py

```diff
@@ -1,28 +1,60 @@
+import functools
 import logging
 import re
 import warnings
 from typing import Dict, List, Optional, Tuple
 
+import timm
 import torch
 import torch._dynamo
 import torch.nn.functional as F
+from omegaconf import DictConfig, OmegaConf
+from timm.data.constants import (
+    IMAGENET_DEFAULT_MEAN,
+    IMAGENET_DEFAULT_STD,
+    IMAGENET_INCEPTION_MEAN,
+    IMAGENET_INCEPTION_STD,
+)
 from torch import nn
 from torch.nn.modules.loss import _Loss
 from transformers import AutoConfig, AutoModel, AutoTokenizer, BertTokenizer, CLIPTokenizer, ElectraTokenizer
 from transformers.models.mask2former.modeling_mask2former import Mask2FormerLoss
 
 from ..constants import (
-
+    ALL_MODALITIES,
+    CATEGORICAL,
+    CATEGORICAL_MLP,
     CLASS_LOGITS,
-
-
+    CLIP,
+    CLIP_IMAGE_MEAN,
+    CLIP_IMAGE_STD,
+    DOCUMENT,
+    DOCUMENT_TRANSFORMER,
+    FT_TRANSFORMER,
+    FUSION_MLP,
+    FUSION_NER,
+    FUSION_TRANSFORMER,
+    HF_TEXT,
+    IMAGE,
     LOGITS,
-
+    META_TRANSFORMER,
+    MMDET_IMAGE,
+    MMOCR_TEXT_DET,
+    MMOCR_TEXT_RECOG,
+    NER_TEXT,
+    NUMERICAL,
+    NUMERICAL_MLP,
     OCR,
     PEFT_ADDITIVE_STRATEGIES,
     REGRESSION,
+    SAM,
     SEMANTIC_MASK,
     SEMANTIC_SEGMENTATION,
+    SEMANTIC_SEGMENTATION_IMG,
+    T_FEW,
+    TEXT,
+    TEXT_NER,
+    TIMM_IMAGE,
 )
 from .adaptation_layers import ConvLoRALinear, IA3Linear, IA3LoRALinear, LoRALinear
 
```
```diff
@@ -450,13 +482,13 @@ def get_column_features(
     return column_features, feature_masks
 
 
-def create_adaptation(efficient_finetune: str, layer: nn.Module, lora_r: int, lora_alpha: int, **kwargs):
+def create_adaptation(peft: str, layer: nn.Module, lora_r: int, lora_alpha: int, **kwargs):
     """
     Creates a model adaptation module (IA3, LoRA, IA3_LoRA) given a linear layer.
 
     Parameters
     ----------
-    efficient_finetune
+    peft
         Name of the adaptation module.
     layer
         The layer the adaptation module should be applied to.
@@ -476,11 +508,11 @@ def create_adaptation(efficient_finetune: str, layer: nn.Module, lora_r: int, lo
     -------
     Model with injected LoRA modules.
     """
-    if "ia3_lora" in efficient_finetune:
+    if "ia3_lora" in peft:
         return IA3LoRALinear(
             layer.in_features, layer.out_features, r=lora_r, lora_alpha=lora_alpha, merge_weights=False
         )
-    elif "conv_lora" in efficient_finetune:
+    elif "conv_lora" in peft:
         return ConvLoRALinear(
             layer.in_features,
             layer.out_features,
@@ -489,13 +521,13 @@ def create_adaptation(efficient_finetune: str, layer: nn.Module, lora_r: int, lo
             merge_weights=False,
             conv_lora_expert_num=kwargs["conv_lora_expert_num"],
         )
-    elif "ia3" in efficient_finetune:
+    elif "ia3" in peft:
         return IA3Linear(layer.in_features, layer.out_features, merge_weights=False)
-    elif "lora" in efficient_finetune:
+    elif "lora" in peft:
         return LoRALinear(layer.in_features, layer.out_features, r=lora_r, lora_alpha=lora_alpha, merge_weights=False)
-    elif efficient_finetune is not None:
+    elif peft is not None:
         raise NotImplementedError(
-            f"The efficient finetuning strategy '{efficient_finetune}'"
+            f"The efficient finetuning strategy '{peft}'"
             f" is not supported. We only support"
             f" {', '.join(PEFT_ADDITIVE_STRATEGIES)}."
         )
@@ -503,7 +535,7 @@ def create_adaptation(efficient_finetune: str, layer: nn.Module, lora_r: int, lo
 
 def inject_adaptation_to_linear_layer(
     model: nn.Module,
-    efficient_finetune: str,
+    peft: str,
     lora_r: int = None,
     lora_alpha: int = None,
     filter: Optional[List[str]] = None,
@@ -520,7 +552,7 @@ def inject_adaptation_to_linear_layer(
     ----------
     model
         A PyTorch model.
-    efficient_finetune
+    peft
         Efficient finetuning method that should be applied.
     lora_r
         The rank r of the low-rank decomposition.
@@ -553,7 +585,7 @@ def inject_adaptation_to_linear_layer(
             assert isinstance(
                 layer, nn.Linear
             ), f"LoRA can only be applied to torch.nn.Linear, but {layer} is {type(layer)}."
-            adaptation_layer = create_adaptation(efficient_finetune, layer, lora_r, lora_alpha, **kwargs)
+            adaptation_layer = create_adaptation(peft, layer, lora_r, lora_alpha, **kwargs)
             adaptation_layer.weight = layer.weight
             adaptation_layer.bias = layer.bias
             setattr(module, c_name, adaptation_layer)
```
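The rename from `efficient_finetune` to `peft` is the user-visible change in this group of hunks. A minimal, hypothetical sketch of the new call shape (the toy module, the `filter` patterns, and the import path are illustrative assumptions; only the parameter names come from the hunks above):

```python
from torch import nn

# Import path inferred from this file's location (models/utils.py).
from autogluon.multimodal.models.utils import inject_adaptation_to_linear_layer

# Toy module whose linear-layer names match the filter patterns below.
toy = nn.ModuleDict({"query": nn.Linear(16, 16), "value": nn.Linear(16, 16)})

# 1.2.1b20250303: inject_adaptation_to_linear_layer(model=toy, efficient_finetune="lora", ...)
# 1.2.1b20250305:
toy = inject_adaptation_to_linear_layer(
    model=toy,
    peft="lora",  # renamed from efficient_finetune
    lora_r=8,
    lora_alpha=16,
    filter=["query", "value"],  # illustrative layer-name patterns
)
```

The same rename surfaces at the config level as `config.optim.peft` (see `apply_peft_adaptation` further down), matching the `optimization` → `optim` moves in the file list.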
```diff
@@ -786,7 +818,7 @@ def run_model(model: nn.Module, batch: dict, trt_model: Optional[nn.Module] = No
     from ..utils.onnx import OnnxModule
     from .document_transformer import DocumentTransformer
     from .fusion.fusion_mlp import MultimodalFusionMLP
-    from .huggingface_text import HFAutoModelForTextPrediction
+    from .hf_text import HFAutoModelForTextPrediction
     from .t_few import TFewModel
     from .timm_image import TimmAutoModelForImagePrediction
 
@@ -807,6 +839,7 @@ def run_model(model: nn.Module, batch: dict, trt_model: Optional[nn.Module] = No
             # HACK input data types in ONNX
             if batch[k].dtype == torch.int32:
                 batch[k] = batch[k].to(torch.int64)
+    # DocumentTransformer inherited from HFAutoModelForTextPrediction
     if (not isinstance(pure_model, DocumentTransformer)) and isinstance(pure_model, supported_models):
         input_vec = [batch[k] for k in pure_model.input_keys]
         column_names, column_values = [], []
```
```diff
@@ -903,3 +936,831 @@ def get_pretrained_tokenizer(
         return tokenizer
     except:
         raise e
```

The remaining 828 lines of this hunk are pure additions appended to the end of the file; they are reproduced below as plain Python for readability.

```python
def extract_value_from_config(
    config: Dict,
    keys: Tuple[str, ...],
):
    """
    Traverse a config dictionary to get some hyper-parameter's value.

    Parameters
    ----------
    config
        A config dictionary.
    keys
        The possible names of a hyper-parameter.

    Returns
    -------
    The hyper-parameter value.
    """
    result = []
    for k, v in config.items():
        if k in keys:
            result.append(v)
        elif isinstance(v, dict):
            result += extract_value_from_config(v, keys)
        else:
            pass

    return result
```
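A quick sketch of the traversal with a hypothetical nested config (`extract_value_from_config` as defined above):

```python
cfg = {"vision": {"image_size": 224}, "text": {"max_text_len": 512}}

print(extract_value_from_config(config=cfg, keys=("image_size",)))
# [224] -- matches are collected from any nesting depth
print(extract_value_from_config(config=cfg, keys=("hidden_size",)))
# [] -- callers below treat an empty result as "not detected"
```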
```python
def extract_image_hparams_from_config(model_name: str, config):
    """
    Extract some default hyper-parameters, e.g., image size, mean, and std,
    from a pre-trained (timm or huggingface) checkpoint.

    Parameters
    ----------
    model_name
        Name of model.
    config
        Config of a pre-trained checkpoint.

    Returns
    -------
    image_size
        Image width/height.
    mean
        Image normalization mean.
    std
        Image normalization std.
    """
    if model_name.lower().startswith((TIMM_IMAGE, META_TRANSFORMER)):
        image_size = config["input_size"][-1]
        image_mean = config["mean"]
        image_std = config["std"]
    elif model_name.lower().startswith((CLIP, DOCUMENT_TRANSFORMER)):
        extracted = extract_value_from_config(
            config=config.to_diff_dict(),
            keys=("image_size",),
        )
        if len(extracted) == 0:
            image_size = None
        elif len(extracted) >= 1:
            image_size = extracted[0]
            if isinstance(image_size, tuple):
                image_size = image_size[-1]
        else:
            raise ValueError(f" more than one image_size values are detected: {extracted}")
        image_mean = None
        image_std = None
    else:
        raise ValueError(f"Unknown image processor prefix: {model_name}")
    return image_size, image_mean, image_std


def image_mean_std(norm_type: str):
    """
    Get image normalization mean and std by its name.

    Parameters
    ----------
    norm_type
        Name of image normalization.

    Returns
    -------
    Normalization mean and std.
    """
    if norm_type == "inception":
        return IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
    elif norm_type == "imagenet":
        return IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
    elif norm_type == "clip":
        return CLIP_IMAGE_MEAN, CLIP_IMAGE_STD
    else:
        raise ValueError(f"unknown image normalization: {norm_type}")


def get_image_size_mean_std(
    model_name: str,
    config,
    provided_size: int,
    provided_norm_type: str,
    support_variable_input_size: Optional[bool] = False,
):
    image_size, image_mean, image_std = extract_image_hparams_from_config(
        model_name=model_name,
        config=config,
    )
    if support_variable_input_size and provided_size is not None:
        # We have detected that the model supports using an image size that is
        # different from the pretrained model, e.g., ConvNets with global pooling
        if provided_size < image_size:
            logger.warning(
                f"The provided image size={provided_size} is smaller than the default size "
                f"of the pretrained backbone, which is {image_size}. "
                f"Detailed configuration of the backbone is in {config}. "
                f"You may like to double check your configuration."
            )
        image_size = provided_size
    elif provided_size is not None and provided_size != image_size:
        logger.warning(
            f"The model does not support using an image size that is different from the default size. "
            f"Provided image size={provided_size}. Default size={image_size}. "
            f"Detailed model configuration={config}. We have ignored the provided image size."
        )

    if image_size is None:
        if provided_size is not None:
            image_size = provided_size
            logger.debug(f"using provided image size: {image_size}.")
        else:
            raise ValueError("image size is missing.")
    else:
        logger.debug(f"using detected image size: {image_size}")

    if image_mean is None or image_std is None:
        if provided_norm_type is not None:
            image_mean, image_std = image_mean_std(provided_norm_type)
            logger.debug(f"using provided normalization: {provided_norm_type}.")
        else:
            raise ValueError("image normalization mean and std are missing.")
    else:
        logger.debug(f"using detected image normalization: {image_mean} and {image_std}.")

    return image_size, image_mean, image_std
```
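A minimal sketch of the lookup above; the mean/std constants come from the timm import at the top of the file:

```python
mean, std = image_mean_std("imagenet")   # IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
mean, std = image_mean_std("inception")  # IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
mean, std = image_mean_std("clip")       # CLIP_IMAGE_MEAN, CLIP_IMAGE_STD
# image_mean_std("other") raises ValueError: unknown image normalization: other
```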
```python
def get_text_segment_num(config, provided_segment_num: int, checkpoint_name: str):
    extracted = extract_value_from_config(config=config.to_diff_dict(), keys=("type_vocab_size",))
    if len(extracted) == 0:
        default_segment_num = 1
    elif len(extracted) == 1:
        default_segment_num = extracted[0]
    else:
        raise ValueError(f" more than one type_vocab_size values are detected: {extracted}")

    if default_segment_num <= 0:
        default_segment_num = 1

    if provided_segment_num < default_segment_num:
        warnings.warn(
            f"provided text_segment_num: {provided_segment_num} "
            f"is smaller than {checkpoint_name}'s default: {default_segment_num}"
        )
    text_segment_num = min(provided_segment_num, default_segment_num)
    assert text_segment_num >= 1
    logger.debug(f"text segment num: {text_segment_num}")

    return text_segment_num


def get_text_token_max_len(provided_max_len, config, tokenizer, checkpoint_name):
    """
    Compute the allowable max length of token sequences.

    Parameters
    ----------
    provided_max_len
        The provided max length.
    config
        Model config.
    tokenizer
        Text tokenizer.
    checkpoint_name
        Name of checkpoint.

    Returns
    -------
    Token sequence max length.
    """
    if hasattr(config, "relative_attention") and config.relative_attention:
        default_max_len = tokenizer.model_max_length
    elif hasattr(config, "position_embedding_type") and "relative" in config.position_embedding_type:
        default_max_len = tokenizer.model_max_length
    elif hasattr(config, "max_position_embeddings"):
        default_max_len = config.max_position_embeddings
    else:
        default_max_len = tokenizer.model_max_length

    if provided_max_len is None or provided_max_len <= 0:
        max_len = default_max_len
    else:
        if provided_max_len < default_max_len:
            if default_max_len < 10**6:  # Larger than this value usually means infinite.
                warnings.warn(
                    f"provided max length: {provided_max_len} "
                    f"is smaller than {checkpoint_name}'s default: {default_max_len}"
                )
        max_len = min(provided_max_len, default_max_len)

    logger.debug(f"text max length: {max_len}")

    return max_len
```
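A hedged sketch of the clamping behavior in `get_text_token_max_len`; `bert-base-uncased` is only an illustrative checkpoint (its default of 512 comes from `max_position_embeddings`):

```python
from transformers import AutoConfig, AutoTokenizer

cfg = AutoConfig.from_pretrained("bert-base-uncased")
tok = AutoTokenizer.from_pretrained("bert-base-uncased")

# Provided value above the checkpoint default: clamped to 512, no warning.
assert get_text_token_max_len(1024, cfg, tok, "bert-base-uncased") == 512

# Provided value below the default: kept as-is, but a warning is emitted.
assert get_text_token_max_len(128, cfg, tok, "bert-base-uncased") == 128
```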
```python
def replace_missing_images_with_learnable(
    images: torch.Tensor,
    image_masks,
    learnable_image: nn.Parameter,
):
    b, n, c, h, w = images.shape
    assert learnable_image.shape == (c, h, w)
    for i in range(b):
        for j in range(n):
            if not image_masks[i][j]:  # False means a missing image
                images[i][j] = learnable_image

    return images


def select_model(
    config: DictConfig,
    df_preprocessor,
    strict: Optional[bool] = True,
):
    """
    Filter model config through the detected modalities in the training data.
    If MultiModalFeaturePreprocessor can't detect some modality,
    this function will remove the models that use this modality. This function is to
    maximize the user flexibility in defining the config.
    For example, if one uses the default, including hf_text and timm_image, as the model config template
    but the training data don't have images, this function will filter out timm_image.

    Parameters
    ----------
    config
        A DictConfig object. The model config should be accessible by "config.model"
    df_preprocessor
        A MultiModalFeaturePreprocessor object, which has called .fit() on the training data.
        Column names of the same modality are grouped into one list. If a modality's list is empty,
        it means the training data don't have this modality.
    strict
        If False, allow retaining one model when partial modalities are available for that model.

    Returns
    -------
    Config with some unused models removed.
    """
    data_status = {}
    for per_modality in ALL_MODALITIES:
        data_status[per_modality] = False
    if len(df_preprocessor.image_feature_names) > 0:
        data_status[IMAGE] = True
    if len(df_preprocessor.text_feature_names) > 0:
        data_status[TEXT] = True
    if len(df_preprocessor.categorical_feature_names) > 0:
        data_status[CATEGORICAL] = True
    if len(df_preprocessor.numerical_feature_names) > 0:
        data_status[NUMERICAL] = True
    if len(df_preprocessor.ner_feature_names) > 0:
        data_status[TEXT_NER] = True
    if len(df_preprocessor.document_feature_names) > 0:
        data_status[DOCUMENT] = True
    if len(df_preprocessor.semantic_segmentation_feature_names) > 0:
        data_status[SEMANTIC_SEGMENTATION_IMG] = True

    names = config.model.names
    if isinstance(names, str):
        names = [names]
    selected_model_names = []
    fusion_model_name = []
    for model_name in names:
        model_config = getattr(config.model, model_name)
        strict = getattr(model_config, "requires_all_dtypes", strict)
        if not model_config.data_types:
            fusion_model_name.append(model_name)
            continue
        model_data_status = [data_status[d_type] for d_type in model_config.data_types]
        if all(model_data_status):
            selected_model_names.append(model_name)
        else:
            if any(model_data_status) and not strict:
                selected_model_names.append(model_name)
                # update data types to be consistent with detected
                model_config.data_types = [d_type for d_type in model_config.data_types if data_status[d_type]]
            else:
                delattr(config.model, model_name)

    if len(selected_model_names) == 0:
        raise ValueError("No model is available for this dataset.")
    # only allow no more than 1 fusion model
    if len(fusion_model_name) > 1:
        raise ValueError(f"More than one fusion models `{fusion_model_name}` are detected, but only one is allowed.")

    if len(selected_model_names) > 1:
        assert len(fusion_model_name) == 1
        selected_model_names.extend(fusion_model_name)
    elif len(fusion_model_name) == 1 and hasattr(config.model, fusion_model_name[0]):
        delattr(config.model, fusion_model_name[0])

    config.model.names = selected_model_names
    logger.debug(f"selected models: {selected_model_names}")
    for model_name in selected_model_names:
        logger.debug(f"model dtypes: {getattr(config.model, model_name).data_types}")

    # clean up unused model configs
    model_keys = list(config.model.keys())
    for model_name in model_keys:
        if model_name not in selected_model_names + ["names"]:
            delattr(config.model, model_name)

    return config
```
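To make the pruning concrete, a hedged sketch with a bare-bones stand-in for the fitted preprocessor (real callers pass a `MultiModalFeaturePreprocessor`; every name and value below is illustrative):

```python
from types import SimpleNamespace

from omegaconf import OmegaConf

stub_preprocessor = SimpleNamespace(
    image_feature_names=[],           # no image columns detected
    text_feature_names=["review"],    # one text column detected
    categorical_feature_names=[],
    numerical_feature_names=[],
    ner_feature_names=[],
    document_feature_names=[],
    semantic_segmentation_feature_names=[],
)
cfg = OmegaConf.create(
    {
        "model": {
            "names": ["hf_text", "timm_image", "fusion_mlp"],
            "hf_text": {"data_types": ["text"]},
            "timm_image": {"data_types": ["image"]},
            "fusion_mlp": {"data_types": []},
        }
    }
)
cfg = select_model(config=cfg, df_preprocessor=stub_preprocessor)
print(cfg.model.names)  # ['hf_text']: timm_image and the now-unneeded fusion head are dropped
```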
```python
def create_model(
    model_name: str,
    model_config: DictConfig,
    num_classes: Optional[int] = 0,
    classes: Optional[list] = None,
    num_numerical_columns: Optional[int] = None,
    num_categories: Optional[Dict] = None,
    numerical_fill_values: Optional[Dict] = None,
    pretrained: Optional[bool] = True,
    is_matching: Optional[bool] = False,
):
    """
    Create a single model.

    Parameters
    ----------
    model_name
        Name of the model.
    model_config
        Config of the model.
    num_classes
        The class number for a classification task. It should be 1 for a regression task.
    classes
        All classes in this dataset.
    num_numerical_columns
        The number of numerical columns in the training dataframe.
    num_categories
        The category number for each categorical column in the training dataframe.
    numerical_fill_values
        If numerical values are null, fill them with these.
    pretrained
        Whether using the pretrained timm models. If pretrained=True, download the pretrained model.
    is_matching
        Whether the model is used for semantic matching.

    Returns
    -------
    A model.
    """
    if model_name.lower().startswith(CLIP):
        from .clip import CLIPForImageText

        model = CLIPForImageText(
            prefix=model_name,
            checkpoint_name=model_config.checkpoint_name,
            num_classes=num_classes,
            pretrained=pretrained,
            tokenizer_name=model_config.tokenizer_name,
            has_image=IMAGE in model_config.data_types,
            has_text=TEXT in model_config.data_types,
            image_size=model_config.image_size,
            image_norm=model_config.image_norm,
            image_chan_num=model_config.image_chan_num,
            use_learnable_image=model_config.use_learnable_image,
            max_text_len=model_config.max_text_len,
            text_segment_num=model_config.text_segment_num,
            is_matching=is_matching,
        )
    elif model_name.lower().startswith(TIMM_IMAGE):
        from .timm_image import TimmAutoModelForImagePrediction

        model = TimmAutoModelForImagePrediction(
            prefix=model_name,
            checkpoint_name=model_config.checkpoint_name,
            num_classes=num_classes,
            mix_choice=model_config.mix_choice,
            pretrained=pretrained,
            image_size=model_config.image_size,
            image_norm=model_config.image_norm,
            image_chan_num=model_config.image_chan_num,
            use_learnable_image=model_config.use_learnable_image,
        )
    elif model_name.lower().startswith(HF_TEXT):
        from .hf_text import HFAutoModelForTextPrediction

        model = HFAutoModelForTextPrediction(
            prefix=model_name,
            checkpoint_name=model_config.checkpoint_name,
            num_classes=num_classes,
            pooling_mode=model_config.pooling_mode,
            gradient_checkpointing=model_config.gradient_checkpointing,
            low_cpu_mem_usage=model_config.low_cpu_mem_usage,
            pretrained=pretrained,
            tokenizer_name=model_config.tokenizer_name,
            max_text_len=model_config.max_text_len,
            text_segment_num=model_config.text_segment_num,
            use_fast=model_config.use_fast,
        )
    elif model_name.lower().startswith(T_FEW):
        from .t_few import TFewModel

        model = TFewModel(
            prefix=model_name,
            checkpoint_name=model_config.checkpoint_name,
            length_norm=model_config.length_norm,  # Normalizes length to adjust for length bias in target template
            unlikely_loss=model_config.unlikely_loss,  # Adds loss term that lowers probability of incorrect outputs
            mc_loss=model_config.mc_loss,  # Adds multiple choice cross entropy loss
            num_classes=num_classes,
            gradient_checkpointing=model_config.gradient_checkpointing,
            low_cpu_mem_usage=model_config.low_cpu_mem_usage,
            pretrained=pretrained,
            tokenizer_name=model_config.tokenizer_name,
            max_text_len=model_config.max_text_len,
            text_segment_num=model_config.text_segment_num,
        )
    elif model_name.lower().startswith(NUMERICAL_MLP):
        from .numerical_mlp import NumericalMLP

        model = NumericalMLP(
            prefix=model_name,
            in_features=num_numerical_columns,
            hidden_features=model_config.hidden_size,
            out_features=model_config.hidden_size,
            num_layers=model_config.num_layers,
            activation=model_config.activation,
            dropout=model_config.dropout,
            normalization=model_config.normalization,
            token_dim=model_config.token_dim,
            embedding_arch=model_config.embedding_arch,
            num_classes=num_classes,
            numerical_fill_values=numerical_fill_values,
        )
    elif model_name.lower().startswith(CATEGORICAL_MLP):
        from .categorical_mlp import CategoricalMLP

        model = CategoricalMLP(
            prefix=model_name,
            num_categories=num_categories,
            out_features=model_config.hidden_size,
            num_layers=model_config.num_layers,
            activation=model_config.activation,
            dropout=model_config.dropout,
            normalization=model_config.normalization,
            num_classes=num_classes,
        )
    elif model_name.lower().startswith(DOCUMENT_TRANSFORMER):
        from .document_transformer import DocumentTransformer

        model = DocumentTransformer(
            prefix=model_name,
            checkpoint_name=model_config.checkpoint_name,
            num_classes=num_classes,
            pooling_mode=model_config.pooling_mode,
            gradient_checkpointing=model_config.gradient_checkpointing,
            low_cpu_mem_usage=model_config.low_cpu_mem_usage,
            pretrained=pretrained,
            tokenizer_name=model_config.tokenizer_name,
            image_size=model_config.image_size,
            image_norm=model_config.image_norm,
        )
    elif model_name.lower().startswith(MMDET_IMAGE):
        from .mmdet_image import MMDetAutoModelForObjectDetection

        model = MMDetAutoModelForObjectDetection(
            prefix=model_name,
            checkpoint_name=model_config.checkpoint_name,
            config_file=model_config.config_file,
            classes=classes,
            pretrained=pretrained,
            output_bbox_format=model_config.output_bbox_format,
            frozen_layers=model_config.frozen_layers,
        )
    elif model_name.lower().startswith(MMOCR_TEXT_DET):
        from .mmocr_text_detection import MMOCRAutoModelForTextDetection

        model = MMOCRAutoModelForTextDetection(
            prefix=model_name,
            checkpoint_name=model_config.checkpoint_name,
        )
    elif model_name.lower().startswith(MMOCR_TEXT_RECOG):
        from .mmocr_text_recognition import MMOCRAutoModelForTextRecognition

        model = MMOCRAutoModelForTextRecognition(
            prefix=model_name,
            checkpoint_name=model_config.checkpoint_name,
        )
    elif model_name.lower().startswith(NER_TEXT):
        from .ner_text import HFAutoModelForNER

        model = HFAutoModelForNER(
            prefix=model_name,
            checkpoint_name=model_config.checkpoint_name,
            num_classes=num_classes,
            gradient_checkpointing=model_config.gradient_checkpointing,
            low_cpu_mem_usage=model_config.low_cpu_mem_usage,
            pretrained=pretrained,
            tokenizer_name=model_config.tokenizer_name,
        )
    elif model_name.lower().startswith(FUSION_MLP):
        from .fusion import MultimodalFusionMLP

        model = functools.partial(
            MultimodalFusionMLP,
            prefix=model_name,
            hidden_features=model_config.hidden_sizes,
            num_classes=num_classes,
            adapt_in_features=model_config.adapt_in_features,
            activation=model_config.activation,
            dropout=model_config.dropout,
            normalization=model_config.normalization,
            aux_loss_weight=model_config.aux_loss_weight,
        )
    elif model_name.lower().startswith(FUSION_NER):
        from .fusion import MultimodalFusionNER

        model = functools.partial(
            MultimodalFusionNER,
            prefix=model_name,
            hidden_features=model_config.hidden_sizes,
            num_classes=num_classes,
            adapt_in_features=model_config.adapt_in_features,
            activation=model_config.activation,
            dropout_prob=model_config.drop_rate,
            normalization=model_config.normalization,
            loss_weight=model_config.weight if hasattr(model_config, "weight") else None,
        )
    elif model_name.lower().startswith(FUSION_TRANSFORMER):
        from .fusion import MultimodalFusionTransformer

        model = functools.partial(
            MultimodalFusionTransformer,
            prefix=model_name,
            hidden_features=model_config.hidden_size,
            num_classes=num_classes,
            num_blocks=model_config.num_blocks,
            attention_num_heads=model_config.attention_num_heads,
            ffn_hidden_size=model_config.ffn_hidden_size,
            attention_dropout=model_config.attention_dropout,
            residual_dropout=model_config.residual_dropout,
            ffn_dropout=model_config.ffn_dropout,
            attention_normalization=model_config.normalization,
            ffn_normalization=model_config.normalization,
            head_normalization=model_config.normalization,
            ffn_activation=model_config.ffn_activation,
            head_activation=model_config.head_activation,
            adapt_in_features=model_config.adapt_in_features,
            aux_loss_weight=model_config.aux_loss_weight,
            additive_attention=model_config.additive_attention,
            share_qv_weights=model_config.share_qv_weights,
        )
    elif model_name.lower().startswith(FT_TRANSFORMER):
        from .ft_transformer import FT_Transformer

        model = FT_Transformer(
            prefix=model_name,
            num_numerical_columns=num_numerical_columns,
            num_categories=num_categories,
            numerical_fill_values=numerical_fill_values,
            embedding_arch=model_config.embedding_arch,
            token_dim=model_config.token_dim,
            hidden_size=model_config.hidden_size,
            hidden_features=model_config.hidden_size,
            num_classes=num_classes,
            num_blocks=model_config.num_blocks,
            attention_num_heads=model_config.attention_num_heads,
            attention_dropout=model_config.attention_dropout,
            attention_normalization=model_config.normalization,
            ffn_hidden_size=model_config.ffn_hidden_size,
            ffn_dropout=model_config.ffn_dropout,
            ffn_normalization=model_config.normalization,
            ffn_activation=model_config.ffn_activation,
            residual_dropout=model_config.residual_dropout,
            head_normalization=model_config.normalization,
            head_activation=model_config.head_activation,
            additive_attention=model_config.additive_attention,
            share_qv_weights=model_config.share_qv_weights,
            pooling_mode=model_config.pooling_mode,
            checkpoint_name=model_config.checkpoint_name,
            pretrained=pretrained,
        )
    elif model_name.lower().startswith(SAM):
        from .sam import SAMForSemanticSegmentation

        model = SAMForSemanticSegmentation(
            prefix=model_name,
            checkpoint_name=model_config.checkpoint_name,
            num_classes=num_classes,
            pretrained=pretrained,
            frozen_layers=model_config.frozen_layers,
            num_mask_tokens=model_config.num_mask_tokens,
            image_norm=model_config.image_norm,
        )
    elif model_name.lower().startswith(META_TRANSFORMER):
        from .meta_transformer import MetaTransformer

        model = MetaTransformer(
            prefix=model_name,
            checkpoint_path=model_config.checkpoint_path,
            num_classes=num_classes,
            model_version=model_config.model_version,
            has_image=IMAGE in model_config.data_types,
            has_text=TEXT in model_config.data_types,
            num_numerical_columns=num_numerical_columns,
            num_categories=num_categories,
            numerical_fill_values=numerical_fill_values,
            image_size=model_config.image_size,
            image_norm=model_config.image_norm,
            image_chan_num=model_config.image_chan_num,
            use_learnable_image=model_config.use_learnable_image,
            max_text_len=model_config.max_text_len,
            text_segment_num=model_config.text_segment_num,
        )
    else:
        raise ValueError(f"unknown model name: {model_name}")

    return model


def create_fusion_model(
    config: DictConfig,
    num_classes: Optional[int] = None,
    classes: Optional[list] = None,
    num_numerical_columns: Optional[int] = None,
    num_categories: Optional[Dict] = None,
    numerical_fill_values: Optional[Dict] = None,
    pretrained: Optional[bool] = True,
):
    """
    Create models. It supports the auto models of huggingface text and timm image.
    Multimodal models, e.g., CLIP, should be added case-by-case since their configs and usages
    may be different. It uses MLP for the numerical features, categorical features, and late-fusion.

    Parameters
    ----------
    config
        A DictConfig object. The model config should be accessible by "config.model".
    num_classes
        The class number for a classification task. It should be 1 for a regression task.
    classes
        All classes in this dataset.
    num_numerical_columns
        The number of numerical columns in the training dataframe.
    num_categories
        The category number for each categorical column in the training dataframe.
    numerical_fill_values
        If numerical values are null, fill them with these.
    pretrained
        Whether using the pretrained timm models. If pretrained=True, download the pretrained model.

    Returns
    -------
    A PyTorch model.
    """
    names = config.model.names
    if isinstance(names, str):
        names = [names]
    # make sure no duplicate model names
    assert len(names) == len(set(names))
    logger.debug(f"output_shape: {num_classes}")
    names = sorted(names)
    config.model.names = names
    single_models = []
    fusion_model = None

    for model_name in names:
        model_config = getattr(config.model, model_name)
        model = create_model(
            model_name=model_name,
            model_config=model_config,
            num_classes=num_classes,
            classes=classes,
            num_numerical_columns=num_numerical_columns,
            num_categories=num_categories,
            numerical_fill_values=numerical_fill_values,
            pretrained=pretrained,
        )

        if isinstance(model, functools.partial):  # fusion model
            if fusion_model is None:
                fusion_model = model
            else:
                raise ValueError(
                    f"More than one fusion models are detected in {names}. Only one fusion model is allowed."
                )
        else:  # single model
            if config.optim.peft is not None:
                model = apply_peft_adaptation(model, config)
            single_models.append(model)

    if len(single_models) > 1:
        # must have one fusion model if there are multiple independent models
        model = fusion_model(models=single_models)
    elif len(single_models) == 1:
        model = single_models[0]
    else:
        raise ValueError(f"No available models for {names}")

    # build augmenter for multimodal data augmentation
    if config.optim.lemda.turn_on:
        from .fusion import MultimodalFusionMLP

        assert isinstance(model, MultimodalFusionMLP)
        from .augmenter import Augmenter

        augmenter = Augmenter(
            arch_type=config.optim.lemda.arch_type,
            input_dim=model.augmenter_in_features,
            z_dim=config.optim.lemda.z_dim,
            num_layers=config.optim.lemda.num_layers,
            adv_weight=config.optim.lemda.adv_weight,
        )
        model.augmenter = augmenter

    return model
```
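`create_fusion_model` consults two new groups of keys under `config.optim` (consistent with the `optimization` → `optim` renames in the file list). A hedged sketch of their shape; every value below is a placeholder rather than a package default:

```python
from omegaconf import OmegaConf

optim_stub = OmegaConf.create(
    {
        "peft": None,                 # e.g. "lora"; routes each single model through apply_peft_adaptation
        "extra_trainable_params": [],
        "lora": {
            "r": 8,
            "alpha": 16,
            "filter": None,
            "module_filter": None,
            "conv_lora_expert_num": 8,
        },
        "lemda": {
            "turn_on": False,         # when True, an Augmenter is attached to the MultimodalFusionMLP head
            "arch_type": "mlp_vae",   # placeholder; whatever Augmenter accepts
            "z_dim": 8,
            "num_layers": 2,
            "adv_weight": 1e-4,
        },
    }
)
```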
```python
def apply_peft_adaptation(model: nn.Module, config: DictConfig) -> nn.Module:
    """
    Apply an adaptation to the model for efficient fine-tuning.

    Parameters
    ----------
    model
        A PyTorch model.
    config:
        A DictConfig object. The optimization config should be accessible by "config.optimization".
    """
    if config.optim.peft in PEFT_ADDITIVE_STRATEGIES:
        model = inject_adaptation_to_linear_layer(
            model=model,
            peft=config.optim.peft,
            lora_r=config.optim.lora.r,
            lora_alpha=config.optim.lora.alpha,
            module_filter=config.optim.lora.module_filter,
            filter=config.optim.lora.filter,
            extra_trainable_params=config.optim.extra_trainable_params,
            conv_lora_expert_num=config.optim.lora.conv_lora_expert_num,
        )
        model.name_to_id = model.get_layer_ids()  # Need to update name to id dictionary.

    return model


def modify_duplicate_model_names(
    learner,
    postfix: str,
    blacklist: List[str],
):
    """
    Modify a learner's model names if they exist in a blacklist.

    Parameters
    ----------
    learner
        A BaseLearner object.
    postfix
        The postfix used to change the duplicate names.
    blacklist
        A list of names. The provided learner can't use model names in the list.

    Returns
    -------
    The learner guaranteed has no duplicate model names with the blacklist names.
    """
    model_names = []
    for n in learner._config.model.names:
        if n in blacklist:
            new_name = f"{n}_{postfix}"
            assert new_name not in blacklist
            assert new_name not in learner._config.model.names
            # modify model prefix
            if n == learner._model.prefix:
                learner._model.prefix = new_name
            else:
                assert isinstance(learner._model.model, nn.ModuleList)
                for per_model in learner._model.model:
                    if n == per_model.prefix:
                        per_model.prefix = new_name
                        break
            # modify data processor prefix
            for per_modality_processors in learner._data_processors.values():
                for per_processor in per_modality_processors:
                    if n == per_processor.prefix:
                        per_processor.prefix = new_name
            # modify model config keys
            setattr(learner._config.model, new_name, getattr(learner._config.model, n))
            delattr(learner._config.model, n)

            model_names.append(new_name)
        else:
            model_names.append(n)

    learner._config.model.names = model_names

    return learner


def list_timm_models(pretrained=True):
    return timm.list_models(pretrained=pretrained)


def is_lazy_weight_tensor(p: torch.Tensor) -> bool:
    from torch.nn.parameter import UninitializedParameter

    if isinstance(p, UninitializedParameter):
        warnings.warn(
            "A layer with UninitializedParameter was found. "
            "Thus, the total number of parameters detected may be inaccurate."
        )
        return True
    return False
```