autogluon.multimodal 1.2.1b20250303__py3-none-any.whl → 1.2.1b20250305__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (126)
  1. autogluon/multimodal/__init__.py +4 -2
  2. autogluon/multimodal/configs/data/default.yaml +4 -2
  3. autogluon/multimodal/configs/{environment → env}/default.yaml +2 -3
  4. autogluon/multimodal/configs/model/default.yaml +58 -11
  5. autogluon/multimodal/configs/{optimization → optim}/default.yaml +21 -4
  6. autogluon/multimodal/constants.py +16 -5
  7. autogluon/multimodal/data/__init__.py +14 -2
  8. autogluon/multimodal/data/dataset.py +2 -2
  9. autogluon/multimodal/data/infer_types.py +16 -2
  10. autogluon/multimodal/data/label_encoder.py +3 -3
  11. autogluon/multimodal/{utils → data}/nlpaug.py +4 -4
  12. autogluon/multimodal/data/preprocess_dataframe.py +55 -38
  13. autogluon/multimodal/data/process_categorical.py +35 -6
  14. autogluon/multimodal/data/process_document.py +59 -33
  15. autogluon/multimodal/data/process_image.py +198 -163
  16. autogluon/multimodal/data/process_label.py +7 -3
  17. autogluon/multimodal/data/process_mmlab/process_mmdet.py +1 -8
  18. autogluon/multimodal/data/process_mmlab/process_mmlab_base.py +2 -9
  19. autogluon/multimodal/data/process_mmlab/process_mmocr.py +1 -9
  20. autogluon/multimodal/data/process_ner.py +192 -4
  21. autogluon/multimodal/data/process_numerical.py +32 -5
  22. autogluon/multimodal/data/process_semantic_seg_img.py +23 -28
  23. autogluon/multimodal/data/process_text.py +95 -58
  24. autogluon/multimodal/data/template_engine.py +7 -9
  25. autogluon/multimodal/data/templates.py +0 -2
  26. autogluon/multimodal/data/trivial_augmenter.py +2 -2
  27. autogluon/multimodal/data/utils.py +564 -338
  28. autogluon/multimodal/learners/__init__.py +2 -1
  29. autogluon/multimodal/learners/base.py +189 -189
  30. autogluon/multimodal/learners/ensemble.py +748 -0
  31. autogluon/multimodal/learners/few_shot_svm.py +6 -15
  32. autogluon/multimodal/learners/matching.py +59 -84
  33. autogluon/multimodal/learners/ner.py +23 -22
  34. autogluon/multimodal/learners/object_detection.py +26 -21
  35. autogluon/multimodal/learners/semantic_segmentation.py +16 -18
  36. autogluon/multimodal/models/__init__.py +12 -3
  37. autogluon/multimodal/models/augmenter.py +175 -0
  38. autogluon/multimodal/models/categorical_mlp.py +13 -8
  39. autogluon/multimodal/models/clip.py +92 -18
  40. autogluon/multimodal/models/custom_transformer.py +75 -75
  41. autogluon/multimodal/models/document_transformer.py +23 -9
  42. autogluon/multimodal/models/ft_transformer.py +40 -35
  43. autogluon/multimodal/models/fusion/base.py +2 -4
  44. autogluon/multimodal/models/fusion/fusion_mlp.py +82 -18
  45. autogluon/multimodal/models/fusion/fusion_ner.py +1 -1
  46. autogluon/multimodal/models/fusion/fusion_transformer.py +23 -23
  47. autogluon/multimodal/models/{huggingface_text.py → hf_text.py} +21 -2
  48. autogluon/multimodal/models/meta_transformer.py +336 -0
  49. autogluon/multimodal/models/mlp.py +6 -6
  50. autogluon/multimodal/models/mmocr_text_detection.py +1 -1
  51. autogluon/multimodal/models/mmocr_text_recognition.py +0 -1
  52. autogluon/multimodal/models/ner_text.py +1 -8
  53. autogluon/multimodal/models/numerical_mlp.py +14 -8
  54. autogluon/multimodal/models/sam.py +12 -2
  55. autogluon/multimodal/models/t_few.py +21 -5
  56. autogluon/multimodal/models/timm_image.py +74 -32
  57. autogluon/multimodal/models/utils.py +877 -16
  58. autogluon/multimodal/optim/__init__.py +17 -0
  59. autogluon/multimodal/{optimization → optim}/lit_distiller.py +2 -1
  60. autogluon/multimodal/{optimization → optim}/lit_matcher.py +4 -10
  61. autogluon/multimodal/{optimization → optim}/lit_mmdet.py +2 -10
  62. autogluon/multimodal/{optimization → optim}/lit_module.py +139 -14
  63. autogluon/multimodal/{optimization → optim}/lit_ner.py +3 -3
  64. autogluon/multimodal/{optimization → optim}/lit_semantic_seg.py +1 -1
  65. autogluon/multimodal/optim/losses/__init__.py +14 -0
  66. autogluon/multimodal/optim/losses/bce_loss.py +25 -0
  67. autogluon/multimodal/optim/losses/focal_loss.py +81 -0
  68. autogluon/multimodal/optim/losses/lemda_loss.py +39 -0
  69. autogluon/multimodal/optim/losses/rkd_loss.py +103 -0
  70. autogluon/multimodal/optim/losses/softmax_losses.py +177 -0
  71. autogluon/multimodal/optim/losses/structure_loss.py +26 -0
  72. autogluon/multimodal/optim/losses/utils.py +313 -0
  73. autogluon/multimodal/optim/lr/__init__.py +1 -0
  74. autogluon/multimodal/optim/lr/utils.py +332 -0
  75. autogluon/multimodal/optim/metrics/__init__.py +4 -0
  76. autogluon/multimodal/optim/metrics/coverage_metrics.py +42 -0
  77. autogluon/multimodal/optim/metrics/hit_rate_metrics.py +78 -0
  78. autogluon/multimodal/optim/metrics/ranking_metrics.py +231 -0
  79. autogluon/multimodal/optim/metrics/utils.py +359 -0
  80. autogluon/multimodal/optim/utils.py +284 -0
  81. autogluon/multimodal/predictor.py +51 -12
  82. autogluon/multimodal/utils/__init__.py +19 -45
  83. autogluon/multimodal/utils/cache.py +23 -2
  84. autogluon/multimodal/utils/checkpoint.py +58 -5
  85. autogluon/multimodal/utils/config.py +127 -55
  86. autogluon/multimodal/utils/device.py +120 -0
  87. autogluon/multimodal/utils/distillation.py +8 -8
  88. autogluon/multimodal/utils/download.py +1 -1
  89. autogluon/multimodal/utils/env.py +22 -0
  90. autogluon/multimodal/utils/export.py +3 -3
  91. autogluon/multimodal/utils/hpo.py +5 -5
  92. autogluon/multimodal/utils/inference.py +37 -4
  93. autogluon/multimodal/utils/install.py +91 -0
  94. autogluon/multimodal/utils/load.py +52 -47
  95. autogluon/multimodal/utils/log.py +6 -41
  96. autogluon/multimodal/utils/matcher.py +3 -2
  97. autogluon/multimodal/utils/onnx.py +0 -4
  98. autogluon/multimodal/utils/path.py +10 -0
  99. autogluon/multimodal/utils/precision.py +130 -0
  100. autogluon/multimodal/{presets.py → utils/presets.py} +259 -66
  101. autogluon/multimodal/{problem_types.py → utils/problem_types.py} +30 -1
  102. autogluon/multimodal/utils/save.py +47 -29
  103. autogluon/multimodal/utils/strategy.py +24 -0
  104. autogluon/multimodal/version.py +1 -1
  105. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/METADATA +5 -5
  106. autogluon.multimodal-1.2.1b20250305.dist-info/RECORD +163 -0
  107. autogluon/multimodal/optimization/__init__.py +0 -16
  108. autogluon/multimodal/optimization/losses.py +0 -394
  109. autogluon/multimodal/optimization/utils.py +0 -1054
  110. autogluon/multimodal/utils/cloud_io.py +0 -80
  111. autogluon/multimodal/utils/data.py +0 -701
  112. autogluon/multimodal/utils/environment.py +0 -395
  113. autogluon/multimodal/utils/metric.py +0 -500
  114. autogluon/multimodal/utils/model.py +0 -558
  115. autogluon.multimodal-1.2.1b20250303.dist-info/RECORD +0 -145
  116. /autogluon/multimodal/{optimization → optim}/deepspeed.py +0 -0
  117. /autogluon/multimodal/{optimization/lr_scheduler.py → optim/lr/lr_schedulers.py} +0 -0
  118. /autogluon/multimodal/{optimization → optim/metrics}/semantic_seg_metrics.py +0 -0
  119. /autogluon/multimodal/{registry.py → utils/registry.py} +0 -0
  120. /autogluon.multimodal-1.2.1b20250303-py3.9-nspkg.pth → /autogluon.multimodal-1.2.1b20250305-py3.9-nspkg.pth +0 -0
  121. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/LICENSE +0 -0
  122. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/NOTICE +0 -0
  123. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/WHEEL +0 -0
  124. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/namespace_packages.txt +0 -0
  125. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/top_level.txt +0 -0
  126. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/zip-safe +0 -0
