autogluon.multimodal-1.2.1b20250303-py3-none-any.whl → autogluon.multimodal-1.2.1b20250304-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (126)
  1. autogluon/multimodal/__init__.py +4 -2
  2. autogluon/multimodal/configs/data/default.yaml +4 -2
  3. autogluon/multimodal/configs/{environment → env}/default.yaml +2 -3
  4. autogluon/multimodal/configs/model/default.yaml +58 -11
  5. autogluon/multimodal/configs/{optimization → optim}/default.yaml +21 -4
  6. autogluon/multimodal/constants.py +16 -5
  7. autogluon/multimodal/data/__init__.py +14 -2
  8. autogluon/multimodal/data/dataset.py +2 -2
  9. autogluon/multimodal/data/infer_types.py +16 -2
  10. autogluon/multimodal/data/label_encoder.py +3 -3
  11. autogluon/multimodal/{utils → data}/nlpaug.py +4 -4
  12. autogluon/multimodal/data/preprocess_dataframe.py +55 -38
  13. autogluon/multimodal/data/process_categorical.py +35 -6
  14. autogluon/multimodal/data/process_document.py +59 -33
  15. autogluon/multimodal/data/process_image.py +198 -163
  16. autogluon/multimodal/data/process_label.py +7 -3
  17. autogluon/multimodal/data/process_mmlab/process_mmdet.py +1 -8
  18. autogluon/multimodal/data/process_mmlab/process_mmlab_base.py +2 -9
  19. autogluon/multimodal/data/process_mmlab/process_mmocr.py +1 -9
  20. autogluon/multimodal/data/process_ner.py +192 -4
  21. autogluon/multimodal/data/process_numerical.py +32 -5
  22. autogluon/multimodal/data/process_semantic_seg_img.py +23 -28
  23. autogluon/multimodal/data/process_text.py +95 -58
  24. autogluon/multimodal/data/template_engine.py +7 -9
  25. autogluon/multimodal/data/templates.py +0 -2
  26. autogluon/multimodal/data/trivial_augmenter.py +2 -2
  27. autogluon/multimodal/data/utils.py +564 -338
  28. autogluon/multimodal/learners/__init__.py +2 -1
  29. autogluon/multimodal/learners/base.py +189 -189
  30. autogluon/multimodal/learners/ensemble.py +748 -0
  31. autogluon/multimodal/learners/few_shot_svm.py +6 -15
  32. autogluon/multimodal/learners/matching.py +59 -84
  33. autogluon/multimodal/learners/ner.py +23 -22
  34. autogluon/multimodal/learners/object_detection.py +26 -21
  35. autogluon/multimodal/learners/semantic_segmentation.py +16 -18
  36. autogluon/multimodal/models/__init__.py +12 -3
  37. autogluon/multimodal/models/augmenter.py +175 -0
  38. autogluon/multimodal/models/categorical_mlp.py +13 -8
  39. autogluon/multimodal/models/clip.py +92 -18
  40. autogluon/multimodal/models/custom_transformer.py +75 -75
  41. autogluon/multimodal/models/document_transformer.py +23 -9
  42. autogluon/multimodal/models/ft_transformer.py +40 -35
  43. autogluon/multimodal/models/fusion/base.py +2 -4
  44. autogluon/multimodal/models/fusion/fusion_mlp.py +82 -18
  45. autogluon/multimodal/models/fusion/fusion_ner.py +1 -1
  46. autogluon/multimodal/models/fusion/fusion_transformer.py +23 -23
  47. autogluon/multimodal/models/{huggingface_text.py → hf_text.py} +21 -2
  48. autogluon/multimodal/models/meta_transformer.py +336 -0
  49. autogluon/multimodal/models/mlp.py +6 -6
  50. autogluon/multimodal/models/mmocr_text_detection.py +1 -1
  51. autogluon/multimodal/models/mmocr_text_recognition.py +0 -1
  52. autogluon/multimodal/models/ner_text.py +1 -8
  53. autogluon/multimodal/models/numerical_mlp.py +14 -8
  54. autogluon/multimodal/models/sam.py +12 -2
  55. autogluon/multimodal/models/t_few.py +21 -5
  56. autogluon/multimodal/models/timm_image.py +74 -32
  57. autogluon/multimodal/models/utils.py +877 -16
  58. autogluon/multimodal/optim/__init__.py +17 -0
  59. autogluon/multimodal/{optimization → optim}/lit_distiller.py +2 -1
  60. autogluon/multimodal/{optimization → optim}/lit_matcher.py +4 -10
  61. autogluon/multimodal/{optimization → optim}/lit_mmdet.py +2 -10
  62. autogluon/multimodal/{optimization → optim}/lit_module.py +139 -14
  63. autogluon/multimodal/{optimization → optim}/lit_ner.py +3 -3
  64. autogluon/multimodal/{optimization → optim}/lit_semantic_seg.py +1 -1
  65. autogluon/multimodal/optim/losses/__init__.py +14 -0
  66. autogluon/multimodal/optim/losses/bce_loss.py +25 -0
  67. autogluon/multimodal/optim/losses/focal_loss.py +81 -0
  68. autogluon/multimodal/optim/losses/lemda_loss.py +39 -0
  69. autogluon/multimodal/optim/losses/rkd_loss.py +103 -0
  70. autogluon/multimodal/optim/losses/softmax_losses.py +177 -0
  71. autogluon/multimodal/optim/losses/structure_loss.py +26 -0
  72. autogluon/multimodal/optim/losses/utils.py +313 -0
  73. autogluon/multimodal/optim/lr/__init__.py +1 -0
  74. autogluon/multimodal/optim/lr/utils.py +332 -0
  75. autogluon/multimodal/optim/metrics/__init__.py +4 -0
  76. autogluon/multimodal/optim/metrics/coverage_metrics.py +42 -0
  77. autogluon/multimodal/optim/metrics/hit_rate_metrics.py +78 -0
  78. autogluon/multimodal/optim/metrics/ranking_metrics.py +231 -0
  79. autogluon/multimodal/optim/metrics/utils.py +359 -0
  80. autogluon/multimodal/optim/utils.py +284 -0
  81. autogluon/multimodal/predictor.py +51 -12
  82. autogluon/multimodal/utils/__init__.py +19 -45
  83. autogluon/multimodal/utils/cache.py +23 -2
  84. autogluon/multimodal/utils/checkpoint.py +58 -5
  85. autogluon/multimodal/utils/config.py +127 -55
  86. autogluon/multimodal/utils/device.py +120 -0
  87. autogluon/multimodal/utils/distillation.py +8 -8
  88. autogluon/multimodal/utils/download.py +1 -1
  89. autogluon/multimodal/utils/env.py +22 -0
  90. autogluon/multimodal/utils/export.py +3 -3
  91. autogluon/multimodal/utils/hpo.py +5 -5
  92. autogluon/multimodal/utils/inference.py +37 -4
  93. autogluon/multimodal/utils/install.py +91 -0
  94. autogluon/multimodal/utils/load.py +52 -47
  95. autogluon/multimodal/utils/log.py +6 -41
  96. autogluon/multimodal/utils/matcher.py +3 -2
  97. autogluon/multimodal/utils/onnx.py +0 -4
  98. autogluon/multimodal/utils/path.py +10 -0
  99. autogluon/multimodal/utils/precision.py +130 -0
  100. autogluon/multimodal/{presets.py → utils/presets.py} +259 -66
  101. autogluon/multimodal/{problem_types.py → utils/problem_types.py} +30 -1
  102. autogluon/multimodal/utils/save.py +47 -29
  103. autogluon/multimodal/utils/strategy.py +24 -0
  104. autogluon/multimodal/version.py +1 -1
  105. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/METADATA +5 -5
  106. autogluon.multimodal-1.2.1b20250304.dist-info/RECORD +163 -0
  107. autogluon/multimodal/optimization/__init__.py +0 -16
  108. autogluon/multimodal/optimization/losses.py +0 -394
  109. autogluon/multimodal/optimization/utils.py +0 -1054
  110. autogluon/multimodal/utils/cloud_io.py +0 -80
  111. autogluon/multimodal/utils/data.py +0 -701
  112. autogluon/multimodal/utils/environment.py +0 -395
  113. autogluon/multimodal/utils/metric.py +0 -500
  114. autogluon/multimodal/utils/model.py +0 -558
  115. autogluon.multimodal-1.2.1b20250303.dist-info/RECORD +0 -145
  116. /autogluon/multimodal/{optimization → optim}/deepspeed.py +0 -0
  117. /autogluon/multimodal/{optimization/lr_scheduler.py → optim/lr/lr_schedulers.py} +0 -0
  118. /autogluon/multimodal/{optimization → optim/metrics}/semantic_seg_metrics.py +0 -0
  119. /autogluon/multimodal/{registry.py → utils/registry.py} +0 -0
  120. /autogluon.multimodal-1.2.1b20250303-py3.9-nspkg.pth → /autogluon.multimodal-1.2.1b20250304-py3.9-nspkg.pth +0 -0
  121. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/LICENSE +0 -0
  122. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/NOTICE +0 -0
  123. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/WHEEL +0 -0
  124. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/namespace_packages.txt +0 -0
  125. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/top_level.txt +0 -0
  126. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/zip-safe +0 -0
autogluon/multimodal/utils/precision.py (new file)
@@ -0,0 +1,130 @@
+import contextlib
+import logging
+import warnings
+from typing import Dict, List, Optional, Tuple, Union
+
+import torch
+
+logger = logging.getLogger(__name__)
+
+
+def convert_to_torch_precision(precision: Union[int, str]):
+    """
+    Convert a precision integer or string to the corresponding torch precision.
+
+    Parameters
+    ----------
+    precision
+        A precision integer or string from the config.
+
+    Returns
+    -------
+    A torch precision object.
+    """
+    precision_mapping = {
+        16: torch.half,
+        "16": torch.half,
+        "16-mixed": torch.half,
+        "16-true": torch.half,
+        "bf16": torch.bfloat16,
+        "bf16-mixed": torch.bfloat16,
+        "bf16-true": torch.bfloat16,
+        32: torch.float32,
+        "32": torch.float32,
+        "32-true": torch.float32,
+        64: torch.float64,
+        "64": torch.float64,
+        "64-true": torch.float64,
+    }
+
+    if precision in precision_mapping:
+        precision = precision_mapping[precision]
+    else:
+        raise ValueError(f"Unknown precision: {precision}")
+
+    return precision
+
+
+def infer_precision(
+    num_gpus: int, precision: Union[int, str], as_torch: Optional[bool] = False, cpu_only_warning: bool = True
+):
+    """
+    Infer the proper precision based on the environment setup and the provided precision.
+
+    Parameters
+    ----------
+    num_gpus
+        GPU number.
+    precision
+        The precision provided in config.
+    as_torch
+        Whether to convert the precision to the PyTorch format.
+    cpu_only_warning
+        Whether to turn on warning if the instance has only CPU.
+
+    Returns
+    -------
+    The inferred precision.
+    """
+    if num_gpus == 0:  # CPU-only prediction
+        if cpu_only_warning:
+            warnings.warn(
+                "Only CPU is detected in the instance. "
+                "This may result in slow speed for MultiModalPredictor. "
+                "Consider using an instance with GPU support.",
+                UserWarning,
+            )
+        precision = 32  # Force fp32 for training since 16-mixed is not available on CPU
+    else:
+        if isinstance(precision, str) and "bf16" in precision and not torch.cuda.is_bf16_supported():
+            warnings.warn(
+                f"{precision} is not supported by the GPU device / cuda version. "
+                "Consider using GPU devices with versions after Ampere or upgrading cuda to be >=11.0. "
+                f"MultiModalPredictor is switching precision from {precision} to 32.",
+                UserWarning,
+            )
+            precision = 32
+
+    if as_torch:
+        precision = convert_to_torch_precision(precision=precision)
+
+    return precision
+
+
+@contextlib.contextmanager
+def double_precision_context():
+    """
+    Double precision context manager.
+    """
+    default_dtype = torch.get_default_dtype()
+    torch.set_default_dtype(torch.float64)
+    yield
+    torch.set_default_dtype(default_dtype)
+
+
+def get_precision_context(precision: Union[int, str], device_type: Optional[str] = None):
+    """
+    Choose the proper context manager based on the precision.
+
+    Parameters
+    ----------
+    precision
+        The precision.
+    device_type
+        gpu or cpu.
+
+    Returns
+    -------
+    A precision context manager.
+    """
+    precision = convert_to_torch_precision(precision=precision)
+
+    if precision in [torch.half, torch.float16, torch.bfloat16]:
+        return torch.autocast(device_type=device_type, dtype=precision)
+    elif precision == torch.float32:
+        assert torch.get_default_dtype() == torch.float32
+        return contextlib.nullcontext()
+    elif precision == torch.float64:
+        return double_precision_context()
+    else:
+        raise ValueError(f"Unknown precision: {precision}")
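The module above consolidates the precision handling that previously lived in the now-deleted `utils/environment.py`. A minimal usage sketch, assuming only what the diff shows (the `autogluon.multimodal.utils.precision` path comes from the file list); the tensor sizes are illustrative:

```python
import torch

from autogluon.multimodal.utils.precision import (
    convert_to_torch_precision,
    get_precision_context,
    infer_precision,
)

num_gpus = torch.cuda.device_count()

# Downgrades to 32 on CPU-only hosts and on GPUs without bf16 support.
precision = infer_precision(num_gpus=num_gpus, precision="bf16-mixed")

# Config-style flags map onto torch dtypes, e.g. "bf16-mixed" -> torch.bfloat16.
dtype = convert_to_torch_precision(precision=precision)
print(precision, dtype)

# Half/bf16 run under torch.autocast; fp32 is a null context; fp64 swaps the default dtype.
device_type = "cuda" if num_gpus > 0 else "cpu"
with get_precision_context(precision=precision, device_type=device_type):
    layer = torch.nn.Linear(8, 2).to(device_type)
    logits = layer(torch.randn(4, 8, device=device_type))
```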
autogluon/multimodal/utils/presets.py (moved from autogluon/multimodal/presets.py)
@@ -2,17 +2,17 @@ from typing import List, Optional
 
 from autogluon.common.utils.try_import import try_import_ray
 
-from .constants import (
+from ..constants import (
     BEST_QUALITY,
     BINARY,
     DATA,
     DEFAULT,
-    ENVIRONMENT,
+    ENV,
     HIGH_QUALITY,
     MEDIUM_QUALITY,
     MODEL,
     MULTICLASS,
-    OPTIMIZATION,
+    OPTIM,
     REGRESSION,
 )
 from .registry import Registry
@@ -32,9 +32,9 @@ def get_default_hpo_setup():
     }
 
     default_tunable_hyperparameters = {
-        "optimization.learning_rate": tune.loguniform(1e-5, 1e-2),
-        "optimization.optim_type": tune.choice(["adamw", "sgd"]),
-        "optimization.max_epochs": tune.choice(list(range(5, 31))),
+        "optim.lr": tune.loguniform(1e-5, 1e-2),
+        "optim.optim_type": tune.choice(["adamw", "sgd"]),
+        "optim.max_epochs": tune.choice(list(range(5, 31))),
         "env.batch_size": tune.choice([16, 32, 64, 128, 256]),
     }
 
@@ -146,7 +146,7 @@ def default(presets: str = DEFAULT):
                 "model.hf_text.checkpoint_name": "google/electra-small-discriminator",
                 "model.timm_image.checkpoint_name": "mobilenetv3_large_100",
                 "model.document_transformer.checkpoint_name": "microsoft/layoutlmv2-base-uncased",
-                "optimization.learning_rate": 4e-4,
+                "optim.lr": 4e-4,
             }
         )
     elif presets == BEST_QUALITY:
@@ -186,7 +186,7 @@ def default(presets: str = DEFAULT):
         hyperparameters.update(
             {
                 "model.hf_text.checkpoint_name": "microsoft/mdeberta-v3-base",
-                "optimization.top_k": 1,
+                "optim.top_k": 1,
                 "env.precision": "bf16-mixed",
                 "env.per_gpu_batch_size": 4,
             }
@@ -222,7 +222,7 @@ def few_shot_classification(presets: str = DEFAULT):
             "model.names": ["hf_text", "clip"],
             "model.clip.checkpoint_name": "openai/clip-vit-large-patch14-336",
             "model.clip.image_size": 336,
-            "env.eval_batch_size_ratio": 1,
+            "env.inference_batch_size_ratio": 1,
         }
     )
     hyperparameter_tune_kwargs = {}
@@ -258,14 +258,14 @@ def zero_shot_image_classification(presets: str = DEFAULT):
             {
                 "model.clip.checkpoint_name": "openai/clip-vit-large-patch14-336",
                 "model.clip.image_size": 336,
-                "env.eval_batch_size_ratio": 1,
+                "env.inference_batch_size_ratio": 1,
             }
         )
     elif presets == HIGH_QUALITY:
         hyperparameters.update(
             {
                 "model.clip.checkpoint_name": "openai/clip-vit-large-patch14",
-                "env.eval_batch_size_ratio": 1,
+                "env.inference_batch_size_ratio": 1,
             }
         )
     elif presets == MEDIUM_QUALITY:
@@ -300,27 +300,27 @@ def object_detection(presets: str = DEFAULT):
     hyperparameters = {
         "model.names": ["mmdet_image"],
         "model.mmdet_image.frozen_layers": [],
-        "optimization.patience": 20,
-        "optimization.val_check_interval": 1.0,
-        "optimization.check_val_every_n_epoch": 1,
+        "optim.patience": 20,
+        "optim.val_check_interval": 1.0,
+        "optim.check_val_every_n_epoch": 1,
         "env.batch_size": 32,
         "env.per_gpu_batch_size": 1,
         "env.num_workers": 2,
-        "optimization.learning_rate": 1e-5,
-        "optimization.weight_decay": 1e-4,
-        "optimization.lr_mult": 10,
-        "optimization.lr_choice": "two_stages",
-        "optimization.lr_schedule": "multi_step",
-        "optimization.gradient_clip_val": 0.1,
-        "optimization.max_epochs": 60,
-        "optimization.warmup_steps": 0.0,
-        "optimization.top_k": 1,
-        "optimization.top_k_average_method": "best",
-        "env.eval_batch_size_ratio": 1,
+        "optim.lr": 1e-5,
+        "optim.weight_decay": 1e-4,
+        "optim.lr_mult": 10,
+        "optim.lr_choice": "two_stages",
+        "optim.lr_schedule": "multi_step",
+        "optim.gradient_clip_val": 0.1,
+        "optim.max_epochs": 60,
+        "optim.warmup_steps": 0.0,
+        "optim.top_k": 1,
+        "optim.top_k_average_method": "best",
+        "env.inference_batch_size_ratio": 1,
         "env.strategy": "ddp",
         "env.auto_select_gpus": True,  # Turn on for detection to return devices in a list, TODO: fix the extra GPU usage bug
         "env.num_gpus": -1,
-        "optimization.lr_decay": 0.9,
+        "optim.lr_decay": 0.9,
     }
     hyperparameter_tune_kwargs = {}
 
@@ -335,15 +335,15 @@ def object_detection(presets: str = DEFAULT):
             {
                 "model.mmdet_image.checkpoint_name": "yolox_l",
                 "env.per_gpu_batch_size": 2,  # Works on 8G GPU
-                "optimization.learning_rate": 5e-5,
-                "optimization.patience": 5,
-                "optimization.max_epochs": 50,
-                "optimization.val_check_interval": 1.0,
-                "optimization.check_val_every_n_epoch": 3,
-                "optimization.lr_mult": 100,
-                "optimization.weight_decay": 1e-3,
-                "optimization.lr_schedule": "cosine_decay",
-                "optimization.gradient_clip_val": 1,
+                "optim.lr": 5e-5,
+                "optim.patience": 5,
+                "optim.max_epochs": 50,
+                "optim.val_check_interval": 1.0,
+                "optim.check_val_every_n_epoch": 3,
+                "optim.lr_mult": 100,
+                "optim.weight_decay": 1e-3,
+                "optim.lr_schedule": "cosine_decay",
+                "optim.gradient_clip_val": 1,
             }
         )
     elif presets in [DEFAULT, HIGH_QUALITY]:
@@ -386,32 +386,32 @@ def semantic_segmentation(presets: str = DEFAULT):
         "model.sam.checkpoint_name": "facebook/sam-vit-huge",
         "env.batch_size": 4,
         "env.per_gpu_batch_size": 1,
-        "env.eval_batch_size_ratio": 1,
+        "env.inference_batch_size_ratio": 1,
         "env.strategy": "ddp_find_unused_parameters_true",
         "env.auto_select_gpus": False,
        "env.num_gpus": -1,
         "env.num_workers": 4,
         "env.precision": "16-mixed",
-        "optimization.learning_rate": 1e-4,
-        "optimization.loss_function": "structure_loss",
-        "optimization.lr_decay": 0,
-        "optimization.lr_mult": 1,
-        "optimization.lr_choice": "single_stage",
-        "optimization.lr_schedule": "polynomial_decay",
-        "optimization.max_epochs": 30,
-        "optimization.top_k": 3,
-        "optimization.top_k_average_method": "best",
-        "optimization.warmup_steps": 0.0,
-        "optimization.weight_decay": 0.0001,
-        "optimization.patience": 10,
-        "optimization.val_check_interval": 1.0,
-        "optimization.check_val_every_n_epoch": 1,
-        "optimization.efficient_finetune": "lora",
-        "optimization.lora.module_filter": [".*vision_encoder.*attn"],
-        "optimization.lora.filter": ["q", "v"],
-        "optimization.extra_trainable_params": [".*mask_decoder"],
-        "optimization.lora.r": 3,
-        "optimization.lora.alpha": 32,
+        "optim.lr": 1e-4,
+        "optim.loss_func": "structure_loss",
+        "optim.lr_decay": 0,
+        "optim.lr_mult": 1,
+        "optim.lr_choice": "single_stage",
+        "optim.lr_schedule": "polynomial_decay",
+        "optim.max_epochs": 30,
+        "optim.top_k": 3,
+        "optim.top_k_average_method": "best",
+        "optim.warmup_steps": 0.0,
+        "optim.weight_decay": 0.0001,
+        "optim.patience": 10,
+        "optim.val_check_interval": 1.0,
+        "optim.check_val_every_n_epoch": 1,
+        "optim.peft": "lora",
+        "optim.lora.module_filter": [".*vision_encoder.*attn"],
+        "optim.lora.filter": ["q", "v"],
+        "optim.extra_trainable_params": [".*mask_decoder"],
+        "optim.lora.r": 3,
+        "optim.lora.alpha": 32,
     }
     hyperparameter_tune_kwargs = {}
 
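Two renames in this preset are easy to miss: `optimization.efficient_finetune` becomes `optim.peft`, and `optimization.loss_function` becomes `optim.loss_func`. A minimal fine-tuning sketch under the new keys, assuming a semantic-segmentation DataFrame `train_data` with image and mask columns; the smaller SAM checkpoint is an assumption for the example, not part of the preset:

```python
from autogluon.multimodal import MultiModalPredictor

predictor = MultiModalPredictor(problem_type="semantic_segmentation", label="label")
predictor.fit(
    train_data=train_data,  # assumed: DataFrame with image paths and ground-truth masks
    hyperparameters={
        "model.sam.checkpoint_name": "facebook/sam-vit-base",  # preset default is sam-vit-huge
        "optim.peft": "lora",                 # formerly "optimization.efficient_finetune"
        "optim.loss_func": "structure_loss",  # formerly "optimization.loss_function"
        "optim.lora.r": 3,
        "optim.lora.alpha": 32,
    },
)
```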
@@ -444,7 +444,7 @@ def ocr_text_detection(presets: str = DEFAULT):
     hyperparameters = {
         "model.names": ["mmocr_text_detection"],
         "model.mmocr_text_detection.checkpoint_name": "TextSnake",
-        "env.eval_batch_size_ratio": 1,
+        "env.inference_batch_size_ratio": 1,
         "env.num_gpus": 1,
         "env.precision": 32,
     }
@@ -479,7 +479,7 @@ def ocr_text_recognition(presets: str = DEFAULT):
     hyperparameters = {
         "model.names": ["mmocr_text_recognition"],
         "model.mmocr_text_recognition.checkpoint_name": "ABINet",
-        "env.eval_batch_size_ratio": 1,
+        "env.inference_batch_size_ratio": 1,
         "env.num_gpus": 1,
         "env.precision": 32,
     }
@@ -514,7 +514,7 @@ def feature_extraction(presets: str = DEFAULT):  # TODO: rename the problem type
         "model.names": ["hf_text"],
         "model.hf_text.checkpoint_name": "sentence-transformers/msmarco-MiniLM-L-12-v3",
         "model.hf_text.pooling_mode": "mean",
-        "env.eval_batch_size_ratio": 1,
+        "env.inference_batch_size_ratio": 1,
     }
     hyperparameter_tune_kwargs = {}
 
@@ -651,7 +651,7 @@ def image_text_similarity(presets: str = DEFAULT):
     hyperparameters = {
         "model.names": ["clip"],
         "matcher.loss.type": "multi_negatives_softmax_loss",
-        "optimization.learning_rate": 1e-5,
+        "optim.lr": 1e-5,
     }
     hyperparameter_tune_kwargs = {}
 
@@ -747,9 +747,196 @@ def ner(presets: str = DEFAULT):
     return hyperparameters, hyperparameter_tune_kwargs
 
 
-def list_automm_presets(verbose: bool = False):
+@automm_presets.register()
+def ensemble(presets: str = DEFAULT):
+    hyperparameters = {
+        "lf_mlp": {
+            "model.names": ["ft_transformer", "timm_image", "hf_text", "fusion_mlp"],
+            "model.timm_image.train_transforms": ["resize_shorter_side", "center_crop"],
+            "model.hf_text.text_trivial_aug_maxscale": 0,
+            "data.categorical.convert_to_text": False,
+            "data.numerical.convert_to_text": False,
+            "optim.cross_modal_align": "null",
+            "data.modality_dropout": 0,
+            "model.timm_image.use_learnable_image": False,
+            "optim.lemda.turn_on": False,
+        },
+        "lf_transformer": {
+            "model.names": ["ft_transformer", "timm_image", "hf_text", "fusion_transformer"],
+            "model.timm_image.train_transforms": ["resize_shorter_side", "center_crop"],
+            "model.hf_text.text_trivial_aug_maxscale": 0,
+            "data.categorical.convert_to_text": False,
+            "data.numerical.convert_to_text": False,
+            "optim.cross_modal_align": "null",
+            "data.modality_dropout": 0,
+            "model.timm_image.use_learnable_image": False,
+            "optim.lemda.turn_on": False,
+        },
+        "lf_clip": {
+            "model.names": ["ft_transformer", "clip_image", "clip_text", "fusion_mlp"],
+            "model.clip_image.data_types": ["image"],
+            "model.clip_text.data_types": ["text"],
+            "model.clip_image.train_transforms": ["resize_shorter_side", "center_crop"],
+            "model.clip_text.text_trivial_aug_maxscale": 0,
+            "data.categorical.convert_to_text": False,
+            "data.numerical.convert_to_text": False,
+            "optim.cross_modal_align": "null",
+            "data.modality_dropout": 0,
+            "model.clip_image.use_learnable_image": False,
+            "optim.lemda.turn_on": False,
+        },
+        "early_fusion": {
+            "model.names": ["meta_transformer"],
+            "model.meta_transformer.checkpoint_path": "null",
+            "model.meta_transformer.train_transforms": ["resize_shorter_side", "center_crop"],
+            "model.meta_transformer.text_trivial_aug_maxscale": 0,
+            "data.categorical.convert_to_text": False,
+            "data.numerical.convert_to_text": False,
+            "optim.cross_modal_align": "null",
+            "data.modality_dropout": 0,
+            "model.meta_transformer.use_learnable_image": False,
+            "optim.lemda.turn_on": False,
+        },
+        "convert_categorical_to_text": {
+            "model.names": ["ft_transformer", "timm_image", "hf_text", "fusion_mlp"],
+            "model.timm_image.train_transforms": ["resize_shorter_side", "center_crop"],
+            "model.hf_text.text_trivial_aug_maxscale": 0,
+            "data.categorical.convert_to_text": True,
+            "data.categorical.convert_to_text_template": "latex",
+            "data.numerical.convert_to_text": False,
+            "optim.cross_modal_align": "null",
+            "data.modality_dropout": 0,
+            "model.timm_image.use_learnable_image": False,
+            "optim.lemda.turn_on": False,
+        },
+        "convert_numeric_to_text": {
+            "model.names": ["ft_transformer", "timm_image", "hf_text", "fusion_mlp"],
+            "model.timm_image.train_transforms": ["resize_shorter_side", "center_crop"],
+            "model.hf_text.text_trivial_aug_maxscale": 0,
+            "data.categorical.convert_to_text": False,
+            "data.numerical.convert_to_text": True,
+            "optim.cross_modal_align": "null",
+            "data.modality_dropout": 0,
+            "model.timm_image.use_learnable_image": False,
+            "optim.lemda.turn_on": False,
+        },
+        "cross_modal_align_pos_only": {
+            "model.names": ["ft_transformer", "timm_image", "hf_text", "fusion_mlp"],
+            "model.timm_image.train_transforms": ["resize_shorter_side", "center_crop"],
+            "model.hf_text.text_trivial_aug_maxscale": 0,
+            "data.categorical.convert_to_text": False,
+            "data.numerical.convert_to_text": False,
+            "optim.cross_modal_align": "positive_only",
+            "optim.cross_modal_align_weight": 1,
+            "data.modality_dropout": 0,
+            "model.timm_image.use_learnable_image": False,
+            "optim.lemda.turn_on": False,
+        },
+        "input_aug": {
+            "model.names": ["ft_transformer", "timm_image", "hf_text", "fusion_mlp"],
+            "model.timm_image.train_transforms": ["resize_shorter_side", "center_crop", "trivial_augment"],
+            "model.hf_text.text_trivial_aug_maxscale": 0.1,
+            "data.categorical.convert_to_text": False,
+            "data.numerical.convert_to_text": False,
+            "optim.cross_modal_align": "null",
+            "data.modality_dropout": 0,
+            "model.timm_image.use_learnable_image": False,
+            "optim.lemda.turn_on": False,
+        },
+        "feature_aug_lemda": {
+            "model.names": ["ft_transformer", "timm_image", "hf_text", "fusion_mlp"],
+            "model.timm_image.train_transforms": ["resize_shorter_side", "center_crop"],
+            "model.hf_text.text_trivial_aug_maxscale": 0,
+            "data.categorical.convert_to_text": False,
+            "data.numerical.convert_to_text": False,
+            "optim.cross_modal_align": "null",
+            "data.modality_dropout": 0,
+            "model.timm_image.use_learnable_image": False,
+            "optim.lemda.turn_on": True,
+            "optim.automatic_optimization": False,
+        },
+        "modality_dropout": {
+            "model.names": ["ft_transformer", "timm_image", "hf_text", "fusion_mlp"],
+            "model.timm_image.train_transforms": ["resize_shorter_side", "center_crop"],
+            "model.hf_text.text_trivial_aug_maxscale": 0,
+            "data.categorical.convert_to_text": False,
+            "data.numerical.convert_to_text": False,
+            "optim.cross_modal_align": "null",
+            "data.modality_dropout": 0.2,
+            "model.timm_image.use_learnable_image": False,
+            "optim.lemda.turn_on": False,
+        },
+        "learnable_image": {
+            "model.names": ["ft_transformer", "timm_image", "hf_text", "fusion_mlp"],
+            "model.timm_image.train_transforms": ["resize_shorter_side", "center_crop"],
+            "model.hf_text.text_trivial_aug_maxscale": 0,
+            "data.categorical.convert_to_text": False,
+            "data.numerical.convert_to_text": False,
+            "optim.cross_modal_align": "null",
+            "data.modality_dropout": 0,
+            "model.timm_image.use_learnable_image": True,
+            "optim.lemda.turn_on": False,
+        },
+        "modality_dropout_and_learnable_image": {
+            "model.names": ["ft_transformer", "timm_image", "hf_text", "fusion_mlp"],
+            "model.timm_image.train_transforms": ["resize_shorter_side", "center_crop"],
+            "model.hf_text.text_trivial_aug_maxscale": 0,
+            "data.categorical.convert_to_text": False,
+            "data.numerical.convert_to_text": False,
+            "optim.cross_modal_align": "null",
+            "data.modality_dropout": 0.2,
+            "model.timm_image.use_learnable_image": True,
+            "optim.lemda.turn_on": False,
+        },
+    }
+
+    if presets in [DEFAULT, HIGH_QUALITY]:
+        for v in hyperparameters.values():
+            if "timm_image" in v["model.names"]:
+                v["model.timm_image.checkpoint_name"] = "caformer_b36.sail_in22k_ft_in1k"
+            if "hf_text" in v["model.names"]:
+                v["model.hf_text.checkpoint_name"] = "google/electra-base-discriminator"
+            if "meta_transformer" in v["model.names"]:
+                v["model.meta_transformer.model_version"] = "base"
+            if "clip_image" in v["model.names"]:
+                v["model.clip_image.checkpoint_name"] = "openai/clip-vit-base-patch32"
+            if "clip_text" in v["model.names"]:
+                v["model.clip_text.checkpoint_name"] = "openai/clip-vit-base-patch32"
+
+    elif presets == MEDIUM_QUALITY:
+        for v in hyperparameters.values():
+            if "timm_image" in v["model.names"]:
+                v["model.timm_image.checkpoint_name"] = "mobilenetv3_large_100"
+            if "hf_text" in v["model.names"]:
+                v["model.hf_text.checkpoint_name"] = "google/electra-small-discriminator"
+            if "meta_transformer" in v["model.names"]:
+                v["model.meta_transformer.model_version"] = "base"
+            if "clip_image" in v["model.names"]:
+                v["model.clip_image.checkpoint_name"] = "openai/clip-vit-base-patch32"
+            if "clip_text" in v["model.names"]:
+                v["model.clip_text.checkpoint_name"] = "openai/clip-vit-base-patch32"
+    elif presets == BEST_QUALITY:
+        for v in hyperparameters.values():
+            if "timm_image" in v["model.names"]:
+                v["model.timm_image.checkpoint_name"] = "swin_large_patch4_window7_224"
+            if "hf_text" in v["model.names"]:
+                v["model.hf_text.checkpoint_name"] = "microsoft/deberta-v3-base"
+            if "meta_transformer" in v["model.names"]:
+                v["model.meta_transformer.model_version"] = "large"
+            if "clip_image" in v["model.names"]:
+                v["model.clip_image.checkpoint_name"] = "openai/clip-vit-large-patch14"
+            if "clip_text" in v["model.names"]:
+                v["model.clip_text.checkpoint_name"] = "openai/clip-vit-large-patch14"
+    else:
+        raise ValueError(f"Unknown preset type: {presets}")
+
+    return hyperparameters, None
+
+
+def list_presets(verbose: bool = False):
     """
     List all available presets.
+
     Returns
     -------
     A list of presets.
@@ -765,7 +952,7 @@ def list_automm_presets(verbose: bool = False):
     return preset_details
 
 
-def get_basic_automm_config(extra: Optional[List[str]] = None):
+def get_basic_config(extra: Optional[List[str]] = None):
     """
     Get the basic config of AutoMM.
 
@@ -776,13 +963,13 @@ def get_basic_automm_config(extra: Optional[List[str]] = None):
 
     Returns
     -------
-    A dict config with keys: MODEL, DATA, OPTIMIZATION, ENVIRONMENT, and their default values.
+    A dict config with keys: MODEL, DATA, OPTIM, ENV, and their default values.
     """
     config = {
         MODEL: DEFAULT,
         DATA: DEFAULT,
-        OPTIMIZATION: DEFAULT,
-        ENVIRONMENT: DEFAULT,
+        OPTIM: DEFAULT,
+        ENV: DEFAULT,
    }
     if extra:
         for k in extra:
@@ -791,7 +978,7 @@ def get_basic_automm_config(extra: Optional[List[str]] = None):
     return config
 
 
-def get_automm_presets(problem_type: str, presets: str):
+def get_presets(problem_type: str, presets: str):
     """
     Get the default hyperparameters and hyperparameter_tune_kwargs given problem type and presets.
 
@@ -831,3 +1018,9 @@
     )
 
     return hyperparameters, hyperparameter_tune_kwargs
+
+
+def get_ensemble_presets(presets):
+    if not presets:
+        presets = DEFAULT
+    return automm_presets.create("ensemble", presets)
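The public preset helpers also drop their `automm` infix: `list_automm_presets` becomes `list_presets`, `get_basic_automm_config` becomes `get_basic_config`, and `get_automm_presets` becomes `get_presets`, while `get_ensemble_presets` exposes the new ensemble preset family. A sketch of querying them under the new `utils.presets` path from the file list; the preset strings are assumptions based on the quality constants:

```python
from autogluon.multimodal.utils.presets import (
    get_ensemble_presets,
    get_presets,
    list_presets,
)

# Names of all registered preset functions (default, ner, ensemble, ...).
print(list_presets())

# Defaults for a registered problem type, now keyed under optim.*/env.*.
hyperparameters, hpo_kwargs = get_presets(problem_type="ner", presets="best_quality")

# One hyperparameter dict per ensemble member: lf_mlp, lf_transformer, early_fusion, ...
ensemble_hyperparameters, _ = get_ensemble_presets(presets="medium_quality")
print(list(ensemble_hyperparameters.keys()))
```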
autogluon/multimodal/utils/problem_types.py (moved from autogluon/multimodal/problem_types.py)
@@ -1,9 +1,10 @@
 """Problem types supported in MultiModalPredictor"""
 
+import logging
 from dataclasses import dataclass, field
 from typing import List, Optional, Set
 
-from .constants import (
+from ..constants import (
     ACCURACY,
     BINARY,
     CATEGORICAL,
@@ -38,6 +39,8 @@ from .constants import (
 )
 from .registry import Registry
 
+logger = logging.getLogger(__name__)
+
 PROBLEM_TYPES_REG = Registry("problem_type_properties")
 
 
@@ -277,3 +280,29 @@
         _fallback_validation_metric=ACCURACY,
     ),
 )
+
+
+def infer_problem_type_by_eval_metric(eval_metric_name: str, problem_type: str):
+    if eval_metric_name is not None and eval_metric_name.lower() in [
+        "rmse",
+        "r2",
+        "pearsonr",
+        "spearmanr",
+    ]:
+        if problem_type is None:
+            logger.debug(
+                f"Infer problem type to be a regression problem "
+                f"since the evaluation metric is set as {eval_metric_name}."
+            )
+            problem_type = REGRESSION
+        else:
+            problem_prop = PROBLEM_TYPES_REG.get(problem_type)
+            if NUMERICAL not in problem_prop.supported_label_type:
+                raise ValueError(
+                    f"The provided evaluation metric will require the problem "
+                    f"to support label type = {NUMERICAL}. However, "
+                    f"the provided problem type = {problem_type} only "
+                    f"supports label type = {problem_prop.supported_label_type}."
+                )
+
+    return problem_type
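The new helper closes a gap where a regression-style metric was supplied without a problem type. A short behavioral sketch, assuming the constants carry their usual string values (`REGRESSION == "regression"`) and that `multiclass` does not support numerical labels:

```python
from autogluon.multimodal.utils.problem_types import infer_problem_type_by_eval_metric

# A regression metric with no problem type infers regression.
assert infer_problem_type_by_eval_metric("rmse", problem_type=None) == "regression"

# Metrics outside the regression list leave the problem type untouched.
assert infer_problem_type_by_eval_metric("accuracy", problem_type="binary") == "binary"

# A regression metric with an incompatible problem type raises ValueError.
try:
    infer_problem_type_by_eval_metric("pearsonr", problem_type="multiclass")
except ValueError as err:
    print(err)
```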