autogluon.multimodal 1.2.1b20250303-py3-none-any.whl → 1.2.1b20250305-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (126)
  1. autogluon/multimodal/__init__.py +4 -2
  2. autogluon/multimodal/configs/data/default.yaml +4 -2
  3. autogluon/multimodal/configs/{environment → env}/default.yaml +2 -3
  4. autogluon/multimodal/configs/model/default.yaml +58 -11
  5. autogluon/multimodal/configs/{optimization → optim}/default.yaml +21 -4
  6. autogluon/multimodal/constants.py +16 -5
  7. autogluon/multimodal/data/__init__.py +14 -2
  8. autogluon/multimodal/data/dataset.py +2 -2
  9. autogluon/multimodal/data/infer_types.py +16 -2
  10. autogluon/multimodal/data/label_encoder.py +3 -3
  11. autogluon/multimodal/{utils → data}/nlpaug.py +4 -4
  12. autogluon/multimodal/data/preprocess_dataframe.py +55 -38
  13. autogluon/multimodal/data/process_categorical.py +35 -6
  14. autogluon/multimodal/data/process_document.py +59 -33
  15. autogluon/multimodal/data/process_image.py +198 -163
  16. autogluon/multimodal/data/process_label.py +7 -3
  17. autogluon/multimodal/data/process_mmlab/process_mmdet.py +1 -8
  18. autogluon/multimodal/data/process_mmlab/process_mmlab_base.py +2 -9
  19. autogluon/multimodal/data/process_mmlab/process_mmocr.py +1 -9
  20. autogluon/multimodal/data/process_ner.py +192 -4
  21. autogluon/multimodal/data/process_numerical.py +32 -5
  22. autogluon/multimodal/data/process_semantic_seg_img.py +23 -28
  23. autogluon/multimodal/data/process_text.py +95 -58
  24. autogluon/multimodal/data/template_engine.py +7 -9
  25. autogluon/multimodal/data/templates.py +0 -2
  26. autogluon/multimodal/data/trivial_augmenter.py +2 -2
  27. autogluon/multimodal/data/utils.py +564 -338
  28. autogluon/multimodal/learners/__init__.py +2 -1
  29. autogluon/multimodal/learners/base.py +189 -189
  30. autogluon/multimodal/learners/ensemble.py +748 -0
  31. autogluon/multimodal/learners/few_shot_svm.py +6 -15
  32. autogluon/multimodal/learners/matching.py +59 -84
  33. autogluon/multimodal/learners/ner.py +23 -22
  34. autogluon/multimodal/learners/object_detection.py +26 -21
  35. autogluon/multimodal/learners/semantic_segmentation.py +16 -18
  36. autogluon/multimodal/models/__init__.py +12 -3
  37. autogluon/multimodal/models/augmenter.py +175 -0
  38. autogluon/multimodal/models/categorical_mlp.py +13 -8
  39. autogluon/multimodal/models/clip.py +92 -18
  40. autogluon/multimodal/models/custom_transformer.py +75 -75
  41. autogluon/multimodal/models/document_transformer.py +23 -9
  42. autogluon/multimodal/models/ft_transformer.py +40 -35
  43. autogluon/multimodal/models/fusion/base.py +2 -4
  44. autogluon/multimodal/models/fusion/fusion_mlp.py +82 -18
  45. autogluon/multimodal/models/fusion/fusion_ner.py +1 -1
  46. autogluon/multimodal/models/fusion/fusion_transformer.py +23 -23
  47. autogluon/multimodal/models/{huggingface_text.py → hf_text.py} +21 -2
  48. autogluon/multimodal/models/meta_transformer.py +336 -0
  49. autogluon/multimodal/models/mlp.py +6 -6
  50. autogluon/multimodal/models/mmocr_text_detection.py +1 -1
  51. autogluon/multimodal/models/mmocr_text_recognition.py +0 -1
  52. autogluon/multimodal/models/ner_text.py +1 -8
  53. autogluon/multimodal/models/numerical_mlp.py +14 -8
  54. autogluon/multimodal/models/sam.py +12 -2
  55. autogluon/multimodal/models/t_few.py +21 -5
  56. autogluon/multimodal/models/timm_image.py +74 -32
  57. autogluon/multimodal/models/utils.py +877 -16
  58. autogluon/multimodal/optim/__init__.py +17 -0
  59. autogluon/multimodal/{optimization → optim}/lit_distiller.py +2 -1
  60. autogluon/multimodal/{optimization → optim}/lit_matcher.py +4 -10
  61. autogluon/multimodal/{optimization → optim}/lit_mmdet.py +2 -10
  62. autogluon/multimodal/{optimization → optim}/lit_module.py +139 -14
  63. autogluon/multimodal/{optimization → optim}/lit_ner.py +3 -3
  64. autogluon/multimodal/{optimization → optim}/lit_semantic_seg.py +1 -1
  65. autogluon/multimodal/optim/losses/__init__.py +14 -0
  66. autogluon/multimodal/optim/losses/bce_loss.py +25 -0
  67. autogluon/multimodal/optim/losses/focal_loss.py +81 -0
  68. autogluon/multimodal/optim/losses/lemda_loss.py +39 -0
  69. autogluon/multimodal/optim/losses/rkd_loss.py +103 -0
  70. autogluon/multimodal/optim/losses/softmax_losses.py +177 -0
  71. autogluon/multimodal/optim/losses/structure_loss.py +26 -0
  72. autogluon/multimodal/optim/losses/utils.py +313 -0
  73. autogluon/multimodal/optim/lr/__init__.py +1 -0
  74. autogluon/multimodal/optim/lr/utils.py +332 -0
  75. autogluon/multimodal/optim/metrics/__init__.py +4 -0
  76. autogluon/multimodal/optim/metrics/coverage_metrics.py +42 -0
  77. autogluon/multimodal/optim/metrics/hit_rate_metrics.py +78 -0
  78. autogluon/multimodal/optim/metrics/ranking_metrics.py +231 -0
  79. autogluon/multimodal/optim/metrics/utils.py +359 -0
  80. autogluon/multimodal/optim/utils.py +284 -0
  81. autogluon/multimodal/predictor.py +51 -12
  82. autogluon/multimodal/utils/__init__.py +19 -45
  83. autogluon/multimodal/utils/cache.py +23 -2
  84. autogluon/multimodal/utils/checkpoint.py +58 -5
  85. autogluon/multimodal/utils/config.py +127 -55
  86. autogluon/multimodal/utils/device.py +120 -0
  87. autogluon/multimodal/utils/distillation.py +8 -8
  88. autogluon/multimodal/utils/download.py +1 -1
  89. autogluon/multimodal/utils/env.py +22 -0
  90. autogluon/multimodal/utils/export.py +3 -3
  91. autogluon/multimodal/utils/hpo.py +5 -5
  92. autogluon/multimodal/utils/inference.py +37 -4
  93. autogluon/multimodal/utils/install.py +91 -0
  94. autogluon/multimodal/utils/load.py +52 -47
  95. autogluon/multimodal/utils/log.py +6 -41
  96. autogluon/multimodal/utils/matcher.py +3 -2
  97. autogluon/multimodal/utils/onnx.py +0 -4
  98. autogluon/multimodal/utils/path.py +10 -0
  99. autogluon/multimodal/utils/precision.py +130 -0
  100. autogluon/multimodal/{presets.py → utils/presets.py} +259 -66
  101. autogluon/multimodal/{problem_types.py → utils/problem_types.py} +30 -1
  102. autogluon/multimodal/utils/save.py +47 -29
  103. autogluon/multimodal/utils/strategy.py +24 -0
  104. autogluon/multimodal/version.py +1 -1
  105. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/METADATA +5 -5
  106. autogluon.multimodal-1.2.1b20250305.dist-info/RECORD +163 -0
  107. autogluon/multimodal/optimization/__init__.py +0 -16
  108. autogluon/multimodal/optimization/losses.py +0 -394
  109. autogluon/multimodal/optimization/utils.py +0 -1054
  110. autogluon/multimodal/utils/cloud_io.py +0 -80
  111. autogluon/multimodal/utils/data.py +0 -701
  112. autogluon/multimodal/utils/environment.py +0 -395
  113. autogluon/multimodal/utils/metric.py +0 -500
  114. autogluon/multimodal/utils/model.py +0 -558
  115. autogluon.multimodal-1.2.1b20250303.dist-info/RECORD +0 -145
  116. /autogluon/multimodal/{optimization → optim}/deepspeed.py +0 -0
  117. /autogluon/multimodal/{optimization/lr_scheduler.py → optim/lr/lr_schedulers.py} +0 -0
  118. /autogluon/multimodal/{optimization → optim/metrics}/semantic_seg_metrics.py +0 -0
  119. /autogluon/multimodal/{registry.py → utils/registry.py} +0 -0
  120. /autogluon.multimodal-1.2.1b20250303-py3.9-nspkg.pth → /autogluon.multimodal-1.2.1b20250305-py3.9-nspkg.pth +0 -0
  121. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/LICENSE +0 -0
  122. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/NOTICE +0 -0
  123. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/WHEEL +0 -0
  124. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/namespace_packages.txt +0 -0
  125. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/top_level.txt +0 -0
  126. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/zip-safe +0 -0
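
The hunks that follow illustrate the recurring theme of this release: the optimization and environment config groups are renamed (configs/optimization → configs/optim, configs/environment → configs/env) and several keys move with them, e.g. optimization.learning_rate → optim.lr, optimization.efficient_finetune → optim.peft, env.num_workers_evaluation → env.num_workers_inference, and env.eval_batch_size_ratio → env.inference_batch_size_ratio. Below is a minimal sketch of how such overrides are spelled before and after, assuming the dotted-key hyperparameters mechanism of MultiModalPredictor is otherwise unchanged; only key names shown in the hunks are used, and anything else should be verified against configs/optim/default.yaml and configs/env/default.yaml in the new wheel.

    # Sketch only: dotted hyperparameter keys before and after this release.
    from autogluon.multimodal import MultiModalPredictor

    # old spelling (1.2.1b20250303 and earlier)
    old_overrides = {
        "optimization.learning_rate": 1.0e-4,
        "env.eval_batch_size_ratio": 4,
        "env.num_workers_evaluation": 2,
    }

    # new spelling (1.2.1b20250305)
    new_overrides = {
        "optim.lr": 1.0e-4,
        "env.inference_batch_size_ratio": 4,
        "env.num_workers_inference": 2,
    }

    predictor = MultiModalPredictor(label="label")
    # predictor.fit(train_data, hyperparameters=new_overrides)
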
autogluon/multimodal/learners/few_shot_svm.py
@@ -18,19 +18,10 @@ from autogluon.core.metrics import Scorer
 from autogluon.core.utils.loaders import load_pd
 
 from ..constants import CLIP, COLUMN_FEATURES, HF_TEXT, TIMM_IMAGE, Y_PRED, Y_TRUE
-from ..data import BaseDataModule, MultiModalFeaturePreprocessor
-from ..utils import (
-    CustomUnpickler,
-    LogFilter,
-    apply_log_filter,
-    compute_score,
-    data_to_df,
-    extract_from_output,
-    get_available_devices,
-    logits_to_prob,
-    select_model,
-    turn_on_off_feature_column_info,
-)
+from ..data import BaseDataModule, MultiModalFeaturePreprocessor, data_to_df, turn_on_off_feature_column_info
+from ..models import select_model
+from ..optim import compute_score
+from ..utils import LogFilter, apply_log_filter, extract_from_output, get_available_devices, logits_to_prob
 from .base import BaseLearner
 
 logger = logging.getLogger(__name__)
@@ -62,7 +53,7 @@ class FewShotSVMLearner(BaseLearner):
                 "model.hf_text.checkpoint_name": "sentence-transformers/all-mpnet-base-v2",
                 "model.hf_text.pooling_mode": "mean",
                 "env.per_gpu_batch_size": 32,
-                "env.eval_batch_size_ratio": 4,
+                "env.inference_batch_size_ratio": 4,
             }
         presets
             Presets regarding model quality, e.g., best_quality, high_quality, and medium_quality.
@@ -589,7 +580,7 @@ class FewShotSVMLearner(BaseLearner):
     ):
         predictor = super().load(path=path, resume=resume, verbosity=verbosity)
         with open(os.path.join(path, "svm.pkl"), "rb") as fp:
-            params = CustomUnpickler(fp).load()
+            params = pickle.load(fp)  # nosec B301
         svm = make_pipeline(StandardScaler(), SVC(gamma="auto"))
         svm.set_params(**params)
         predictor._svm = svm
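
The few_shot_svm.py hunks above drop the CustomUnpickler helper and restore svm.pkl with the standard library's pickle. A sketch of the round trip this implies, assuming the save side persists the pipeline's parameters with pickle.dump (only the load side appears in this diff); the sklearn pipeline mirrors the one rebuilt in the hunk.

    import os
    import pickle

    from sklearn.pipeline import make_pipeline
    from sklearn.preprocessing import StandardScaler
    from sklearn.svm import SVC

    path = "./few_shot_svm_artifacts"
    os.makedirs(path, exist_ok=True)

    # assumed save side: persist the pipeline's parameters
    svm = make_pipeline(StandardScaler(), SVC(gamma="auto"))
    with open(os.path.join(path, "svm.pkl"), "wb") as fp:
        pickle.dump(svm.get_params(), fp)

    # load side, as in the hunk: rebuild the pipeline and restore its parameters
    with open(os.path.join(path, "svm.pkl"), "rb") as fp:
        params = pickle.load(fp)  # nosec B301
    restored = make_pipeline(StandardScaler(), SVC(gamma="auto"))
    restored.set_params(**params)

The # nosec B301 marker only silences Bandit's pickle warning; the file is still trusted input, exactly as before the change.
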
autogluon/multimodal/learners/matching.py
@@ -46,41 +46,43 @@ from ..constants import (
     Y_PRED_PROB,
     Y_TRUE,
 )
-from ..data import BaseDataModule, MultiModalFeaturePreprocessor, infer_column_types
-from ..optimization import MatcherLitModule, get_matcher_loss_func, get_matcher_miner_func, get_metric
-from ..presets import matcher_presets
-from ..problem_types import PROBLEM_TYPES_REG
-from ..utils import (
-    CustomUnpickler,
-    assign_feature_column_names,
-    average_checkpoints,
+from ..data import (
+    BaseDataModule,
+    MultiModalFeaturePreprocessor,
+    create_fusion_data_processors,
+    data_to_df,
+    infer_column_types,
+    infer_dtypes_by_model_names,
+    init_df_preprocessor,
+)
+from ..models import is_lazy_weight_tensor, select_model
+from ..optim import (
+    MatcherLitModule,
     compute_ranking_score,
     compute_score,
+    get_matcher_loss_func,
+    get_matcher_miner_func,
+    get_torchmetric,
+)
+from ..utils import (
+    average_checkpoints,
     compute_semantic_similarity,
     convert_data_for_ranking,
-    create_fusion_data_processors,
     create_siamese_model,
     customize_model_names,
-    data_to_df,
     extract_from_output,
     get_config,
     get_dir_ckpt_paths,
     get_load_ckpt_paths,
     get_local_pretrained_config_paths,
-    get_minmax_mode,
     hyperparameter_tune,
-    infer_dtypes_by_model_names,
-    init_df_preprocessor,
-    is_lazy_weight_tensor,
-    load_text_tokenizers,
+    matcher_presets,
     on_fit_end_message,
     save_pretrained_model_configs,
-    save_text_tokenizers,
-    select_model,
     split_hyperparameters,
     update_config_by_rules,
-    upgrade_config,
 )
+from ..utils.problem_types import PROBLEM_TYPES_REG
 from .base import BaseLearner
 
 pl_logger = logging.getLogger("lightning")
@@ -88,9 +90,9 @@ pl_logger.propagate = False # https://github.com/Lightning-AI/lightning/issues/
 logger = logging.getLogger(__name__)
 
 
-class MultiModalMatcher(BaseLearner):
+class MatchingLearner(BaseLearner):
     """
-    MultiModalMatcher is a framework to learn/extract embeddings for multimodal data including image, text, and tabular.
+    MatchingLearner is a framework to learn/extract embeddings for multimodal data including image, text, and tabular.
     These embeddings can be used e.g. with cosine-similarity to find items with similar semantic meanings.
     This can be useful for computing the semantic similarity of two items, semantic search, paraphrase mining, etc.
     """
@@ -448,7 +450,7 @@ class MultiModalMatcher(BaseLearner):
         # top_k_average is called inside hyperparameter_tune() when building the final predictor.
         self.top_k_average(
             save_path=self._save_path,
-            top_k_average_method=self._config.optimization.top_k_average_method,
+            top_k_average_method=self._config.optim.top_k_average_method,
             standalone=standalone,
             clean_ckpts=clean_ckpts,
         )
@@ -476,7 +478,7 @@ class MultiModalMatcher(BaseLearner):
         **kwargs,
     ):
         """
-        Fit MultiModalMatcher. Train the model to learn embeddings to simultaneously maximize and minimize
+        Fit MatchingLearner. Train the model to learn embeddings to simultaneously maximize and minimize
         the semantic similarities of positive and negative pairs.
         The data may contain image, text, numeric, or categorical features.
 
@@ -538,7 +540,7 @@ class MultiModalMatcher(BaseLearner):
 
         Returns
         -------
-        An "MultiModalMatcher" object (itself).
+        An "MatchingLearner" object (itself).
         """
         self.setup_save_path(save_path=save_path)
         training_start = self.on_fit_start(presets=presets)
@@ -805,7 +807,7 @@ class MultiModalMatcher(BaseLearner):
         label_processors_count = {k: len(v) for k, v in label_processors.items()}
         logger.debug(f"label_processors_count: {label_processors_count}")
 
-        validation_metric, custom_metric_func = get_metric(
+        validation_metric, custom_metric_func = get_torchmetric(
             metric_name=self._validation_metric_name,
             num_classes=self._output_shape,
             is_matching=self._pipeline in matcher_presets.list_keys(),
@@ -863,17 +865,17 @@ class MultiModalMatcher(BaseLearner):
             validate_data=self._tuning_data,
             id_mappings=id_mappings,
         )
-        optimization_kwargs = dict(
-            optim_type=config.optimization.optim_type,
-            lr_choice=config.optimization.lr_choice,
-            lr_schedule=config.optimization.lr_schedule,
-            lr=config.optimization.learning_rate,
-            lr_decay=config.optimization.lr_decay,
-            end_lr=config.optimization.end_lr,
-            lr_mult=config.optimization.lr_mult,
-            weight_decay=config.optimization.weight_decay,
-            warmup_steps=config.optimization.warmup_steps,
-            track_grad_norm=OmegaConf.select(config, "optimization.track_grad_norm", default=-1),
+        optim_kwargs = dict(
+            optim_type=config.optim.optim_type,
+            lr_choice=config.optim.lr_choice,
+            lr_schedule=config.optim.lr_schedule,
+            lr=config.optim.lr,
+            lr_decay=config.optim.lr_decay,
+            end_lr=config.optim.end_lr,
+            lr_mult=config.optim.lr_mult,
+            weight_decay=config.optim.weight_decay,
+            warmup_steps=config.optim.warmup_steps,
+            track_grad_norm=config.optim.track_grad_norm,
         )
         metrics_kwargs = dict(
             validation_metric=validation_metric,
@@ -893,7 +895,7 @@ class MultiModalMatcher(BaseLearner):
             loss_func=loss_func,
             miner_func=miner_func,
             **metrics_kwargs,
-            **optimization_kwargs,
+            **optim_kwargs,
         )
         callbacks = self.get_callbacks_per_run(save_path=save_path, config=config, litmodule=litmodule)
         tb_logger = self.get_tb_logger(save_path=save_path)
@@ -1025,7 +1027,7 @@ class MultiModalMatcher(BaseLearner):
             ingredients = [top_k_model_paths[0]]
         else:
             raise ValueError(
-                f"The key for 'optimization.top_k_average_method' is not supported. "
+                f"The key for 'optim.top_k_average_method' is not supported. "
                 f"We only support '{GREEDY_SOUP}', '{UNIFORM_SOUP}' and '{BEST}'. "
                 f"The provided value is '{top_k_average_method}'."
             )
@@ -1202,7 +1204,7 @@ class MultiModalMatcher(BaseLearner):
             df_preprocessor=df_preprocessor,
             data_processors=data_processors,
             per_gpu_batch_size=batch_size,
-            num_workers=self._config.env.num_workers_evaluation,
+            num_workers=self._config.env.num_workers_inference,
             predict_data=data,
             id_mappings=id_mappings,
         )
@@ -1931,18 +1933,14 @@ class MultiModalMatcher(BaseLearner):
         # Save text tokenizers before saving data processors
         query_processors = copy.deepcopy(query_processors)
         if TEXT in query_processors:
-            query_processors[TEXT] = save_text_tokenizers(
-                text_processors=query_processors[TEXT],
-                path=path,
-            )
+            for per_text_processor in query_processors[TEXT]:
+                per_text_processor.save_tokenizer(path)
 
         # Save text tokenizers before saving data processors
         response_processors = copy.deepcopy(response_processors)
         if TEXT in response_processors:
-            response_processors[TEXT] = save_text_tokenizers(
-                text_processors=response_processors[TEXT],
-                path=path,
-            )
+            for per_text_processor in response_processors[TEXT]:
+                per_text_processor.save_tokenizer(path)
 
         data_processors = {
             QUERY: query_processors,
@@ -1955,7 +1953,7 @@ class MultiModalMatcher(BaseLearner):
         with open(os.path.join(path, f"assets.json"), "w") as fp:
             json.dump(
                 {
-                    "class_name": self.__class__.__name__,
+                    "learner_class": self.__class__.__name__,
                     "query": self._query,
                     "response": self._response,
                     "match_label": self._match_label,
@@ -1990,7 +1988,7 @@ class MultiModalMatcher(BaseLearner):
 
     @staticmethod
     def _load_metadata(
-        matcher: MultiModalMatcher,
+        matcher: MatchingLearner,
         path: str,
         resume: Optional[bool] = False,
         verbosity: Optional[int] = 3,
@@ -2013,11 +2011,8 @@ class MultiModalMatcher(BaseLearner):
         with open(os.path.join(path, "assets.json"), "r") as fp:
            assets = json.load(fp)
 
-        query_config = upgrade_config(query_config, assets["version"])
-        response_config = upgrade_config(response_config, assets["version"])
-
         with open(os.path.join(path, "df_preprocessor.pkl"), "rb") as fp:
-            df_preprocessor = CustomUnpickler(fp).load()
+            df_preprocessor = pickle.load(fp)  # nosec B301
 
         query_df_preprocessor = df_preprocessor[QUERY]
         response_df_preprocessor = df_preprocessor[RESPONSE]
@@ -2025,7 +2020,7 @@ class MultiModalMatcher(BaseLearner):
 
         try:
             with open(os.path.join(path, "data_processors.pkl"), "rb") as fp:
-                data_processors = CustomUnpickler(fp).load()
+                data_processors = pickle.load(fp)  # nosec B301
 
             query_processors = data_processors[QUERY]
             response_processors = data_processors[RESPONSE]
@@ -2033,32 +2028,20 @@ class MultiModalMatcher(BaseLearner):
 
             # Load text tokenizers after loading data processors.
             if TEXT in query_processors:
-                query_processors[TEXT] = load_text_tokenizers(
-                    text_processors=query_processors[TEXT],
-                    path=path,
-                )
-            # backward compatibility. Add feature column names in each data processor.
-            query_processors = assign_feature_column_names(
-                data_processors=query_processors,
-                df_preprocessor=query_df_preprocessor,
-            )
+                for per_text_processor in query_processors[TEXT]:
+                    per_text_processor.load_tokenizer(path)
+
             # Only keep the modalities with non-empty processors.
             query_processors = {k: v for k, v in query_processors.items() if len(v) > 0}
 
             # Load text tokenizers after loading data processors.
             if TEXT in response_processors:
-                response_processors[TEXT] = load_text_tokenizers(
-                    text_processors=response_processors[TEXT],
-                    path=path,
-                )
-            # backward compatibility. Add feature column names in each data processor.
-            response_processors = assign_feature_column_names(
-                data_processors=response_processors,
-                df_preprocessor=response_df_preprocessor,
-            )
+                for per_text_processor in response_processors[TEXT]:
+                    per_text_processor.load_tokenizer(path)
+
             # Only keep the modalities with non-empty processors.
             response_processors = {k: v for k, v in response_processors.items() if len(v) > 0}
-        except:  # backward compatibility. reconstruct the data processor in case something went wrong.
+        except:  # reconstruct the data processor in case something went wrong.
             query_processors = None
            response_processors = None
            label_processors = None
@@ -2069,19 +2052,14 @@ class MultiModalMatcher(BaseLearner):
         matcher._label_column = assets["label_column"]
         matcher._problem_type = assets["problem_type"]
         matcher._pipeline = assets["pipeline"]
-        if "presets" in assets:
-            matcher._presets = assets["presets"]
+        matcher._presets = assets["presets"]
         matcher._eval_metric_name = assets["eval_metric_name"]
         matcher._verbosity = verbosity
         matcher._resume = resume
         matcher._save_path = path  # in case the original exp dir is copied to somewhere else
         matcher._pretrained_path = path
-        if "pretrained" in assets:
-            matcher._pretrained = assets["pretrained"]
-        if "fit_called" in assets:
-            matcher._fit_called = assets["fit_called"]
-        else:
-            matcher._fit_called = True  # backward compatible
+        matcher._pretrained = assets["pretrained"]
+        matcher._fit_called = assets["fit_called"]
         matcher._config = config
         matcher._query_config = query_config
         matcher._response_config = response_config
@@ -2094,10 +2072,7 @@ class MultiModalMatcher(BaseLearner):
         matcher._query_processors = query_processors
         matcher._response_processors = response_processors
         matcher._label_processors = label_processors
-        if "minmax_mode" in assets:
-            matcher._minmax_mode = assets["minmax_mode"]
-        else:
-            matcher._minmax_mode = get_minmax_mode(matcher._validation_metric_name)
+        matcher._minmax_mode = assets["minmax_mode"]
 
         return matcher
 
autogluon/multimodal/learners/ner.py
@@ -7,15 +7,16 @@ from typing import Callable, Dict, List, Optional, Union
 
 import lightning.pytorch as pl
 import pandas as pd
-from omegaconf import DictConfig, OmegaConf
+from omegaconf import DictConfig
 from torch import nn
 
 from autogluon.core.metrics import Scorer
 
 from ..constants import NER, NER_RET, Y_PRED, Y_TRUE
 from ..data import MultiModalFeaturePreprocessor
-from ..optimization import NerLitModule, get_metric
-from ..utils import compute_score, create_fusion_model, extract_from_output, merge_bio_format
+from ..models import create_fusion_model
+from ..optim import NerLitModule, compute_score, get_minmax_mode, get_torchmetric, infer_metrics
+from ..utils import extract_from_output, merge_bio_format
 from .base import BaseLearner
 
 logger = logging.getLogger(__name__)
@@ -86,7 +87,7 @@ class NERLearner(BaseLearner):
         self._output_shape = output_shape  # since ner infers output_shape in fit_per_run(), the learners needs to update the attribute afterwards.
 
     def get_validation_metric_per_run(self, output_shape: int):
-        validation_metric, custom_metric_func = get_metric(
+        validation_metric, custom_metric_func = get_torchmetric(
             metric_name=self._validation_metric_name,
             num_classes=output_shape,
             problem_type=self._problem_type,
@@ -109,38 +110,38 @@ class NERLearner(BaseLearner):
         )
         return model
 
-    def get_optimization_kwargs_per_run(self, config, validation_metric, custom_metric_func, loss_func):
+    def get_optim_kwargs_per_run(self, config, validation_metric, custom_metric_func, loss_func):
         return dict(
-            optim_type=config.optimization.optim_type,
-            lr_choice=config.optimization.lr_choice,
-            lr_schedule=config.optimization.lr_schedule,
-            lr=config.optimization.learning_rate,
-            lr_decay=config.optimization.lr_decay,
-            end_lr=config.optimization.end_lr,
-            lr_mult=config.optimization.lr_mult,
-            weight_decay=config.optimization.weight_decay,
-            warmup_steps=config.optimization.warmup_steps,
-            track_grad_norm=OmegaConf.select(config, "optimization.track_grad_norm", default=-1),
+            optim_type=config.optim.optim_type,
+            lr_choice=config.optim.lr_choice,
+            lr_schedule=config.optim.lr_schedule,
+            lr=config.optim.lr,
+            lr_decay=config.optim.lr_decay,
+            end_lr=config.optim.end_lr,
+            lr_mult=config.optim.lr_mult,
+            weight_decay=config.optim.weight_decay,
+            warmup_steps=config.optim.warmup_steps,
+            track_grad_norm=config.optim.track_grad_norm,
             validation_metric=validation_metric,
             validation_metric_name=self._validation_metric_name,
             custom_metric_func=custom_metric_func,
             loss_func=loss_func,
-            efficient_finetune=OmegaConf.select(config, "optimization.efficient_finetune"),
-            skip_final_val=OmegaConf.select(config, "optimization.skip_final_val", default=False),
+            peft=config.optim.peft,
+            skip_final_val=config.optim.skip_final_val,
         )
 
     def get_litmodule_per_run(
         self,
         model: Optional[nn.Module] = None,
         peft_param_names: Optional[List[str]] = None,
-        optimization_kwargs: Optional[dict] = None,
+        optim_kwargs: Optional[dict] = None,
         is_train=True,
     ):
         if is_train:
             return NerLitModule(
                 model=model,
                 trainable_param_names=peft_param_names,
-                **optimization_kwargs,
+                **optim_kwargs,
             )
         else:
             return NerLitModule(model=self._model)
@@ -214,7 +215,7 @@ class NERLearner(BaseLearner):
             advanced_hyperparameters=advanced_hyperparameters,
         )
         validation_metric, custom_metric_func = self.get_validation_metric_per_run(output_shape=output_shape)
-        loss_func = self.get_loss_func_per_run(config=config)
+        loss_func, _ = self.get_loss_func_per_run(config=config)
         if max_time == timedelta(seconds=0):
             return dict(
                 config=config,
@@ -230,7 +231,7 @@ class NERLearner(BaseLearner):
             per_gpu_batch_size=config.env.per_gpu_batch_size,
             num_workers=config.env.num_workers,
         )
-        optimization_kwargs = self.get_optimization_kwargs_per_run(
+        optim_kwargs = self.get_optim_kwargs_per_run(
             config=config,
             validation_metric=validation_metric,
             custom_metric_func=custom_metric_func,
@@ -239,7 +240,7 @@ class NERLearner(BaseLearner):
         litmodule = self.get_litmodule_per_run(
             model=model,
             peft_param_names=peft_param_names,
-            optimization_kwargs=optimization_kwargs,
+            optim_kwargs=optim_kwargs,
         )
         callbacks = self.get_callbacks_per_run(save_path=save_path, config=config, litmodule=litmodule)
         plugins = self.get_plugins_per_run(model=model, peft_param_names=peft_param_names)
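
The ner.py hunks, like the matching ones, also replace defensive OmegaConf.select lookups with direct attribute access such as config.optim.track_grad_norm and config.optim.peft. That is only safe if these keys are always present, presumably seeded by configs/optim/default.yaml. A small standalone illustration of the behavioral difference, using omegaconf directly rather than any AutoGluon code:

    from omegaconf import OmegaConf
    from omegaconf.errors import ConfigAttributeError, ConfigKeyError

    cfg = OmegaConf.create({"optim": {"lr": 1.0e-4}})  # track_grad_norm deliberately missing

    # tolerant lookup, as in the old code: falls back to the default
    print(OmegaConf.select(cfg, "optim.track_grad_norm", default=-1))  # -> -1

    # direct access, as in the new code: requires the key to exist
    try:
        print(cfg.optim.track_grad_norm)
    except (ConfigAttributeError, ConfigKeyError):
        print("missing key raises, so it must now be defined in the default config")
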
autogluon/multimodal/learners/object_detection.py
@@ -5,25 +5,30 @@ from datetime import timedelta
 from typing import Dict, List, Optional, Union
 
 import pandas as pd
-from omegaconf import DictConfig, OmegaConf
+from omegaconf import DictConfig
 from torch import nn
 
 from ..constants import BBOX, DDP, MAP, MULTI_IMAGE_MIX_DATASET, OBJECT_DETECTION, XYWH
-from ..data import BaseDataModule, MultiImageMixDataset, MultiModalFeaturePreprocessor, infer_rois_column_type
-from ..optimization import LitModule, MMDetLitModule
+from ..data import (
+    BaseDataModule,
+    MultiImageMixDataset,
+    MultiModalFeaturePreprocessor,
+    infer_rois_column_type,
+    split_train_tuning_data,
+)
+from ..models import create_fusion_model
+from ..optim import MMDetLitModule
 from ..utils import (
     check_if_packages_installed,
     cocoeval,
     convert_pred_to_xywh,
     convert_result_df,
-    create_fusion_model,
     extract_from_output,
     from_coco_or_voc,
     get_detection_classes,
     object_detection_data_to_df,
     save_result_coco_format,
     setup_save_path,
-    split_train_tuning_data,
 )
 from .base import BaseLearner
 
@@ -303,18 +308,18 @@ class ObjectDetectionLearner(BaseLearner):
 
         return num_gpus
 
-    def get_optimization_kwargs_per_run(self, config, validation_metric, custom_metric_func):
+    def get_optim_kwargs_per_run(self, config, validation_metric, custom_metric_func):
         return dict(
-            optim_type=config.optimization.optim_type,
-            lr_choice=config.optimization.lr_choice,
-            lr_schedule=config.optimization.lr_schedule,
-            lr=config.optimization.learning_rate,
-            lr_decay=config.optimization.lr_decay,
-            end_lr=config.optimization.end_lr,
-            lr_mult=config.optimization.lr_mult,
-            weight_decay=config.optimization.weight_decay,
-            warmup_steps=config.optimization.warmup_steps,
-            track_grad_norm=OmegaConf.select(config, "optimization.track_grad_norm", default=-1),
+            optim_type=config.optim.optim_type,
+            lr_choice=config.optim.lr_choice,
+            lr_schedule=config.optim.lr_schedule,
+            lr=config.optim.lr,
+            lr_decay=config.optim.lr_decay,
+            end_lr=config.optim.end_lr,
+            lr_mult=config.optim.lr_mult,
+            weight_decay=config.optim.weight_decay,
+            warmup_steps=config.optim.warmup_steps,
+            track_grad_norm=config.optim.track_grad_norm,
             validation_metric=validation_metric,
             validation_metric_name=self._validation_metric_name,
             custom_metric_func=custom_metric_func,
@@ -323,7 +328,7 @@ class ObjectDetectionLearner(BaseLearner):
     def get_litmodule_per_run(
         self,
         model: Optional[nn.Module] = None,
-        optimization_kwargs: Optional[dict] = None,
+        optim_kwargs: Optional[dict] = None,
         is_train=True,
     ):
         if self._problem_type == OBJECT_DETECTION:
@@ -334,7 +339,7 @@ class ObjectDetectionLearner(BaseLearner):
         if is_train:
             return LightningModule(
                 model=model,
-                **optimization_kwargs,
+                **optim_kwargs,
             )
         else:
             return LightningModule(model=self._model)
@@ -395,14 +400,14 @@ class ObjectDetectionLearner(BaseLearner):
             num_workers=config.env.num_workers,
             model_config=model.config,
         )
-        optimization_kwargs = self.get_optimization_kwargs_per_run(
+        optim_kwargs = self.get_optim_kwargs_per_run(
             config=config,
             validation_metric=validation_metric,
             custom_metric_func=custom_metric_func,
         )
         litmodule = self.get_litmodule_per_run(
             model=model,
-            optimization_kwargs=optimization_kwargs,
+            optim_kwargs=optim_kwargs,
         )
         callbacks = self.get_callbacks_per_run(save_path=save_path, config=config, litmodule=litmodule)
         plugins = self.get_plugins_per_run(model=model)
@@ -524,7 +529,7 @@ class ObjectDetectionLearner(BaseLearner):
             df_preprocessor=df_preprocessor,
             data_processors=data_processors,
             per_gpu_batch_size=batch_size,
-            num_workers=self._config.env.num_workers_evaluation,
+            num_workers=self._config.env.num_workers_inference,
             predict_data=data,
             is_train=False,
         )
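
The object_detection.py hunks show the wider import reshuffle: optimization becomes optim, and helpers such as create_fusion_model and split_train_tuning_data move out of utils into models and data. External code that imported these names from the old locations has to follow the move; below is a hedged fallback for code that must run against both wheels, using only paths that appear in the hunks above.

    try:
        # new layout (1.2.1b20250305)
        from autogluon.multimodal.data import split_train_tuning_data
        from autogluon.multimodal.models import create_fusion_model
        from autogluon.multimodal.optim import MMDetLitModule
    except ImportError:
        # old layout (1.2.1b20250303)
        from autogluon.multimodal.optimization import MMDetLitModule
        from autogluon.multimodal.utils import create_fusion_model, split_train_tuning_data
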
autogluon/multimodal/learners/semantic_segmentation.py
@@ -7,19 +7,17 @@ import numpy as np
 import pandas as pd
 import torch
 import torch.nn.functional as F
-from omegaconf import OmegaConf
 from PIL import Image
 from scipy.special import softmax
 
 from autogluon.core.metrics import Scorer
 
-from ..constants import LABEL, LOGITS, SEMANTIC_MASK, SEMANTIC_SEGMENTATION, SEMANTIC_SEGMENTATION_IMG
-from ..optimization.lit_semantic_seg import SemanticSegmentationLitModule
-from ..optimization.semantic_seg_metrics import Balanced_Error_Rate_Pred as Balanced_Error_Rate
-from ..optimization.semantic_seg_metrics import Binary_IoU_Pred as Binary_IoU
-from ..optimization.semantic_seg_metrics import COD_METRICS_NAMES_Pred as COD_METRICS_NAMES
-from ..optimization.semantic_seg_metrics import Multiclass_IoU_Pred as Multiclass_IoU
-from ..optimization.utils import get_loss_func, get_norm_layer_param_names, get_trainable_params_efficient_finetune
+from ..constants import LABEL, LOGITS, SEMANTIC_MASK, SEMANTIC_SEGMENTATION_IMG
+from ..optim import SemanticSegmentationLitModule, get_loss_func, get_norm_layer_param_names, get_peft_param_names
+from ..optim.metrics.semantic_seg_metrics import Balanced_Error_Rate_Pred as Balanced_Error_Rate
+from ..optim.metrics.semantic_seg_metrics import Binary_IoU_Pred as Binary_IoU
+from ..optim.metrics.semantic_seg_metrics import COD_METRICS_NAMES_Pred as COD_METRICS_NAMES
+from ..optim.metrics.semantic_seg_metrics import Multiclass_IoU_Pred as Multiclass_IoU
 from ..utils import extract_from_output, setup_save_path
 from .base import BaseLearner
 
@@ -122,24 +120,24 @@ class SemanticSegmentationLearner(BaseLearner):
     @staticmethod
     def get_peft_param_names_per_run(model, config):
         peft_param_names = None
-        peft = OmegaConf.select(config, "optimization.efficient_finetune")
+        peft = config.optim.peft
         if peft:
             norm_param_names = get_norm_layer_param_names(model)
-            peft_param_names = get_trainable_params_efficient_finetune(
+            peft_param_names = get_peft_param_names(
                 norm_param_names,
-                efficient_finetune=peft,
-                extra_params=OmegaConf.select(config, "optimization.extra_trainable_params"),
+                peft=peft,
+                extra_params=config.optim.extra_trainable_params,
             )
         return peft_param_names
 
     def get_loss_func_per_run(self, config, mixup_active=None):
         loss_func = get_loss_func(
             problem_type=self._problem_type,
-            loss_func_name=OmegaConf.select(config, "optimization.loss_function"),
-            config=config.optimization,
+            loss_func_name=config.optim.loss_func,
+            config=config.optim,
             num_classes=self._output_shape,
         )
-        return loss_func
+        return loss_func, None
 
     def evaluate_semantic_segmentation(
         self,
@@ -240,7 +238,7 @@ class SemanticSegmentationLearner(BaseLearner):
         model=None,
         model_postprocess_fn=None,
         peft_param_names=None,
-        optimization_kwargs=None,
+        optim_kwargs=None,
         distillation_kwargs=None,
         is_train=True,
     ):
@@ -249,13 +247,13 @@ class SemanticSegmentationLearner(BaseLearner):
                 model=model,
                 model_postprocess_fn=model_postprocess_fn,
                 trainable_param_names=peft_param_names,
-                **optimization_kwargs,
+                **optim_kwargs,
             )
         else:
             return SemanticSegmentationLitModule(
                 model=self._model,
                 model_postprocess_fn=self._model_postprocess_fn,
-                **optimization_kwargs,
+                **optim_kwargs,
             )
 
     def on_predict_start(self, data: pd.DataFrame):
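
Finally, get_loss_func_per_run in semantic_segmentation.py now returns a pair instead of a bare loss function, matching the earlier NER hunk where the caller unpacks loss_func, _ = self.get_loss_func_per_run(config=config). A minimal sketch of that calling convention; the class name below is hypothetical, and the diff shows only that this learner returns None as the second element.

    from typing import Optional, Tuple

    from torch import nn


    class LearnerSketch:
        """Hypothetical stand-in for the learner base class."""

        def get_loss_func_per_run(self, config) -> Tuple[nn.Module, Optional[nn.Module]]:
            loss_func = nn.CrossEntropyLoss()  # stand-in for get_loss_func(...)
            return loss_func, None  # second element unused by the segmentation learner


    # caller pattern, as in the NER hunk:
    loss_func, _ = LearnerSketch().get_loss_func_per_run(config=None)
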