autogluon.multimodal 1.2.1b20250303__py3-none-any.whl → 1.2.1b20250305__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- autogluon/multimodal/__init__.py +4 -2
- autogluon/multimodal/configs/data/default.yaml +4 -2
- autogluon/multimodal/configs/{environment → env}/default.yaml +2 -3
- autogluon/multimodal/configs/model/default.yaml +58 -11
- autogluon/multimodal/configs/{optimization → optim}/default.yaml +21 -4
- autogluon/multimodal/constants.py +16 -5
- autogluon/multimodal/data/__init__.py +14 -2
- autogluon/multimodal/data/dataset.py +2 -2
- autogluon/multimodal/data/infer_types.py +16 -2
- autogluon/multimodal/data/label_encoder.py +3 -3
- autogluon/multimodal/{utils → data}/nlpaug.py +4 -4
- autogluon/multimodal/data/preprocess_dataframe.py +55 -38
- autogluon/multimodal/data/process_categorical.py +35 -6
- autogluon/multimodal/data/process_document.py +59 -33
- autogluon/multimodal/data/process_image.py +198 -163
- autogluon/multimodal/data/process_label.py +7 -3
- autogluon/multimodal/data/process_mmlab/process_mmdet.py +1 -8
- autogluon/multimodal/data/process_mmlab/process_mmlab_base.py +2 -9
- autogluon/multimodal/data/process_mmlab/process_mmocr.py +1 -9
- autogluon/multimodal/data/process_ner.py +192 -4
- autogluon/multimodal/data/process_numerical.py +32 -5
- autogluon/multimodal/data/process_semantic_seg_img.py +23 -28
- autogluon/multimodal/data/process_text.py +95 -58
- autogluon/multimodal/data/template_engine.py +7 -9
- autogluon/multimodal/data/templates.py +0 -2
- autogluon/multimodal/data/trivial_augmenter.py +2 -2
- autogluon/multimodal/data/utils.py +564 -338
- autogluon/multimodal/learners/__init__.py +2 -1
- autogluon/multimodal/learners/base.py +189 -189
- autogluon/multimodal/learners/ensemble.py +748 -0
- autogluon/multimodal/learners/few_shot_svm.py +6 -15
- autogluon/multimodal/learners/matching.py +59 -84
- autogluon/multimodal/learners/ner.py +23 -22
- autogluon/multimodal/learners/object_detection.py +26 -21
- autogluon/multimodal/learners/semantic_segmentation.py +16 -18
- autogluon/multimodal/models/__init__.py +12 -3
- autogluon/multimodal/models/augmenter.py +175 -0
- autogluon/multimodal/models/categorical_mlp.py +13 -8
- autogluon/multimodal/models/clip.py +92 -18
- autogluon/multimodal/models/custom_transformer.py +75 -75
- autogluon/multimodal/models/document_transformer.py +23 -9
- autogluon/multimodal/models/ft_transformer.py +40 -35
- autogluon/multimodal/models/fusion/base.py +2 -4
- autogluon/multimodal/models/fusion/fusion_mlp.py +82 -18
- autogluon/multimodal/models/fusion/fusion_ner.py +1 -1
- autogluon/multimodal/models/fusion/fusion_transformer.py +23 -23
- autogluon/multimodal/models/{huggingface_text.py → hf_text.py} +21 -2
- autogluon/multimodal/models/meta_transformer.py +336 -0
- autogluon/multimodal/models/mlp.py +6 -6
- autogluon/multimodal/models/mmocr_text_detection.py +1 -1
- autogluon/multimodal/models/mmocr_text_recognition.py +0 -1
- autogluon/multimodal/models/ner_text.py +1 -8
- autogluon/multimodal/models/numerical_mlp.py +14 -8
- autogluon/multimodal/models/sam.py +12 -2
- autogluon/multimodal/models/t_few.py +21 -5
- autogluon/multimodal/models/timm_image.py +74 -32
- autogluon/multimodal/models/utils.py +877 -16
- autogluon/multimodal/optim/__init__.py +17 -0
- autogluon/multimodal/{optimization → optim}/lit_distiller.py +2 -1
- autogluon/multimodal/{optimization → optim}/lit_matcher.py +4 -10
- autogluon/multimodal/{optimization → optim}/lit_mmdet.py +2 -10
- autogluon/multimodal/{optimization → optim}/lit_module.py +139 -14
- autogluon/multimodal/{optimization → optim}/lit_ner.py +3 -3
- autogluon/multimodal/{optimization → optim}/lit_semantic_seg.py +1 -1
- autogluon/multimodal/optim/losses/__init__.py +14 -0
- autogluon/multimodal/optim/losses/bce_loss.py +25 -0
- autogluon/multimodal/optim/losses/focal_loss.py +81 -0
- autogluon/multimodal/optim/losses/lemda_loss.py +39 -0
- autogluon/multimodal/optim/losses/rkd_loss.py +103 -0
- autogluon/multimodal/optim/losses/softmax_losses.py +177 -0
- autogluon/multimodal/optim/losses/structure_loss.py +26 -0
- autogluon/multimodal/optim/losses/utils.py +313 -0
- autogluon/multimodal/optim/lr/__init__.py +1 -0
- autogluon/multimodal/optim/lr/utils.py +332 -0
- autogluon/multimodal/optim/metrics/__init__.py +4 -0
- autogluon/multimodal/optim/metrics/coverage_metrics.py +42 -0
- autogluon/multimodal/optim/metrics/hit_rate_metrics.py +78 -0
- autogluon/multimodal/optim/metrics/ranking_metrics.py +231 -0
- autogluon/multimodal/optim/metrics/utils.py +359 -0
- autogluon/multimodal/optim/utils.py +284 -0
- autogluon/multimodal/predictor.py +51 -12
- autogluon/multimodal/utils/__init__.py +19 -45
- autogluon/multimodal/utils/cache.py +23 -2
- autogluon/multimodal/utils/checkpoint.py +58 -5
- autogluon/multimodal/utils/config.py +127 -55
- autogluon/multimodal/utils/device.py +120 -0
- autogluon/multimodal/utils/distillation.py +8 -8
- autogluon/multimodal/utils/download.py +1 -1
- autogluon/multimodal/utils/env.py +22 -0
- autogluon/multimodal/utils/export.py +3 -3
- autogluon/multimodal/utils/hpo.py +5 -5
- autogluon/multimodal/utils/inference.py +37 -4
- autogluon/multimodal/utils/install.py +91 -0
- autogluon/multimodal/utils/load.py +52 -47
- autogluon/multimodal/utils/log.py +6 -41
- autogluon/multimodal/utils/matcher.py +3 -2
- autogluon/multimodal/utils/onnx.py +0 -4
- autogluon/multimodal/utils/path.py +10 -0
- autogluon/multimodal/utils/precision.py +130 -0
- autogluon/multimodal/{presets.py → utils/presets.py} +259 -66
- autogluon/multimodal/{problem_types.py → utils/problem_types.py} +30 -1
- autogluon/multimodal/utils/save.py +47 -29
- autogluon/multimodal/utils/strategy.py +24 -0
- autogluon/multimodal/version.py +1 -1
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/METADATA +5 -5
- autogluon.multimodal-1.2.1b20250305.dist-info/RECORD +163 -0
- autogluon/multimodal/optimization/__init__.py +0 -16
- autogluon/multimodal/optimization/losses.py +0 -394
- autogluon/multimodal/optimization/utils.py +0 -1054
- autogluon/multimodal/utils/cloud_io.py +0 -80
- autogluon/multimodal/utils/data.py +0 -701
- autogluon/multimodal/utils/environment.py +0 -395
- autogluon/multimodal/utils/metric.py +0 -500
- autogluon/multimodal/utils/model.py +0 -558
- autogluon.multimodal-1.2.1b20250303.dist-info/RECORD +0 -145
- /autogluon/multimodal/{optimization → optim}/deepspeed.py +0 -0
- /autogluon/multimodal/{optimization/lr_scheduler.py → optim/lr/lr_schedulers.py} +0 -0
- /autogluon/multimodal/{optimization → optim/metrics}/semantic_seg_metrics.py +0 -0
- /autogluon/multimodal/{registry.py → utils/registry.py} +0 -0
- /autogluon.multimodal-1.2.1b20250303-py3.9-nspkg.pth → /autogluon.multimodal-1.2.1b20250305-py3.9-nspkg.pth +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/LICENSE +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/NOTICE +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/WHEEL +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/namespace_packages.txt +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/top_level.txt +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/zip-safe +0 -0
@@ -8,8 +8,14 @@ from timm import create_model
|
|
8
8
|
from timm.layers.linear import Linear
|
9
9
|
from torch import nn
|
10
10
|
|
11
|
-
from ..constants import
|
12
|
-
from .utils import
|
11
|
+
from ..constants import COLUMN, COLUMN_FEATURES, FEATURES, IMAGE, IMAGE_VALID_NUM, LABEL, LOGITS, MASKS
|
12
|
+
from .utils import (
|
13
|
+
assign_layer_ids,
|
14
|
+
get_column_features,
|
15
|
+
get_image_size_mean_std,
|
16
|
+
get_model_head,
|
17
|
+
replace_missing_images_with_learnable,
|
18
|
+
)
|
13
19
|
|
14
20
|
logger = logging.getLogger(__name__)
|
15
21
|
|
@@ -31,6 +37,10 @@ class TimmAutoModelForImagePrediction(nn.Module):
|
|
31
37
|
num_classes: Optional[int] = 0,
|
32
38
|
mix_choice: Optional[str] = "all_logits",
|
33
39
|
pretrained: Optional[bool] = True,
|
40
|
+
image_size: Optional[int] = None,
|
41
|
+
image_norm: Optional[str] = None,
|
42
|
+
image_chan_num: Optional[int] = 3,
|
43
|
+
use_learnable_image: Optional[bool] = False,
|
34
44
|
):
|
35
45
|
"""
|
36
46
|
Load a pretrained image backbone from TIMM.
|
@@ -51,10 +61,22 @@ class TimmAutoModelForImagePrediction(nn.Module):
|
|
51
61
|
The logits output from individual images are averaged to generate the final output.
|
52
62
|
pretrained
|
53
63
|
Whether using the pretrained timm models. If pretrained=True, download the pretrained model.
|
64
|
+
image_norm
|
65
|
+
How to normalize an image. We now support:
|
66
|
+
- inception
|
67
|
+
Normalize image by IMAGENET_INCEPTION_MEAN and IMAGENET_INCEPTION_STD from timm
|
68
|
+
- imagenet
|
69
|
+
Normalize image by IMAGENET_DEFAULT_MEAN and IMAGENET_DEFAULT_STD from timm
|
70
|
+
- clip
|
71
|
+
Normalize image by mean (0.48145466, 0.4578275, 0.40821073) and
|
72
|
+
std (0.26862954, 0.26130258, 0.27577711), used for CLIP.
|
73
|
+
image_size
|
74
|
+
The provided width / height of a square image.
|
54
75
|
"""
|
55
76
|
super().__init__()
|
56
77
|
# In TIMM, if num_classes==0, then create_model would automatically set self.model.head = nn.Identity()
|
57
|
-
logger.debug(f"initializing {
|
78
|
+
logger.debug(f"initializing {prefix} (TimmAutoModelForImagePrediction)")
|
79
|
+
logger.debug(f"model checkpoint: {checkpoint_name}")
|
58
80
|
if os.path.exists(checkpoint_name):
|
59
81
|
checkpoint_path = f"{checkpoint_name}/pytorch_model.bin"
|
60
82
|
try:
|
@@ -91,6 +113,18 @@ class TimmAutoModelForImagePrediction(nn.Module):
|
|
91
113
|
logger.debug(f"mix_choice: {mix_choice}")
|
92
114
|
|
93
115
|
self.prefix = prefix
|
116
|
+
self.image_size, self.image_mean, self.image_std = get_image_size_mean_std(
|
117
|
+
model_name=self.prefix,
|
118
|
+
config=self.config,
|
119
|
+
provided_size=image_size,
|
120
|
+
provided_norm_type=image_norm,
|
121
|
+
support_variable_input_size=self.support_variable_input_size(),
|
122
|
+
)
|
123
|
+
self.image_chan_num = image_chan_num
|
124
|
+
self.use_learnable_image = use_learnable_image
|
125
|
+
if self.use_learnable_image:
|
126
|
+
self.learnable_image = nn.Parameter(torch.zeros(image_chan_num, self.image_size, self.image_size))
|
127
|
+
logger.debug("will use a learnable image to replace missing ones")
|
94
128
|
|
95
129
|
self.name_to_id = self.get_layer_ids()
|
96
130
|
self.head_layer_names = [n for n, layer_id in self.name_to_id.items() if layer_id == 0]
|
@@ -152,6 +186,7 @@ class TimmAutoModelForImagePrediction(nn.Module):
|
|
152
186
|
-------
|
153
187
|
A dictionary with logits and features.
|
154
188
|
"""
|
189
|
+
column_features = column_feature_masks = dict()
|
155
190
|
if self.mix_choice == "all_images": # mix inputs
|
156
191
|
mixed_images = (
|
157
192
|
images.sum(dim=1) / torch.clamp(image_valid_num, min=1e-6)[:, None, None, None]
|
@@ -162,49 +197,55 @@ class TimmAutoModelForImagePrediction(nn.Module):
|
|
162
197
|
else:
|
163
198
|
logits = features
|
164
199
|
|
165
|
-
column_features = {}
|
166
|
-
column_feature_masks = {}
|
167
|
-
|
168
200
|
elif self.mix_choice == "all_logits": # mix outputs
|
169
201
|
b, n, c, h, w = images.shape
|
202
|
+
steps = torch.arange(0, n).type_as(image_valid_num)
|
203
|
+
image_masks = steps.reshape((1, -1)) < image_valid_num.reshape((-1, 1)) # (b, n)
|
204
|
+
|
205
|
+
if self.use_learnable_image:
|
206
|
+
images = replace_missing_images_with_learnable(
|
207
|
+
images=images,
|
208
|
+
image_masks=image_masks,
|
209
|
+
learnable_image=self.learnable_image,
|
210
|
+
)
|
170
211
|
features = self.model(images.reshape((b * n, c, h, w))) # (b*n, num_features)
|
171
212
|
if self.num_classes > 0:
|
172
213
|
logits = self.head(features)
|
173
|
-
|
174
|
-
|
175
|
-
features = features.reshape((b, n, -1))
|
214
|
+
logits = logits.reshape((b, n, -1)) # (b, n, num_classes)
|
215
|
+
# reshape features after head prediction
|
216
|
+
features = features.reshape((b, n, -1)) # (b, n, num_features)
|
176
217
|
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
|
218
|
+
if not self.use_learnable_image:
|
219
|
+
features = features * image_masks[:, :, None].type_as(features) # (b, n, num_features)
|
220
|
+
|
221
|
+
# need to collect column features before summing them
|
181
222
|
if image_column_names:
|
182
223
|
assert len(image_column_names) == len(image_column_indices), "invalid image column inputs"
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
|
224
|
+
# collect features by image column names
|
225
|
+
column_features, column_feature_masks = get_column_features(
|
226
|
+
batch=dict(zip(image_column_names, image_column_indices)),
|
227
|
+
column_name_prefix=self.image_column_prefix,
|
228
|
+
features=features,
|
229
|
+
valid_lengths=image_valid_num,
|
230
|
+
)
|
231
|
+
|
232
|
+
if self.use_learnable_image:
|
233
|
+
features = features.mean(dim=1)
|
234
|
+
else:
|
235
|
+
features = features.sum(dim=1) / torch.clamp(image_valid_num, min=1e-6)[:, None] # (b, num_features)
|
195
236
|
if self.num_classes > 0:
|
196
|
-
|
197
|
-
|
237
|
+
if self.use_learnable_image:
|
238
|
+
logits = logits.mean(dim=1)
|
239
|
+
else:
|
240
|
+
logits = logits * image_masks[:, :, None].type_as(logits) # (b, n, num_classes)
|
241
|
+
logits = logits.sum(dim=1) / torch.clamp(image_valid_num, min=1e-6)[:, None] # (b, num_classes)
|
198
242
|
else:
|
199
243
|
logits = features
|
200
244
|
|
201
245
|
else:
|
202
246
|
raise ValueError(f"unknown mix_choice: {self.mix_choice}")
|
203
247
|
|
204
|
-
|
205
|
-
return features, logits
|
206
|
-
else:
|
207
|
-
return features, logits, column_features, column_feature_masks
|
248
|
+
return features, logits, column_features, column_feature_masks
|
208
249
|
|
209
250
|
def get_output_dict(
|
210
251
|
self,
|
@@ -215,7 +256,8 @@ class TimmAutoModelForImagePrediction(nn.Module):
|
|
215
256
|
):
|
216
257
|
ret = {COLUMN_FEATURES: {FEATURES: {}, MASKS: {}}}
|
217
258
|
|
218
|
-
if column_features
|
259
|
+
if column_features is not None and len(column_features) > 0:
|
260
|
+
assert column_feature_masks is not None and len(column_features) == len(column_feature_masks)
|
219
261
|
ret[COLUMN_FEATURES][FEATURES].update(column_features)
|
220
262
|
ret[COLUMN_FEATURES][MASKS].update(column_feature_masks)
|
221
263
|
|