autogluon.multimodal 1.2.1b20250303__py3-none-any.whl → 1.2.1b20250304__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- autogluon/multimodal/__init__.py +4 -2
- autogluon/multimodal/configs/data/default.yaml +4 -2
- autogluon/multimodal/configs/{environment → env}/default.yaml +2 -3
- autogluon/multimodal/configs/model/default.yaml +58 -11
- autogluon/multimodal/configs/{optimization → optim}/default.yaml +21 -4
- autogluon/multimodal/constants.py +16 -5
- autogluon/multimodal/data/__init__.py +14 -2
- autogluon/multimodal/data/dataset.py +2 -2
- autogluon/multimodal/data/infer_types.py +16 -2
- autogluon/multimodal/data/label_encoder.py +3 -3
- autogluon/multimodal/{utils → data}/nlpaug.py +4 -4
- autogluon/multimodal/data/preprocess_dataframe.py +55 -38
- autogluon/multimodal/data/process_categorical.py +35 -6
- autogluon/multimodal/data/process_document.py +59 -33
- autogluon/multimodal/data/process_image.py +198 -163
- autogluon/multimodal/data/process_label.py +7 -3
- autogluon/multimodal/data/process_mmlab/process_mmdet.py +1 -8
- autogluon/multimodal/data/process_mmlab/process_mmlab_base.py +2 -9
- autogluon/multimodal/data/process_mmlab/process_mmocr.py +1 -9
- autogluon/multimodal/data/process_ner.py +192 -4
- autogluon/multimodal/data/process_numerical.py +32 -5
- autogluon/multimodal/data/process_semantic_seg_img.py +23 -28
- autogluon/multimodal/data/process_text.py +95 -58
- autogluon/multimodal/data/template_engine.py +7 -9
- autogluon/multimodal/data/templates.py +0 -2
- autogluon/multimodal/data/trivial_augmenter.py +2 -2
- autogluon/multimodal/data/utils.py +564 -338
- autogluon/multimodal/learners/__init__.py +2 -1
- autogluon/multimodal/learners/base.py +189 -189
- autogluon/multimodal/learners/ensemble.py +748 -0
- autogluon/multimodal/learners/few_shot_svm.py +6 -15
- autogluon/multimodal/learners/matching.py +59 -84
- autogluon/multimodal/learners/ner.py +23 -22
- autogluon/multimodal/learners/object_detection.py +26 -21
- autogluon/multimodal/learners/semantic_segmentation.py +16 -18
- autogluon/multimodal/models/__init__.py +12 -3
- autogluon/multimodal/models/augmenter.py +175 -0
- autogluon/multimodal/models/categorical_mlp.py +13 -8
- autogluon/multimodal/models/clip.py +92 -18
- autogluon/multimodal/models/custom_transformer.py +75 -75
- autogluon/multimodal/models/document_transformer.py +23 -9
- autogluon/multimodal/models/ft_transformer.py +40 -35
- autogluon/multimodal/models/fusion/base.py +2 -4
- autogluon/multimodal/models/fusion/fusion_mlp.py +82 -18
- autogluon/multimodal/models/fusion/fusion_ner.py +1 -1
- autogluon/multimodal/models/fusion/fusion_transformer.py +23 -23
- autogluon/multimodal/models/{huggingface_text.py → hf_text.py} +21 -2
- autogluon/multimodal/models/meta_transformer.py +336 -0
- autogluon/multimodal/models/mlp.py +6 -6
- autogluon/multimodal/models/mmocr_text_detection.py +1 -1
- autogluon/multimodal/models/mmocr_text_recognition.py +0 -1
- autogluon/multimodal/models/ner_text.py +1 -8
- autogluon/multimodal/models/numerical_mlp.py +14 -8
- autogluon/multimodal/models/sam.py +12 -2
- autogluon/multimodal/models/t_few.py +21 -5
- autogluon/multimodal/models/timm_image.py +74 -32
- autogluon/multimodal/models/utils.py +877 -16
- autogluon/multimodal/optim/__init__.py +17 -0
- autogluon/multimodal/{optimization → optim}/lit_distiller.py +2 -1
- autogluon/multimodal/{optimization → optim}/lit_matcher.py +4 -10
- autogluon/multimodal/{optimization → optim}/lit_mmdet.py +2 -10
- autogluon/multimodal/{optimization → optim}/lit_module.py +139 -14
- autogluon/multimodal/{optimization → optim}/lit_ner.py +3 -3
- autogluon/multimodal/{optimization → optim}/lit_semantic_seg.py +1 -1
- autogluon/multimodal/optim/losses/__init__.py +14 -0
- autogluon/multimodal/optim/losses/bce_loss.py +25 -0
- autogluon/multimodal/optim/losses/focal_loss.py +81 -0
- autogluon/multimodal/optim/losses/lemda_loss.py +39 -0
- autogluon/multimodal/optim/losses/rkd_loss.py +103 -0
- autogluon/multimodal/optim/losses/softmax_losses.py +177 -0
- autogluon/multimodal/optim/losses/structure_loss.py +26 -0
- autogluon/multimodal/optim/losses/utils.py +313 -0
- autogluon/multimodal/optim/lr/__init__.py +1 -0
- autogluon/multimodal/optim/lr/utils.py +332 -0
- autogluon/multimodal/optim/metrics/__init__.py +4 -0
- autogluon/multimodal/optim/metrics/coverage_metrics.py +42 -0
- autogluon/multimodal/optim/metrics/hit_rate_metrics.py +78 -0
- autogluon/multimodal/optim/metrics/ranking_metrics.py +231 -0
- autogluon/multimodal/optim/metrics/utils.py +359 -0
- autogluon/multimodal/optim/utils.py +284 -0
- autogluon/multimodal/predictor.py +51 -12
- autogluon/multimodal/utils/__init__.py +19 -45
- autogluon/multimodal/utils/cache.py +23 -2
- autogluon/multimodal/utils/checkpoint.py +58 -5
- autogluon/multimodal/utils/config.py +127 -55
- autogluon/multimodal/utils/device.py +120 -0
- autogluon/multimodal/utils/distillation.py +8 -8
- autogluon/multimodal/utils/download.py +1 -1
- autogluon/multimodal/utils/env.py +22 -0
- autogluon/multimodal/utils/export.py +3 -3
- autogluon/multimodal/utils/hpo.py +5 -5
- autogluon/multimodal/utils/inference.py +37 -4
- autogluon/multimodal/utils/install.py +91 -0
- autogluon/multimodal/utils/load.py +52 -47
- autogluon/multimodal/utils/log.py +6 -41
- autogluon/multimodal/utils/matcher.py +3 -2
- autogluon/multimodal/utils/onnx.py +0 -4
- autogluon/multimodal/utils/path.py +10 -0
- autogluon/multimodal/utils/precision.py +130 -0
- autogluon/multimodal/{presets.py → utils/presets.py} +259 -66
- autogluon/multimodal/{problem_types.py → utils/problem_types.py} +30 -1
- autogluon/multimodal/utils/save.py +47 -29
- autogluon/multimodal/utils/strategy.py +24 -0
- autogluon/multimodal/version.py +1 -1
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/METADATA +5 -5
- autogluon.multimodal-1.2.1b20250304.dist-info/RECORD +163 -0
- autogluon/multimodal/optimization/__init__.py +0 -16
- autogluon/multimodal/optimization/losses.py +0 -394
- autogluon/multimodal/optimization/utils.py +0 -1054
- autogluon/multimodal/utils/cloud_io.py +0 -80
- autogluon/multimodal/utils/data.py +0 -701
- autogluon/multimodal/utils/environment.py +0 -395
- autogluon/multimodal/utils/metric.py +0 -500
- autogluon/multimodal/utils/model.py +0 -558
- autogluon.multimodal-1.2.1b20250303.dist-info/RECORD +0 -145
- /autogluon/multimodal/{optimization → optim}/deepspeed.py +0 -0
- /autogluon/multimodal/{optimization/lr_scheduler.py → optim/lr/lr_schedulers.py} +0 -0
- /autogluon/multimodal/{optimization → optim/metrics}/semantic_seg_metrics.py +0 -0
- /autogluon/multimodal/{registry.py → utils/registry.py} +0 -0
- /autogluon.multimodal-1.2.1b20250303-py3.9-nspkg.pth → /autogluon.multimodal-1.2.1b20250304-py3.9-nspkg.pth +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/LICENSE +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/NOTICE +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/WHEEL +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/namespace_packages.txt +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/top_level.txt +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/zip-safe +0 -0
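Beyond the moved modules, the renames above carry a user-visible change: the `optimization` and `environment` config groups become `optim` and `env` (and `huggingface_text.py` becomes `hf_text.py`). Below is a minimal, hypothetical sketch of hyperparameter overrides written against the renamed config tree; the dotted keys mirror the new attribute accesses that appear in the `learners/base.py` diff further down (`config.optim.lr`, `config.optim.max_epochs`, `config.env.per_gpu_batch_size`), while the tiny DataFrame and the values are placeholders, not recommendations.

```python
import pandas as pd

from autogluon.multimodal import MultiModalPredictor

# Placeholder data; a real dataset would have many more rows.
train_df = pd.DataFrame(
    {
        "text": ["good product", "terrible quality", "works fine", "broke after a day"],
        "label": [1, 0, 1, 0],
    }
)

predictor = MultiModalPredictor(label="label")
predictor.fit(
    train_data=train_df,
    time_limit=60,
    hyperparameters={
        "optim.lr": 1.0e-4,           # formerly under the "optimization.*" group
        "optim.max_epochs": 10,
        "env.per_gpu_batch_size": 8,  # formerly under the "environment.*" group
    },
)
```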
autogluon/multimodal/learners/base.py

@@ -25,7 +25,7 @@ from torch import nn
 
 from autogluon.common.utils.resource_utils import ResourceManager
 from autogluon.common.utils.try_import import try_import_ray
-from autogluon.core.metrics import Scorer
+from autogluon.core.metrics import Scorer, get_metric
 from autogluon.core.utils.loaders import load_pd
 
 from .. import version as ag_version
@@ -59,7 +59,6 @@ from ..constants import (
     NER,
     RAY_TUNE_CHECKPOINT,
     REGRESSION,
-    ROIS,
     TEXT,
     TEXT_NER,
     TORCH_COMPILE_MIN_VERSION,
@@ -72,39 +71,51 @@ from ..constants import (
 from ..data import (
     BaseDataModule,
     MultiModalFeaturePreprocessor,
+    create_fusion_data_processors,
+    data_to_df,
+    get_mixup,
     infer_column_types,
+    infer_dtypes_by_model_names,
     infer_output_shape,
     infer_problem_type,
+    infer_scarcity_mode_by_data_size,
+    init_df_preprocessor,
     is_image_column,
+    split_train_tuning_data,
+    turn_on_off_feature_column_info,
+)
+from ..models import (
+    create_fusion_model,
+    get_model_postprocess_fn,
+    is_lazy_weight_tensor,
+    list_timm_models,
+    select_model,
 )
-from ..
-
-
-from ..optimization.utils import (
+from ..optim import (
+    compute_score,
+    get_aug_loss_func,
     get_loss_func,
-
+    get_minmax_mode,
     get_norm_layer_param_names,
-
+    get_peft_param_names,
+    get_stopping_threshold,
+    get_torchmetric,
+    infer_metrics,
 )
-from ..
+from ..optim.lit_distiller import DistillerLitModule
+from ..optim.lit_module import LitModule
 from ..utils import (
     AutoMMModelCheckpoint,
     AutoMMModelCheckpointIO,
-    CustomUnpickler,
     DDPPredictionWriter,
     DistillationMixin,
     ExportMixin,
     LogFilter,
     RealtimeMixin,
     apply_log_filter,
-    assign_feature_column_names,
     average_checkpoints,
     compute_inference_batch_size,
     compute_num_gpus,
-    compute_score,
-    create_fusion_data_processors,
-    create_fusion_model,
-    data_to_df,
     extract_from_output,
     filter_hyperparameters,
     get_config,
@@ -112,39 +123,25 @@ from ..utils import (
     get_gpu_message,
     get_load_ckpt_paths,
     get_local_pretrained_config_paths,
-    get_minmax_mode,
-    get_mixup,
-    get_stopping_threshold,
     hyperparameter_tune,
-    infer_dtypes_by_model_names,
-    infer_metrics,
     infer_precision,
     infer_problem_type_by_eval_metric,
-    infer_scarcity_mode_by_data_size,
-    init_df_preprocessor,
     is_interactive_env,
     is_interactive_strategy,
-    is_lazy_weight_tensor,
-    list_timm_models,
-    load_text_tokenizers,
     logits_to_prob,
     on_fit_end_message,
     on_fit_per_run_start_message,
     on_fit_start_message,
     run_ddp_only_once,
     save_pretrained_model_configs,
-    save_text_tokenizers,
-    select_model,
     setup_save_path,
     split_hyperparameters,
-    split_train_tuning_data,
     tensor_to_ndarray,
-    turn_on_off_feature_column_info,
     update_config_by_rules,
     update_hyperparameters,
     update_tabular_config_by_resources,
-    upgrade_config,
 )
+from ..utils.problem_types import PROBLEM_TYPES_REG
 
 pl_logger = logging.getLogger("lightning")
 pl_logger.propagate = False  # https://github.com/Lightning-AI/lightning/issues/4621
@@ -253,6 +250,7 @@ class BaseLearner(ExportMixin, DistillationMixin, RealtimeMixin):
         self._eval_metric_func = None
         if isinstance(eval_metric, str):
             self._eval_metric_name = eval_metric.lower()
+            self.set_eval_metric_func()
         elif isinstance(eval_metric, Scorer):
             self._eval_metric_name = eval_metric.name
             self._eval_metric_func = eval_metric
@@ -350,6 +348,18 @@ class BaseLearner(ExportMixin, DistillationMixin, RealtimeMixin):
         )
         return model_size * 1e-6  # convert to megabytes
 
+    def set_eval_metric_func(self):
+        from .matching import MatchingLearner
+        from .ner import NERLearner
+        from .object_detection import ObjectDetectionLearner
+        from .semantic_segmentation import SemanticSegmentationLearner
+
+        if (
+            not isinstance(self, (NERLearner, SemanticSegmentationLearner, MatchingLearner, ObjectDetectionLearner))
+            and self._eval_metric_func is None
+        ):
+            self._eval_metric_func = get_metric(self._eval_metric_name)
+
     def ensure_fit_ready(self):
         if self._problem_type and not self.problem_property.support_fit:
             raise RuntimeError(
@@ -482,6 +492,7 @@ class BaseLearner(ExportMixin, DistillationMixin, RealtimeMixin):
             validation_metric_name=self._validation_metric_name,
             is_matching=is_matching,
         )
+        self.set_eval_metric_func()
         self._minmax_mode = get_minmax_mode(self._validation_metric_name)
         logger.debug(f"validation_metric_name: {self._validation_metric_name}")
         logger.debug(f"minmax_mode: {self._minmax_mode}")
@@ -598,7 +609,7 @@ class BaseLearner(ExportMixin, DistillationMixin, RealtimeMixin):
         # top_k_average is called inside hyperparameter_tune() when building the final predictor.
         self.top_k_average(
             save_path=self._save_path,
-            top_k_average_method=self._config.
+            top_k_average_method=self._config.optim.top_k_average_method,
             strategy=strategy,
             strict_loading=strict_loading,
             # Not strict loading if using parameter-efficient finetuning
@@ -750,12 +761,13 @@ class BaseLearner(ExportMixin, DistillationMixin, RealtimeMixin):
             num_classes=self._output_shape,
             num_numerical_columns=len(df_preprocessor.numerical_feature_names),
            num_categories=df_preprocessor.categorical_num_categories,
+            numerical_fill_values=df_preprocessor.numerical_fill_values,
         )
         return model
 
     @staticmethod
     def compile_model_per_run(config, model):
-        if
+        if config.env.compile.turn_on:
             assert version.parse(torch.__version__) >= version.parse(TORCH_COMPILE_MIN_VERSION), (
                 f"torch.compile requires torch version >= {TORCH_COMPILE_MIN_VERSION}, "
                 f"but torch version {torch.__version__} is detected."
@@ -763,21 +775,21 @@ class BaseLearner(ExportMixin, DistillationMixin, RealtimeMixin):
             logger.debug("Using torch.compile() in compiling the model.")
             model = torch.compile(
                 model,
-                mode=
-                dynamic=
-                backend=
+                mode=config.env.compile.mode,
+                dynamic=config.env.compile.dynamic,
+                backend=config.env.compile.backend,
             )
         return model
 
     @staticmethod
     def get_peft_param_names_per_run(model, config):
         peft_param_names = None
-        peft =
+        peft = config.optim.peft
         if peft:
             norm_param_names = get_norm_layer_param_names(model)
-            peft_param_names =
+            peft_param_names = get_peft_param_names(
                 norm_param_names,
-
+                peft=peft,
             )
         return peft_param_names
 
@@ -809,7 +821,7 @@ class BaseLearner(ExportMixin, DistillationMixin, RealtimeMixin):
         return data_processors
 
     def get_validation_metric_per_run(self):
-        validation_metric, custom_metric_func =
+        validation_metric, custom_metric_func = get_torchmetric(
             metric_name=self._validation_metric_name,
             num_classes=self._output_shape,
             problem_type=self._problem_type,
@@ -818,8 +830,8 @@ class BaseLearner(ExportMixin, DistillationMixin, RealtimeMixin):
 
     def get_mixup_func_per_run(self, config):
         mixup_active, mixup_func = get_mixup(
-            model_config=
-            mixup_config=
+            model_config=config.model,
+            mixup_config=config.data.mixup,
             num_classes=self._output_shape,
         )
         if mixup_active and (config.env.per_gpu_batch_size == 1 or config.env.per_gpu_batch_size % 2 == 1):
@@ -834,10 +846,14 @@ class BaseLearner(ExportMixin, DistillationMixin, RealtimeMixin):
         loss_func = get_loss_func(
             problem_type=self._problem_type,
             mixup_active=mixup_active,
-            loss_func_name=
-            config=config.
+            loss_func_name=config.optim.loss_func,
+            config=config.optim,
         )
-
+        aug_loss_func = get_aug_loss_func(
+            config=config.optim,
+            problem_type=self._problem_type,
+        )
+        return loss_func, aug_loss_func
 
     def get_model_postprocess_fn_per_run(self, loss_func):
         model_postprocess_fn = get_model_postprocess_fn(
@@ -872,26 +888,46 @@ class BaseLearner(ExportMixin, DistillationMixin, RealtimeMixin):
         datamodule = BaseDataModule(**datamodule_kwargs)
         return datamodule
 
-    def
+    def get_optim_kwargs_per_run(
+        self,
+        config,
+        validation_metric,
+        custom_metric_func,
+        loss_func,
+        aug_loss_func,
+        mixup_func,
+        grad_steps,
+    ):
         return dict(
-            optim_type=config.
-            lr_choice=config.
-            lr_schedule=config.
-            lr=config.
-            lr_decay=config.
-            end_lr=config.
-            lr_mult=config.
-            weight_decay=config.
-            warmup_steps=config.
-            track_grad_norm=
+            optim_type=config.optim.optim_type,
+            lr_choice=config.optim.lr_choice,
+            lr_schedule=config.optim.lr_schedule,
+            lr=config.optim.lr,
+            lr_decay=config.optim.lr_decay,
+            end_lr=config.optim.end_lr,
+            lr_mult=config.optim.lr_mult,
+            weight_decay=config.optim.weight_decay,
+            warmup_steps=config.optim.warmup_steps,
+            track_grad_norm=config.optim.track_grad_norm,
             validation_metric=validation_metric,
            validation_metric_name=self._validation_metric_name,
             custom_metric_func=custom_metric_func,
             loss_func=loss_func,
             mixup_fn=mixup_func,
-
-            mixup_off_epoch=
-            skip_final_val=
+            peft=config.optim.peft,
+            mixup_off_epoch=config.data.mixup.turn_off_epoch,
+            skip_final_val=config.optim.skip_final_val,
+            cross_modal_align=config.optim.cross_modal_align,
+            cross_modal_align_weight=config.optim.cross_modal_align_weight,
+            automatic_optimization=config.optim.automatic_optimization,
+            accumulate_grad_batches=grad_steps,
+            gradient_clip_val=config.optim.gradient_clip_val,
+            gradient_clip_algorithm=config.optim.gradient_clip_algorithm,
+            use_aug_optim=config.optim.lemda.turn_on,
+            aug_loss_func=aug_loss_func,
+            aug_lr=config.optim.lemda.lr,
+            aug_weight_decay=config.optim.lemda.weight_decay,
+            aug_optim_type=config.optim.lemda.optim_type,
        )
 
     def get_litmodule_per_run(
@@ -899,7 +935,7 @@ class BaseLearner(ExportMixin, DistillationMixin, RealtimeMixin):
         model=None,
         model_postprocess_fn=None,
         peft_param_names=None,
-
+        optim_kwargs=dict(),
         distillation_kwargs=dict(),
         is_train=True,
     ):
@@ -908,7 +944,7 @@ class BaseLearner(ExportMixin, DistillationMixin, RealtimeMixin):
                 return DistillerLitModule(
                     student_model=model,
                     teacher_model=self._teacher_learner._model,
-                    **
+                    **optim_kwargs,
                     **distillation_kwargs,
                 )
             else:
@@ -916,13 +952,13 @@ class BaseLearner(ExportMixin, DistillationMixin, RealtimeMixin):
                     model=model,
                     model_postprocess_fn=model_postprocess_fn,
                     trainable_param_names=peft_param_names,
-                    **
+                    **optim_kwargs,
                 )
         else:
             return LitModule(
                 model=self._model,
                 model_postprocess_fn=self._model_postprocess_fn,
-                **
+                **optim_kwargs,
            )
 
     def get_callbacks_per_run(self, save_path=None, config=None, litmodule=None, pred_writer=None, is_train=True):
@@ -935,7 +971,7 @@ class BaseLearner(ExportMixin, DistillationMixin, RealtimeMixin):
 
         checkpoint_callback = AutoMMModelCheckpoint(
             dirpath=save_path,
-            save_top_k=config.
+            save_top_k=config.optim.top_k,
             verbose=True,
             monitor=litmodule.validation_metric_name,
             mode=self._minmax_mode,
@@ -943,7 +979,7 @@ class BaseLearner(ExportMixin, DistillationMixin, RealtimeMixin):
         )
         early_stopping_callback = pl.callbacks.EarlyStopping(
             monitor=litmodule.validation_metric_name,
-            patience=config.
+            patience=config.optim.patience,
             mode=self._minmax_mode,
             stopping_threshold=get_stopping_threshold(self._validation_metric_name),
         )
@@ -1020,7 +1056,7 @@ class BaseLearner(ExportMixin, DistillationMixin, RealtimeMixin):
             assert (
                 version.parse(pl.__version__) >= version.parse(DEEPSPEED_MIN_PL_VERSION)
             ), f"For DeepSpeed Offloading to work reliably you need at least lightning version {DEEPSPEED_MIN_PL_VERSION}, however, found {pl.__version__}. Please update your lightning version."
-            from ..
+            from ..optim.deepspeed import CustomDeepSpeedStrategy
 
             strategy = CustomDeepSpeedStrategy(
                 stage=3,
@@ -1059,7 +1095,7 @@ class BaseLearner(ExportMixin, DistillationMixin, RealtimeMixin):
 
         num_gpus = compute_num_gpus(
             config_num_gpus=config.env.num_gpus,
-            accelerator=
+            accelerator=config.env.accelerator,
         )
         num_gpus = self.update_num_gpus_by_data_size(num_gpus=num_gpus, data=data)
         strategy = self.get_strategy_per_run(num_gpus=num_gpus, config=config)
@@ -1096,36 +1132,38 @@ class BaseLearner(ExportMixin, DistillationMixin, RealtimeMixin):
         is_train=True,
     ):
         if is_train:
+            trainer_kwargs = dict(
+                accelerator="gpu" if num_gpus > 0 else config.env.accelerator,
+                devices=num_gpus if num_gpus > 0 else "auto",
+                num_nodes=config.env.num_nodes,
+                precision=precision,
+                strategy=strategy if strategy else "auto",
+                benchmark=False,
+                deterministic=config.env.deterministic,
+                max_epochs=config.optim.max_epochs,
+                max_steps=config.optim.max_steps,
+                max_time=max_time,
+                callbacks=callbacks,
+                logger=tb_logger,
+                log_every_n_steps=config.optim.log_every_n_steps,
+                enable_progress_bar=enable_progress_bar,
+                fast_dev_run=config.env.fast_dev_run,
+                val_check_interval=config.optim.val_check_interval,
+                check_val_every_n_epoch=config.optim.check_val_every_n_epoch,
+                plugins=plugins,
+            )
+            if config.optim.automatic_optimization:
+                trainer_kwargs.update(
+                    dict(
+                        gradient_clip_val=config.optim.gradient_clip_val,
+                        gradient_clip_algorithm=config.optim.gradient_clip_algorithm,
+                        accumulate_grad_batches=grad_steps,
+                    )
+                )
             blacklist_msgs = ["already configured with model summary"]
             log_filter = LogFilter(blacklist_msgs)
             with apply_log_filter(log_filter):
-                trainer = pl.Trainer(
-                    accelerator="gpu" if num_gpus > 0 else OmegaConf.select(config, "env.accelerator", default="auto"),
-                    devices=num_gpus if num_gpus > 0 else "auto",
-                    num_nodes=config.env.num_nodes,
-                    precision=precision,
-                    strategy=strategy if strategy else "auto",
-                    benchmark=False,
-                    deterministic=config.env.deterministic,
-                    max_epochs=config.optimization.max_epochs,
-                    max_steps=config.optimization.max_steps,
-                    max_time=max_time,
-                    callbacks=callbacks,
-                    logger=tb_logger,
-                    gradient_clip_val=OmegaConf.select(config, "optimization.gradient_clip_val", default=1),
-                    gradient_clip_algorithm=OmegaConf.select(
-                        config, "optimization.gradient_clip_algorithm", default="norm"
-                    ),
-                    accumulate_grad_batches=grad_steps,
-                    log_every_n_steps=OmegaConf.select(config, "optimization.log_every_n_steps", default=10),
-                    enable_progress_bar=enable_progress_bar,
-                    fast_dev_run=config.env.fast_dev_run,
-                    val_check_interval=config.optimization.val_check_interval,
-                    check_val_every_n_epoch=config.optimization.check_val_every_n_epoch
-                    if hasattr(config.optimization, "check_val_every_n_epoch")
-                    else 1,
-                    plugins=plugins,
-                )
+                trainer = pl.Trainer(**trainer_kwargs)
         else:
             blacklist_msgs = []
             if self._verbosity <= 3:  # turn off logging in prediction
@@ -1140,9 +1178,7 @@ class BaseLearner(ExportMixin, DistillationMixin, RealtimeMixin):
 
         with apply_log_filter(log_filter):
             trainer = pl.Trainer(
-                accelerator="gpu"
-                if num_gpus > 0
-                else OmegaConf.select(self._config, "env.accelerator", default="auto"),
+                accelerator="gpu" if num_gpus > 0 else self._config.env.accelerator,
                 devices=num_gpus if num_gpus > 0 else "auto",
                 num_nodes=self._config.env.num_nodes,
                 precision=precision,
@@ -1253,8 +1289,11 @@ class BaseLearner(ExportMixin, DistillationMixin, RealtimeMixin):
         )
         validation_metric, custom_metric_func = self.get_validation_metric_per_run()
         mixup_active, mixup_func = self.get_mixup_func_per_run(config=config)
-        loss_func = self.get_loss_func_per_run(config=config, mixup_active=mixup_active)
+        loss_func, aug_loss_func = self.get_loss_func_per_run(config=config, mixup_active=mixup_active)
         model_postprocess_fn = self.get_model_postprocess_fn_per_run(loss_func=loss_func)
+        num_gpus, strategy = self.get_num_gpus_and_strategy_per_run(config=config)
+        precision = self.get_precision_per_run(num_gpus=num_gpus, precision=config.env.precision)
+        grad_steps = self.get_grad_steps(num_gpus=num_gpus, config=config)
 
         if max_time == timedelta(seconds=0):
             return dict(
@@ -1278,26 +1317,25 @@ class BaseLearner(ExportMixin, DistillationMixin, RealtimeMixin):
             per_gpu_batch_size=config.env.per_gpu_batch_size,
             num_workers=config.env.num_workers,
         )
-
+        optim_kwargs = self.get_optim_kwargs_per_run(
             config=config,
             validation_metric=validation_metric,
             custom_metric_func=custom_metric_func,
             loss_func=loss_func,
+            aug_loss_func=aug_loss_func,
             mixup_func=mixup_func,
+            grad_steps=grad_steps,
         )
         litmodule = self.get_litmodule_per_run(
            model=model,
             model_postprocess_fn=model_postprocess_fn,
             peft_param_names=peft_param_names,
-
+            optim_kwargs=optim_kwargs,
             distillation_kwargs=distillation_kwargs,
         )
         callbacks = self.get_callbacks_per_run(save_path=save_path, config=config, litmodule=litmodule)
         plugins = self.get_plugins_per_run(model=model, peft_param_names=peft_param_names)
         tb_logger = self.get_tb_logger(save_path=save_path)
-        num_gpus, strategy = self.get_num_gpus_and_strategy_per_run(config=config)
-        precision = self.get_precision_per_run(num_gpus=num_gpus, precision=config.env.precision)
-        grad_steps = self.get_grad_steps(num_gpus=num_gpus, config=config)
         config = self.post_update_config_per_run(
             config=config,
             num_gpus=num_gpus,
@@ -1361,9 +1399,8 @@ class BaseLearner(ExportMixin, DistillationMixin, RealtimeMixin):
         standalone=True,
         clean_ckpts=True,
     ):
-
-
-        )
+        eval_metric = self._eval_metric_name if self._eval_metric_func is None else self._eval_metric_func
+        minmax_mode = get_minmax_mode(eval_metric)
         best_k_models_yaml_path = os.path.join(save_path, BEST_K_MODELS_FILE)
         if os.path.exists(best_k_models_yaml_path):
             with open(best_k_models_yaml_path, "r") as f:
@@ -1409,9 +1446,8 @@ class BaseLearner(ExportMixin, DistillationMixin, RealtimeMixin):
                 prefix=prefix,
                 strict=strict_loading,
             )
-            best_score = self.evaluate(self._tuning_data, metrics=[
-
-            ]
+            best_score = self.evaluate(self._tuning_data, metrics=[eval_metric])
+            best_score = next(iter(best_score.values()))
             for i in range(1, len(top_k_model_paths)):
                 cand_avg_state_dict = average_checkpoints(
                     checkpoint_paths=ingredients + [top_k_model_paths[i]],
@@ -1421,9 +1457,8 @@ class BaseLearner(ExportMixin, DistillationMixin, RealtimeMixin):
                     prefix=prefix,
                     strict=strict_loading,
                 )
-                cand_score = self.evaluate(self._tuning_data, metrics=[
-
-                ]
+                cand_score = self.evaluate(self._tuning_data, metrics=[eval_metric])
+                cand_score = next(iter(cand_score.values()))
                 if monitor_op(cand_score, best_score):
                     # Add new ingredient
                     ingredients.append(top_k_model_paths[i])
@@ -1432,7 +1467,7 @@ class BaseLearner(ExportMixin, DistillationMixin, RealtimeMixin):
             ingredients = [top_k_model_paths[0]]
         else:
             raise ValueError(
-                f"The key for '
+                f"The key for 'optim.top_k_average_method' is not supported. "
                 f"We only support '{GREEDY_SOUP}', '{UNIFORM_SOUP}' and '{BEST}'. "
                 f"The provided value is '{top_k_average_method}'."
            )
@@ -1495,7 +1530,7 @@ class BaseLearner(ExportMixin, DistillationMixin, RealtimeMixin):
        # TODO: Using optimiation_kwargs for inference is confusing and bad design. Remove as soon as fixed in lightning.
         if self._config.env.strategy == DEEPSPEED_OFFLOADING and DEEPSPEED_MODULE not in sys.modules:
             # Need to initialize DeepSpeed and optimizer as currently required in lightning's integration of deepspeed.
-            from ..
+            from ..optim.deepspeed import CustomDeepSpeedStrategy
 
             strategy = CustomDeepSpeedStrategy(
                 stage=3,
@@ -1505,21 +1540,21 @@ class BaseLearner(ExportMixin, DistillationMixin, RealtimeMixin):
                 reduce_bucket_size=self._config.env.deepspeed_allreduce_size,
             )
 
-
-                optim_type=self._config.
-                lr_choice=self._config.
-                lr_schedule=self._config.
-                lr=self._config.
-                lr_decay=self._config.
-                end_lr=self._config.
-                lr_mult=self._config.
-                weight_decay=self._config.
-                warmup_steps=self._config.
+            optim_kwargs = dict(
+                optim_type=self._config.optim.optim_type,
+                lr_choice=self._config.optim.lr_choice,
+                lr_schedule=self._config.optim.lr_schedule,
+                lr=self._config.optim.lr,
+                lr_decay=self._config.optim.lr_decay,
+                end_lr=self._config.optim.end_lr,
+                lr_mult=self._config.optim.lr_mult,
+                weight_decay=self._config.optim.weight_decay,
+                warmup_steps=self._config.optim.warmup_steps,
             )
         else:
-
+            optim_kwargs = {}
 
-        return strategy,
+        return strategy, optim_kwargs
 
     def get_pred_writer(self, strategy):
         pred_writer = None
@@ -1650,9 +1685,7 @@ class BaseLearner(ExportMixin, DistillationMixin, RealtimeMixin):
     def get_predict_batch_size_per_run(self, num_gpus: int, strategy: str):
         return compute_inference_batch_size(
             per_gpu_batch_size=self._config.env.per_gpu_batch_size,
-
-            per_gpu_batch_size_evaluation=self._config.env.per_gpu_batch_size_evaluation,
-            # backward compatibility.
+            inference_batch_size_ratio=self._config.env.inference_batch_size_ratio,
             num_gpus=num_gpus,
             strategy=strategy,
         )
@@ -1737,19 +1770,19 @@ class BaseLearner(ExportMixin, DistillationMixin, RealtimeMixin):
            )
            return outputs
 
-        strategy,
+        strategy, optim_kwargs = self.prepare_deepspeed_offloading(strategy=strategy)
         datamodule = self.get_datamodule_per_run(
             df_preprocessor=df_preprocessor,
             data_processors=data_processors,
             per_gpu_batch_size=batch_size,
-            num_workers=self._config.env.
+            num_workers=self._config.env.num_workers_inference,
             predict_data=data,
             is_train=False,
         )
         pred_writer = self.get_pred_writer(strategy=strategy)
         callbacks = self.get_callbacks_per_run(pred_writer=pred_writer, is_train=False)
-        # TODO: remove
-        litmodule = self.get_litmodule_per_run(
+        # TODO: remove optim_kwargs from inference
+        litmodule = self.get_litmodule_per_run(optim_kwargs=optim_kwargs, is_train=False)
         trainer = self.init_trainer_per_run(
             num_gpus=num_gpus,
             precision=precision,
@@ -2167,7 +2200,7 @@ class BaseLearner(ExportMixin, DistillationMixin, RealtimeMixin):
         config = config if config else self._config
         config = copy.deepcopy(config)
         model = model if model else self._model
-        if standalone and not
+        if standalone and not config.optim.peft:
             config = save_pretrained_model_configs(model=model, config=config, path=path)
         os.makedirs(path, exist_ok=True)
         OmegaConf.save(config=config, f=os.path.join(path, "config.yaml"))
@@ -2182,10 +2215,8 @@ class BaseLearner(ExportMixin, DistillationMixin, RealtimeMixin):
         # Save text tokenizers before saving data processors
         for modality in [TEXT, TEXT_NER, NER, DOCUMENT]:
             if modality in data_processors:
-                data_processors[modality]
-
-                    path=path,
-                )
+                for per_processor in data_processors[modality]:
+                    per_processor.save_tokenizer(path)
 
         # Clear the documents cache dictionary before saving.
         for modality in [DOCUMENT]:
@@ -2196,10 +2227,13 @@ class BaseLearner(ExportMixin, DistillationMixin, RealtimeMixin):
         with open(os.path.join(path, "data_processors.pkl"), "wb") as fp:
             pickle.dump(data_processors, fp)
 
+        with open(os.path.join(path, "eval_metric.pkl"), "wb") as fp:
+            pickle.dump(self._eval_metric_func, fp)
+
         with open(os.path.join(path, f"assets.json"), "w") as fp:
             json.dump(
                 {
-                    "
+                    "learner_class": self.__class__.__name__,
                     "column_types": self._column_types,
                     "label_column": self._label_column,
                    "problem_type": self._problem_type,
@@ -2241,68 +2275,37 @@ class BaseLearner(ExportMixin, DistillationMixin, RealtimeMixin):
 
         with open(os.path.join(path, "assets.json"), "r") as fp:
             assets = json.load(fp)
-        config = upgrade_config(config, assets["version"])
 
         with open(os.path.join(path, "df_preprocessor.pkl"), "rb") as fp:
-            df_preprocessor =
-        if (
-            not hasattr(df_preprocessor, "_rois_feature_names")
-            and hasattr(df_preprocessor, "_image_feature_names")
-            and ROIS in df_preprocessor._image_feature_names
-        ):  # backward compatibility for mmlab models
-            df_preprocessor._image_feature_names = [
-                name for name in df_preprocessor._image_feature_names if name != ROIS
-            ]
-            df_preprocessor._rois_feature_names = [ROIS]
-
+            df_preprocessor = pickle.load(fp)  # nosec B301
         try:
             with open(os.path.join(path, "data_processors.pkl"), "rb") as fp:
-                data_processors =
+                data_processors = pickle.load(fp)  # nosec B301
             # Load text tokenizers after loading data processors.
-            for modality in [
-                TEXT,
-                TEXT_NER,
-                NER,
-                DOCUMENT,
-            ]:  # NER is included for backward compatibility
+            for modality in [TEXT, TEXT_NER, NER, DOCUMENT]:
                 if modality in data_processors:
-                    data_processors[modality]
-
-                        path=path,
-                    )
-
-            # backward compatibility. Add feature column names in each data processor.
-            data_processors = assign_feature_column_names(
-                data_processors=data_processors,
-                df_preprocessor=df_preprocessor,
-            )
+                    for per_processor in data_processors[modality]:
+                        per_processor.load_tokenizer(path)
 
             # Only keep the modalities with non-empty processors.
             data_processors = {k: v for k, v in data_processors.items() if len(v) > 0}
-        except:  #
+        except:  # reconstruct the data processor in case something went wrong.
             data_processors = None
 
         learner._label_column = assets["label_column"]
         learner._problem_type = assets["problem_type"]
-
-
-
-        learner._presets = assets["presets"]
-        if "best_score" in assets:  # backward compatibility
-            learner._best_score = assets["best_score"]
-        if "total_train_time" in assets:  # backward compatibility
-            learner._total_train_time = assets["total_train_time"]
+        learner._presets = assets["presets"]
+        learner._best_score = assets["best_score"]
+        learner._total_train_time = assets["total_train_time"]
         learner._eval_metric_name = assets["eval_metric_name"]
+        with open(os.path.join(path, "eval_metric.pkl"), "rb") as fp:
+            learner._eval_metric_func = pickle.load(fp)  # nosec B301
         learner._verbosity = verbosity
         learner._resume = resume
         learner._save_path = path  # in case the original exp dir is copied to somewhere else
         learner._pretrained_path = path
-
-
-        if "fit_called" in assets:
-            learner._fit_called = assets["fit_called"]
-        else:
-            learner._fit_called = True  # backward compatible
+        learner._pretrained = assets["pretrained"]
+        learner._fit_called = assets["fit_called"]
         learner._config = config
         learner._output_shape = assets["output_shape"]
         if "classes" in assets:
@@ -2311,10 +2314,7 @@ class BaseLearner(ExportMixin, DistillationMixin, RealtimeMixin):
         learner._validation_metric_name = assets["validation_metric_name"]
         learner._df_preprocessor = df_preprocessor
         learner._data_processors = data_processors
-
-            learner._minmax_mode = assets["minmax_mode"]
-        else:
-            learner._minmax_mode = get_minmax_mode(learner._validation_metric_name)
+        learner._minmax_mode = assets["minmax_mode"]
 
         return learner
 
@@ -2352,7 +2352,7 @@ class BaseLearner(ExportMixin, DistillationMixin, RealtimeMixin):
         assert os.path.isdir(dir_path), f"'{dir_path}' must be an existing directory."
         learner = cls(label="dummy_label")
         learner = cls._load_metadata(learner=learner, path=dir_path, resume=resume, verbosity=verbosity)
-        peft =
+        peft = learner._config.optim.peft
         learner._model = create_fusion_model(
             config=learner._config,
             num_classes=learner._output_shape,
@@ -2379,8 +2379,8 @@ class BaseLearner(ExportMixin, DistillationMixin, RealtimeMixin):
         loss_func = get_loss_func(
             problem_type=learner._problem_type,
             mixup_active=False,
-            loss_func_name=
-            config=learner._config.
+            loss_func_name=learner._config.optim.loss_func,
+            config=learner._config.optim,
             num_classes=learner._output_shape,
         )
         model_postprocess_fn = get_model_postprocess_fn(