autogluon.multimodal 1.2.1b20250303__py3-none-any.whl → 1.2.1b20250305__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (126)
  1. autogluon/multimodal/__init__.py +4 -2
  2. autogluon/multimodal/configs/data/default.yaml +4 -2
  3. autogluon/multimodal/configs/{environment → env}/default.yaml +2 -3
  4. autogluon/multimodal/configs/model/default.yaml +58 -11
  5. autogluon/multimodal/configs/{optimization → optim}/default.yaml +21 -4
  6. autogluon/multimodal/constants.py +16 -5
  7. autogluon/multimodal/data/__init__.py +14 -2
  8. autogluon/multimodal/data/dataset.py +2 -2
  9. autogluon/multimodal/data/infer_types.py +16 -2
  10. autogluon/multimodal/data/label_encoder.py +3 -3
  11. autogluon/multimodal/{utils → data}/nlpaug.py +4 -4
  12. autogluon/multimodal/data/preprocess_dataframe.py +55 -38
  13. autogluon/multimodal/data/process_categorical.py +35 -6
  14. autogluon/multimodal/data/process_document.py +59 -33
  15. autogluon/multimodal/data/process_image.py +198 -163
  16. autogluon/multimodal/data/process_label.py +7 -3
  17. autogluon/multimodal/data/process_mmlab/process_mmdet.py +1 -8
  18. autogluon/multimodal/data/process_mmlab/process_mmlab_base.py +2 -9
  19. autogluon/multimodal/data/process_mmlab/process_mmocr.py +1 -9
  20. autogluon/multimodal/data/process_ner.py +192 -4
  21. autogluon/multimodal/data/process_numerical.py +32 -5
  22. autogluon/multimodal/data/process_semantic_seg_img.py +23 -28
  23. autogluon/multimodal/data/process_text.py +95 -58
  24. autogluon/multimodal/data/template_engine.py +7 -9
  25. autogluon/multimodal/data/templates.py +0 -2
  26. autogluon/multimodal/data/trivial_augmenter.py +2 -2
  27. autogluon/multimodal/data/utils.py +564 -338
  28. autogluon/multimodal/learners/__init__.py +2 -1
  29. autogluon/multimodal/learners/base.py +189 -189
  30. autogluon/multimodal/learners/ensemble.py +748 -0
  31. autogluon/multimodal/learners/few_shot_svm.py +6 -15
  32. autogluon/multimodal/learners/matching.py +59 -84
  33. autogluon/multimodal/learners/ner.py +23 -22
  34. autogluon/multimodal/learners/object_detection.py +26 -21
  35. autogluon/multimodal/learners/semantic_segmentation.py +16 -18
  36. autogluon/multimodal/models/__init__.py +12 -3
  37. autogluon/multimodal/models/augmenter.py +175 -0
  38. autogluon/multimodal/models/categorical_mlp.py +13 -8
  39. autogluon/multimodal/models/clip.py +92 -18
  40. autogluon/multimodal/models/custom_transformer.py +75 -75
  41. autogluon/multimodal/models/document_transformer.py +23 -9
  42. autogluon/multimodal/models/ft_transformer.py +40 -35
  43. autogluon/multimodal/models/fusion/base.py +2 -4
  44. autogluon/multimodal/models/fusion/fusion_mlp.py +82 -18
  45. autogluon/multimodal/models/fusion/fusion_ner.py +1 -1
  46. autogluon/multimodal/models/fusion/fusion_transformer.py +23 -23
  47. autogluon/multimodal/models/{huggingface_text.py → hf_text.py} +21 -2
  48. autogluon/multimodal/models/meta_transformer.py +336 -0
  49. autogluon/multimodal/models/mlp.py +6 -6
  50. autogluon/multimodal/models/mmocr_text_detection.py +1 -1
  51. autogluon/multimodal/models/mmocr_text_recognition.py +0 -1
  52. autogluon/multimodal/models/ner_text.py +1 -8
  53. autogluon/multimodal/models/numerical_mlp.py +14 -8
  54. autogluon/multimodal/models/sam.py +12 -2
  55. autogluon/multimodal/models/t_few.py +21 -5
  56. autogluon/multimodal/models/timm_image.py +74 -32
  57. autogluon/multimodal/models/utils.py +877 -16
  58. autogluon/multimodal/optim/__init__.py +17 -0
  59. autogluon/multimodal/{optimization → optim}/lit_distiller.py +2 -1
  60. autogluon/multimodal/{optimization → optim}/lit_matcher.py +4 -10
  61. autogluon/multimodal/{optimization → optim}/lit_mmdet.py +2 -10
  62. autogluon/multimodal/{optimization → optim}/lit_module.py +139 -14
  63. autogluon/multimodal/{optimization → optim}/lit_ner.py +3 -3
  64. autogluon/multimodal/{optimization → optim}/lit_semantic_seg.py +1 -1
  65. autogluon/multimodal/optim/losses/__init__.py +14 -0
  66. autogluon/multimodal/optim/losses/bce_loss.py +25 -0
  67. autogluon/multimodal/optim/losses/focal_loss.py +81 -0
  68. autogluon/multimodal/optim/losses/lemda_loss.py +39 -0
  69. autogluon/multimodal/optim/losses/rkd_loss.py +103 -0
  70. autogluon/multimodal/optim/losses/softmax_losses.py +177 -0
  71. autogluon/multimodal/optim/losses/structure_loss.py +26 -0
  72. autogluon/multimodal/optim/losses/utils.py +313 -0
  73. autogluon/multimodal/optim/lr/__init__.py +1 -0
  74. autogluon/multimodal/optim/lr/utils.py +332 -0
  75. autogluon/multimodal/optim/metrics/__init__.py +4 -0
  76. autogluon/multimodal/optim/metrics/coverage_metrics.py +42 -0
  77. autogluon/multimodal/optim/metrics/hit_rate_metrics.py +78 -0
  78. autogluon/multimodal/optim/metrics/ranking_metrics.py +231 -0
  79. autogluon/multimodal/optim/metrics/utils.py +359 -0
  80. autogluon/multimodal/optim/utils.py +284 -0
  81. autogluon/multimodal/predictor.py +51 -12
  82. autogluon/multimodal/utils/__init__.py +19 -45
  83. autogluon/multimodal/utils/cache.py +23 -2
  84. autogluon/multimodal/utils/checkpoint.py +58 -5
  85. autogluon/multimodal/utils/config.py +127 -55
  86. autogluon/multimodal/utils/device.py +120 -0
  87. autogluon/multimodal/utils/distillation.py +8 -8
  88. autogluon/multimodal/utils/download.py +1 -1
  89. autogluon/multimodal/utils/env.py +22 -0
  90. autogluon/multimodal/utils/export.py +3 -3
  91. autogluon/multimodal/utils/hpo.py +5 -5
  92. autogluon/multimodal/utils/inference.py +37 -4
  93. autogluon/multimodal/utils/install.py +91 -0
  94. autogluon/multimodal/utils/load.py +52 -47
  95. autogluon/multimodal/utils/log.py +6 -41
  96. autogluon/multimodal/utils/matcher.py +3 -2
  97. autogluon/multimodal/utils/onnx.py +0 -4
  98. autogluon/multimodal/utils/path.py +10 -0
  99. autogluon/multimodal/utils/precision.py +130 -0
  100. autogluon/multimodal/{presets.py → utils/presets.py} +259 -66
  101. autogluon/multimodal/{problem_types.py → utils/problem_types.py} +30 -1
  102. autogluon/multimodal/utils/save.py +47 -29
  103. autogluon/multimodal/utils/strategy.py +24 -0
  104. autogluon/multimodal/version.py +1 -1
  105. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/METADATA +5 -5
  106. autogluon.multimodal-1.2.1b20250305.dist-info/RECORD +163 -0
  107. autogluon/multimodal/optimization/__init__.py +0 -16
  108. autogluon/multimodal/optimization/losses.py +0 -394
  109. autogluon/multimodal/optimization/utils.py +0 -1054
  110. autogluon/multimodal/utils/cloud_io.py +0 -80
  111. autogluon/multimodal/utils/data.py +0 -701
  112. autogluon/multimodal/utils/environment.py +0 -395
  113. autogluon/multimodal/utils/metric.py +0 -500
  114. autogluon/multimodal/utils/model.py +0 -558
  115. autogluon.multimodal-1.2.1b20250303.dist-info/RECORD +0 -145
  116. /autogluon/multimodal/{optimization → optim}/deepspeed.py +0 -0
  117. /autogluon/multimodal/{optimization/lr_scheduler.py → optim/lr/lr_schedulers.py} +0 -0
  118. /autogluon/multimodal/{optimization → optim/metrics}/semantic_seg_metrics.py +0 -0
  119. /autogluon/multimodal/{registry.py → utils/registry.py} +0 -0
  120. /autogluon.multimodal-1.2.1b20250303-py3.9-nspkg.pth → /autogluon.multimodal-1.2.1b20250305-py3.9-nspkg.pth +0 -0
  121. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/LICENSE +0 -0
  122. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/NOTICE +0 -0
  123. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/WHEEL +0 -0
  124. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/namespace_packages.txt +0 -0
  125. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/top_level.txt +0 -0
  126. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/zip-safe +0 -0
@@ -1,28 +1,60 @@
+import functools
 import logging
 import re
 import warnings
 from typing import Dict, List, Optional, Tuple
 
+import timm
 import torch
 import torch._dynamo
 import torch.nn.functional as F
+from omegaconf import DictConfig, OmegaConf
+from timm.data.constants import (
+    IMAGENET_DEFAULT_MEAN,
+    IMAGENET_DEFAULT_STD,
+    IMAGENET_INCEPTION_MEAN,
+    IMAGENET_INCEPTION_STD,
+)
 from torch import nn
 from torch.nn.modules.loss import _Loss
 from transformers import AutoConfig, AutoModel, AutoTokenizer, BertTokenizer, CLIPTokenizer, ElectraTokenizer
 from transformers.models.mask2former.modeling_mask2former import Mask2FormerLoss
 
 from ..constants import (
-    AUTOMM,
+    ALL_MODALITIES,
+    CATEGORICAL,
+    CATEGORICAL_MLP,
     CLASS_LOGITS,
-    COLUMN_FEATURES,
-    FEATURES,
+    CLIP,
+    CLIP_IMAGE_MEAN,
+    CLIP_IMAGE_STD,
+    DOCUMENT,
+    DOCUMENT_TRANSFORMER,
+    FT_TRANSFORMER,
+    FUSION_MLP,
+    FUSION_NER,
+    FUSION_TRANSFORMER,
+    HF_TEXT,
+    IMAGE,
     LOGITS,
-    MASKS,
+    META_TRANSFORMER,
+    MMDET_IMAGE,
+    MMOCR_TEXT_DET,
+    MMOCR_TEXT_RECOG,
+    NER_TEXT,
+    NUMERICAL,
+    NUMERICAL_MLP,
     OCR,
     PEFT_ADDITIVE_STRATEGIES,
     REGRESSION,
+    SAM,
     SEMANTIC_MASK,
     SEMANTIC_SEGMENTATION,
+    SEMANTIC_SEGMENTATION_IMG,
+    T_FEW,
+    TEXT,
+    TEXT_NER,
+    TIMM_IMAGE,
 )
 from .adaptation_layers import ConvLoRALinear, IA3Linear, IA3LoRALinear, LoRALinear
 
@@ -450,13 +482,13 @@ def get_column_features(
     return column_features, feature_masks
 
 
-def create_adaptation(efficient_finetune: str, layer: nn.Module, lora_r: int, lora_alpha: int, **kwargs):
+def create_adaptation(peft: str, layer: nn.Module, lora_r: int, lora_alpha: int, **kwargs):
     """
     Creates a model adaptation module (IA3, LoRA, IA3_LoRA) given a linear layer.
 
     Parameters
     ----------
-    efficient_finetune
+    peft
         Name of the adaptation module.
     layer
         The layer the adaptation module should be applied to.
@@ -476,11 +508,11 @@ def create_adaptation(efficient_finetune: str, layer: nn.Module, lora_r: int, lo
     -------
     Model with injected LoRA modules.
     """
-    if "ia3_lora" in efficient_finetune:
+    if "ia3_lora" in peft:
         return IA3LoRALinear(
             layer.in_features, layer.out_features, r=lora_r, lora_alpha=lora_alpha, merge_weights=False
         )
-    elif "conv_lora" in efficient_finetune:
+    elif "conv_lora" in peft:
         return ConvLoRALinear(
             layer.in_features,
             layer.out_features,
@@ -489,13 +521,13 @@ def create_adaptation(efficient_finetune: str, layer: nn.Module, lora_r: int, lo
             merge_weights=False,
             conv_lora_expert_num=kwargs["conv_lora_expert_num"],
         )
-    elif "ia3" in efficient_finetune:
+    elif "ia3" in peft:
         return IA3Linear(layer.in_features, layer.out_features, merge_weights=False)
-    elif "lora" in efficient_finetune:
+    elif "lora" in peft:
         return LoRALinear(layer.in_features, layer.out_features, r=lora_r, lora_alpha=lora_alpha, merge_weights=False)
-    elif efficient_finetune is not None and efficient_finetune != "None":
+    elif peft is not None:
         raise NotImplementedError(
-            f"The efficient finetuning strategy '{efficient_finetune}'"
+            f"The efficient finetuning strategy '{peft}'"
             f" is not supported. We only support"
             f" {', '.join(PEFT_ADDITIVE_STRATEGIES)}."
         )
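
The rename from efficient_finetune to peft is mechanical, and the old string sentinel check (!= "None") becomes a plain is not None. A minimal sketch of calling the renamed function (assuming the imports above; the layer can be any torch.nn.Linear):

    import torch.nn as nn

    layer = nn.Linear(768, 768)
    # substring dispatch: "lora" selects LoRALinear with rank 8
    lora_layer = create_adaptation(peft="lora", layer=layer, lora_r=8, lora_alpha=16)
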
@@ -503,7 +535,7 @@ def create_adaptation(efficient_finetune: str, layer: nn.Module, lora_r: int, lo
 
 def inject_adaptation_to_linear_layer(
     model: nn.Module,
-    efficient_finetune: str,
+    peft: str,
     lora_r: int = None,
     lora_alpha: int = None,
     filter: Optional[List[str]] = None,
@@ -520,7 +552,7 @@ def inject_adaptation_to_linear_layer(
     ----------
     model
         A PyTorch model.
-    efficient_finetune
+    peft
         Efficient finetuning method that should be applied.
     lora_r
         The rank r of the low-rank decomposition.
@@ -553,7 +585,7 @@ def inject_adaptation_to_linear_layer(
             assert isinstance(
                 layer, nn.Linear
             ), f"LoRA can only be applied to torch.nn.Linear, but {layer} is {type(layer)}."
-            adaptation_layer = create_adaptation(efficient_finetune, layer, lora_r, lora_alpha, **kwargs)
+            adaptation_layer = create_adaptation(peft, layer, lora_r, lora_alpha, **kwargs)
             adaptation_layer.weight = layer.weight
             adaptation_layer.bias = layer.bias
             setattr(module, c_name, adaptation_layer)
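
For intuition, the injected layer keeps the pretrained weight and adds a trainable low-rank path; this is the standard LoRA formulation (Hu et al., 2021), not code from this diff:

    # y = x @ W.T + (lora_alpha / lora_r) * ((x @ A.T) @ B.T)
    # W is the frozen pretrained weight (out, in); A (r, in) and B (out, r) are the
    # trainable factors, with B initialized to zero so training starts at y = x @ W.T
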
@@ -786,7 +818,7 @@ def run_model(model: nn.Module, batch: dict, trt_model: Optional[nn.Module] = No
     from ..utils.onnx import OnnxModule
     from .document_transformer import DocumentTransformer
     from .fusion.fusion_mlp import MultimodalFusionMLP
-    from .huggingface_text import HFAutoModelForTextPrediction
+    from .hf_text import HFAutoModelForTextPrediction
     from .t_few import TFewModel
     from .timm_image import TimmAutoModelForImagePrediction
 
@@ -807,6 +839,7 @@ def run_model(model: nn.Module, batch: dict, trt_model: Optional[nn.Module] = No
         # HACK input data types in ONNX
         if batch[k].dtype == torch.int32:
             batch[k] = batch[k].to(torch.int64)
+    # DocumentTransformer inherits from HFAutoModelForTextPrediction
     if (not isinstance(pure_model, DocumentTransformer)) and isinstance(pure_model, supported_models):
         input_vec = [batch[k] for k in pure_model.input_keys]
         column_names, column_values = [], []
@@ -903,3 +936,831 @@ def get_pretrained_tokenizer(
             return tokenizer
         except:
             raise e
+
+
+def extract_value_from_config(
+    config: Dict,
+    keys: Tuple[str, ...],
+):
+    """
+    Traverse a config dictionary to get some hyper-parameter's value.
+
+    Parameters
+    ----------
+    config
+        A config dictionary.
+    keys
+        The possible names of a hyper-parameter.
+
+    Returns
+    -------
+    A list of all matched hyper-parameter values.
+    """
+    result = []
+    for k, v in config.items():
+        if k in keys:
+            result.append(v)
+        elif isinstance(v, dict):
+            result += extract_value_from_config(v, keys)
+        else:
+            pass
+
+    return result
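
For orientation, extract_value_from_config walks nested dicts and collects every value whose key matches; a quick sketch (the config below is made up):

    cfg = {"vision_config": {"image_size": 224}, "text_config": {"hidden_size": 512}}
    extract_value_from_config(cfg, keys=("image_size",))  # -> [224]
    extract_value_from_config(cfg, keys=("patch_size",))  # -> []
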
+
+
+def extract_image_hparams_from_config(model_name: str, config):
+    """
+    Extract some default hyper-parameters, e.g., image size, mean, and std,
+    from a pre-trained (timm or huggingface) checkpoint.
+
+    Parameters
+    ----------
+    model_name
+        Name of model.
+    config
+        Config of a pre-trained checkpoint.
+
+    Returns
+    -------
+    image_size
+        Image width/height.
+    mean
+        Image normalization mean.
+    std
+        Image normalization std.
+    """
+    if model_name.lower().startswith((TIMM_IMAGE, META_TRANSFORMER)):
+        image_size = config["input_size"][-1]
+        image_mean = config["mean"]
+        image_std = config["std"]
+    elif model_name.lower().startswith((CLIP, DOCUMENT_TRANSFORMER)):
+        extracted = extract_value_from_config(
+            config=config.to_diff_dict(),
+            keys=("image_size",),
+        )
+        if len(extracted) == 0:
+            image_size = None
+        elif len(extracted) == 1:
+            image_size = extracted[0]
+            if isinstance(image_size, tuple):
+                image_size = image_size[-1]
+        else:
+            raise ValueError(f"more than one image_size value is detected: {extracted}")
+        image_mean = None
+        image_std = None
+    else:
+        raise ValueError(f"Unknown image processor prefix: {model_name}")
+    return image_size, image_mean, image_std
+
+
+def image_mean_std(norm_type: str):
+    """
+    Get image normalization mean and std by its name.
+
+    Parameters
+    ----------
+    norm_type
+        Name of image normalization.
+
+    Returns
+    -------
+    Normalization mean and std.
+    """
+    if norm_type == "inception":
+        return IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
+    elif norm_type == "imagenet":
+        return IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+    elif norm_type == "clip":
+        return CLIP_IMAGE_MEAN, CLIP_IMAGE_STD
+    else:
+        raise ValueError(f"unknown image normalization: {norm_type}")
+
+
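
As a quick reference, the three accepted names map to the constants imported at the top of this module; the first two are timm's published ImageNet statistics:

    image_mean_std("imagenet")   # (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
    image_mean_std("inception")  # (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
    image_mean_std("clip")       # CLIP's mean/std from autogluon.multimodal.constants
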
+
+def get_image_size_mean_std(
+    model_name: str,
+    config,
+    provided_size: int,
+    provided_norm_type: str,
+    support_variable_input_size: Optional[bool] = False,
+):
+    image_size, image_mean, image_std = extract_image_hparams_from_config(
+        model_name=model_name,
+        config=config,
+    )
+    if support_variable_input_size and provided_size is not None:
+        # We have detected that the model supports using an image size that is
+        # different from the pretrained model, e.g., ConvNets with global pooling
+        if provided_size < image_size:
+            logger.warning(
+                f"The provided image size={provided_size} is smaller than the default size "
+                f"of the pretrained backbone, which is {image_size}. "
+                f"Detailed configuration of the backbone is in {config}. "
+                f"You may like to double check your configuration."
+            )
+        image_size = provided_size
+    elif provided_size is not None and provided_size != image_size:
+        logger.warning(
+            f"The model does not support using an image size that is different from the default size. "
+            f"Provided image size={provided_size}. Default size={image_size}. "
+            f"Detailed model configuration={config}. We have ignored the provided image size."
+        )
+
+    if image_size is None:
+        if provided_size is not None:
+            image_size = provided_size
+            logger.debug(f"using provided image size: {image_size}.")
+        else:
+            raise ValueError("image size is missing.")
+    else:
+        logger.debug(f"using detected image size: {image_size}")
+
+    if image_mean is None or image_std is None:
+        if provided_norm_type is not None:
+            image_mean, image_std = image_mean_std(provided_norm_type)
+            logger.debug(f"using provided normalization: {provided_norm_type}.")
+        else:
+            raise ValueError("image normalization mean and std are missing.")
+    else:
+        logger.debug(f"using detected image normalization: {image_mean} and {image_std}.")
+
+    return image_size, image_mean, image_std
+
+
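
A sketch of the precedence rules for a timm-style config (the dict values are illustrative; "timm_image" assumes the usual value of the TIMM_IMAGE prefix):

    cfg = {"input_size": (3, 224, 224), "mean": (0.485, 0.456, 0.406), "std": (0.229, 0.224, 0.225)}
    # the backbone supports variable input size, so the provided 384 overrides the default 224
    get_image_size_mean_std("timm_image", cfg, provided_size=384, provided_norm_type=None,
                            support_variable_input_size=True)
    # -> (384, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
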
+
+def get_text_segment_num(config, provided_segment_num: int, checkpoint_name: str):
+    extracted = extract_value_from_config(config=config.to_diff_dict(), keys=("type_vocab_size",))
+    if len(extracted) == 0:
+        default_segment_num = 1
+    elif len(extracted) == 1:
+        default_segment_num = extracted[0]
+    else:
+        raise ValueError(f"more than one type_vocab_size value is detected: {extracted}")
+
+    if default_segment_num <= 0:
+        default_segment_num = 1
+
+    if provided_segment_num < default_segment_num:
+        warnings.warn(
+            f"provided text_segment_num: {provided_segment_num} "
+            f"is smaller than {checkpoint_name}'s default: {default_segment_num}"
+        )
+    text_segment_num = min(provided_segment_num, default_segment_num)
+    assert text_segment_num >= 1
+    logger.debug(f"text segment num: {text_segment_num}")
+
+    return text_segment_num
+
+
+def get_text_token_max_len(provided_max_len, config, tokenizer, checkpoint_name):
+    """
+    Compute the allowable max length of token sequences.
+
+    Parameters
+    ----------
+    provided_max_len
+        The provided max length.
+    config
+        Model config.
+    tokenizer
+        Text tokenizer.
+    checkpoint_name
+        Name of checkpoint.
+
+    Returns
+    -------
+    Token sequence max length.
+    """
+    if hasattr(config, "relative_attention") and config.relative_attention:
+        default_max_len = tokenizer.model_max_length
+    elif hasattr(config, "position_embedding_type") and "relative" in config.position_embedding_type:
+        default_max_len = tokenizer.model_max_length
+    elif hasattr(config, "max_position_embeddings"):
+        default_max_len = config.max_position_embeddings
+    else:
+        default_max_len = tokenizer.model_max_length
+
+    if provided_max_len is None or provided_max_len <= 0:
+        max_len = default_max_len
+    else:
+        if provided_max_len < default_max_len:
+            if default_max_len < 10**6:  # Larger than this value usually means infinite.
+                warnings.warn(
+                    f"provided max length: {provided_max_len} "
+                    f"is smaller than {checkpoint_name}'s default: {default_max_len}"
+                )
+        max_len = min(provided_max_len, default_max_len)
+
+    logger.debug(f"text max length: {max_len}")
+
+    return max_len
+
+
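
The precedence: a positive provided length caps the checkpoint default, while a missing or non-positive length falls back to the default. With hypothetical values config.max_position_embeddings=512 and tokenizer.model_max_length=512:

    get_text_token_max_len(None, config, tokenizer, "bert-base-uncased")  # -> 512 (default)
    get_text_token_max_len(128, config, tokenizer, "bert-base-uncased")   # -> 128, with a warning
    get_text_token_max_len(1024, config, tokenizer, "bert-base-uncased")  # -> 512 (capped at default)
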
+def replace_missing_images_with_learnable(
+    images: torch.Tensor,
+    image_masks,
+    learnable_image: nn.Parameter,
+):
+    b, n, c, h, w = images.shape
+    assert learnable_image.shape == (c, h, w)
+    for i in range(b):
+        for j in range(n):
+            if not image_masks[i][j]:  # False means a missing image
+                images[i][j] = learnable_image
+
+    return images
+
+
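
A toy check of the replacement semantics (shapes are batch x slots x C x H x W; values arbitrary):

    import torch
    from torch import nn

    images = torch.zeros(2, 2, 3, 4, 4)
    masks = torch.tensor([[True, False], [True, True]])  # sample 0 is missing its second image
    learnable = nn.Parameter(torch.ones(3, 4, 4))
    out = replace_missing_images_with_learnable(images, masks, learnable)
    # out[0, 1] now holds the learnable placeholder; every other slot is untouched
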
+def select_model(
+    config: DictConfig,
+    df_preprocessor,
+    strict: Optional[bool] = True,
+):
+    """
+    Filter the model config through the modalities detected in the training data.
+    If MultiModalFeaturePreprocessor can't detect some modality,
+    this function will remove the models that use this modality. This function is to
+    maximize the user flexibility in defining the config.
+    For example, if one uses the default, including hf_text and timm_image, as the model config template
+    but the training data don't have images, this function will filter out timm_image.
+
+    Parameters
+    ----------
+    config
+        A DictConfig object. The model config should be accessible by "config.model".
+    df_preprocessor
+        A MultiModalFeaturePreprocessor object, which has called .fit() on the training data.
+        Column names of the same modality are grouped into one list. If a modality's list is empty,
+        it means the training data don't have this modality.
+    strict
+        If False, allow retaining one model when partial modalities are available for that model.
+
+    Returns
+    -------
+    Config with some unused models removed.
+    """
+    data_status = {}
+    for per_modality in ALL_MODALITIES:
+        data_status[per_modality] = False
+    if len(df_preprocessor.image_feature_names) > 0:
+        data_status[IMAGE] = True
+    if len(df_preprocessor.text_feature_names) > 0:
+        data_status[TEXT] = True
+    if len(df_preprocessor.categorical_feature_names) > 0:
+        data_status[CATEGORICAL] = True
+    if len(df_preprocessor.numerical_feature_names) > 0:
+        data_status[NUMERICAL] = True
+    if len(df_preprocessor.ner_feature_names) > 0:
+        data_status[TEXT_NER] = True
+    if len(df_preprocessor.document_feature_names) > 0:
+        data_status[DOCUMENT] = True
+    if len(df_preprocessor.semantic_segmentation_feature_names) > 0:
+        data_status[SEMANTIC_SEGMENTATION_IMG] = True
+
+    names = config.model.names
+    if isinstance(names, str):
+        names = [names]
+    selected_model_names = []
+    fusion_model_name = []
+    for model_name in names:
+        model_config = getattr(config.model, model_name)
+        strict = getattr(model_config, "requires_all_dtypes", strict)
+        if not model_config.data_types:
+            fusion_model_name.append(model_name)
+            continue
+        model_data_status = [data_status[d_type] for d_type in model_config.data_types]
+        if all(model_data_status):
+            selected_model_names.append(model_name)
+        else:
+            if any(model_data_status) and not strict:
+                selected_model_names.append(model_name)
+                # update data types to be consistent with the detected modalities
+                model_config.data_types = [d_type for d_type in model_config.data_types if data_status[d_type]]
+            else:
+                delattr(config.model, model_name)
+
+    if len(selected_model_names) == 0:
+        raise ValueError("No model is available for this dataset.")
+    # allow at most one fusion model
+    if len(fusion_model_name) > 1:
+        raise ValueError(f"More than one fusion model `{fusion_model_name}` is detected, but only one is allowed.")
+
+    if len(selected_model_names) > 1:
+        assert len(fusion_model_name) == 1
+        selected_model_names.extend(fusion_model_name)
+    elif len(fusion_model_name) == 1 and hasattr(config.model, fusion_model_name[0]):
+        delattr(config.model, fusion_model_name[0])
+
+    config.model.names = selected_model_names
+    logger.debug(f"selected models: {selected_model_names}")
+    for model_name in selected_model_names:
+        logger.debug(f"model dtypes: {getattr(config.model, model_name).data_types}")
+
+    # clean up unused model configs
+    model_keys = list(config.model.keys())
+    for model_name in model_keys:
+        if model_name not in selected_model_names + ["names"]:
+            delattr(config.model, model_name)
+
+    return config
+
+
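
A sketch of the filtering on a text-only dataset (the stub below only mimics the attributes select_model reads, and assumes the usual constant values "text"/"image"; real callers pass a fitted MultiModalFeaturePreprocessor):

    from types import SimpleNamespace
    from omegaconf import OmegaConf

    config = OmegaConf.create({"model": {
        "names": ["hf_text", "timm_image", "fusion_mlp"],
        "hf_text": {"data_types": ["text"]},
        "timm_image": {"data_types": ["image"]},
        "fusion_mlp": {"data_types": []},
    }})
    stub = SimpleNamespace(
        image_feature_names=[], text_feature_names=["abstract"],
        categorical_feature_names=[], numerical_feature_names=[],
        ner_feature_names=[], document_feature_names=[],
        semantic_segmentation_feature_names=[],
    )
    config = select_model(config, stub)
    # config.model.names -> ["hf_text"]; timm_image and the now-unneeded fusion_mlp are dropped
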
+def create_model(
+    model_name: str,
+    model_config: DictConfig,
+    num_classes: Optional[int] = 0,
+    classes: Optional[list] = None,
+    num_numerical_columns: Optional[int] = None,
+    num_categories: Optional[Dict] = None,
+    numerical_fill_values: Optional[Dict] = None,
+    pretrained: Optional[bool] = True,
+    is_matching: Optional[bool] = False,
+):
+    """
+    Create a single model.
+
+    Parameters
+    ----------
+    model_name
+        Name of the model.
+    model_config
+        Config of the model.
+    num_classes
+        The class number for a classification task. It should be 1 for a regression task.
+    classes
+        All classes in this dataset.
+    num_numerical_columns
+        The number of numerical columns in the training dataframe.
+    num_categories
+        The category number for each categorical column in the training dataframe.
+    numerical_fill_values
+        If numerical values are null, fill them with these.
+    pretrained
+        Whether to use the pretrained timm models. If pretrained=True, download the pretrained model.
+    is_matching
+        Whether the model is used for semantic matching.
+
+    Returns
+    -------
+    A model.
+    """
+    if model_name.lower().startswith(CLIP):
+        from .clip import CLIPForImageText
+
+        model = CLIPForImageText(
+            prefix=model_name,
+            checkpoint_name=model_config.checkpoint_name,
+            num_classes=num_classes,
+            pretrained=pretrained,
+            tokenizer_name=model_config.tokenizer_name,
+            has_image=IMAGE in model_config.data_types,
+            has_text=TEXT in model_config.data_types,
+            image_size=model_config.image_size,
+            image_norm=model_config.image_norm,
+            image_chan_num=model_config.image_chan_num,
+            use_learnable_image=model_config.use_learnable_image,
+            max_text_len=model_config.max_text_len,
+            text_segment_num=model_config.text_segment_num,
+            is_matching=is_matching,
+        )
+    elif model_name.lower().startswith(TIMM_IMAGE):
+        from .timm_image import TimmAutoModelForImagePrediction
+
+        model = TimmAutoModelForImagePrediction(
+            prefix=model_name,
+            checkpoint_name=model_config.checkpoint_name,
+            num_classes=num_classes,
+            mix_choice=model_config.mix_choice,
+            pretrained=pretrained,
+            image_size=model_config.image_size,
+            image_norm=model_config.image_norm,
+            image_chan_num=model_config.image_chan_num,
+            use_learnable_image=model_config.use_learnable_image,
+        )
+    elif model_name.lower().startswith(HF_TEXT):
+        from .hf_text import HFAutoModelForTextPrediction
+
+        model = HFAutoModelForTextPrediction(
+            prefix=model_name,
+            checkpoint_name=model_config.checkpoint_name,
+            num_classes=num_classes,
+            pooling_mode=model_config.pooling_mode,
+            gradient_checkpointing=model_config.gradient_checkpointing,
+            low_cpu_mem_usage=model_config.low_cpu_mem_usage,
+            pretrained=pretrained,
+            tokenizer_name=model_config.tokenizer_name,
+            max_text_len=model_config.max_text_len,
+            text_segment_num=model_config.text_segment_num,
+            use_fast=model_config.use_fast,
+        )
+    elif model_name.lower().startswith(T_FEW):
+        from .t_few import TFewModel
+
+        model = TFewModel(
+            prefix=model_name,
+            checkpoint_name=model_config.checkpoint_name,
+            length_norm=model_config.length_norm,  # Normalizes length to adjust for length bias in target template
+            unlikely_loss=model_config.unlikely_loss,  # Adds loss term that lowers probability of incorrect outputs
+            mc_loss=model_config.mc_loss,  # Adds multiple choice cross entropy loss
+            num_classes=num_classes,
+            gradient_checkpointing=model_config.gradient_checkpointing,
+            low_cpu_mem_usage=model_config.low_cpu_mem_usage,
+            pretrained=pretrained,
+            tokenizer_name=model_config.tokenizer_name,
+            max_text_len=model_config.max_text_len,
+            text_segment_num=model_config.text_segment_num,
+        )
+    elif model_name.lower().startswith(NUMERICAL_MLP):
+        from .numerical_mlp import NumericalMLP
+
+        model = NumericalMLP(
+            prefix=model_name,
+            in_features=num_numerical_columns,
+            hidden_features=model_config.hidden_size,
+            out_features=model_config.hidden_size,
+            num_layers=model_config.num_layers,
+            activation=model_config.activation,
+            dropout=model_config.dropout,
+            normalization=model_config.normalization,
+            token_dim=model_config.token_dim,
+            embedding_arch=model_config.embedding_arch,
+            num_classes=num_classes,
+            numerical_fill_values=numerical_fill_values,
+        )
+    elif model_name.lower().startswith(CATEGORICAL_MLP):
+        from .categorical_mlp import CategoricalMLP
+
+        model = CategoricalMLP(
+            prefix=model_name,
+            num_categories=num_categories,
+            out_features=model_config.hidden_size,
+            num_layers=model_config.num_layers,
+            activation=model_config.activation,
+            dropout=model_config.dropout,
+            normalization=model_config.normalization,
+            num_classes=num_classes,
+        )
+    elif model_name.lower().startswith(DOCUMENT_TRANSFORMER):
+        from .document_transformer import DocumentTransformer
+
+        model = DocumentTransformer(
+            prefix=model_name,
+            checkpoint_name=model_config.checkpoint_name,
+            num_classes=num_classes,
+            pooling_mode=model_config.pooling_mode,
+            gradient_checkpointing=model_config.gradient_checkpointing,
+            low_cpu_mem_usage=model_config.low_cpu_mem_usage,
+            pretrained=pretrained,
+            tokenizer_name=model_config.tokenizer_name,
+            image_size=model_config.image_size,
+            image_norm=model_config.image_norm,
+        )
+    elif model_name.lower().startswith(MMDET_IMAGE):
+        from .mmdet_image import MMDetAutoModelForObjectDetection
+
+        model = MMDetAutoModelForObjectDetection(
+            prefix=model_name,
+            checkpoint_name=model_config.checkpoint_name,
+            config_file=model_config.config_file,
+            classes=classes,
+            pretrained=pretrained,
+            output_bbox_format=model_config.output_bbox_format,
+            frozen_layers=model_config.frozen_layers,
+        )
+    elif model_name.lower().startswith(MMOCR_TEXT_DET):
+        from .mmocr_text_detection import MMOCRAutoModelForTextDetection
+
+        model = MMOCRAutoModelForTextDetection(
+            prefix=model_name,
+            checkpoint_name=model_config.checkpoint_name,
+        )
+    elif model_name.lower().startswith(MMOCR_TEXT_RECOG):
+        from .mmocr_text_recognition import MMOCRAutoModelForTextRecognition
+
+        model = MMOCRAutoModelForTextRecognition(
+            prefix=model_name,
+            checkpoint_name=model_config.checkpoint_name,
+        )
+    elif model_name.lower().startswith(NER_TEXT):
+        from .ner_text import HFAutoModelForNER
+
+        model = HFAutoModelForNER(
+            prefix=model_name,
+            checkpoint_name=model_config.checkpoint_name,
+            num_classes=num_classes,
+            gradient_checkpointing=model_config.gradient_checkpointing,
+            low_cpu_mem_usage=model_config.low_cpu_mem_usage,
+            pretrained=pretrained,
+            tokenizer_name=model_config.tokenizer_name,
+        )
+    elif model_name.lower().startswith(FUSION_MLP):
+        from .fusion import MultimodalFusionMLP
+
+        model = functools.partial(
+            MultimodalFusionMLP,
+            prefix=model_name,
+            hidden_features=model_config.hidden_sizes,
+            num_classes=num_classes,
+            adapt_in_features=model_config.adapt_in_features,
+            activation=model_config.activation,
+            dropout=model_config.dropout,
+            normalization=model_config.normalization,
+            aux_loss_weight=model_config.aux_loss_weight,
+        )
+    elif model_name.lower().startswith(FUSION_NER):
+        from .fusion import MultimodalFusionNER
+
+        model = functools.partial(
+            MultimodalFusionNER,
+            prefix=model_name,
+            hidden_features=model_config.hidden_sizes,
+            num_classes=num_classes,
+            adapt_in_features=model_config.adapt_in_features,
+            activation=model_config.activation,
+            dropout_prob=model_config.drop_rate,
+            normalization=model_config.normalization,
+            loss_weight=model_config.weight if hasattr(model_config, "weight") else None,
+        )
+    elif model_name.lower().startswith(FUSION_TRANSFORMER):
+        from .fusion import MultimodalFusionTransformer
+
+        model = functools.partial(
+            MultimodalFusionTransformer,
+            prefix=model_name,
+            hidden_features=model_config.hidden_size,
+            num_classes=num_classes,
+            num_blocks=model_config.num_blocks,
+            attention_num_heads=model_config.attention_num_heads,
+            ffn_hidden_size=model_config.ffn_hidden_size,
+            attention_dropout=model_config.attention_dropout,
+            residual_dropout=model_config.residual_dropout,
+            ffn_dropout=model_config.ffn_dropout,
+            attention_normalization=model_config.normalization,
+            ffn_normalization=model_config.normalization,
+            head_normalization=model_config.normalization,
+            ffn_activation=model_config.ffn_activation,
+            head_activation=model_config.head_activation,
+            adapt_in_features=model_config.adapt_in_features,
+            aux_loss_weight=model_config.aux_loss_weight,
+            additive_attention=model_config.additive_attention,
+            share_qv_weights=model_config.share_qv_weights,
+        )
+    elif model_name.lower().startswith(FT_TRANSFORMER):
+        from .ft_transformer import FT_Transformer
+
+        model = FT_Transformer(
+            prefix=model_name,
+            num_numerical_columns=num_numerical_columns,
+            num_categories=num_categories,
+            numerical_fill_values=numerical_fill_values,
+            embedding_arch=model_config.embedding_arch,
+            token_dim=model_config.token_dim,
+            hidden_size=model_config.hidden_size,
+            hidden_features=model_config.hidden_size,
+            num_classes=num_classes,
+            num_blocks=model_config.num_blocks,
+            attention_num_heads=model_config.attention_num_heads,
+            attention_dropout=model_config.attention_dropout,
+            attention_normalization=model_config.normalization,
+            ffn_hidden_size=model_config.ffn_hidden_size,
+            ffn_dropout=model_config.ffn_dropout,
+            ffn_normalization=model_config.normalization,
+            ffn_activation=model_config.ffn_activation,
+            residual_dropout=model_config.residual_dropout,
+            head_normalization=model_config.normalization,
+            head_activation=model_config.head_activation,
+            additive_attention=model_config.additive_attention,
+            share_qv_weights=model_config.share_qv_weights,
+            pooling_mode=model_config.pooling_mode,
+            checkpoint_name=model_config.checkpoint_name,
+            pretrained=pretrained,
+        )
+    elif model_name.lower().startswith(SAM):
+        from .sam import SAMForSemanticSegmentation
+
+        model = SAMForSemanticSegmentation(
+            prefix=model_name,
+            checkpoint_name=model_config.checkpoint_name,
+            num_classes=num_classes,
+            pretrained=pretrained,
+            frozen_layers=model_config.frozen_layers,
+            num_mask_tokens=model_config.num_mask_tokens,
+            image_norm=model_config.image_norm,
+        )
+    elif model_name.lower().startswith(META_TRANSFORMER):
+        from .meta_transformer import MetaTransformer
+
+        model = MetaTransformer(
+            prefix=model_name,
+            checkpoint_path=model_config.checkpoint_path,
+            num_classes=num_classes,
+            model_version=model_config.model_version,
+            has_image=IMAGE in model_config.data_types,
+            has_text=TEXT in model_config.data_types,
+            num_numerical_columns=num_numerical_columns,
+            num_categories=num_categories,
+            numerical_fill_values=numerical_fill_values,
+            image_size=model_config.image_size,
+            image_norm=model_config.image_norm,
+            image_chan_num=model_config.image_chan_num,
+            use_learnable_image=model_config.use_learnable_image,
+            max_text_len=model_config.max_text_len,
+            text_segment_num=model_config.text_segment_num,
+        )
+    else:
+        raise ValueError(f"unknown model name: {model_name}")
+
+    return model
+
+
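
For a single backbone, create_model can be exercised directly; the keys below mirror the timm_image branch above (the checkpoint and values are illustrative, not a prescribed configuration):

    from omegaconf import OmegaConf

    model_config = OmegaConf.create({
        "checkpoint_name": "swin_base_patch4_window7_224",
        "mix_choice": "all_logits",
        "data_types": ["image"],
        "image_size": 224,
        "image_norm": "imagenet",
        "image_chan_num": 3,
        "use_learnable_image": False,
    })
    model = create_model("timm_image", model_config, num_classes=2, pretrained=False)
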
+def create_fusion_model(
+    config: DictConfig,
+    num_classes: Optional[int] = None,
+    classes: Optional[list] = None,
+    num_numerical_columns: Optional[int] = None,
+    num_categories: Optional[Dict] = None,
+    numerical_fill_values: Optional[Dict] = None,
+    pretrained: Optional[bool] = True,
+):
+    """
+    Create models. It supports the auto models of huggingface text and timm image.
+    Multimodal models, e.g., CLIP, should be added case-by-case since their configs and usages
+    may be different. It uses MLP for the numerical features, categorical features, and late-fusion.
+
+    Parameters
+    ----------
+    config
+        A DictConfig object. The model config should be accessible by "config.model".
+    num_classes
+        The class number for a classification task. It should be 1 for a regression task.
+    classes
+        All classes in this dataset.
+    num_numerical_columns
+        The number of numerical columns in the training dataframe.
+    num_categories
+        The category number for each categorical column in the training dataframe.
+    numerical_fill_values
+        If numerical values are null, fill them with these.
+    pretrained
+        Whether to use the pretrained timm models. If pretrained=True, download the pretrained model.
+
+    Returns
+    -------
+    A PyTorch model.
+    """
+    names = config.model.names
+    if isinstance(names, str):
+        names = [names]
+    # make sure no duplicate model names
+    assert len(names) == len(set(names))
+    logger.debug(f"output_shape: {num_classes}")
+    names = sorted(names)
+    config.model.names = names
+    single_models = []
+    fusion_model = None
+
+    for model_name in names:
+        model_config = getattr(config.model, model_name)
+        model = create_model(
+            model_name=model_name,
+            model_config=model_config,
+            num_classes=num_classes,
+            classes=classes,
+            num_numerical_columns=num_numerical_columns,
+            num_categories=num_categories,
+            numerical_fill_values=numerical_fill_values,
+            pretrained=pretrained,
+        )
+
+        if isinstance(model, functools.partial):  # fusion model
+            if fusion_model is None:
+                fusion_model = model
+            else:
+                raise ValueError(
+                    f"More than one fusion model is detected in {names}. Only one fusion model is allowed."
+                )
+        else:  # single model
+            if config.optim.peft is not None:
+                model = apply_peft_adaptation(model, config)
+            single_models.append(model)
+
+    if len(single_models) > 1:
+        # must have one fusion model if there are multiple independent models
+        model = fusion_model(models=single_models)
+    elif len(single_models) == 1:
+        model = single_models[0]
+    else:
+        raise ValueError(f"No available models for {names}")
+
+    # build the augmenter for multimodal data augmentation
+    if config.optim.lemda.turn_on:
+        from .fusion import MultimodalFusionMLP
+
+        assert isinstance(model, MultimodalFusionMLP)
+        from .augmenter import Augmenter
+
+        augmenter = Augmenter(
+            arch_type=config.optim.lemda.arch_type,
+            input_dim=model.augmenter_in_features,
+            z_dim=config.optim.lemda.z_dim,
+            num_layers=config.optim.lemda.num_layers,
+            adv_weight=config.optim.lemda.adv_weight,
+        )
+        model.augmenter = augmenter
+
+    return model
+
+
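
Note the design choice this function relies on: fusion heads come back from create_model as functools.partial objects, so they can be finalized here once the single towers exist. In condensed form (argument lists abbreviated; not the full call):

    fusion_ctor = functools.partial(MultimodalFusionMLP, prefix="fusion_mlp", num_classes=2)
    towers = [text_model, image_model]  # concrete nn.Module instances
    model = fusion_ctor(models=towers) if len(towers) > 1 else towers[0]
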
+def apply_peft_adaptation(model: nn.Module, config: DictConfig) -> nn.Module:
+    """
+    Apply an adaptation to the model for efficient fine-tuning.
+
+    Parameters
+    ----------
+    model
+        A PyTorch model.
+    config
+        A DictConfig object. The optimization config should be accessible by "config.optim".
+    """
+    if config.optim.peft in PEFT_ADDITIVE_STRATEGIES:
+        model = inject_adaptation_to_linear_layer(
+            model=model,
+            peft=config.optim.peft,
+            lora_r=config.optim.lora.r,
+            lora_alpha=config.optim.lora.alpha,
+            module_filter=config.optim.lora.module_filter,
+            filter=config.optim.lora.filter,
+            extra_trainable_params=config.optim.extra_trainable_params,
+            conv_lora_expert_num=config.optim.lora.conv_lora_expert_num,
+        )
+        model.name_to_id = model.get_layer_ids()  # Need to update the name-to-id dictionary.
+
+    return model
+
+
+def modify_duplicate_model_names(
+    learner,
+    postfix: str,
+    blacklist: List[str],
+):
+    """
+    Modify a learner's model names if they exist in a blacklist.
+
+    Parameters
+    ----------
+    learner
+        A BaseLearner object.
+    postfix
+        The postfix used to change the duplicate names.
+    blacklist
+        A list of names. The provided learner can't use model names in the list.
+
+    Returns
+    -------
+    The learner, guaranteed to have no model names duplicating those in the blacklist.
+    """
+    model_names = []
+    for n in learner._config.model.names:
+        if n in blacklist:
+            new_name = f"{n}_{postfix}"
+            assert new_name not in blacklist
+            assert new_name not in learner._config.model.names
+            # modify the model prefix
+            if n == learner._model.prefix:
+                learner._model.prefix = new_name
+            else:
+                assert isinstance(learner._model.model, nn.ModuleList)
+                for per_model in learner._model.model:
+                    if n == per_model.prefix:
+                        per_model.prefix = new_name
+                        break
+            # modify the data processor prefix
+            for per_modality_processors in learner._data_processors.values():
+                for per_processor in per_modality_processors:
+                    if n == per_processor.prefix:
+                        per_processor.prefix = new_name
+            # modify the model config keys
+            setattr(learner._config.model, new_name, getattr(learner._config.model, n))
+            delattr(learner._config.model, n)
+
+            model_names.append(new_name)
+        else:
+            model_names.append(n)
+
+    learner._config.model.names = model_names
+
+    return learner
+
+
+def list_timm_models(pretrained=True):
+    return timm.list_models(pretrained=pretrained)
+
+
+def is_lazy_weight_tensor(p: torch.Tensor) -> bool:
+    from torch.nn.parameter import UninitializedParameter
+
+    if isinstance(p, UninitializedParameter):
+        warnings.warn(
+            "A layer with UninitializedParameter was found. "
+            "Thus, the total number of parameters detected may be inaccurate."
+        )
+        return True
+    return False
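
For context, lazy modules are the usual source of UninitializedParameter; a quick check with standard PyTorch:

    from torch import nn

    lazy = nn.LazyLinear(out_features=8)
    is_lazy_weight_tensor(lazy.weight)             # True (and emits the warning)
    is_lazy_weight_tensor(nn.Linear(4, 8).weight)  # False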