autogluon.multimodal 1.2.1b20250303__py3-none-any.whl → 1.2.1b20250304__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (126):
  1. autogluon/multimodal/__init__.py +4 -2
  2. autogluon/multimodal/configs/data/default.yaml +4 -2
  3. autogluon/multimodal/configs/{environment → env}/default.yaml +2 -3
  4. autogluon/multimodal/configs/model/default.yaml +58 -11
  5. autogluon/multimodal/configs/{optimization → optim}/default.yaml +21 -4
  6. autogluon/multimodal/constants.py +16 -5
  7. autogluon/multimodal/data/__init__.py +14 -2
  8. autogluon/multimodal/data/dataset.py +2 -2
  9. autogluon/multimodal/data/infer_types.py +16 -2
  10. autogluon/multimodal/data/label_encoder.py +3 -3
  11. autogluon/multimodal/{utils → data}/nlpaug.py +4 -4
  12. autogluon/multimodal/data/preprocess_dataframe.py +55 -38
  13. autogluon/multimodal/data/process_categorical.py +35 -6
  14. autogluon/multimodal/data/process_document.py +59 -33
  15. autogluon/multimodal/data/process_image.py +198 -163
  16. autogluon/multimodal/data/process_label.py +7 -3
  17. autogluon/multimodal/data/process_mmlab/process_mmdet.py +1 -8
  18. autogluon/multimodal/data/process_mmlab/process_mmlab_base.py +2 -9
  19. autogluon/multimodal/data/process_mmlab/process_mmocr.py +1 -9
  20. autogluon/multimodal/data/process_ner.py +192 -4
  21. autogluon/multimodal/data/process_numerical.py +32 -5
  22. autogluon/multimodal/data/process_semantic_seg_img.py +23 -28
  23. autogluon/multimodal/data/process_text.py +95 -58
  24. autogluon/multimodal/data/template_engine.py +7 -9
  25. autogluon/multimodal/data/templates.py +0 -2
  26. autogluon/multimodal/data/trivial_augmenter.py +2 -2
  27. autogluon/multimodal/data/utils.py +564 -338
  28. autogluon/multimodal/learners/__init__.py +2 -1
  29. autogluon/multimodal/learners/base.py +189 -189
  30. autogluon/multimodal/learners/ensemble.py +748 -0
  31. autogluon/multimodal/learners/few_shot_svm.py +6 -15
  32. autogluon/multimodal/learners/matching.py +59 -84
  33. autogluon/multimodal/learners/ner.py +23 -22
  34. autogluon/multimodal/learners/object_detection.py +26 -21
  35. autogluon/multimodal/learners/semantic_segmentation.py +16 -18
  36. autogluon/multimodal/models/__init__.py +12 -3
  37. autogluon/multimodal/models/augmenter.py +175 -0
  38. autogluon/multimodal/models/categorical_mlp.py +13 -8
  39. autogluon/multimodal/models/clip.py +92 -18
  40. autogluon/multimodal/models/custom_transformer.py +75 -75
  41. autogluon/multimodal/models/document_transformer.py +23 -9
  42. autogluon/multimodal/models/ft_transformer.py +40 -35
  43. autogluon/multimodal/models/fusion/base.py +2 -4
  44. autogluon/multimodal/models/fusion/fusion_mlp.py +82 -18
  45. autogluon/multimodal/models/fusion/fusion_ner.py +1 -1
  46. autogluon/multimodal/models/fusion/fusion_transformer.py +23 -23
  47. autogluon/multimodal/models/{huggingface_text.py → hf_text.py} +21 -2
  48. autogluon/multimodal/models/meta_transformer.py +336 -0
  49. autogluon/multimodal/models/mlp.py +6 -6
  50. autogluon/multimodal/models/mmocr_text_detection.py +1 -1
  51. autogluon/multimodal/models/mmocr_text_recognition.py +0 -1
  52. autogluon/multimodal/models/ner_text.py +1 -8
  53. autogluon/multimodal/models/numerical_mlp.py +14 -8
  54. autogluon/multimodal/models/sam.py +12 -2
  55. autogluon/multimodal/models/t_few.py +21 -5
  56. autogluon/multimodal/models/timm_image.py +74 -32
  57. autogluon/multimodal/models/utils.py +877 -16
  58. autogluon/multimodal/optim/__init__.py +17 -0
  59. autogluon/multimodal/{optimization → optim}/lit_distiller.py +2 -1
  60. autogluon/multimodal/{optimization → optim}/lit_matcher.py +4 -10
  61. autogluon/multimodal/{optimization → optim}/lit_mmdet.py +2 -10
  62. autogluon/multimodal/{optimization → optim}/lit_module.py +139 -14
  63. autogluon/multimodal/{optimization → optim}/lit_ner.py +3 -3
  64. autogluon/multimodal/{optimization → optim}/lit_semantic_seg.py +1 -1
  65. autogluon/multimodal/optim/losses/__init__.py +14 -0
  66. autogluon/multimodal/optim/losses/bce_loss.py +25 -0
  67. autogluon/multimodal/optim/losses/focal_loss.py +81 -0
  68. autogluon/multimodal/optim/losses/lemda_loss.py +39 -0
  69. autogluon/multimodal/optim/losses/rkd_loss.py +103 -0
  70. autogluon/multimodal/optim/losses/softmax_losses.py +177 -0
  71. autogluon/multimodal/optim/losses/structure_loss.py +26 -0
  72. autogluon/multimodal/optim/losses/utils.py +313 -0
  73. autogluon/multimodal/optim/lr/__init__.py +1 -0
  74. autogluon/multimodal/optim/lr/utils.py +332 -0
  75. autogluon/multimodal/optim/metrics/__init__.py +4 -0
  76. autogluon/multimodal/optim/metrics/coverage_metrics.py +42 -0
  77. autogluon/multimodal/optim/metrics/hit_rate_metrics.py +78 -0
  78. autogluon/multimodal/optim/metrics/ranking_metrics.py +231 -0
  79. autogluon/multimodal/optim/metrics/utils.py +359 -0
  80. autogluon/multimodal/optim/utils.py +284 -0
  81. autogluon/multimodal/predictor.py +51 -12
  82. autogluon/multimodal/utils/__init__.py +19 -45
  83. autogluon/multimodal/utils/cache.py +23 -2
  84. autogluon/multimodal/utils/checkpoint.py +58 -5
  85. autogluon/multimodal/utils/config.py +127 -55
  86. autogluon/multimodal/utils/device.py +120 -0
  87. autogluon/multimodal/utils/distillation.py +8 -8
  88. autogluon/multimodal/utils/download.py +1 -1
  89. autogluon/multimodal/utils/env.py +22 -0
  90. autogluon/multimodal/utils/export.py +3 -3
  91. autogluon/multimodal/utils/hpo.py +5 -5
  92. autogluon/multimodal/utils/inference.py +37 -4
  93. autogluon/multimodal/utils/install.py +91 -0
  94. autogluon/multimodal/utils/load.py +52 -47
  95. autogluon/multimodal/utils/log.py +6 -41
  96. autogluon/multimodal/utils/matcher.py +3 -2
  97. autogluon/multimodal/utils/onnx.py +0 -4
  98. autogluon/multimodal/utils/path.py +10 -0
  99. autogluon/multimodal/utils/precision.py +130 -0
  100. autogluon/multimodal/{presets.py → utils/presets.py} +259 -66
  101. autogluon/multimodal/{problem_types.py → utils/problem_types.py} +30 -1
  102. autogluon/multimodal/utils/save.py +47 -29
  103. autogluon/multimodal/utils/strategy.py +24 -0
  104. autogluon/multimodal/version.py +1 -1
  105. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/METADATA +5 -5
  106. autogluon.multimodal-1.2.1b20250304.dist-info/RECORD +163 -0
  107. autogluon/multimodal/optimization/__init__.py +0 -16
  108. autogluon/multimodal/optimization/losses.py +0 -394
  109. autogluon/multimodal/optimization/utils.py +0 -1054
  110. autogluon/multimodal/utils/cloud_io.py +0 -80
  111. autogluon/multimodal/utils/data.py +0 -701
  112. autogluon/multimodal/utils/environment.py +0 -395
  113. autogluon/multimodal/utils/metric.py +0 -500
  114. autogluon/multimodal/utils/model.py +0 -558
  115. autogluon.multimodal-1.2.1b20250303.dist-info/RECORD +0 -145
  116. /autogluon/multimodal/{optimization → optim}/deepspeed.py +0 -0
  117. /autogluon/multimodal/{optimization/lr_scheduler.py → optim/lr/lr_schedulers.py} +0 -0
  118. /autogluon/multimodal/{optimization → optim/metrics}/semantic_seg_metrics.py +0 -0
  119. /autogluon/multimodal/{registry.py → utils/registry.py} +0 -0
  120. /autogluon.multimodal-1.2.1b20250303-py3.9-nspkg.pth → /autogluon.multimodal-1.2.1b20250304-py3.9-nspkg.pth +0 -0
  121. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/LICENSE +0 -0
  122. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/NOTICE +0 -0
  123. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/WHEEL +0 -0
  124. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/namespace_packages.txt +0 -0
  125. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/top_level.txt +0 -0
  126. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/zip-safe +0 -0
@@ -8,6 +8,7 @@ import os
 import warnings
 from typing import Dict, List, Optional, Union
 
+import numpy as np
 import pandas as pd
 import transformers
 
@@ -17,14 +18,15 @@ from autogluon.core.metrics import Scorer
 from .constants import AUTOMM_TUTORIAL_MODE, FEW_SHOT_CLASSIFICATION, NER, OBJECT_DETECTION, SEMANTIC_SEGMENTATION
 from .learners import (
     BaseLearner,
+    EnsembleLearner,
     FewShotSVMLearner,
-    MultiModalMatcher,
+    MatchingLearner,
     NERLearner,
     ObjectDetectionLearner,
     SemanticSegmentationLearner,
 )
-from .problem_types import PROBLEM_TYPES_REG
 from .utils import get_dir_ckpt_paths
+from .utils.problem_types import PROBLEM_TYPES_REG
 
 pl_logger = logging.getLogger("lightning")
 pl_logger.propagate = False  # https://github.com/Lightning-AI/lightning/issues/4621
@@ -64,6 +66,9 @@ class MultiModalPredictor:
         pretrained: Optional[bool] = True,
         validation_metric: Optional[str] = None,
         sample_data_path: Optional[str] = None,
+        use_ensemble: Optional[bool] = False,
+        ensemble_size: Optional[int] = 2,
+        ensemble_mode: Optional[str] = "one_shot",
     ):
         """
         Parameters
@@ -164,6 +169,15 @@
             If not provided, it would be automatically chosen based on the problem type.
         sample_data_path
             The path to sample data from which we can infer num_classes or classes used for object detection.
+        use_ensemble
+            Whether to use ensembling when fitting the predictor (default False).
+            Currently, it works only on multimodal data (image+text, image+tabular, text+tabular, image+text+tabular) with classification or regression tasks.
+        ensemble_size
+            A multiplier on the number of models in the ensembling pool (default 2); the actual ensemble size is ensemble_size * the number of models.
+        ensemble_mode
+            The mode of conducting ensembling:
+            - `one_shot`: run the classic ensemble selection once.
+            - `sequential`: iteratively run the classic ensemble selection, each time growing the model zoo by the next best model.
         """
         if problem_type is not None:
             problem_type = problem_type.lower()
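
In practice, the new constructor flags combine as below — a minimal sketch, assuming a pandas DataFrame train_data whose "label" column is the target (the column name and data are illustrative, not taken from this diff):

    from autogluon.multimodal import MultiModalPredictor

    # Opt in to ensemble selection over the candidate model pool.
    # With ensemble_mode="one_shot", classic ensemble selection runs once;
    # the final pool holds ensemble_size * (number of candidate models) entries.
    predictor = MultiModalPredictor(
        label="label",
        use_ensemble=True,
        ensemble_size=2,
        ensemble_mode="one_shot",
    )
    predictor.fit(train_data=train_data)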
@@ -192,7 +206,7 @@
         self._verbosity = verbosity
 
         if problem_property and problem_property.is_matching:
-            learner_class = MultiModalMatcher
+            learner_class = MatchingLearner
         elif problem_type == OBJECT_DETECTION:
             learner_class = ObjectDetectionLearner
         elif problem_type == NER:
@@ -204,6 +218,9 @@
         else:
             learner_class = BaseLearner
 
+        if use_ensemble:
+            learner_class = EnsembleLearner
+
         self._learner = learner_class(
             label=label,
             problem_type=problem_type,
@@ -222,6 +239,8 @@
             query=query,
             response=response,
             match_label=match_label,
+            ensemble_size=ensemble_size,
+            ensemble_mode=ensemble_mode,
         )
 
     @property
@@ -413,6 +432,9 @@
         standalone: Optional[bool] = True,
         hyperparameter_tune_kwargs: Optional[dict] = None,
         clean_ckpts: Optional[bool] = True,
+        predictions: Optional[List[np.ndarray]] = None,
+        labels: Optional[np.ndarray] = None,
+        predictors: Optional[List[Union[str, MultiModalPredictor]]] = None,
     ):
         """
         Fit models to predict a column of a data table (label) based on the other columns (features).
@@ -435,6 +457,8 @@
         time_limit
             How long `fit()` should run for (wall clock time in seconds).
             If not specified, `fit()` will run until the model has completed training.
+            Note that if use_ensemble=True, the total running time is time_limit * N,
+            where N is the number of models in the ensemble.
         save_path
             Path to directory where models and artifacts should be saved.
         hyperparameters
@@ -506,6 +530,13 @@
                 teacher_learner = teacher_predictor
             else:
                 teacher_learner = teacher_predictor._learner
+
+        if predictors is None:
+            learners = None
+        else:
+            assert isinstance(predictors, list)
+            learners = [ele if isinstance(ele, str) else ele._learner for ele in predictors]
+
         self._learner.fit(
             train_data=train_data,
             presets=presets,
@@ -522,6 +553,9 @@
             hyperparameter_tune_kwargs=hyperparameter_tune_kwargs,
             clean_ckpts=clean_ckpts,
             id_mappings=id_mappings,
+            predictions=predictions,
+            labels=labels,
+            learners=learners,
         )
 
         return self
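
Following the conversion above, fit() hands predictors to the learner as learners, accepting saved-learner paths and live predictor objects interchangeably. A hedged sketch (the path and variable names are hypothetical):

    predictor.fit(
        train_data=train_data,
        predictors=[
            "runs/text_learner",  # a string is passed through as a saved-learner path
            image_predictor,      # a MultiModalPredictor contributes its ._learner
        ],
    )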
@@ -540,6 +574,8 @@
         return_pred: Optional[bool] = False,
         realtime: Optional[bool] = False,
         eval_tool: Optional[str] = None,
+        predictions: Optional[List[np.ndarray]] = None,
+        labels: Optional[np.ndarray] = None,
     ):
         """
         Evaluate the model on a given dataset.
@@ -595,6 +631,8 @@
             similarity_type=similarity_type,
             cutoffs=cutoffs,
             label=label,
+            predictions=predictions,
+            labels=labels,
         )
 
     def predict(
@@ -807,18 +845,19 @@
 
         with open(os.path.join(dir_path, "assets.json"), "r") as fp:
             assets = json.load(fp)
-        if "class_name" in assets and assets["class_name"] == "MultiModalMatcher":
-            learner_class = MultiModalMatcher
-        elif assets["problem_type"] == OBJECT_DETECTION:
+        learner_class = BaseLearner
+        if assets["learner_class"] == "MatchingLearner":
+            learner_class = MatchingLearner
+        elif assets["learner_class"] == "EnsembleLearner":
+            learner_class = EnsembleLearner
+        elif assets["learner_class"] == "FewShotSVMLearner":
+            learner_class = FewShotSVMLearner
+        elif assets["learner_class"] == "ObjectDetectionLearner":
             learner_class = ObjectDetectionLearner
-        elif assets["problem_type"] == NER:
+        elif assets["learner_class"] == "NERLearner":
             learner_class = NERLearner
-        elif assets["problem_type"] == FEW_SHOT_CLASSIFICATION:
-            learner_class = FewShotSVMLearner
-        elif assets["problem_type"] == SEMANTIC_SEGMENTATION:
+        elif assets["learner_class"] == "SemanticSegmentationLearner":
             learner_class = SemanticSegmentationLearner
-        else:
-            learner_class = BaseLearner
 
         predictor._learner = learner_class.load(path=path, resume=resume, verbosity=verbosity)
         return predictor
@@ -16,66 +16,34 @@ from .config import (
     update_config_by_rules,
     update_hyperparameters,
     update_tabular_config_by_resources,
-    upgrade_config,
-)
-from .data import (
-    assign_feature_column_names,
-    create_data_processor,
-    create_fusion_data_processors,
-    data_to_df,
-    get_mixup,
-    infer_dtypes_by_model_names,
-    infer_scarcity_mode_by_data_size,
-    init_df_preprocessor,
-    split_train_tuning_data,
-    turn_on_off_feature_column_info,
+    update_ensemble_hyperparameters,
 )
+from .device import compute_num_gpus, get_available_devices, move_to_device
 from .distillation import DistillationMixin
 from .download import download, is_url
-from .environment import (
-    check_if_packages_installed,
-    compute_inference_batch_size,
-    compute_num_gpus,
-    get_available_devices,
-    get_precision_context,
-    infer_precision,
-    is_interactive_env,
-    is_interactive_strategy,
-    move_to_device,
-    run_ddp_only_once,
-)
 from .export import ExportMixin
 from .hpo import hyperparameter_tune
-from .inference import RealtimeMixin, extract_from_output
-from .load import CustomUnpickler, get_dir_ckpt_paths, get_load_ckpt_paths, load_text_tokenizers
+from .inference import RealtimeMixin, compute_inference_batch_size, extract_from_output
+from .load import CustomUnpickler, protected_zip_extraction, get_dir_ckpt_paths, get_load_ckpt_paths
 from .log import (
     LogFilter,
     apply_log_filter,
     get_gpu_message,
-    make_exp_dir,
     on_fit_end_message,
     on_fit_per_run_start_message,
     on_fit_start_message,
 )
 from .matcher import compute_semantic_similarity, convert_data_for_ranking, create_siamese_model, semantic_search
-from .metric import (
-    compute_ranking_score,
-    compute_score,
-    get_minmax_mode,
-    get_stopping_threshold,
-    infer_metrics,
-    infer_problem_type_by_eval_metric,
+from .misc import (
+    logits_to_prob,
+    path_expander,
+    path_to_base64str_expander,
+    path_to_bytearray_expander,
+    shopee_dataset,
+    tensor_to_ndarray,
+    merge_bio_format,
 )
-from .misc import logits_to_prob, merge_bio_format, shopee_dataset, tensor_to_ndarray
 from .mmcv import CollateMMDet, CollateMMOcr
-from .model import (
-    create_fusion_model,
-    create_model,
-    is_lazy_weight_tensor,
-    list_timm_models,
-    modify_duplicate_model_names,
-    select_model,
-)
 from .object_detection import (
     COCODataset,
     bbox_ratio_xywh_to_index_xyxy,
@@ -94,5 +62,11 @@ from .object_detection import (
     save_result_voc_format,
     visualize_detection,
 )
-from .save import process_save_path, save_text_tokenizers, setup_save_path
+from .precision import get_precision_context, infer_precision
+from .presets import get_basic_config, get_ensemble_presets, get_presets, list_presets, matcher_presets
+from .problem_types import PROBLEM_TYPES_REG, infer_problem_type_by_eval_metric
+from .save import process_save_path, setup_save_path, make_exp_dir
+from .strategy import is_interactive_strategy, run_ddp_only_once
+from .env import is_interactive_env
 from .visualizer import NERVisualizer, ObjectDetectionVisualizer, SemanticSegmentationVisualizer, visualize_ner
+from .install import check_if_packages_installed
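
Although environment.py is split into device.py, precision.py, strategy.py, env.py, and install.py (with metric.py and model.py also dissolved), the re-exports above keep existing import paths working. For example, this import resolves both before and after the reorganization:

    # now served from .device rather than the removed .environment module
    from autogluon.multimodal.utils import compute_num_gpus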
@@ -10,7 +10,18 @@ import lightning.pytorch as pl
 import torch
 from lightning.pytorch.callbacks import BasePredictionWriter
 
-from ..constants import BBOX, LM_TARGET, LOGIT_SCALE, LOGITS, TEMPLATE_LOGITS, WEIGHT
+from ..constants import (
+    AUG_LOGITS,
+    BBOX,
+    LOGIT_SCALE,
+    MULTIMODAL_FEATURES,
+    MULTIMODAL_FEATURES_POST_AUG,
+    MULTIMODAL_FEATURES_PRE_AUG,
+    ORI_LOGITS,
+    VAE_MEAN,
+    VAE_VAR,
+    WEIGHT,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -146,7 +157,17 @@ class DDPPredictionWriter(BasePredictionWriter):
             return dict()
 
         for k, v in x[0].items():
-            if k in [WEIGHT, LOGIT_SCALE]:  # ignore the keys
+            if k in [
+                WEIGHT,
+                LOGIT_SCALE,
+                MULTIMODAL_FEATURES,
+                MULTIMODAL_FEATURES_PRE_AUG,
+                MULTIMODAL_FEATURES_POST_AUG,
+                ORI_LOGITS,
+                AUG_LOGITS,
+                VAE_MEAN,
+                VAE_VAR,
+            ]:  # ignore the keys
                 continue
             elif isinstance(v, dict):
                 results[k] = self.collate([i[k] for i in x])
@@ -1,16 +1,30 @@
+"""
+Some utilities are copied from
+https://github.com/Lightning-AI/lightning/blob/master/src/lightning/fabric/utilities/cloud_io.py
+to address warnings:
+    LightningDeprecationWarning: lightning.pytorch.utilities.cloud_io.atomic_save has been
+    deprecated in v1.8.0 and will be removed in v1.10.0. This function is internal but you
+    can copy over its implementation.
+"""
+
+import io
 import logging
 import os
 import re
 import shutil
-from typing import Any, Dict, List, Optional, Tuple, Union
+from pathlib import Path
+from typing import IO, Any, Callable, Dict, List, Optional, Tuple, Union
 
+import fsspec
 import lightning.pytorch as pl
 import torch
 from lightning.pytorch.strategies import DeepSpeedStrategy
 from lightning.pytorch.utilities.rank_zero import rank_zero_warn
 
-from .cloud_io import _atomic_save, get_filesystem
-from .cloud_io import _load as pl_load
+from .env import get_filesystem
+
+_DEVICE = Union[torch.device, str, int]
+_MAP_LOCATION_TYPE = Optional[Union[_DEVICE, Callable[[_DEVICE], _DEVICE], Dict[_DEVICE, _DEVICE]]]
 
 logger = logging.getLogger(__name__)
 
@@ -65,6 +79,45 @@ def average_checkpoints(
     return avg_state_dict
 
 
+def pl_load(
+    path_or_url: Union[IO, str, Path],
+    map_location: _MAP_LOCATION_TYPE = None,
+) -> Any:
+    """Loads a checkpoint.
+
+    Args:
+        path_or_url: Path or URL of the checkpoint.
+        map_location: a function, ``torch.device``, string or a dict specifying how to remap storage locations.
+    """
+    if not isinstance(path_or_url, (str, Path)):
+        # any sort of BytesIO or similar
+        return torch.load(path_or_url, map_location=map_location)  # nosec B614
+    if str(path_or_url).startswith("http"):
+        return torch.hub.load_state_dict_from_url(
+            str(path_or_url),
+            map_location=map_location,  # type: ignore[arg-type] # upstream annotation is not correct
+        )
+    fs = get_filesystem(path_or_url)
+    with fs.open(path_or_url, "rb") as f:
+        return torch.load(f, map_location=map_location)  # nosec B614
+
+
+def pl_save(checkpoint: Dict[str, Any], filepath: Union[str, Path]) -> None:
+    """Saves a checkpoint atomically, avoiding the creation of incomplete checkpoints.
+
+    Args:
+        checkpoint: The object to save.
+            Built to be used with the ``dump_checkpoint`` method, but can deal with anything which ``torch.save``
+            accepts.
+        filepath: The path to which the checkpoint will be saved.
+            This points to the file that the checkpoint will be stored in.
+    """
+    bytesbuffer = io.BytesIO()
+    torch.save(checkpoint, bytesbuffer)  # nosec B614
+    with fsspec.open(filepath, "wb") as f:
+        f.write(bytesbuffer.getvalue())
+
+
 class AutoMMModelCheckpointIO(pl.plugins.CheckpointIO):
     """
     Class that customizes how checkpoints are saved. Saves either the entire model or only parameters that have been explicitly updated during training. The latter reduces memory footprint substantially when training very large models with parameter-efficient finetuning methods.
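
A quick usage sketch of the two vendored helpers (the checkpoint paths are illustrative). Because pl_save serializes the whole checkpoint into an in-memory buffer before opening the target file, an interrupted save cannot leave a truncated checkpoint behind:

    state = pl_load("run/model.ckpt", map_location="cpu")  # accepts a path, URL, or file-like object
    pl_save(state, "run/model_copy.ckpt")  # buffered write through fsspec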
@@ -124,14 +177,14 @@ class AutoMMModelCheckpointIO(pl.plugins.CheckpointIO):
         fs.makedirs(os.path.dirname(path), exist_ok=True)
         try:
             # write the checkpoint dictionary on the file
-            _atomic_save(checkpoint, path)
+            pl_save(checkpoint, path)
         except AttributeError as err:
             # todo (sean): is this try catch necessary still?
             # https://github.com/Lightning-AI/lightning/pull/431
             key = pl.LightningModule.CHECKPOINT_HYPER_PARAMS_KEY
             checkpoint.pop(key, None)
             rank_zero_warn(f"Warning, `{key}` dropped from checkpoint. An attribute is not picklable: {err}")
-            _atomic_save(checkpoint, path)
+            pl_save(checkpoint, path)
 
     def load_checkpoint(self, path, map_location: Optional[Any] = None) -> Dict[str, Any]:
         """
@@ -5,22 +5,11 @@ import re
 import warnings
 from typing import Callable, Dict, List, Optional, Tuple, Union
 
-from omegaconf import DictConfig, OmegaConf
-from packaging import version
+from omegaconf import DictConfig, ListConfig, OmegaConf
 from torch import nn
 
-from ..constants import (
-    AUTOMM,
-    DATA,
-    FT_TRANSFORMER,
-    FUSION_TRANSFORMER,
-    HF_MODELS,
-    MODEL,
-    REGRESSION,
-    VALID_CONFIG_KEYS,
-)
-from ..presets import get_automm_presets, get_basic_automm_config
-from .data import get_detected_data_types
+from ..constants import DATA, FT_TRANSFORMER, FUSION_TRANSFORMER, HF_MODELS, MODEL, REGRESSION, VALID_CONFIG_KEYS
+from .presets import get_basic_config, get_ensemble_presets, get_presets
 
 logger = logging.getLogger(__name__)
 
@@ -68,7 +57,7 @@ def get_default_config(config: Optional[Union[Dict, DictConfig]] = None, extra:
     Parameters
     ----------
     config
-        A dictionary including four keys: "model", "data", "optimization", and "environment".
+        A dictionary including four keys: "model", "data", "optim", and "env".
         If any key is not given, we will fill in with the default value.
     extra
         A list of extra config keys.
@@ -83,7 +72,7 @@ def get_default_config(config: Optional[Union[Dict, DictConfig]] = None, extra:
     if config is None:
         config = {}
 
-    basic_config = get_basic_automm_config(extra=extra)
+    basic_config = get_basic_config(extra=extra)
     for k, default_value in basic_config.items():
         if k not in config:
             config[k] = default_value
@@ -119,7 +108,7 @@ def get_config(
     extra: Optional[List[str]] = None,
 ):
     """
-    Construct configurations for model, data, optimization, and environment.
+    Construct configurations for model, data, optim, and env.
     It supports overriding some default configurations.
 
     Parameters
@@ -129,29 +118,29 @@ def get_config(
     presets
         Presets regarding model quality, e.g., best_quality, high_quality, and medium_quality.
     config
-        A dictionary including four keys: "model", "data", "optimization", and "environment".
+        A dictionary including four keys: "model", "data", "optim", and "env".
         If any key is not given, we will fill in with the default value.
 
         The value of each key can be a string, yaml path, or DictConfig object. For example:
             config = {
                 "model": "default",
                 "data": "default",
-                "optimization": "default",
-                "environment": "default",
+                "optim": "default",
+                "env": "default",
             }
         or
             config = {
                 "model": "/path/to/model/config.yaml",
                 "data": "/path/to/data/config.yaml",
-                "optimization": "/path/to/optimization/config.yaml",
-                "environment": "/path/to/environment/config.yaml",
+                "optim": "/path/to/optim/config.yaml",
+                "env": "/path/to/env/config.yaml",
            }
        or
            config = {
                "model": OmegaConf.load("/path/to/model/config.yaml"),
                "data": OmegaConf.load("/path/to/data/config.yaml"),
-                "optimization": OmegaConf.load("/path/to/optimization/config.yaml"),
-                "environment": OmegaConf.load("/path/to/environment/config.yaml"),
+                "optim": OmegaConf.load("/path/to/optim/config.yaml"),
+                "env": OmegaConf.load("/path/to/env/config.yaml"),
             }
     overrides
         This is to override some default configurations.
@@ -185,7 +174,7 @@ def get_config(
     if presets is None:
         preset_overrides = None
     else:
-        preset_overrides, _ = get_automm_presets(problem_type=problem_type, presets=presets)
+        preset_overrides, _ = get_presets(problem_type=problem_type, presets=presets)
 
     config = get_default_config(config, extra=extra)
     # apply the preset's overrides
@@ -404,33 +393,6 @@ def get_local_pretrained_config_paths(config: DictConfig, path: str) -> DictConf
     return config
 
 
-def upgrade_config(config, loaded_version):
-    """Upgrade outdated configurations
-
-    Parameters
-    ----------
-    config
-        The configuration
-    loaded_version
-        The version of the config that has been loaded
-
-    Returns
-    -------
-    config
-        The upgraded configuration
-    """
-    # backward compatibility for variable image size.
-    if version.parse(loaded_version) <= version.parse("0.6.2"):
-        logger.info(f"Start to upgrade the previous configuration trained by AutoMM version={loaded_version}.")
-        if OmegaConf.select(config, "model.timm_image") is not None:
-            logger.warning(
-                "Loading a model that has been trained via AutoGluon Multimodal<=0.6.2. "
-                "Setting config.model.timm_image.image_size = None."
-            )
-            config.model.timm_image.image_size = None
-    return config
-
-
 def parse_dotlist_conf(conf):
     """
     Parse the config files that is potentially in the dotlist format to a dictionary.
@@ -499,6 +461,7 @@ def apply_omegaconf_overrides(
         The updated configuration.
     """
     overrides = parse_dotlist_conf(overrides)
+    overrides = make_overrides_backward_compatible(overrides)
 
     def _check_exist_dotlist(C, key_in_dotlist):
         if not isinstance(key_in_dotlist, list):
@@ -519,10 +482,34 @@ def apply_omegaconf_overrides(
                 f"overrides={overrides}"
             )
     override_conf = OmegaConf.from_dotlist([f"{ele[0]}={ele[1]}" for ele in overrides.items()])
+    replace_none_str(override_conf)
     conf = OmegaConf.merge(conf, override_conf)
     return conf
 
 
+def replace_none_str(config: Union[DictConfig, ListConfig, dict, list]):
+    """
+    In-place replace "None" and "none" strings in the config with None.
+
+    Parameters
+    ----------
+    config
+        A config of type DictConfig, ListConfig, dict, or list.
+    """
+    if isinstance(config, (dict, DictConfig)):
+        for key, value in config.items():
+            if isinstance(value, str) and value.lower() == "none":
+                config[key] = None
+            elif isinstance(value, (dict, list, DictConfig, ListConfig)):
+                replace_none_str(value)
+    elif isinstance(config, (list, ListConfig)):
+        for i, value in enumerate(config):
+            if isinstance(value, str) and value.lower() == "none":
+                config[i] = None
+            elif isinstance(value, (dict, list, DictConfig, ListConfig)):
+                replace_none_str(value)
+
+
 def update_config_by_rules(
     problem_type: str,
     config: DictConfig,
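
A small illustration of replace_none_str on a dotlist override, where OmegaConf would otherwise keep the literal string "None" (the key is illustrative):

    from omegaconf import OmegaConf

    conf = OmegaConf.from_dotlist(["model.timm_image.image_size=None"])
    replace_none_str(conf)  # recurses in place
    assert conf.model.timm_image.image_size is None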
@@ -542,11 +529,11 @@ def update_config_by_rules(
     -------
     The modified config.
     """
-    loss_func = OmegaConf.select(config, "optimization.loss_function")
+    loss_func = config.optim.loss_func
     if loss_func is not None:
         if problem_type == REGRESSION and "bce" in loss_func.lower():
             # To use BCELoss for regression problems, need to first scale the labels.
-            config.data.label.numerical_label_preprocessing = "minmaxscaler"
+            config.data.label.numerical_preprocessing = "minmaxscaler"
 
     return config
@@ -658,7 +645,7 @@ def update_hyperparameters(
     -------
     The updated hyperparameters and hyperparameter_tune_kwargs.
     """
-    hyperparameters, hyperparameter_tune_kwargs = get_automm_presets(problem_type=problem_type, presets=presets)
+    hyperparameters, hyperparameter_tune_kwargs = get_presets(problem_type=problem_type, presets=presets)
 
     if hyperparameter_tune_kwargs and provided_hyperparameter_tune_kwargs:
         hyperparameter_tune_kwargs.update(provided_hyperparameter_tune_kwargs)
@@ -732,6 +719,8 @@ def filter_hyperparameters(
     # Filter models whose data types are not detected.
     # Avoid sampling unused checkpoints, e.g., hf_text models for image classification, to run jobs,
     # which wastes resources and time.
+    from ..data.utils import get_detected_data_types
+
     detected_data_types = get_detected_data_types(column_types)
     selected_model_names = []
     for model_name in hyperparameters[model_names_key]:
@@ -796,3 +785,86 @@ def split_hyperparameters(hyperparameters: Dict):
             raise ValueError(f"transform_types {v} contain neither all strings nor all callable objects.")
 
     return hyperparameters, advanced_hyperparameters
+
+
+def update_ensemble_hyperparameters(
+    presets,
+    provided_hyperparameters,
+):
+    presets_hyperparameters, _ = get_ensemble_presets(presets=presets)
+    if provided_hyperparameters:
+        learner_names = provided_hyperparameters.pop("learner_names", None)
+        if learner_names:
+            assert isinstance(
+                learner_names, list
+            ), f"learner_names should be a list, but got type {type(learner_names)}"
+            presets_hyperparameters = {k: v for k, v in presets_hyperparameters.items() if k in learner_names}
+            provided_hyperparameters = {k: v for k, v in provided_hyperparameters.items() if k in learner_names}
+
+        hyperparameters = copy.deepcopy(provided_hyperparameters)
+        for k, v in presets_hyperparameters.items():
+            if k not in hyperparameters:
+                hyperparameters[k] = v
+            else:
+                for kk, vv in presets_hyperparameters[k].items():
+                    if kk not in hyperparameters[k]:  # don't use presets to overwrite user-provided
+                        hyperparameters[k][kk] = vv
+    else:
+        hyperparameters = presets_hyperparameters
+
+    return hyperparameters
+
+
+def make_overrides_backward_compatible(overrides: Dict):
+    """
+    Some config keys were changed in PR https://github.com/autogluon/autogluon/pull/4737.
+    This function makes the changes backward compatible.
+
+    Parameters
+    ----------
+    overrides
+        A dictionary containing the user-provided hyperparameters,
+        which may contain old config keys.
+
+    Returns
+    -------
+    Overrides with up-to-date config keys.
+    """
+    key_pairs = {
+        "optim.learning_rate": "optim.lr",
+        "optim.efficient_finetune": "optim.peft",
+        "optim.loss_function": "optim.loss_func",
+        "env.num_workers_evaluation": "env.num_workers_inference",
+        "env.eval_batch_size_ratio": "env.inference_batch_size_ratio",
+        "data.label.numerical_label_preprocessing": "data.label.numerical_preprocessing",
+        "model.categorical_mlp.drop_rate": "model.categorical_mlp.dropout",
+        "model.numerical_mlp.drop_rate": "model.numerical_mlp.dropout",
+        "model.numerical_mlp.d_token": "model.numerical_mlp.token_dim",
+        "model.timm_image.max_img_num_per_col": "model.timm_image.max_image_num_per_column",
+        "model.clip.max_img_num_per_col": "model.clip.max_image_num_per_column",
+        "model.clip_image.max_img_num_per_col": "model.clip_image.max_image_num_per_column",
+        "model.fusion_mlp.weight": "model.fusion_mlp.aux_loss_weight",
+        "model.fusion_mlp.drop_rate": "model.fusion_mlp.dropout",
+        "model.fusion_transformer.n_blocks": "model.fusion_transformer.num_blocks",
+        "model.fusion_transformer.attention_n_heads": "model.fusion_transformer.attention_num_heads",
+        "model.fusion_transformer.ffn_d_hidden": "model.fusion_transformer.ffn_hidden_size",
+        "model.ft_transformer.attention_n_heads": "model.ft_transformer.attention_num_heads",
+    }
+    for k in list(overrides.keys()):
+        provided_k = k
+        if k.startswith("optimization."):
+            k = "optim." + k[len("optimization.") :]
+            logger.warning(
+                f"The provided hyperparameter name {provided_k} contains a deprecated key `optimization.`. "
+                f"Please replace `optimization.` with `optim.` when customizing the optimization hyperparameters."
+            )
+
+        if k in key_pairs:
+            overrides[key_pairs[k]] = overrides.pop(provided_k)
+            logger.warning(
+                f"The hyperparameter name {provided_k} is deprecated. "
+                f"We recommend using the new name {key_pairs[k]} instead. "
+                f"The deprecated hyperparameter will raise an exception starting in AutoGluon 1.4.0."
+            )
+
+    return overrides
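
An illustration of the key migration, using two of the renamed keys from the mapping above:

    overrides = {
        "optimization.learning_rate": 1e-4,  # deprecated prefix and deprecated name
        "env.num_workers_evaluation": 2,     # deprecated name only
    }
    overrides = make_overrides_backward_compatible(overrides)
    # overrides == {"optim.lr": 0.0001, "env.num_workers_inference": 2};
    # a deprecation warning is logged for each rewritten key.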