autogluon.multimodal 1.2.1b20250303__py3-none-any.whl → 1.2.1b20250304__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (126)
  1. autogluon/multimodal/__init__.py +4 -2
  2. autogluon/multimodal/configs/data/default.yaml +4 -2
  3. autogluon/multimodal/configs/{environment → env}/default.yaml +2 -3
  4. autogluon/multimodal/configs/model/default.yaml +58 -11
  5. autogluon/multimodal/configs/{optimization → optim}/default.yaml +21 -4
  6. autogluon/multimodal/constants.py +16 -5
  7. autogluon/multimodal/data/__init__.py +14 -2
  8. autogluon/multimodal/data/dataset.py +2 -2
  9. autogluon/multimodal/data/infer_types.py +16 -2
  10. autogluon/multimodal/data/label_encoder.py +3 -3
  11. autogluon/multimodal/{utils → data}/nlpaug.py +4 -4
  12. autogluon/multimodal/data/preprocess_dataframe.py +55 -38
  13. autogluon/multimodal/data/process_categorical.py +35 -6
  14. autogluon/multimodal/data/process_document.py +59 -33
  15. autogluon/multimodal/data/process_image.py +198 -163
  16. autogluon/multimodal/data/process_label.py +7 -3
  17. autogluon/multimodal/data/process_mmlab/process_mmdet.py +1 -8
  18. autogluon/multimodal/data/process_mmlab/process_mmlab_base.py +2 -9
  19. autogluon/multimodal/data/process_mmlab/process_mmocr.py +1 -9
  20. autogluon/multimodal/data/process_ner.py +192 -4
  21. autogluon/multimodal/data/process_numerical.py +32 -5
  22. autogluon/multimodal/data/process_semantic_seg_img.py +23 -28
  23. autogluon/multimodal/data/process_text.py +95 -58
  24. autogluon/multimodal/data/template_engine.py +7 -9
  25. autogluon/multimodal/data/templates.py +0 -2
  26. autogluon/multimodal/data/trivial_augmenter.py +2 -2
  27. autogluon/multimodal/data/utils.py +564 -338
  28. autogluon/multimodal/learners/__init__.py +2 -1
  29. autogluon/multimodal/learners/base.py +189 -189
  30. autogluon/multimodal/learners/ensemble.py +748 -0
  31. autogluon/multimodal/learners/few_shot_svm.py +6 -15
  32. autogluon/multimodal/learners/matching.py +59 -84
  33. autogluon/multimodal/learners/ner.py +23 -22
  34. autogluon/multimodal/learners/object_detection.py +26 -21
  35. autogluon/multimodal/learners/semantic_segmentation.py +16 -18
  36. autogluon/multimodal/models/__init__.py +12 -3
  37. autogluon/multimodal/models/augmenter.py +175 -0
  38. autogluon/multimodal/models/categorical_mlp.py +13 -8
  39. autogluon/multimodal/models/clip.py +92 -18
  40. autogluon/multimodal/models/custom_transformer.py +75 -75
  41. autogluon/multimodal/models/document_transformer.py +23 -9
  42. autogluon/multimodal/models/ft_transformer.py +40 -35
  43. autogluon/multimodal/models/fusion/base.py +2 -4
  44. autogluon/multimodal/models/fusion/fusion_mlp.py +82 -18
  45. autogluon/multimodal/models/fusion/fusion_ner.py +1 -1
  46. autogluon/multimodal/models/fusion/fusion_transformer.py +23 -23
  47. autogluon/multimodal/models/{huggingface_text.py → hf_text.py} +21 -2
  48. autogluon/multimodal/models/meta_transformer.py +336 -0
  49. autogluon/multimodal/models/mlp.py +6 -6
  50. autogluon/multimodal/models/mmocr_text_detection.py +1 -1
  51. autogluon/multimodal/models/mmocr_text_recognition.py +0 -1
  52. autogluon/multimodal/models/ner_text.py +1 -8
  53. autogluon/multimodal/models/numerical_mlp.py +14 -8
  54. autogluon/multimodal/models/sam.py +12 -2
  55. autogluon/multimodal/models/t_few.py +21 -5
  56. autogluon/multimodal/models/timm_image.py +74 -32
  57. autogluon/multimodal/models/utils.py +877 -16
  58. autogluon/multimodal/optim/__init__.py +17 -0
  59. autogluon/multimodal/{optimization → optim}/lit_distiller.py +2 -1
  60. autogluon/multimodal/{optimization → optim}/lit_matcher.py +4 -10
  61. autogluon/multimodal/{optimization → optim}/lit_mmdet.py +2 -10
  62. autogluon/multimodal/{optimization → optim}/lit_module.py +139 -14
  63. autogluon/multimodal/{optimization → optim}/lit_ner.py +3 -3
  64. autogluon/multimodal/{optimization → optim}/lit_semantic_seg.py +1 -1
  65. autogluon/multimodal/optim/losses/__init__.py +14 -0
  66. autogluon/multimodal/optim/losses/bce_loss.py +25 -0
  67. autogluon/multimodal/optim/losses/focal_loss.py +81 -0
  68. autogluon/multimodal/optim/losses/lemda_loss.py +39 -0
  69. autogluon/multimodal/optim/losses/rkd_loss.py +103 -0
  70. autogluon/multimodal/optim/losses/softmax_losses.py +177 -0
  71. autogluon/multimodal/optim/losses/structure_loss.py +26 -0
  72. autogluon/multimodal/optim/losses/utils.py +313 -0
  73. autogluon/multimodal/optim/lr/__init__.py +1 -0
  74. autogluon/multimodal/optim/lr/utils.py +332 -0
  75. autogluon/multimodal/optim/metrics/__init__.py +4 -0
  76. autogluon/multimodal/optim/metrics/coverage_metrics.py +42 -0
  77. autogluon/multimodal/optim/metrics/hit_rate_metrics.py +78 -0
  78. autogluon/multimodal/optim/metrics/ranking_metrics.py +231 -0
  79. autogluon/multimodal/optim/metrics/utils.py +359 -0
  80. autogluon/multimodal/optim/utils.py +284 -0
  81. autogluon/multimodal/predictor.py +51 -12
  82. autogluon/multimodal/utils/__init__.py +19 -45
  83. autogluon/multimodal/utils/cache.py +23 -2
  84. autogluon/multimodal/utils/checkpoint.py +58 -5
  85. autogluon/multimodal/utils/config.py +127 -55
  86. autogluon/multimodal/utils/device.py +120 -0
  87. autogluon/multimodal/utils/distillation.py +8 -8
  88. autogluon/multimodal/utils/download.py +1 -1
  89. autogluon/multimodal/utils/env.py +22 -0
  90. autogluon/multimodal/utils/export.py +3 -3
  91. autogluon/multimodal/utils/hpo.py +5 -5
  92. autogluon/multimodal/utils/inference.py +37 -4
  93. autogluon/multimodal/utils/install.py +91 -0
  94. autogluon/multimodal/utils/load.py +52 -47
  95. autogluon/multimodal/utils/log.py +6 -41
  96. autogluon/multimodal/utils/matcher.py +3 -2
  97. autogluon/multimodal/utils/onnx.py +0 -4
  98. autogluon/multimodal/utils/path.py +10 -0
  99. autogluon/multimodal/utils/precision.py +130 -0
  100. autogluon/multimodal/{presets.py → utils/presets.py} +259 -66
  101. autogluon/multimodal/{problem_types.py → utils/problem_types.py} +30 -1
  102. autogluon/multimodal/utils/save.py +47 -29
  103. autogluon/multimodal/utils/strategy.py +24 -0
  104. autogluon/multimodal/version.py +1 -1
  105. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/METADATA +5 -5
  106. autogluon.multimodal-1.2.1b20250304.dist-info/RECORD +163 -0
  107. autogluon/multimodal/optimization/__init__.py +0 -16
  108. autogluon/multimodal/optimization/losses.py +0 -394
  109. autogluon/multimodal/optimization/utils.py +0 -1054
  110. autogluon/multimodal/utils/cloud_io.py +0 -80
  111. autogluon/multimodal/utils/data.py +0 -701
  112. autogluon/multimodal/utils/environment.py +0 -395
  113. autogluon/multimodal/utils/metric.py +0 -500
  114. autogluon/multimodal/utils/model.py +0 -558
  115. autogluon.multimodal-1.2.1b20250303.dist-info/RECORD +0 -145
  116. /autogluon/multimodal/{optimization → optim}/deepspeed.py +0 -0
  117. /autogluon/multimodal/{optimization/lr_scheduler.py → optim/lr/lr_schedulers.py} +0 -0
  118. /autogluon/multimodal/{optimization → optim/metrics}/semantic_seg_metrics.py +0 -0
  119. /autogluon/multimodal/{registry.py → utils/registry.py} +0 -0
  120. /autogluon.multimodal-1.2.1b20250303-py3.9-nspkg.pth → /autogluon.multimodal-1.2.1b20250304-py3.9-nspkg.pth +0 -0
  121. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/LICENSE +0 -0
  122. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/NOTICE +0 -0
  123. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/WHEEL +0 -0
  124. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/namespace_packages.txt +0 -0
  125. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/top_level.txt +0 -0
  126. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/zip-safe +0 -0