@@ -8,8 +8,14 @@ from timm import create_model
8
8
  from timm.layers.linear import Linear
9
9
  from torch import nn
10
10
 
11
- from ..constants import AUTOMM, COLUMN, COLUMN_FEATURES, FEATURES, IMAGE, IMAGE_VALID_NUM, LABEL, LOGITS, MASKS
12
- from .utils import assign_layer_ids, get_column_features, get_model_head
11
+ from ..constants import COLUMN, COLUMN_FEATURES, FEATURES, IMAGE, IMAGE_VALID_NUM, LABEL, LOGITS, MASKS
12
+ from .utils import (
13
+ assign_layer_ids,
14
+ get_column_features,
15
+ get_image_size_mean_std,
16
+ get_model_head,
17
+ replace_missing_images_with_learnable,
18
+ )
13
19
 
14
20
  logger = logging.getLogger(__name__)
15
21
 
@@ -31,6 +37,10 @@ class TimmAutoModelForImagePrediction(nn.Module):
31
37
  num_classes: Optional[int] = 0,
32
38
  mix_choice: Optional[str] = "all_logits",
33
39
  pretrained: Optional[bool] = True,
40
+ image_size: Optional[int] = None,
41
+ image_norm: Optional[str] = None,
42
+ image_chan_num: Optional[int] = 3,
43
+ use_learnable_image: Optional[bool] = False,
34
44
  ):
