autogluon.multimodal 1.2.1b20250303-py3-none-any.whl → 1.2.1b20250304-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (126)
  1. autogluon/multimodal/__init__.py +4 -2
  2. autogluon/multimodal/configs/data/default.yaml +4 -2
  3. autogluon/multimodal/configs/{environment → env}/default.yaml +2 -3
  4. autogluon/multimodal/configs/model/default.yaml +58 -11
  5. autogluon/multimodal/configs/{optimization → optim}/default.yaml +21 -4
  6. autogluon/multimodal/constants.py +16 -5
  7. autogluon/multimodal/data/__init__.py +14 -2
  8. autogluon/multimodal/data/dataset.py +2 -2
  9. autogluon/multimodal/data/infer_types.py +16 -2
  10. autogluon/multimodal/data/label_encoder.py +3 -3
  11. autogluon/multimodal/{utils → data}/nlpaug.py +4 -4
  12. autogluon/multimodal/data/preprocess_dataframe.py +55 -38
  13. autogluon/multimodal/data/process_categorical.py +35 -6
  14. autogluon/multimodal/data/process_document.py +59 -33
  15. autogluon/multimodal/data/process_image.py +198 -163
  16. autogluon/multimodal/data/process_label.py +7 -3
  17. autogluon/multimodal/data/process_mmlab/process_mmdet.py +1 -8
  18. autogluon/multimodal/data/process_mmlab/process_mmlab_base.py +2 -9
  19. autogluon/multimodal/data/process_mmlab/process_mmocr.py +1 -9
  20. autogluon/multimodal/data/process_ner.py +192 -4
  21. autogluon/multimodal/data/process_numerical.py +32 -5
  22. autogluon/multimodal/data/process_semantic_seg_img.py +23 -28
  23. autogluon/multimodal/data/process_text.py +95 -58
  24. autogluon/multimodal/data/template_engine.py +7 -9
  25. autogluon/multimodal/data/templates.py +0 -2
  26. autogluon/multimodal/data/trivial_augmenter.py +2 -2
  27. autogluon/multimodal/data/utils.py +564 -338
  28. autogluon/multimodal/learners/__init__.py +2 -1
  29. autogluon/multimodal/learners/base.py +189 -189
  30. autogluon/multimodal/learners/ensemble.py +748 -0
  31. autogluon/multimodal/learners/few_shot_svm.py +6 -15
  32. autogluon/multimodal/learners/matching.py +59 -84
  33. autogluon/multimodal/learners/ner.py +23 -22
  34. autogluon/multimodal/learners/object_detection.py +26 -21
  35. autogluon/multimodal/learners/semantic_segmentation.py +16 -18
  36. autogluon/multimodal/models/__init__.py +12 -3
  37. autogluon/multimodal/models/augmenter.py +175 -0
  38. autogluon/multimodal/models/categorical_mlp.py +13 -8
  39. autogluon/multimodal/models/clip.py +92 -18
  40. autogluon/multimodal/models/custom_transformer.py +75 -75
  41. autogluon/multimodal/models/document_transformer.py +23 -9
  42. autogluon/multimodal/models/ft_transformer.py +40 -35
  43. autogluon/multimodal/models/fusion/base.py +2 -4
  44. autogluon/multimodal/models/fusion/fusion_mlp.py +82 -18
  45. autogluon/multimodal/models/fusion/fusion_ner.py +1 -1
  46. autogluon/multimodal/models/fusion/fusion_transformer.py +23 -23
  47. autogluon/multimodal/models/{huggingface_text.py → hf_text.py} +21 -2
  48. autogluon/multimodal/models/meta_transformer.py +336 -0
  49. autogluon/multimodal/models/mlp.py +6 -6
  50. autogluon/multimodal/models/mmocr_text_detection.py +1 -1
  51. autogluon/multimodal/models/mmocr_text_recognition.py +0 -1
  52. autogluon/multimodal/models/ner_text.py +1 -8
  53. autogluon/multimodal/models/numerical_mlp.py +14 -8
  54. autogluon/multimodal/models/sam.py +12 -2
  55. autogluon/multimodal/models/t_few.py +21 -5
  56. autogluon/multimodal/models/timm_image.py +74 -32
  57. autogluon/multimodal/models/utils.py +877 -16
  58. autogluon/multimodal/optim/__init__.py +17 -0
  59. autogluon/multimodal/{optimization → optim}/lit_distiller.py +2 -1
  60. autogluon/multimodal/{optimization → optim}/lit_matcher.py +4 -10
  61. autogluon/multimodal/{optimization → optim}/lit_mmdet.py +2 -10
  62. autogluon/multimodal/{optimization → optim}/lit_module.py +139 -14
  63. autogluon/multimodal/{optimization → optim}/lit_ner.py +3 -3
  64. autogluon/multimodal/{optimization → optim}/lit_semantic_seg.py +1 -1
  65. autogluon/multimodal/optim/losses/__init__.py +14 -0
  66. autogluon/multimodal/optim/losses/bce_loss.py +25 -0
  67. autogluon/multimodal/optim/losses/focal_loss.py +81 -0
  68. autogluon/multimodal/optim/losses/lemda_loss.py +39 -0
  69. autogluon/multimodal/optim/losses/rkd_loss.py +103 -0
  70. autogluon/multimodal/optim/losses/softmax_losses.py +177 -0
  71. autogluon/multimodal/optim/losses/structure_loss.py +26 -0
  72. autogluon/multimodal/optim/losses/utils.py +313 -0
  73. autogluon/multimodal/optim/lr/__init__.py +1 -0
  74. autogluon/multimodal/optim/lr/utils.py +332 -0
  75. autogluon/multimodal/optim/metrics/__init__.py +4 -0
  76. autogluon/multimodal/optim/metrics/coverage_metrics.py +42 -0
  77. autogluon/multimodal/optim/metrics/hit_rate_metrics.py +78 -0
  78. autogluon/multimodal/optim/metrics/ranking_metrics.py +231 -0
  79. autogluon/multimodal/optim/metrics/utils.py +359 -0
  80. autogluon/multimodal/optim/utils.py +284 -0
  81. autogluon/multimodal/predictor.py +51 -12
  82. autogluon/multimodal/utils/__init__.py +19 -45
  83. autogluon/multimodal/utils/cache.py +23 -2
  84. autogluon/multimodal/utils/checkpoint.py +58 -5
  85. autogluon/multimodal/utils/config.py +127 -55
  86. autogluon/multimodal/utils/device.py +120 -0
  87. autogluon/multimodal/utils/distillation.py +8 -8
  88. autogluon/multimodal/utils/download.py +1 -1
  89. autogluon/multimodal/utils/env.py +22 -0
  90. autogluon/multimodal/utils/export.py +3 -3
  91. autogluon/multimodal/utils/hpo.py +5 -5
  92. autogluon/multimodal/utils/inference.py +37 -4
  93. autogluon/multimodal/utils/install.py +91 -0
  94. autogluon/multimodal/utils/load.py +52 -47
  95. autogluon/multimodal/utils/log.py +6 -41
  96. autogluon/multimodal/utils/matcher.py +3 -2
  97. autogluon/multimodal/utils/onnx.py +0 -4
  98. autogluon/multimodal/utils/path.py +10 -0
  99. autogluon/multimodal/utils/precision.py +130 -0
  100. autogluon/multimodal/{presets.py → utils/presets.py} +259 -66
  101. autogluon/multimodal/{problem_types.py → utils/problem_types.py} +30 -1
  102. autogluon/multimodal/utils/save.py +47 -29
  103. autogluon/multimodal/utils/strategy.py +24 -0
  104. autogluon/multimodal/version.py +1 -1
  105. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/METADATA +5 -5
  106. autogluon.multimodal-1.2.1b20250304.dist-info/RECORD +163 -0
  107. autogluon/multimodal/optimization/__init__.py +0 -16
  108. autogluon/multimodal/optimization/losses.py +0 -394
  109. autogluon/multimodal/optimization/utils.py +0 -1054
  110. autogluon/multimodal/utils/cloud_io.py +0 -80
  111. autogluon/multimodal/utils/data.py +0 -701
  112. autogluon/multimodal/utils/environment.py +0 -395
  113. autogluon/multimodal/utils/metric.py +0 -500
  114. autogluon/multimodal/utils/model.py +0 -558
  115. autogluon.multimodal-1.2.1b20250303.dist-info/RECORD +0 -145
  116. /autogluon/multimodal/{optimization → optim}/deepspeed.py +0 -0
  117. /autogluon/multimodal/{optimization/lr_scheduler.py → optim/lr/lr_schedulers.py} +0 -0
  118. /autogluon/multimodal/{optimization → optim/metrics}/semantic_seg_metrics.py +0 -0
  119. /autogluon/multimodal/{registry.py → utils/registry.py} +0 -0
  120. /autogluon.multimodal-1.2.1b20250303-py3.9-nspkg.pth → /autogluon.multimodal-1.2.1b20250304-py3.9-nspkg.pth +0 -0
  121. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/LICENSE +0 -0
  122. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/NOTICE +0 -0
  123. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/WHEEL +0 -0
  124. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/namespace_packages.txt +0 -0
  125. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/top_level.txt +0 -0
  126. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/zip-safe +0 -0
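
The two structural changes dominating this diff are (1) the rename of the `optimization` package to `optim`, which also splits losses, learning-rate utilities, and metrics into dedicated `optim/losses`, `optim/lr`, and `optim/metrics` subpackages (entries 58-80), and (2) the removal of the monolithic helper modules `utils/model.py`, `utils/data.py`, `utils/metric.py`, and `utils/environment.py` (entries 111-118). Downstream code importing from the old paths will break on upgrade. A minimal, hypothetical import guard is sketched below; `get_loss_func` is an assumed re-export from the new `optim` package, so verify the actual exported names against the 1.2.1b20250304 wheel before relying on this:

    # Hypothetical guard for the optimization -> optim rename (a sketch, not
    # a confirmed API): substitute whatever name your code actually imports.
    try:
        from autogluon.multimodal.optim import get_loss_func  # 1.2.1b20250304+ layout
    except ImportError:
        from autogluon.multimodal.optimization import get_loss_func  # older layout

The full diff for the deleted `utils/model.py` (entry 114 above) follows.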
autogluon/multimodal/utils/model.py (deleted)
@@ -1,558 +0,0 @@
- import functools
- import json
- import logging
- import warnings
- from typing import Dict, List, Optional, Tuple, Union
-
- import timm
- from omegaconf import DictConfig, OmegaConf
- from torch import Tensor, nn
-
- from ..constants import (
-     ALL_MODALITIES,
-     AUTOMM,
-     CATEGORICAL,
-     CATEGORICAL_MLP,
-     CLIP,
-     DOCUMENT,
-     DOCUMENT_TRANSFORMER,
-     FT_TRANSFORMER,
-     FUSION_MLP,
-     FUSION_NER,
-     FUSION_TRANSFORMER,
-     HF_TEXT,
-     IMAGE,
-     MMDET_IMAGE,
-     MMOCR_TEXT_DET,
-     MMOCR_TEXT_RECOG,
-     NER,
-     NER_TEXT,
-     NUMERICAL,
-     NUMERICAL_MLP,
-     PEFT_ADDITIVE_STRATEGIES,
-     SAM,
-     SEMANTIC_SEGMENTATION_IMG,
-     T_FEW,
-     TEXT,
-     TEXT_NER,
-     TIMM_IMAGE,
-     XYXY,
- )
- from ..data import MultiModalFeaturePreprocessor
- from ..models import (
-     CategoricalMLP,
-     CLIPForImageText,
-     DocumentTransformer,
-     FT_Transformer,
-     HFAutoModelForNER,
-     HFAutoModelForTextPrediction,
-     MMDetAutoModelForObjectDetection,
-     MMOCRAutoModelForTextDetection,
-     MMOCRAutoModelForTextRecognition,
-     MultimodalFusionMLP,
-     MultimodalFusionNER,
-     MultimodalFusionTransformer,
-     NumericalMLP,
-     SAMForSemanticSegmentation,
-     TFewModel,
-     TimmAutoModelForImagePrediction,
- )
- from ..models.utils import inject_adaptation_to_linear_layer
-
- logger = logging.getLogger(__name__)
-
-
- def select_model(
-     config: DictConfig,
-     df_preprocessor: MultiModalFeaturePreprocessor,
-     strict: Optional[bool] = True,
- ):
-     """
-     Filter model config through the detected modalities in the training data.
-     If MultiModalFeaturePreprocessor can't detect some modality,
-     this function will remove the models that use this modality. This function is to
-     maximize the user flexibility in defining the config.
-     For example, if one uses the default, including hf_text and timm_image, as the model config template
-     but the training data don't have images, this function will filter out timm_image.
-
-     Parameters
-     ----------
-     config
-         A DictConfig object. The model config should be accessible by "config.model"
-     df_preprocessor
-         A MultiModalFeaturePreprocessor object, which has called .fit() on the training data.
-         Column names of the same modality are grouped into one list. If a modality's list is empty,
-         it means the training data don't have this modality.
-     strict
-         If False, allow retaining one model when partial modalities are available for that model.
-
-     Returns
-     -------
-     Config with some unused models removed.
-     """
-     data_status = {}
-     for per_modality in ALL_MODALITIES:
-         data_status[per_modality] = False
-     if len(df_preprocessor.image_feature_names) > 0:
-         data_status[IMAGE] = True
-     if len(df_preprocessor.text_feature_names) > 0:
-         data_status[TEXT] = True
-     if len(df_preprocessor.categorical_feature_names) > 0:
-         data_status[CATEGORICAL] = True
-     if len(df_preprocessor.numerical_feature_names) > 0:
-         data_status[NUMERICAL] = True
-     if len(df_preprocessor.ner_feature_names) > 0:
-         data_status[TEXT_NER] = True
-     if len(df_preprocessor.document_feature_names) > 0:
-         data_status[DOCUMENT] = True
-     if len(df_preprocessor.semantic_segmentation_feature_names) > 0:
-         data_status[SEMANTIC_SEGMENTATION_IMG] = True
-
-     names = config.model.names
-     if isinstance(names, str):
-         names = [names]
-     selected_model_names = []
-     fusion_model_name = []
-     for model_name in names:
-         model_config = getattr(config.model, model_name)
-         strict = getattr(model_config, "requires_all_dtypes", strict)
-         if not model_config.data_types:
-             fusion_model_name.append(model_name)
-             continue
-         model_data_status = [data_status[d_type] for d_type in model_config.data_types]
-         if all(model_data_status):
-             selected_model_names.append(model_name)
-         else:
-             if any(model_data_status) and not strict:
-                 selected_model_names.append(model_name)
-             else:
-                 delattr(config.model, model_name)
-
-     if len(selected_model_names) == 0:
-         raise ValueError("No model is available for this dataset.")
-     # only allow no more than 1 fusion model
-     if len(fusion_model_name) > 1:
-         raise ValueError(f"More than one fusion models `{fusion_model_name}` are detected, but only one is allowed.")
-
-     if len(selected_model_names) > 1:
-         assert len(fusion_model_name) == 1
-         selected_model_names.extend(fusion_model_name)
-     elif len(fusion_model_name) == 1 and hasattr(config.model, fusion_model_name[0]):
-         delattr(config.model, fusion_model_name[0])
-
-     config.model.names = selected_model_names
-     logger.debug(f"selected models: {selected_model_names}")
-     for model_name in selected_model_names:
-         logger.debug(f"model dtypes: {getattr(config.model, model_name).data_types}")
-
-     # clean up unused model configs
-     model_keys = list(config.model.keys())
-     for model_name in model_keys:
-         if model_name not in selected_model_names + ["names"]:
-             delattr(config.model, model_name)
-
-     return config
-
-
- def create_model(
-     model_name: str,
-     model_config: DictConfig,
-     num_classes: Optional[int] = 0,
-     classes: Optional[list] = None,
-     num_numerical_columns: Optional[int] = None,
-     num_categories: Optional[List[int]] = None,
-     pretrained: Optional[bool] = True,
- ):
-     """
-     Create a single model.
-
-     Parameters
-     ----------
-     model_name
-         Name of the model.
-     model_config
-         Config of the model.
-     num_classes
-         The class number for a classification task. It should be 1 for a regression task.
-     classes
-         All classes in this dataset.
-     num_numerical_columns
-         The number of numerical columns in the training dataframe.
-     num_categories
-         The category number for each categorical column in the training dataframe.
-     pretrained
-         Whether using the pretrained timm models. If pretrained=True, download the pretrained model.
-
-     Returns
-     -------
-     A model.
-     """
-     if model_name.lower().startswith(CLIP):
-         model = CLIPForImageText(
-             prefix=model_name,
-             checkpoint_name=model_config.checkpoint_name,
-             num_classes=num_classes,
-             pretrained=pretrained,
-             tokenizer_name=model_config.tokenizer_name,
-         )
-     elif model_name.lower().startswith(TIMM_IMAGE):
-         model = TimmAutoModelForImagePrediction(
-             prefix=model_name,
-             checkpoint_name=model_config.checkpoint_name,
-             num_classes=num_classes,
-             mix_choice=model_config.mix_choice,
-             pretrained=pretrained,
-         )
-     elif model_name.lower().startswith(HF_TEXT):
-         model = HFAutoModelForTextPrediction(
-             prefix=model_name,
-             checkpoint_name=model_config.checkpoint_name,
-             num_classes=num_classes,
-             pooling_mode=OmegaConf.select(model_config, "pooling_mode", default="cls"),
-             gradient_checkpointing=OmegaConf.select(model_config, "gradient_checkpointing"),
-             low_cpu_mem_usage=OmegaConf.select(model_config, "low_cpu_mem_usage", default=False),
-             pretrained=pretrained,
-             tokenizer_name=model_config.tokenizer_name,
-             use_fast=OmegaConf.select(model_config, "use_fast", default=True),
-         )
-     elif model_name.lower().startswith(T_FEW):
-         model = TFewModel(
-             prefix=model_name,
-             checkpoint_name=model_config.checkpoint_name,
-             length_norm=model_config.length_norm,  # Normalizes length to adjust for length bias in target template
-             unlikely_loss=model_config.unlikely_loss,  # Adds loss term that lowers probability of incorrect outputs
-             mc_loss=model_config.mc_loss,  # Adds multiple choice cross entropy loss
-             num_classes=num_classes,
-             gradient_checkpointing=OmegaConf.select(model_config, "gradient_checkpointing"),
-             low_cpu_mem_usage=OmegaConf.select(model_config, "low_cpu_mem_usage", default=False),
-             pretrained=pretrained,
-             tokenizer_name=model_config.tokenizer_name,
-         )
-     elif model_name.lower().startswith(NUMERICAL_MLP):
-         model = NumericalMLP(
-             prefix=model_name,
-             in_features=num_numerical_columns,
-             hidden_features=model_config.hidden_size,
-             out_features=model_config.hidden_size,
-             num_layers=model_config.num_layers,
-             activation=model_config.activation,
-             dropout_prob=model_config.drop_rate,
-             normalization=model_config.normalization,
-             d_token=OmegaConf.select(model_config, "d_token"),
-             embedding_arch=OmegaConf.select(model_config, "embedding_arch"),
-             num_classes=num_classes,
-         )
-     elif model_name.lower().startswith(CATEGORICAL_MLP):
-         model = CategoricalMLP(
-             prefix=model_name,
-             num_categories=num_categories,
-             out_features=model_config.hidden_size,
-             num_layers=model_config.num_layers,
-             activation=model_config.activation,
-             dropout_prob=model_config.drop_rate,
-             normalization=model_config.normalization,
-             num_classes=num_classes,
-         )
-     elif model_name.lower().startswith(DOCUMENT_TRANSFORMER):
-         model = DocumentTransformer(
-             prefix=model_name,
-             checkpoint_name=model_config.checkpoint_name,
-             num_classes=num_classes,
-             pooling_mode=OmegaConf.select(model_config, "pooling_mode", default="cls"),
-             gradient_checkpointing=OmegaConf.select(model_config, "gradient_checkpointing"),
-             low_cpu_mem_usage=OmegaConf.select(model_config, "low_cpu_mem_usage", default=False),
-             pretrained=pretrained,
-             tokenizer_name=model_config.tokenizer_name,
-         )
-     elif model_name.lower().startswith(MMDET_IMAGE):
-         model = MMDetAutoModelForObjectDetection(
-             prefix=model_name,
-             checkpoint_name=model_config.checkpoint_name,
-             config_file=OmegaConf.select(model_config, "config_file", default=None),
-             classes=classes,
-             pretrained=pretrained,
-             output_bbox_format=OmegaConf.select(model_config, "output_bbox_format", default=XYXY),
-             frozen_layers=OmegaConf.select(model_config, "frozen_layers", default=None),
-         )
-     elif model_name.lower().startswith(MMOCR_TEXT_DET):
-         model = MMOCRAutoModelForTextDetection(
-             prefix=model_name,
-             checkpoint_name=model_config.checkpoint_name,
-         )
-     elif model_name.lower().startswith(MMOCR_TEXT_RECOG):
-         model = MMOCRAutoModelForTextRecognition(
-             prefix=model_name,
-             checkpoint_name=model_config.checkpoint_name,
-         )
-     elif model_name.lower().startswith(NER_TEXT):
-         model = HFAutoModelForNER(
-             prefix=model_name,
-             checkpoint_name=model_config.checkpoint_name,
-             num_classes=num_classes,
-             gradient_checkpointing=OmegaConf.select(model_config, "gradient_checkpointing"),
-             low_cpu_mem_usage=OmegaConf.select(model_config, "low_cpu_mem_usage", default=False),
-             pretrained=pretrained,
-             tokenizer_name=model_config.tokenizer_name,
-         )
-     elif model_name.lower().startswith(FUSION_MLP):
-         model = functools.partial(
-             MultimodalFusionMLP,
-             prefix=model_name,
-             hidden_features=model_config.hidden_sizes,
-             num_classes=num_classes,
-             adapt_in_features=model_config.adapt_in_features,
-             activation=model_config.activation,
-             dropout_prob=model_config.drop_rate,
-             normalization=model_config.normalization,
-             loss_weight=model_config.weight if hasattr(model_config, "weight") else None,
-         )
-     elif model_name.lower().startswith(FUSION_NER):
-         model = functools.partial(
-             MultimodalFusionNER,
-             prefix=model_name,
-             hidden_features=model_config.hidden_sizes,
-             num_classes=num_classes,
-             adapt_in_features=model_config.adapt_in_features,
-             activation=model_config.activation,
-             dropout_prob=model_config.drop_rate,
-             normalization=model_config.normalization,
-             loss_weight=model_config.weight if hasattr(model_config, "weight") else None,
-         )
-     elif model_name.lower().startswith(FUSION_TRANSFORMER):
-         model = functools.partial(
-             MultimodalFusionTransformer,
-             prefix=model_name,
-             hidden_features=model_config.hidden_size,
-             num_classes=num_classes,
-             n_blocks=model_config.n_blocks,
-             attention_n_heads=model_config.attention_n_heads,
-             ffn_d_hidden=model_config.ffn_d_hidden,
-             attention_dropout=model_config.attention_dropout,
-             residual_dropout=model_config.residual_dropout,
-             ffn_dropout=model_config.ffn_dropout,
-             attention_normalization=model_config.normalization,
-             ffn_normalization=model_config.normalization,
-             head_normalization=model_config.normalization,
-             ffn_activation=model_config.ffn_activation,
-             head_activation=model_config.head_activation,
-             adapt_in_features=model_config.adapt_in_features,
-             loss_weight=model_config.weight if hasattr(model_config, "weight") else None,
-             additive_attention=OmegaConf.select(model_config, "additive_attention", default=False),
-             share_qv_weights=OmegaConf.select(model_config, "share_qv_weights", default=False),
-         )
-     elif model_name.lower().startswith(FT_TRANSFORMER):
-         model = FT_Transformer(
-             prefix=model_name,
-             num_numerical_columns=num_numerical_columns,
-             num_categories=num_categories,
-             embedding_arch=model_config.embedding_arch,
-             token_dim=model_config.token_dim,
-             hidden_size=model_config.hidden_size,
-             hidden_features=model_config.hidden_size,
-             num_classes=num_classes,
-             num_blocks=model_config.num_blocks,
-             attention_n_heads=model_config.attention_n_heads,
-             attention_dropout=model_config.attention_dropout,
-             attention_normalization=model_config.normalization,
-             ffn_hidden_size=model_config.ffn_hidden_size,
-             ffn_dropout=model_config.ffn_dropout,
-             ffn_normalization=model_config.normalization,
-             ffn_activation=model_config.ffn_activation,
-             residual_dropout=model_config.residual_dropout,
-             head_normalization=model_config.normalization,
-             head_activation=model_config.head_activation,
-             additive_attention=OmegaConf.select(model_config, "additive_attention", default=False),
-             share_qv_weights=OmegaConf.select(model_config, "share_qv_weights", default=False),
-             pooling_mode=OmegaConf.select(model_config, "pooling_mode", default="cls"),
-             checkpoint_name=model_config.checkpoint_name,
-             pretrained=pretrained,
-         )
-     elif model_name.lower().startswith(SAM):
-         model = SAMForSemanticSegmentation(
-             prefix=model_name,
-             checkpoint_name=model_config.checkpoint_name,
-             num_classes=num_classes,
-             pretrained=pretrained,
-             frozen_layers=OmegaConf.select(model_config, "frozen_layers", default=None),
-             num_mask_tokens=OmegaConf.select(model_config, "num_mask_tokens", default=1),
-         )
-     else:
-         raise ValueError(f"unknown model name: {model_name}")
-
-     return model
-
-
- def create_fusion_model(
-     config: DictConfig,
-     num_classes: Optional[int] = None,
-     classes: Optional[list] = None,
-     num_numerical_columns: Optional[int] = None,
-     num_categories: Optional[List[int]] = None,
-     pretrained: Optional[bool] = True,
- ):
-     """
-     Create models. It supports the auto models of huggingface text and timm image.
-     Multimodal models, e.g., CLIP, should be added case-by-case since their configs and usages
-     may be different. It uses MLP for the numerical features, categorical features, and late-fusion.
-
-     Parameters
-     ----------
-     config
-         A DictConfig object. The model config should be accessible by "config.model".
-     num_classes
-         The class number for a classification task. It should be 1 for a regression task.
-     classes
-         All classes in this dataset.
-     num_numerical_columns
-         The number of numerical columns in the training dataframe.
-     num_categories
-         The category number for each categorical column in the training dataframe.
-     pretrained
-         Whether using the pretrained timm models. If pretrained=True, download the pretrained model.
-
-     Returns
-     -------
-     A Pytorch model.
-     """
-     names = config.model.names
-     if isinstance(names, str):
-         names = [names]
-     # make sure no duplicate model names
-     assert len(names) == len(set(names))
-     logger.debug(f"output_shape: {num_classes}")
-     names = sorted(names)
-     config.model.names = names
-     single_models = []
-     fusion_model = None
-
-     for model_name in names:
-         model_config = getattr(config.model, model_name)
-         model = create_model(
-             model_name=model_name,
-             model_config=model_config,
-             num_classes=num_classes,
-             classes=classes,
-             num_numerical_columns=num_numerical_columns,
-             num_categories=num_categories,
-             pretrained=pretrained,
-         )
-
-         if isinstance(model, functools.partial):  # fusion model
-             if fusion_model is None:
-                 fusion_model = model
-             else:
-                 raise ValueError(
-                     f"More than one fusion models are detected in {names}. Only one fusion model is allowed."
-                 )
-         else:  # single model
-             if (
-                 OmegaConf.select(config, "optimization.efficient_finetune") is not None
-                 and OmegaConf.select(config, "optimization.efficient_finetune") != "None"
-             ):
-                 model = apply_model_adaptation(model, config)
-             single_models.append(model)
-
-     if len(single_models) > 1:
-         # must have one fusion model if there are multiple independent models
-         return fusion_model(models=single_models)
-     elif len(single_models) == 1:
-         return single_models[0]
-     else:
-         raise ValueError(f"No available models for {names}")
-
-
- def apply_model_adaptation(model: nn.Module, config: DictConfig) -> nn.Module:
-     """
-     Apply an adaptation to the model for efficient fine-tuning.
-
-     Parameters
-     ----------
-     model
-         A PyTorch model.
-     config:
-         A DictConfig object. The optimization config should be accessible by "config.optimization".
-     """
-     if OmegaConf.select(config, "optimization.efficient_finetune") in PEFT_ADDITIVE_STRATEGIES:
-         model = inject_adaptation_to_linear_layer(
-             model=model,
-             efficient_finetune=OmegaConf.select(config, "optimization.efficient_finetune"),
-             lora_r=config.optimization.lora.r,
-             lora_alpha=config.optimization.lora.alpha,
-             module_filter=config.optimization.lora.module_filter,
-             filter=config.optimization.lora.filter,
-             extra_trainable_params=OmegaConf.select(config, "optimization.extra_trainable_params"),
-             conv_lora_expert_num=config.optimization.lora.conv_lora_expert_num,
-         )
-         model.name_to_id = model.get_layer_ids()  # Need to update name to id dictionary.
-
-     return model
-
-
- def modify_duplicate_model_names(
-     learner,
-     postfix: str,
-     blacklist: List[str],
- ):
-     """
-     Modify a learner's model names if they exist in a blacklist.
-
-     Parameters
-     ----------
-     learner
-         A BaseLearner object.
-     postfix
-         The postfix used to change the duplicate names.
-     blacklist
-         A list of names. The provided learner can't use model names in the list.
-
-     Returns
-     -------
-     The learner guaranteed has no duplicate model names with the blacklist names.
-     """
-     model_names = []
-     for n in learner._config.model.names:
-         if n in blacklist:
-             new_name = f"{n}_{postfix}"
-             assert new_name not in blacklist
-             assert new_name not in learner._config.model.names
-             # modify model prefix
-             if n == learner._model.prefix:
-                 learner._model.prefix = new_name
-             else:
-                 assert isinstance(learner._model.model, nn.ModuleList)
-                 for per_model in learner._model.model:
-                     if n == per_model.prefix:
-                         per_model.prefix = new_name
-                         break
-             # modify data processor prefix
-             for per_modality_processors in learner._data_processors.values():
-                 for per_processor in per_modality_processors:
-                     if n == per_processor.prefix:
-                         per_processor.prefix = new_name
-             # modify model config keys
-             setattr(learner._config.model, new_name, getattr(learner._config.model, n))
-             delattr(learner._config.model, n)
-
-             model_names.append(new_name)
-         else:
-             model_names.append(n)
-
-     learner._config.model.names = model_names
-
-     return learner
-
-
- def list_timm_models(pretrained=True):
-     return timm.list_models(pretrained=pretrained)
-
-
- def is_lazy_weight_tensor(p: Tensor) -> bool:
-     from torch.nn.parameter import UninitializedParameter
-
-     if isinstance(p, UninitializedParameter):
-         warnings.warn(
-             "A layer with UninitializedParameter was found. "
-             "Thus, the total number of parameters detected may be inaccurate."
-         )
-         return True
-     return False
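
For reference, the removed `utils/model.py` housed the model-factory layer. The sketch below reconstructs the old call pattern solely from the signatures and docstrings in the deleted code above; it is not the replacement API in 1.2.1b20250304, and all inputs (the config, the fitted `df_preprocessor`, and the column counts) are assumed to be prepared elsewhere:

    # Sketch of the deleted factories' call pattern (pre-1.2.1b20250304),
    # based only on the removed file shown above.
    config = select_model(config=config, df_preprocessor=df_preprocessor, strict=True)
    model = create_fusion_model(
        config=config,
        num_classes=num_classes,                      # 1 for regression tasks
        num_numerical_columns=num_numerical_columns,  # numerical column count
        num_categories=num_categories,                # per-column category counts
        pretrained=True,
    )

`select_model` prunes models whose required modalities are absent from the fitted preprocessor; `create_fusion_model` then instantiates each surviving model via `create_model` and, when more than one remains, wraps them with the single permitted fusion model.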