@@ -25,7 +25,7 @@ from torch import nn
 
  from autogluon.common.utils.resource_utils import ResourceManager
  from autogluon.common.utils.try_import import try_import_ray
- from autogluon.core.metrics import Scorer
+ from autogluon.core.metrics import Scorer, get_metric
  from autogluon.core.utils.loaders import load_pd
 
  from .. import version as ag_version
@@ -59,7 +59,6 @@ from ..constants import (
  NER,
  RAY_TUNE_CHECKPOINT,
  REGRESSION,
- ROIS,
  TEXT,
  TEXT_NER,
  TORCH_COMPILE_MIN_VERSION,
@@ -72,39 +71,51 @@ from ..constants import (
  from ..data import (
  BaseDataModule,
  MultiModalFeaturePreprocessor,
+ create_fusion_data_processors,
+ data_to_df,
+ get_mixup,
  infer_column_types,
+ infer_dtypes_by_model_names,
  infer_output_shape,
  infer_problem_type,
+ infer_scarcity_mode_by_data_size,
+ init_df_preprocessor,
  is_image_column,
+ split_train_tuning_data,
+ turn_on_off_feature_column_info,
+ )
+ from ..models import (
+ create_fusion_model,
+ get_model_postprocess_fn,
+ is_lazy_weight_tensor,
+ list_timm_models,
+ select_model,
  )
- from ..models import get_model_postprocess_fn
- from ..optimization.lit_distiller import DistillerLitModule
- from ..optimization.lit_module import LitModule
- from ..optimization.utils import (
+ from ..optim import (
+ compute_score,
+ get_aug_loss_func,
  get_loss_func,
- get_metric,
+ get_minmax_mode,
  get_norm_layer_param_names,
- get_trainable_params_efficient_finetune,
+ get_peft_param_names,
+ get_stopping_threshold,
+ get_torchmetric,
+ infer_metrics,
  )
- from ..problem_types import PROBLEM_TYPES_REG
+ from ..optim.lit_distiller import DistillerLitModule
+ from ..optim.lit_module import LitModule
  from ..utils import (
  AutoMMModelCheckpoint,
  AutoMMModelCheckpointIO,
- CustomUnpickler,
  DDPPredictionWriter,
  DistillationMixin,
  ExportMixin,
  LogFilter,
  RealtimeMixin,
  apply_log_filter,
- assign_feature_column_names,
  average_checkpoints,
  compute_inference_batch_size,
  compute_num_gpus,
- compute_score,
- create_fusion_data_processors,
- create_fusion_model,
- data_to_df,
  extract_from_output,
  filter_hyperparameters,
  get_config,
@@ -112,39 +123,25 @@ from ..utils import (
  get_gpu_message,
  get_load_ckpt_paths,
  get_local_pretrained_config_paths,
- get_minmax_mode,
- get_mixup,
- get_stopping_threshold,
  hyperparameter_tune,
- infer_dtypes_by_model_names,
- infer_metrics,
  infer_precision,
  infer_problem_type_by_eval_metric,
- infer_scarcity_mode_by_data_size,
- init_df_preprocessor,
  is_interactive_env,
  is_interactive_strategy,
- is_lazy_weight_tensor,
- list_timm_models,
- load_text_tokenizers,
  logits_to_prob,
  on_fit_end_message,
  on_fit_per_run_start_message,
  on_fit_start_message,
  run_ddp_only_once,
  save_pretrained_model_configs,
- save_text_tokenizers,
- select_model,
  setup_save_path,
  split_hyperparameters,
- split_train_tuning_data,
  tensor_to_ndarray,
- turn_on_off_feature_column_info,
  update_config_by_rules,
  update_hyperparameters,
  update_tabular_config_by_resources,
- upgrade_config,
  )
+ from ..utils.problem_types import PROBLEM_TYPES_REG
 
  pl_logger = logging.getLogger("lightning")
  pl_logger.propagate = False # https://github.com/Lightning-AI/lightning/issues/4621
@@ -253,6 +250,7 @@ class BaseLearner(ExportMixin, DistillationMixin, RealtimeMixin):
  self._eval_metric_func = None
  if isinstance(eval_metric, str):
  self._eval_metric_name = eval_metric.lower()
+ self.set_eval_metric_func()
  elif isinstance(eval_metric, Scorer):
  self._eval_metric_name = eval_metric.name
  self._eval_metric_func = eval_metric
@@ -350,6 +348,18 @@ class BaseLearner(ExportMixin, DistillationMixin, RealtimeMixin):
  )
  return model_size * 1e-6 # convert to megabytes
 
+ def set_eval_metric_func(self):
+ from .matching import MatchingLearner
+ from .ner import NERLearner
+ from .object_detection import ObjectDetectionLearner
+ from .semantic_segmentation import SemanticSegmentationLearner
+
+ if (
+ not isinstance(self, (NERLearner, SemanticSegmentationLearner, MatchingLearner, ObjectDetectionLearner))
+ and self._eval_metric_func is None
+ ):
+ self._eval_metric_func = get_metric(self._eval_metric_name)
+
  def ensure_fit_ready(self):
  if self._problem_type and not self.problem_property.support_fit:
  raise RuntimeError(
@@ -482,6 +492,7 @@ class BaseLearner(ExportMixin, DistillationMixin, RealtimeMixin):
  validation_metric_name=self._validation_metric_name,
  is_matching=is_matching,
  )
+ self.set_eval_metric_func()
  self._minmax_mode = get_minmax_mode(self._validation_metric_name)
  logger.debug(f"validation_metric_name: {self._validation_metric_name}")
  logger.debug(f"minmax_mode: {self._minmax_mode}")
@@ -598,7 +609,7 @@ class BaseLearner(ExportMixin, DistillationMixin, RealtimeMixin):
  # top_k_average is called inside hyperparameter_tune() when building the final predictor.
  self.top_k_average(
  save_path=self._save_path,
- top_k_average_method=self._config.optimization.top_k_average_method,
+ top_k_average_method=self._config.optim.top_k_average_method,
  strategy=strategy,
  strict_loading=strict_loading,
  # Not strict loading if using parameter-efficient finetuning
@@ -750,12 +761,13 @@ class BaseLearner(ExportMixin, DistillationMixin, RealtimeMixin):
  num_classes=self._output_shape,
  num_numerical_columns=len(df_preprocessor.numerical_feature_names),
  num_categories=df_preprocessor.categorical_num_categories,
+ numerical_fill_values=df_preprocessor.numerical_fill_values,
  )
  return model
 
  @staticmethod
  def compile_model_per_run(config, model):
- if OmegaConf.select(config, "env.compile.turn_on", default=False):
+ if config.env.compile.turn_on:
  assert version.parse(torch.__version__) >= version.parse(TORCH_COMPILE_MIN_VERSION), (
  f"torch.compile requires torch version >= {TORCH_COMPILE_MIN_VERSION}, "
  f"but torch version {torch.__version__} is detected."
@@ -763,21 +775,21 @@ class BaseLearner(ExportMixin, DistillationMixin, RealtimeMixin):
  logger.debug("Using torch.compile() in compiling the model.")
  model = torch.compile(
  model,
- mode=OmegaConf.select(config, "env.compile.mode", default="default"),
- dynamic=OmegaConf.select(config, "env.compile.dynamic", default=True),
- backend=OmegaConf.select(config, "env.compile.backend", default="inductor"),
+ mode=config.env.compile.mode,
+ dynamic=config.env.compile.dynamic,
+ backend=config.env.compile.backend,
  )
  return model
 
  @staticmethod
  def get_peft_param_names_per_run(model, config):
  peft_param_names = None
- peft = OmegaConf.select(config, "optimization.efficient_finetune")
+ peft = config.optim.peft
  if peft:
  norm_param_names = get_norm_layer_param_names(model)
- peft_param_names = get_trainable_params_efficient_finetune(
+ peft_param_names = get_peft_param_names(
  norm_param_names,
- efficient_finetune=peft,
+ peft=peft,
  )
  return peft_param_names
 
@@ -809,7 +821,7 @@ class BaseLearner(ExportMixin, DistillationMixin, RealtimeMixin):
  return data_processors
 
  def get_validation_metric_per_run(self):
- validation_metric, custom_metric_func = get_metric(
+ validation_metric, custom_metric_func = get_torchmetric(
  metric_name=self._validation_metric_name,
  num_classes=self._output_shape,
  problem_type=self._problem_type,
@@ -818,8 +830,8 @@ class BaseLearner(ExportMixin, DistillationMixin, RealtimeMixin):
 
  def get_mixup_func_per_run(self, config):
  mixup_active, mixup_func = get_mixup(
- model_config=OmegaConf.select(config, "model"),
- mixup_config=OmegaConf.select(config, "data.mixup"),
+ model_config=config.model,
+ mixup_config=config.data.mixup,
  num_classes=self._output_shape,
  )
  if mixup_active and (config.env.per_gpu_batch_size == 1 or config.env.per_gpu_batch_size % 2 == 1):
@@ -834,10 +846,14 @@ class BaseLearner(ExportMixin, DistillationMixin, RealtimeMixin):
  loss_func = get_loss_func(
  problem_type=self._problem_type,
  mixup_active=mixup_active,
- loss_func_name=OmegaConf.select(config, "optimization.loss_function"),
- config=config.optimization,
+ loss_func_name=config.optim.loss_func,
+ config=config.optim,
  )
- return loss_func
+ aug_loss_func = get_aug_loss_func(
+ config=config.optim,
+ problem_type=self._problem_type,
+ )
+ return loss_func, aug_loss_func
 
  def get_model_postprocess_fn_per_run(self, loss_func):
  model_postprocess_fn = get_model_postprocess_fn(
@@ -872,26 +888,46 @@ class BaseLearner(ExportMixin, DistillationMixin, RealtimeMixin):
  datamodule = BaseDataModule(**datamodule_kwargs)
  return datamodule
 
- def get_optimization_kwargs_per_run(self, config, validation_metric, custom_metric_func, loss_func, mixup_func):
+ def get_optim_kwargs_per_run(
+ self,
+ config,
+ validation_metric,
+ custom_metric_func,
+ loss_func,
+ aug_loss_func,
+ mixup_func,
+ grad_steps,
+ ):
  return dict(
- optim_type=config.optimization.optim_type,
- lr_choice=config.optimization.lr_choice,
- lr_schedule=config.optimization.lr_schedule,
- lr=config.optimization.learning_rate,
- lr_decay=config.optimization.lr_decay,
- end_lr=config.optimization.end_lr,
- lr_mult=config.optimization.lr_mult,
- weight_decay=config.optimization.weight_decay,
- warmup_steps=config.optimization.warmup_steps,
- track_grad_norm=OmegaConf.select(config, "optimization.track_grad_norm", default=-1),
+ optim_type=config.optim.optim_type,
+ lr_choice=config.optim.lr_choice,
+ lr_schedule=config.optim.lr_schedule,
+ lr=config.optim.lr,
+ lr_decay=config.optim.lr_decay,
+ end_lr=config.optim.end_lr,
+ lr_mult=config.optim.lr_mult,
+ weight_decay=config.optim.weight_decay,
+ warmup_steps=config.optim.warmup_steps,
+ track_grad_norm=config.optim.track_grad_norm,
  validation_metric=validation_metric,
  validation_metric_name=self._validation_metric_name,
  custom_metric_func=custom_metric_func,
  loss_func=loss_func,
  mixup_fn=mixup_func,
- efficient_finetune=OmegaConf.select(config, "optimization.efficient_finetune"),
- mixup_off_epoch=OmegaConf.select(config, "data.mixup.turn_off_epoch"),
- skip_final_val=OmegaConf.select(config, "optimization.skip_final_val", default=False),
+ peft=config.optim.peft,
+ mixup_off_epoch=config.data.mixup.turn_off_epoch,
+ skip_final_val=config.optim.skip_final_val,
+ cross_modal_align=config.optim.cross_modal_align,
+ cross_modal_align_weight=config.optim.cross_modal_align_weight,
+ automatic_optimization=config.optim.automatic_optimization,
+ accumulate_grad_batches=grad_steps,
+ gradient_clip_val=config.optim.gradient_clip_val,
+ gradient_clip_algorithm=config.optim.gradient_clip_algorithm,
+ use_aug_optim=config.optim.lemda.turn_on,
+ aug_loss_func=aug_loss_func,
+ aug_lr=config.optim.lemda.lr,
+ aug_weight_decay=config.optim.lemda.weight_decay,
+ aug_optim_type=config.optim.lemda.optim_type,
  )
 
  def get_litmodule_per_run(
@@ -899,7 +935,7 @@ class BaseLearner(ExportMixin, DistillationMixin, RealtimeMixin):
  model=None,
  model_postprocess_fn=None,
  peft_param_names=None,
- optimization_kwargs=dict(),
+ optim_kwargs=dict(),
  distillation_kwargs=dict(),
  is_train=True,
  ):
@@ -908,7 +944,7 @@ class BaseLearner(ExportMixin, DistillationMixin, RealtimeMixin):
  return DistillerLitModule(
  student_model=model,
  teacher_model=self._teacher_learner._model,
- **optimization_kwargs,
+ **optim_kwargs,
  **distillation_kwargs,
  )
  else:
@@ -916,13 +952,13 @@ class BaseLearner(ExportMixin, DistillationMixin, RealtimeMixin):
  model=model,
  model_postprocess_fn=model_postprocess_fn,
  trainable_param_names=peft_param_names,
- **optimization_kwargs,
+ **optim_kwargs,
  )
  else:
  return LitModule(
  model=self._model,
  model_postprocess_fn=self._model_postprocess_fn,
- **optimization_kwargs,
+ **optim_kwargs,
  )
 
  def get_callbacks_per_run(self, save_path=None, config=None, litmodule=None, pred_writer=None, is_train=True):
@@ -935,7 +971,7 @@ class BaseLearner(ExportMixin, DistillationMixin, RealtimeMixin):
 
  checkpoint_callback = AutoMMModelCheckpoint(
  dirpath=save_path,
- save_top_k=config.optimization.top_k,
+ save_top_k=config.optim.top_k,
  verbose=True,
  monitor=litmodule.validation_metric_name,
  mode=self._minmax_mode,
@@ -943,7 +979,7 @@ class BaseLearner(ExportMixin, DistillationMixin, RealtimeMixin):
  )
  early_stopping_callback = pl.callbacks.EarlyStopping(
  monitor=litmodule.validation_metric_name,
- patience=config.optimization.patience,
+ patience=config.optim.patience,
  mode=self._minmax_mode,
  stopping_threshold=get_stopping_threshold(self._validation_metric_name),
  )
@@ -1020,7 +1056,7 @@ class BaseLearner(ExportMixin, DistillationMixin, RealtimeMixin):
  assert (
  version.parse(pl.__version__) >= version.parse(DEEPSPEED_MIN_PL_VERSION)
  ), f"For DeepSpeed Offloading to work reliably you need at least lightning version {DEEPSPEED_MIN_PL_VERSION}, however, found {pl.__version__}. Please update your lightning version."
- from ..optimization.deepspeed import CustomDeepSpeedStrategy
+ from ..optim.deepspeed import CustomDeepSpeedStrategy
 
  strategy = CustomDeepSpeedStrategy(
  stage=3,
@@ -1059,7 +1095,7 @@ class BaseLearner(ExportMixin, DistillationMixin, RealtimeMixin):
 
  num_gpus = compute_num_gpus(
  config_num_gpus=config.env.num_gpus,
- accelerator=OmegaConf.select(config, "env.accelerator", default="auto"),
+ accelerator=config.env.accelerator,
  )
  num_gpus = self.update_num_gpus_by_data_size(num_gpus=num_gpus, data=data)
  strategy = self.get_strategy_per_run(num_gpus=num_gpus, config=config)
@@ -1096,36 +1132,38 @@ class BaseLearner(ExportMixin, DistillationMixin, RealtimeMixin):
  is_train=True,
  ):
  if is_train:
+ trainer_kwargs = dict(
+ accelerator="gpu" if num_gpus > 0 else config.env.accelerator,
+ devices=num_gpus if num_gpus > 0 else "auto",
+ num_nodes=config.env.num_nodes,
+ precision=precision,
+ strategy=strategy if strategy else "auto",
+ benchmark=False,
+ deterministic=config.env.deterministic,
+ max_epochs=config.optim.max_epochs,
+ max_steps=config.optim.max_steps,
+ max_time=max_time,
+ callbacks=callbacks,
+ logger=tb_logger,
+ log_every_n_steps=config.optim.log_every_n_steps,
+ enable_progress_bar=enable_progress_bar,
+ fast_dev_run=config.env.fast_dev_run,
+ val_check_interval=config.optim.val_check_interval,
+ check_val_every_n_epoch=config.optim.check_val_every_n_epoch,
+ plugins=plugins,
+ )
+ if config.optim.automatic_optimization:
+ trainer_kwargs.update(
+ dict(
+ gradient_clip_val=config.optim.gradient_clip_val,
+ gradient_clip_algorithm=config.optim.gradient_clip_algorithm,
+ accumulate_grad_batches=grad_steps,
+ )
+ )
  blacklist_msgs = ["already configured with model summary"]
  log_filter = LogFilter(blacklist_msgs)
  with apply_log_filter(log_filter):
- trainer = pl.Trainer(
- accelerator="gpu" if num_gpus > 0 else OmegaConf.select(config, "env.accelerator", default="auto"),
- devices=num_gpus if num_gpus > 0 else "auto",
- num_nodes=config.env.num_nodes,
- precision=precision,
- strategy=strategy if strategy else "auto",
- benchmark=False,
- deterministic=config.env.deterministic,
- max_epochs=config.optimization.max_epochs,
- max_steps=config.optimization.max_steps,
- max_time=max_time,
- callbacks=callbacks,
- logger=tb_logger,
- gradient_clip_val=OmegaConf.select(config, "optimization.gradient_clip_val", default=1),
- gradient_clip_algorithm=OmegaConf.select(
- config, "optimization.gradient_clip_algorithm", default="norm"
- ),
- accumulate_grad_batches=grad_steps,
- log_every_n_steps=OmegaConf.select(config, "optimization.log_every_n_steps", default=10),
- enable_progress_bar=enable_progress_bar,
- fast_dev_run=config.env.fast_dev_run,
- val_check_interval=config.optimization.val_check_interval,
- check_val_every_n_epoch=config.optimization.check_val_every_n_epoch
- if hasattr(config.optimization, "check_val_every_n_epoch")
- else 1,
- plugins=plugins,
- )
+ trainer = pl.Trainer(**trainer_kwargs)
  else:
  blacklist_msgs = []
  if self._verbosity <= 3: # turn off logging in prediction
@@ -1140,9 +1178,7 @@ class BaseLearner(ExportMixin, DistillationMixin, RealtimeMixin):
 
  with apply_log_filter(log_filter):
  trainer = pl.Trainer(
- accelerator="gpu"
- if num_gpus > 0
- else OmegaConf.select(self._config, "env.accelerator", default="auto"),
+ accelerator="gpu" if num_gpus > 0 else self._config.env.accelerator,
  devices=num_gpus if num_gpus > 0 else "auto",
  num_nodes=self._config.env.num_nodes,
  precision=precision,
@@ -1253,8 +1289,11 @@ class BaseLearner(ExportMixin, DistillationMixin, RealtimeMixin):
  )
  validation_metric, custom_metric_func = self.get_validation_metric_per_run()
  mixup_active, mixup_func = self.get_mixup_func_per_run(config=config)
- loss_func = self.get_loss_func_per_run(config=config, mixup_active=mixup_active)
+ loss_func, aug_loss_func = self.get_loss_func_per_run(config=config, mixup_active=mixup_active)
  model_postprocess_fn = self.get_model_postprocess_fn_per_run(loss_func=loss_func)
+ num_gpus, strategy = self.get_num_gpus_and_strategy_per_run(config=config)
+ precision = self.get_precision_per_run(num_gpus=num_gpus, precision=config.env.precision)
+ grad_steps = self.get_grad_steps(num_gpus=num_gpus, config=config)
 
  if max_time == timedelta(seconds=0):
  return dict(
@@ -1278,26 +1317,25 @@ class BaseLearner(ExportMixin, DistillationMixin, RealtimeMixin):
  per_gpu_batch_size=config.env.per_gpu_batch_size,
  num_workers=config.env.num_workers,
  )
- optimization_kwargs = self.get_optimization_kwargs_per_run(
+ optim_kwargs = self.get_optim_kwargs_per_run(
  config=config,
  validation_metric=validation_metric,
  custom_metric_func=custom_metric_func,
  loss_func=loss_func,
+ aug_loss_func=aug_loss_func,
  mixup_func=mixup_func,
+ grad_steps=grad_steps,
  )
  litmodule = self.get_litmodule_per_run(
  model=model,
  model_postprocess_fn=model_postprocess_fn,
  peft_param_names=peft_param_names,
- optimization_kwargs=optimization_kwargs,
+ optim_kwargs=optim_kwargs,
  distillation_kwargs=distillation_kwargs,
  )
  callbacks = self.get_callbacks_per_run(save_path=save_path, config=config, litmodule=litmodule)
  plugins = self.get_plugins_per_run(model=model, peft_param_names=peft_param_names)
  tb_logger = self.get_tb_logger(save_path=save_path)
- num_gpus, strategy = self.get_num_gpus_and_strategy_per_run(config=config)
- precision = self.get_precision_per_run(num_gpus=num_gpus, precision=config.env.precision)
- grad_steps = self.get_grad_steps(num_gpus=num_gpus, config=config)
  config = self.post_update_config_per_run(
  config=config,
  num_gpus=num_gpus,
@@ -1361,9 +1399,8 @@ class BaseLearner(ExportMixin, DistillationMixin, RealtimeMixin):
  standalone=True,
  clean_ckpts=True,
  ):
- minmax_mode = get_minmax_mode(
- self._eval_metric_name if self._eval_metric_func is None else self._eval_metric_func
- )
+ eval_metric = self._eval_metric_name if self._eval_metric_func is None else self._eval_metric_func
+ minmax_mode = get_minmax_mode(eval_metric)
  best_k_models_yaml_path = os.path.join(save_path, BEST_K_MODELS_FILE)
  if os.path.exists(best_k_models_yaml_path):
  with open(best_k_models_yaml_path, "r") as f:
@@ -1409,9 +1446,8 @@ class BaseLearner(ExportMixin, DistillationMixin, RealtimeMixin):
  prefix=prefix,
  strict=strict_loading,
  )
- best_score = self.evaluate(self._tuning_data, metrics=[self._eval_metric_name])[
- self._eval_metric_name
- ]
+ best_score = self.evaluate(self._tuning_data, metrics=[eval_metric])
+ best_score = next(iter(best_score.values()))
  for i in range(1, len(top_k_model_paths)):
  cand_avg_state_dict = average_checkpoints(
  checkpoint_paths=ingredients + [top_k_model_paths[i]],
@@ -1421,9 +1457,8 @@ class BaseLearner(ExportMixin, DistillationMixin, RealtimeMixin):
  prefix=prefix,
  strict=strict_loading,
  )
- cand_score = self.evaluate(self._tuning_data, metrics=[self._eval_metric_name])[
- self._eval_metric_name
- ]
+ cand_score = self.evaluate(self._tuning_data, metrics=[eval_metric])
+ cand_score = next(iter(cand_score.values()))
  if monitor_op(cand_score, best_score):
  # Add new ingredient
  ingredients.append(top_k_model_paths[i])
@@ -1432,7 +1467,7 @@ class BaseLearner(ExportMixin, DistillationMixin, RealtimeMixin):
  ingredients = [top_k_model_paths[0]]
  else:
  raise ValueError(
- f"The key for 'optimization.top_k_average_method' is not supported. "
+ f"The key for 'optim.top_k_average_method' is not supported. "
  f"We only support '{GREEDY_SOUP}', '{UNIFORM_SOUP}' and '{BEST}'. "
  f"The provided value is '{top_k_average_method}'."
  )
@@ -1495,7 +1530,7 @@ class BaseLearner(ExportMixin, DistillationMixin, RealtimeMixin):
  # TODO: Using optimiation_kwargs for inference is confusing and bad design. Remove as soon as fixed in lightning.
  if self._config.env.strategy == DEEPSPEED_OFFLOADING and DEEPSPEED_MODULE not in sys.modules:
  # Need to initialize DeepSpeed and optimizer as currently required in lightning's integration of deepspeed.
- from ..optimization.deepspeed import CustomDeepSpeedStrategy
+ from ..optim.deepspeed import CustomDeepSpeedStrategy
 
  strategy = CustomDeepSpeedStrategy(
  stage=3,
@@ -1505,21 +1540,21 @@ class BaseLearner(ExportMixin, DistillationMixin, RealtimeMixin):
  reduce_bucket_size=self._config.env.deepspeed_allreduce_size,
  )
 
- optimization_kwargs = dict(
- optim_type=self._config.optimization.optim_type,
- lr_choice=self._config.optimization.lr_choice,
- lr_schedule=self._config.optimization.lr_schedule,
- lr=self._config.optimization.learning_rate,
- lr_decay=self._config.optimization.lr_decay,
- end_lr=self._config.optimization.end_lr,
- lr_mult=self._config.optimization.lr_mult,
- weight_decay=self._config.optimization.weight_decay,
- warmup_steps=self._config.optimization.warmup_steps,
+ optim_kwargs = dict(
+ optim_type=self._config.optim.optim_type,
+ lr_choice=self._config.optim.lr_choice,
+ lr_schedule=self._config.optim.lr_schedule,
+ lr=self._config.optim.lr,
+ lr_decay=self._config.optim.lr_decay,
+ end_lr=self._config.optim.end_lr,
+ lr_mult=self._config.optim.lr_mult,
+ weight_decay=self._config.optim.weight_decay,
+ warmup_steps=self._config.optim.warmup_steps,
  )
  else:
- optimization_kwargs = {}
+ optim_kwargs = {}
 
- return strategy, optimization_kwargs
+ return strategy, optim_kwargs
 
  def get_pred_writer(self, strategy):
  pred_writer = None
@@ -1650,9 +1685,7 @@ class BaseLearner(ExportMixin, DistillationMixin, RealtimeMixin):
  def get_predict_batch_size_per_run(self, num_gpus: int, strategy: str):
  return compute_inference_batch_size(
  per_gpu_batch_size=self._config.env.per_gpu_batch_size,
- eval_batch_size_ratio=OmegaConf.select(self._config, "env.eval_batch_size_ratio"),
- per_gpu_batch_size_evaluation=self._config.env.per_gpu_batch_size_evaluation,
- # backward compatibility.
+ inference_batch_size_ratio=self._config.env.inference_batch_size_ratio,
  num_gpus=num_gpus,
  strategy=strategy,
  )
@@ -1737,19 +1770,19 @@ class BaseLearner(ExportMixin, DistillationMixin, RealtimeMixin):
  )
  return outputs
 
- strategy, optimization_kwargs = self.prepare_deepspeed_offloading(strategy=strategy)
+ strategy, optim_kwargs = self.prepare_deepspeed_offloading(strategy=strategy)
  datamodule = self.get_datamodule_per_run(
  df_preprocessor=df_preprocessor,
  data_processors=data_processors,
  per_gpu_batch_size=batch_size,
- num_workers=self._config.env.num_workers_evaluation,
+ num_workers=self._config.env.num_workers_inference,
  predict_data=data,
  is_train=False,
  )
  pred_writer = self.get_pred_writer(strategy=strategy)
  callbacks = self.get_callbacks_per_run(pred_writer=pred_writer, is_train=False)
- # TODO: remove optimization_kwargs from inference
- litmodule = self.get_litmodule_per_run(optimization_kwargs=optimization_kwargs, is_train=False)
+ # TODO: remove optim_kwargs from inference
+ litmodule = self.get_litmodule_per_run(optim_kwargs=optim_kwargs, is_train=False)
  trainer = self.init_trainer_per_run(
  num_gpus=num_gpus,
  precision=precision,
@@ -2167,7 +2200,7 @@ class BaseLearner(ExportMixin, DistillationMixin, RealtimeMixin):
  config = config if config else self._config
  config = copy.deepcopy(config)
  model = model if model else self._model
- if standalone and not OmegaConf.select(config, "optimization.efficient_finetune"):
+ if standalone and not config.optim.peft:
  config = save_pretrained_model_configs(model=model, config=config, path=path)
  os.makedirs(path, exist_ok=True)
  OmegaConf.save(config=config, f=os.path.join(path, "config.yaml"))
@@ -2182,10 +2215,8 @@ class BaseLearner(ExportMixin, DistillationMixin, RealtimeMixin):
  # Save text tokenizers before saving data processors
  for modality in [TEXT, TEXT_NER, NER, DOCUMENT]:
  if modality in data_processors:
- data_processors[modality] = save_text_tokenizers(
- text_processors=data_processors[modality],
- path=path,
- )
+ for per_processor in data_processors[modality]:
+ per_processor.save_tokenizer(path)
 
  # Clear the documents cache dictionary before saving.
  for modality in [DOCUMENT]:
@@ -2196,10 +2227,13 @@ class BaseLearner(ExportMixin, DistillationMixin, RealtimeMixin):
  with open(os.path.join(path, "data_processors.pkl"), "wb") as fp:
  pickle.dump(data_processors, fp)
 
+ with open(os.path.join(path, "eval_metric.pkl"), "wb") as fp:
+ pickle.dump(self._eval_metric_func, fp)
+
  with open(os.path.join(path, f"assets.json"), "w") as fp:
  json.dump(
  {
- "class_name": self.__class__.__name__,
+ "learner_class": self.__class__.__name__,
  "column_types": self._column_types,
  "label_column": self._label_column,
  "problem_type": self._problem_type,
@@ -2241,68 +2275,37 @@ class BaseLearner(ExportMixin, DistillationMixin, RealtimeMixin):
 
  with open(os.path.join(path, "assets.json"), "r") as fp:
  assets = json.load(fp)
- config = upgrade_config(config, assets["version"])
 
  with open(os.path.join(path, "df_preprocessor.pkl"), "rb") as fp:
- df_preprocessor = CustomUnpickler(fp).load()
- if (
- not hasattr(df_preprocessor, "_rois_feature_names")
- and hasattr(df_preprocessor, "_image_feature_names")
- and ROIS in df_preprocessor._image_feature_names
- ): # backward compatibility for mmlab models
- df_preprocessor._image_feature_names = [
- name for name in df_preprocessor._image_feature_names if name != ROIS
- ]
- df_preprocessor._rois_feature_names = [ROIS]
-
+ df_preprocessor = pickle.load(fp) # nosec B301
  try:
  with open(os.path.join(path, "data_processors.pkl"), "rb") as fp:
- data_processors = CustomUnpickler(fp).load()
+ data_processors = pickle.load(fp) # nosec B301
  # Load text tokenizers after loading data processors.
- for modality in [
- TEXT,
- TEXT_NER,
- NER,
- DOCUMENT,
- ]: # NER is included for backward compatibility
+ for modality in [TEXT, TEXT_NER, NER, DOCUMENT]:
  if modality in data_processors:
- data_processors[modality] = load_text_tokenizers(
- text_processors=data_processors[modality],
- path=path,
- )
-
- # backward compatibility. Add feature column names in each data processor.
- data_processors = assign_feature_column_names(
- data_processors=data_processors,
- df_preprocessor=df_preprocessor,
- )
+ for per_processor in data_processors[modality]:
+ per_processor.load_tokenizer(path)
 
  # Only keep the modalities with non-empty processors.
  data_processors = {k: v for k, v in data_processors.items() if len(v) > 0}
- except: # backward compatibility. reconstruct the data processor in case something went wrong.
+ except: # reconstruct the data processor in case something went wrong.
  data_processors = None
 
  learner._label_column = assets["label_column"]
  learner._problem_type = assets["problem_type"]
- if "pipeline" in assets: # backward compatibility
- learner._problem_type = assets["pipeline"]
- if "presets" in assets:
- learner._presets = assets["presets"]
- if "best_score" in assets: # backward compatibility
- learner._best_score = assets["best_score"]
- if "total_train_time" in assets: # backward compatibility
- learner._total_train_time = assets["total_train_time"]
+ learner._presets = assets["presets"]
+ learner._best_score = assets["best_score"]
+ learner._total_train_time = assets["total_train_time"]
  learner._eval_metric_name = assets["eval_metric_name"]
+ with open(os.path.join(path, "eval_metric.pkl"), "rb") as fp:
+ learner._eval_metric_func = pickle.load(fp) # nosec B301
  learner._verbosity = verbosity
  learner._resume = resume
  learner._save_path = path # in case the original exp dir is copied to somewhere else
  learner._pretrained_path = path
- if "pretrained" in assets:
- learner._pretrained = assets["pretrained"]
- if "fit_called" in assets:
- learner._fit_called = assets["fit_called"]
- else:
- learner._fit_called = True # backward compatible
+ learner._pretrained = assets["pretrained"]
+ learner._fit_called = assets["fit_called"]
  learner._config = config
  learner._output_shape = assets["output_shape"]
  if "classes" in assets:
@@ -2311,10 +2314,7 @@ class BaseLearner(ExportMixin, DistillationMixin, RealtimeMixin):
  learner._validation_metric_name = assets["validation_metric_name"]
  learner._df_preprocessor = df_preprocessor
  learner._data_processors = data_processors
- if "minmax_mode" in assets:
- learner._minmax_mode = assets["minmax_mode"]
- else:
- learner._minmax_mode = get_minmax_mode(learner._validation_metric_name)
+ learner._minmax_mode = assets["minmax_mode"]
 
  return learner
 
@@ -2352,7 +2352,7 @@ class BaseLearner(ExportMixin, DistillationMixin, RealtimeMixin):
  assert os.path.isdir(dir_path), f"'{dir_path}' must be an existing directory."
  learner = cls(label="dummy_label")
  learner = cls._load_metadata(learner=learner, path=dir_path, resume=resume, verbosity=verbosity)
- peft = OmegaConf.select(learner._config, "optimization.efficient_finetune")
+ peft = learner._config.optim.peft
  learner._model = create_fusion_model(
  config=learner._config,
  num_classes=learner._output_shape,
@@ -2379,8 +2379,8 @@ class BaseLearner(ExportMixin, DistillationMixin, RealtimeMixin):
  loss_func = get_loss_func(
  problem_type=learner._problem_type,
  mixup_active=False,
- loss_func_name=OmegaConf.select(learner._config, "optimization.loss_function"),
- config=learner._config.optimization,
+ loss_func_name=learner._config.optim.loss_func,
+ config=learner._config.optim,
  num_classes=learner._output_shape,
  )
  model_postprocess_fn = get_model_postprocess_fn(