35
45
  """
36
46
  Load a pretrained image backbone from TIMM.
@@ -51,10 +61,22 @@ class TimmAutoModelForImagePrediction(nn.Module):
51
61
  The logits output from individual images are averaged to generate the final output.
52
62
  pretrained
53
63
  Whether using the pretrained timm models. If pretrained=True, download the pretrained model.
64
+ image_norm
65
+ How to normalize an image. We now support:
66
+ - inception
67
+ Normalize image by IMAGENET_INCEPTION_MEAN and IMAGENET_INCEPTION_STD from timm
68
+ - imagenet
69
+ Normalize image by IMAGENET_DEFAULT_MEAN and IMAGENET_DEFAULT_STD from timm
70
+ - clip
71
+ Normalize image by mean (0.48145466, 0.4578275, 0.40821073) and
72
+ std (0.26862954, 0.26130258, 0.27577711), used for CLIP.
73
+ image_size
74
+ The provided width / height of a square image.
54
75
  """
55
76
  super().__init__()
56
77
  # In TIMM, if num_classes==0, then create_model would automatically set self.model.head = nn.Identity()
57
- logger.debug(f"initializing {checkpoint_name}")
78
+ logger.debug(f"initializing {prefix} (TimmAutoModelForImagePrediction)")
79
+ logger.debug(f"model checkpoint: {checkpoint_name}")
58
80
  if os.path.exists(checkpoint_name):
59
81
  checkpoint_path = f"{checkpoint_name}/pytorch_model.bin"
60
82
  try:
@@ -91,6 +113,18 @@ class TimmAutoModelForImagePrediction(nn.Module):
91
113
  logger.debug(f"mix_choice: {mix_choice}")
92
114
 
93
115
  self.prefix = prefix
116
+ self.image_size, self.image_mean, self.image_std = get_image_size_mean_std(
117
+ model_name=self.prefix,
118
+ config=self.config,
119
+ provided_size=image_size,
120
+ provided_norm_type=image_norm,
121
+ support_variable_input_size=self.support_variable_input_size(),
122
+ )
123
+ self.image_chan_num = image_chan_num
124
+ self.use_learnable_image = use_learnable_image
125
+ if self.use_learnable_image:
126
+ self.learnable_image = nn.Parameter(torch.zeros(image_chan_num, self.image_size, self.image_size))
127
+ logger.debug("will use a learnable image to replace missing ones")
94
128
 
95
129
  self.name_to_id = self.get_layer_ids()
96
130
  self.head_layer_names = [n for n, layer_id in self.name_to_id.items() if layer_id == 0]
@@ -152,6 +186,7 @@ class TimmAutoModelForImagePrediction(nn.Module):
152
186
  -------
153
187
  A dictionary with logits and features.
154
188
  """
189
+ column_features = column_feature_masks = dict()
155
190
  if self.mix_choice == "all_images": # mix inputs
156
191
  mixed_images = (
157
192
  images.sum(dim=1) / torch.clamp(image_valid_num, min=1e-6)[:, None, None, None]
@@ -162,49 +197,55 @@ class TimmAutoModelForImagePrediction(nn.Module):
162
197
  else:
163
198
  logits = features
164
199
 
165
- column_features = {}
166
- column_feature_masks = {}
167
-
168
200
  elif self.mix_choice == "all_logits": # mix outputs
169
201
  b, n, c, h, w = images.shape
202
+ steps = torch.arange(0, n).type_as(image_valid_num)
203
+ image_masks = steps.reshape((1, -1)) < image_valid_num.reshape((-1, 1)) # (b, n)
204
+
205
+ if self.use_learnable_image:
206
+ images = replace_missing_images_with_learnable(
207
+ images=images,
208
+ image_masks=image_masks,
209
+ learnable_image=self.learnable_image,
210
+ )
170
211
  features = self.model(images.reshape((b * n, c, h, w))) # (b*n, num_features)
171
212
  if self.num_classes > 0:
172
213
  logits = self.head(features)
173
- steps = torch.arange(0, n).type_as(image_valid_num)
174
- image_masks = (steps.reshape((1, -1)) < image_valid_num.reshape((-1, 1))).type_as(features) # (b, n)
175
- features = features.reshape((b, n, -1)) * image_masks[:, :, None] # (b, n, num_features)
214
+ logits = logits.reshape((b, n, -1)) # (b, n, num_classes)
215
+ # reshape features after head prediction
216
+ features = features.reshape((b, n, -1)) # (b, n, num_features)
176
217
 
177
- batch = {
178
- self.image_key: images,
179
- self.image_valid_num_key: image_valid_num,
180
- }
218
+ if not self.use_learnable_image:
219
+ features = features * image_masks[:, :, None].type_as(features) # (b, n, num_features)
220
+
221
+ # need to collect column features before summing them
181
222
  if image_column_names:
182
223
  assert len(image_column_names) == len(image_column_indices), "invalid image column inputs"
183
- for idx, name in enumerate(image_column_names):
184
- batch[name] = image_column_indices[idx]
185
-
186
- # collect features by image column names
187
- column_features, column_feature_masks = get_column_features(
188
- batch=batch,
189
- column_name_prefix=self.image_column_prefix,
190
- features=features,
191
- valid_lengths=image_valid_num,
192
- )
193
-
194
- features = features.sum(dim=1) / torch.clamp(image_valid_num, min=1e-6)[:, None] # (b, num_features)
224
+ # collect features by image column names
225
+ column_features, column_feature_masks = get_column_features(
226
+ batch=dict(zip(image_column_names, image_column_indices)),
227
+ column_name_prefix=self.image_column_prefix,
228
+ features=features,
229
+ valid_lengths=image_valid_num,
230
+ )
231
+
232
+ if self.use_learnable_image:
233
+ features = features.mean(dim=1)
234
+ else:
235
+ features = features.sum(dim=1) / torch.clamp(image_valid_num, min=1e-6)[:, None] # (b, num_features)
195
236
  if self.num_classes > 0:
196
- logits = logits.reshape((b, n, -1)) * image_masks[:, :, None] # (b, n, num_classes)
197
- logits = logits.sum(dim=1) / torch.clamp(image_valid_num, min=1e-6)[:, None] # (b, num_classes)
237
+ if self.use_learnable_image:
238
+ logits = logits.mean(dim=1)
239
+ else:
240
+ logits = logits * image_masks[:, :, None].type_as(logits) # (b, n, num_classes)
241
+ logits = logits.sum(dim=1) / torch.clamp(image_valid_num, min=1e-6)[:, None] # (b, num_classes)
198
242
  else:
199
243
  logits = features
200
244
 
201
245
  else:
202
246
  raise ValueError(f"unknown mix_choice: {self.mix_choice}")
203
247
 
204
- if column_features == {} or column_feature_masks == {}:
205
- return features, logits
206
- else:
207
- return features, logits, column_features, column_feature_masks
248
+ return features, logits, column_features, column_feature_masks
208
249
 
209
250
  def get_output_dict(
210
251
  self,
@@ -215,7 +256,8 @@ class TimmAutoModelForImagePrediction(nn.Module):
215
256
  ):
216
257
  ret = {COLUMN_FEATURES: {FEATURES: {}, MASKS: {}}}
217
258
 
218
- if column_features != None:
259
+ if column_features is not None and len(column_features) > 0:
260
+ assert column_feature_masks is not None and len(column_features) == len(column_feature_masks)
219
261
  ret[COLUMN_FEATURES][FEATURES].update(column_features)
220
262
  ret[COLUMN_FEATURES][MASKS].update(column_feature_masks)
221
263