fusion-bench 0.2.12__py3-none-any.whl → 0.2.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/compat/method/__init__.py +2 -0
- fusion_bench/compat/taskpool/flan_t5_glue_text_generation.py +4 -1
- fusion_bench/constants/clip_vision.py +22 -0
- fusion_bench/dataset/clip_dataset.py +10 -2
- fusion_bench/dataset/fer2013.py +1 -0
- fusion_bench/dataset/gsm8k.py +2 -2
- fusion_bench/method/__init__.py +10 -0
- fusion_bench/method/adamerging/clip_task_wise_adamerging.py +1 -29
- fusion_bench/method/fisher_merging/fisher_merging.py +29 -17
- fusion_bench/method/gossip/__init__.py +3 -0
- fusion_bench/method/gossip/clip_layer_wise_gossip.py +43 -0
- fusion_bench/method/gossip/clip_task_wise_gossip.py +190 -0
- fusion_bench/method/gossip/entropy_loss.py +25 -0
- fusion_bench/method/gossip/flan_t5_layer_wise_gossip.py +388 -0
- fusion_bench/method/gossip/layer_wise_gossip.py +434 -0
- fusion_bench/method/gossip/min_norm_solvers.py +227 -0
- fusion_bench/method/gossip/task_wise_gossip.py +265 -0
- fusion_bench/method/gossip/utils.py +74 -0
- fusion_bench/method/isotropic_merging/__init__.py +1 -1
- fusion_bench/method/opcm/opcm.py +16 -7
- fusion_bench/method/pwe_moe/module.py +1 -1
- fusion_bench/method/pwe_moe/openclip_pwe_moe.py +476 -0
- fusion_bench/method/regmean/regmean.py +25 -17
- fusion_bench/method/smile_upscaling/__init__.py +1 -1
- fusion_bench/method/smile_upscaling/smile_upscaling.py +13 -10
- fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +7 -0
- fusion_bench/method/task_arithmetic/task_arithmetic.py +8 -6
- fusion_bench/method/ties_merging/ties_merging.py +36 -31
- fusion_bench/method/we_moe/we_moe.py +14 -15
- fusion_bench/mixins/__init__.py +6 -3
- fusion_bench/mixins/hydra_config.py +49 -0
- fusion_bench/mixins/openclip_classification.py +11 -0
- fusion_bench/mixins/simple_profiler.py +4 -2
- fusion_bench/modelpool/__init__.py +3 -1
- fusion_bench/modelpool/base_pool.py +2 -2
- fusion_bench/modelpool/openclip_vision/__init__.py +1 -0
- fusion_bench/modelpool/openclip_vision/modelpool.py +255 -0
- fusion_bench/models/open_clip/__init__.py +6 -0
- fusion_bench/models/open_clip/modeling.py +176 -0
- fusion_bench/models/open_clip/utils.py +311 -0
- fusion_bench/models/open_clip/variables_and_paths.py +56 -0
- fusion_bench/models/parameter_dict.py +54 -13
- fusion_bench/scripts/nyuv2_mtl_train.py +1 -1
- fusion_bench/taskpool/__init__.py +5 -3
- fusion_bench/taskpool/clip_vision/__init__.py +1 -0
- fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py +2 -30
- fusion_bench/taskpool/clip_vision/clip_smile_taskpool.py +102 -0
- fusion_bench/taskpool/clip_vision/clip_sparse_wemoe_taskpool.py +2 -30
- fusion_bench/taskpool/clip_vision/taskpool.py +1 -2
- fusion_bench/taskpool/clip_vision/utils/__init__.py +0 -0
- fusion_bench/taskpool/clip_vision/utils/routing_analysis_utils.py +65 -0
- fusion_bench/taskpool/gpt2_text_classification.py +30 -1
- fusion_bench/taskpool/openclip_vision/__init__.py +1 -0
- fusion_bench/taskpool/openclip_vision/openclip_taskpool.py +196 -0
- fusion_bench/utils/data.py +12 -0
- fusion_bench/utils/devices.py +14 -0
- fusion_bench/utils/instantiate.py +12 -0
- fusion_bench/utils/misc.py +9 -2
- fusion_bench/utils/packages.py +14 -0
- fusion_bench/utils/parameters.py +1 -1
- fusion_bench/utils/tensorboard.py +1 -1
- {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.13.dist-info}/METADATA +1 -1
- {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.13.dist-info}/RECORD +190 -151
- {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.13.dist-info}/WHEEL +1 -1
- fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +1 -2
- fusion_bench_config/dataset/image_classification/test/TALL20.yaml +0 -1
- fusion_bench_config/dataset/image_classification/test/emnist_letters.yaml +0 -1
- fusion_bench_config/dataset/image_classification/test/fashion_mnist.yaml +1 -1
- fusion_bench_config/dataset/image_classification/train/TALL20.yaml +0 -1
- fusion_bench_config/dataset/image_classification/train/fashion_mnist.yaml +1 -1
- fusion_bench_config/fabric/auto.yaml +0 -1
- fusion_bench_config/fabric/llama_ddp.yaml +0 -1
- fusion_bench_config/fabric/llama_fsdp.yaml +0 -1
- fusion_bench_config/fabric/llama_peft_fsdp.yaml +0 -1
- fusion_bench_config/fabric/strategy/deepspeed.yaml +0 -1
- fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +0 -1
- fusion_bench_config/fabric_model_fusion.yaml +0 -1
- fusion_bench_config/llama_full_finetune.yaml +0 -2
- fusion_bench_config/llama_model_fusion.yaml +0 -2
- fusion_bench_config/method/ada_svd/clip_vision.yaml +0 -1
- fusion_bench_config/method/adamerging/layer_wise_flan_t5.yaml +0 -5
- fusion_bench_config/method/adamerging/layer_wise_gpt2.yaml +0 -5
- fusion_bench_config/method/adamerging/llama_sft.yaml +0 -2
- fusion_bench_config/method/adamerging.yaml +2 -2
- fusion_bench_config/method/analysis/task_vector_cos_similarity.yaml +0 -1
- fusion_bench_config/method/analysis/task_vector_violin_plot.yaml +0 -1
- fusion_bench_config/method/classification/clip_continual_finetune.yaml +0 -1
- fusion_bench_config/method/concrete_subspace/clip_concrete_layer_wise_adamerging.yaml +0 -1
- fusion_bench_config/method/concrete_subspace/clip_concrete_task_wise_adamerging.yaml +0 -1
- fusion_bench_config/method/concrete_subspace/clip_post_defense_AWM.yaml +1 -12
- fusion_bench_config/method/concrete_subspace/clip_post_defense_SAU.yaml +1 -12
- fusion_bench_config/method/concrete_subspace/clip_safe_concrete_layer_wise_adamerging.yaml +1 -10
- fusion_bench_config/method/concrete_subspace/clip_safe_concrete_task_arithmetic.yaml +1 -14
- fusion_bench_config/method/dare/simple_average.yaml +0 -1
- fusion_bench_config/method/dare/task_arithmetic.yaml +0 -1
- fusion_bench_config/method/dare/ties_merging.yaml +0 -2
- fusion_bench_config/method/dawe/dawe_for_clip.yaml +0 -3
- fusion_bench_config/method/doge_ta/doge_ta.yaml +1 -1
- fusion_bench_config/method/ensemble/max_model_predictor.yaml +1 -1
- fusion_bench_config/method/ensemble/simple_ensemble.yaml +0 -1
- fusion_bench_config/method/ensemble/weighted_ensemble.yaml +0 -1
- fusion_bench_config/method/gossip/layer_wise_clip.yaml +30 -0
- fusion_bench_config/method/gossip/layer_wise_flan_t5.yaml +25 -0
- fusion_bench_config/method/isotropic_merging/iso_c.yaml +0 -1
- fusion_bench_config/method/isotropic_merging/iso_cts.yaml +0 -1
- fusion_bench_config/method/linear/linear_interpolation.yaml +0 -1
- fusion_bench_config/method/linear/llama_expo.yaml +0 -3
- fusion_bench_config/method/linear/llama_expo_with_dare.yaml +0 -5
- fusion_bench_config/method/linear/weighted_average.yaml +0 -1
- fusion_bench_config/method/linear/weighted_average_for_llama.yaml +0 -1
- fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +0 -4
- fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +0 -4
- fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +0 -6
- fusion_bench_config/method/mixtral_moe_upscaling.yaml +1 -2
- fusion_bench_config/method/model_recombination.yaml +0 -1
- fusion_bench_config/method/opcm/opcm.yaml +0 -1
- fusion_bench_config/method/opcm/task_arithmetic.yaml +0 -2
- fusion_bench_config/method/opcm/ties_merging.yaml +0 -2
- fusion_bench_config/method/opcm/weight_average.yaml +0 -1
- fusion_bench_config/method/pwe_moe/epo_for_openclip.yaml +30 -0
- fusion_bench_config/method/pwe_moe/ls_for_openclip.yaml +30 -0
- fusion_bench_config/method/{pwe_moe_ls_for_clip.yaml → pwe_moe/pwe_moe_ls_for_clip.yaml} +7 -6
- fusion_bench_config/method/rankone_moe/rankone_moe.yaml +1 -3
- fusion_bench_config/method/regmean/gpt2_regmean.yaml +0 -1
- fusion_bench_config/method/slerp/slerp.yaml +0 -2
- fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml +1 -1
- fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml +1 -1
- fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml +1 -1
- fusion_bench_config/method/surgery/adamerging_surgery.yaml +1 -2
- fusion_bench_config/method/task_arithmetic.yaml +1 -1
- fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +0 -1
- fusion_bench_config/method/ties_merging.yaml +1 -1
- fusion_bench_config/method/trust_region/clip_task_arithmetic.yaml +0 -1
- fusion_bench_config/method/wemoe/sparse_weight_ensembling_moe.yaml +0 -8
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar10.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford-iiit-pet.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford_flowers102.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_pcam.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_rendered-sst2.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stl10.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml +1 -1
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_lora.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual_lora.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +0 -4
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +0 -3
- fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +0 -4
- fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +0 -4
- fusion_bench_config/modelpool/CausalLMPool/llama_for_causallm.yaml +0 -1
- fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +0 -4
- fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +0 -4
- fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +0 -1
- fusion_bench_config/modelpool/CausalLMPool/single_llama_model.yaml +0 -3
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/README.md +90 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-16_TA8.yaml +27 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA8.yaml +45 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_cars_dtd.yaml +23 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_sun397_cars.yaml +23 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_sun397_dtd.yaml +23 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_individual.yaml +7 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-L-14_TA8.yaml +26 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue.yaml +0 -1
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16.yaml +0 -2
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16_tta.yaml +0 -2
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_tta.yaml +1 -3
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_individual.yaml +0 -1
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-large_glue_lora16.yaml +0 -3
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +0 -4
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +0 -3
- fusion_bench_config/modelpool/gpt-2_glue.yaml +0 -3
- fusion_bench_config/nyuv2_config.yaml +0 -2
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/_template.yaml +0 -3
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_B16.yaml +0 -2
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml +0 -2
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_sparse_wemoe_clip-vit-classification_TA8.yaml +0 -2
- fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-16_TA8.yaml +24 -0
- fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-32_TA8.yaml +24 -0
- fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-L-14_TA8.yaml +24 -0
- fusion_bench_config/taskpool/gpt-2_glue.yaml +0 -1
- fusion_bench_config/taskpool/reward_model_evaluation.yaml +0 -4
- {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.13.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.13.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.13.dist-info}/top_level.txt +0 -0
|
@@ -16,6 +16,7 @@ from torch import Tensor, nn
|
|
|
16
16
|
|
|
17
17
|
from fusion_bench.compat.modelpool import to_modelpool
|
|
18
18
|
from fusion_bench.method import BaseAlgorithm
|
|
19
|
+
from fusion_bench.mixins import SimpleProfilerMixin
|
|
19
20
|
from fusion_bench.modelpool import BaseModelPool
|
|
20
21
|
from fusion_bench.utils.type import StateDictType
|
|
21
22
|
|
|
@@ -24,7 +25,7 @@ from .ties_merging_utils import state_dict_to_vector, ties_merging, vector_to_st
|
|
|
24
25
|
log = logging.getLogger(__name__)
|
|
25
26
|
|
|
26
27
|
|
|
27
|
-
class TiesMergingAlgorithm(BaseAlgorithm):
|
|
28
|
+
class TiesMergingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
|
|
28
29
|
"""
|
|
29
30
|
TiesMergingAlgorithm is a class for fusing multiple models using the TIES merging technique.
|
|
30
31
|
|
|
@@ -84,34 +85,38 @@ class TiesMergingAlgorithm(BaseAlgorithm):
|
|
|
84
85
|
scaling_factor = self.scaling_factor
|
|
85
86
|
threshold = self.threshold
|
|
86
87
|
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
merged_check
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
88
|
+
with self.profile("loading models"):
|
|
89
|
+
# Load the pretrained model
|
|
90
|
+
pretrained_model = modelpool.load_model("_pretrained_")
|
|
91
|
+
|
|
92
|
+
# Load the state dicts of the models
|
|
93
|
+
ft_checks: List[StateDictType] = [
|
|
94
|
+
modelpool.load_model(model_name).state_dict(keep_vars=True)
|
|
95
|
+
for model_name in modelpool.model_names
|
|
96
|
+
]
|
|
97
|
+
ptm_check: StateDictType = pretrained_model.state_dict(keep_vars=True)
|
|
98
|
+
|
|
99
|
+
with self.profile("merging models"):
|
|
100
|
+
# Compute the task vectors
|
|
101
|
+
flat_ft: Tensor = torch.vstack(
|
|
102
|
+
[state_dict_to_vector(check, remove_keys) for check in ft_checks]
|
|
103
|
+
)
|
|
104
|
+
flat_ptm: Tensor = state_dict_to_vector(ptm_check, remove_keys)
|
|
105
|
+
tv_flat_checks = flat_ft - flat_ptm
|
|
106
|
+
|
|
107
|
+
# Perform TIES Merging
|
|
108
|
+
merged_tv = ties_merging(
|
|
109
|
+
tv_flat_checks,
|
|
110
|
+
reset_thresh=threshold,
|
|
111
|
+
merge_func=merge_func,
|
|
112
|
+
)
|
|
113
|
+
merged_check = flat_ptm + scaling_factor * merged_tv
|
|
114
|
+
merged_state_dict = vector_to_state_dict(
|
|
115
|
+
merged_check, ptm_check, remove_keys=remove_keys
|
|
116
|
+
)
|
|
117
|
+
|
|
118
|
+
# Load the merged state dict into the pretrained model
|
|
119
|
+
pretrained_model.load_state_dict(merged_state_dict)
|
|
120
|
+
|
|
121
|
+
self.print_profile_summary()
|
|
117
122
|
return pretrained_model
|
|
@@ -5,7 +5,6 @@ from typing import cast # noqa: F401
|
|
|
5
5
|
import lightning as L
|
|
6
6
|
import lightning.fabric.wrappers
|
|
7
7
|
import torch
|
|
8
|
-
from lightning.pytorch.profilers import SimpleProfiler
|
|
9
8
|
from omegaconf import DictConfig
|
|
10
9
|
from torch import Tensor
|
|
11
10
|
from torch.utils.data import DataLoader
|
|
@@ -13,6 +12,7 @@ from tqdm.autonotebook import tqdm
|
|
|
13
12
|
|
|
14
13
|
from fusion_bench.compat.method.base_algorithm import ModelFusionAlgorithm
|
|
15
14
|
from fusion_bench.compat.modelpool import ModelPool
|
|
15
|
+
from fusion_bench.mixins import SimpleProfilerMixin
|
|
16
16
|
from fusion_bench.models.we_moe import WeightEnsemblingMoE
|
|
17
17
|
from fusion_bench.utils import timeit_context
|
|
18
18
|
from fusion_bench.utils.parameters import print_parameters
|
|
@@ -34,7 +34,10 @@ def entropy_loss(logits: Tensor) -> Tensor:
|
|
|
34
34
|
return -torch.sum(probs * torch.log(probs + 1e-8), dim=-1).mean()
|
|
35
35
|
|
|
36
36
|
|
|
37
|
-
class WeightEnsemblingMoEAlgorithm(
|
|
37
|
+
class WeightEnsemblingMoEAlgorithm(
|
|
38
|
+
ModelFusionAlgorithm,
|
|
39
|
+
SimpleProfilerMixin,
|
|
40
|
+
):
|
|
38
41
|
"""
|
|
39
42
|
Algorithm for fusing models using Weight Ensembling Mixture of Experts (MoE).
|
|
40
43
|
|
|
@@ -44,7 +47,6 @@ class WeightEnsemblingMoEAlgorithm(ModelFusionAlgorithm):
|
|
|
44
47
|
Attributes:
|
|
45
48
|
_fabric (L.Fabric): The fabric for distributed training.
|
|
46
49
|
modelpool (ModelPool): The pool of models to be fused.
|
|
47
|
-
profiler (SimpleProfiler): The profiler for measuring performance.
|
|
48
50
|
"""
|
|
49
51
|
|
|
50
52
|
_fabric: L.Fabric = None
|
|
@@ -66,9 +68,6 @@ class WeightEnsemblingMoEAlgorithm(ModelFusionAlgorithm):
|
|
|
66
68
|
self._fabric.launch()
|
|
67
69
|
else:
|
|
68
70
|
assert "No CUDA device available."
|
|
69
|
-
self.profiler = SimpleProfiler(
|
|
70
|
-
self.config.get("cache_dir", "outputs"), "we_moe_profiler.txt"
|
|
71
|
-
)
|
|
72
71
|
|
|
73
72
|
@abstractmethod
|
|
74
73
|
def load_checkpoint(self, model, checkpoint):
|
|
@@ -177,9 +176,9 @@ class WeightEnsemblingMoEAlgorithm(ModelFusionAlgorithm):
|
|
|
177
176
|
for step_idx in pbar:
|
|
178
177
|
if self.config.use_grad_accumulate:
|
|
179
178
|
for task in self.modelpool.model_names:
|
|
180
|
-
with self.
|
|
179
|
+
with self.profile("data time"):
|
|
181
180
|
batch = next(self.get_shuffled_test_loader_iter(task))
|
|
182
|
-
with self.
|
|
181
|
+
with self.profile("forward pass"):
|
|
183
182
|
logits = self.compute_logits(module, batch, task)
|
|
184
183
|
assert (
|
|
185
184
|
logits.dim() == 2
|
|
@@ -187,23 +186,23 @@ class WeightEnsemblingMoEAlgorithm(ModelFusionAlgorithm):
|
|
|
187
186
|
loss = entropy_loss(logits)
|
|
188
187
|
# .backward() accumulates when .zero_grad() wasn't called
|
|
189
188
|
# this can save memory
|
|
190
|
-
with self.
|
|
189
|
+
with self.profile("backward pass"):
|
|
191
190
|
self._fabric.backward(loss, retain_graph=True)
|
|
192
191
|
else:
|
|
193
192
|
loss = 0
|
|
194
193
|
for task in self.modelpool.model_names:
|
|
195
|
-
with self.
|
|
194
|
+
with self.profile("data time"):
|
|
196
195
|
batch = next(self.get_shuffled_test_loader_iter(task))
|
|
197
|
-
with self.
|
|
196
|
+
with self.profile("forward pass"):
|
|
198
197
|
logits = self.compute_logits(module, batch, task)
|
|
199
198
|
assert (
|
|
200
199
|
logits.dim() == 2
|
|
201
200
|
), f"Expected logits to be 2D, got {logits.dim()}"
|
|
202
201
|
loss = loss + entropy_loss(logits)
|
|
203
|
-
with self.
|
|
202
|
+
with self.profile("backward pass"):
|
|
204
203
|
self._fabric.backward(loss, retain_graph=True)
|
|
205
204
|
|
|
206
|
-
with self.
|
|
205
|
+
with self.profile("optimizer step"):
|
|
207
206
|
optimizer.step()
|
|
208
207
|
optimizer.zero_grad()
|
|
209
208
|
|
|
@@ -232,7 +231,7 @@ class WeightEnsemblingMoEAlgorithm(ModelFusionAlgorithm):
|
|
|
232
231
|
)
|
|
233
232
|
self.load_checkpoint(moe_model, self.config.checkpoint)
|
|
234
233
|
else:
|
|
235
|
-
with self.
|
|
234
|
+
with self.profile("test-time adaptation"):
|
|
236
235
|
moe_model = self.test_time_adaptation(moe_model)
|
|
237
236
|
if self.config.get("save_checkpoint", False):
|
|
238
237
|
log.info(f"save checkpoint to {self.config.save_checkpoint}")
|
|
@@ -243,5 +242,5 @@ class WeightEnsemblingMoEAlgorithm(ModelFusionAlgorithm):
|
|
|
243
242
|
|
|
244
243
|
# enable sample-wise adaptation
|
|
245
244
|
moe_model.batch_reduce = False
|
|
246
|
-
|
|
245
|
+
self.print_profile_summary()
|
|
247
246
|
return moe_model
|
fusion_bench/mixins/__init__.py
CHANGED
|
@@ -6,20 +6,23 @@ from typing_extensions import TYPE_CHECKING
|
|
|
6
6
|
from fusion_bench.utils.lazy_imports import LazyImporter
|
|
7
7
|
|
|
8
8
|
_import_structure = {
|
|
9
|
+
"clip_classification": ["CLIPClassificationMixin"],
|
|
10
|
+
"fabric_training": ["FabricTrainingMixin"],
|
|
11
|
+
"hydra_config": ["HydraConfigMixin"],
|
|
9
12
|
"lightning_fabric": ["LightningFabricMixin"],
|
|
13
|
+
"openclip_classification": ["OpenCLIPClassificationMixin"],
|
|
10
14
|
"serialization": ["YAMLSerializationMixin", "BaseYAMLSerializableModel"],
|
|
11
15
|
"simple_profiler": ["SimpleProfilerMixin"],
|
|
12
|
-
"clip_classification": ["CLIPClassificationMixin"],
|
|
13
|
-
"fabric_training": ["FabricTrainingMixin"],
|
|
14
16
|
}
|
|
15
17
|
|
|
16
18
|
if TYPE_CHECKING:
|
|
17
19
|
from .clip_classification import CLIPClassificationMixin
|
|
18
20
|
from .fabric_training import FabricTrainingMixin
|
|
21
|
+
from .hydra_config import HydraConfigMixin
|
|
19
22
|
from .lightning_fabric import LightningFabricMixin
|
|
23
|
+
from .openclip_classification import OpenCLIPClassificationMixin
|
|
20
24
|
from .serialization import BaseYAMLSerializableModel, YAMLSerializationMixin
|
|
21
25
|
from .simple_profiler import SimpleProfilerMixin
|
|
22
|
-
|
|
23
26
|
else:
|
|
24
27
|
sys.modules[__name__] = LazyImporter(
|
|
25
28
|
__name__,
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import os
|
|
3
|
+
from copy import deepcopy
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Dict, List, Optional, Union
|
|
6
|
+
|
|
7
|
+
import hydra.core.global_hydra
|
|
8
|
+
from hydra import compose, initialize
|
|
9
|
+
from omegaconf import DictConfig, OmegaConf
|
|
10
|
+
|
|
11
|
+
from fusion_bench.utils import import_object, instantiate
|
|
12
|
+
from fusion_bench.utils.instantiate import set_print_function_call
|
|
13
|
+
|
|
14
|
+
log = logging.getLogger(__name__)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class HydraConfigMixin:
|
|
18
|
+
"""
|
|
19
|
+
A mixin for classes that need to be instantiated from a config file.
|
|
20
|
+
"""
|
|
21
|
+
|
|
22
|
+
@classmethod
|
|
23
|
+
def from_config(
|
|
24
|
+
cls,
|
|
25
|
+
config_name: Union[str, Path],
|
|
26
|
+
overrides: Optional[List[str]] = None,
|
|
27
|
+
):
|
|
28
|
+
if not hydra.core.global_hydra.GlobalHydra.instance().is_initialized():
|
|
29
|
+
raise RuntimeError("Hydra is not initialized.")
|
|
30
|
+
else:
|
|
31
|
+
cfg = compose(config_name=config_name, overrides=overrides)
|
|
32
|
+
|
|
33
|
+
config_groups = config_name.split("/")[:-1]
|
|
34
|
+
for config_group in config_groups:
|
|
35
|
+
cfg = cfg[config_group]
|
|
36
|
+
|
|
37
|
+
if "_target_" in cfg:
|
|
38
|
+
# if the config has a _target_ key, check if it is equal to the class name
|
|
39
|
+
target_cls = import_object(cfg["_target_"])
|
|
40
|
+
if target_cls != cls:
|
|
41
|
+
log.warning(
|
|
42
|
+
f"The _target_ key in the config is {cfg['_target_']}, but the class name is {cls.__name__}."
|
|
43
|
+
)
|
|
44
|
+
with set_print_function_call(False):
|
|
45
|
+
obj = instantiate(cfg)
|
|
46
|
+
else:
|
|
47
|
+
obj = cls(**cfg)
|
|
48
|
+
|
|
49
|
+
return obj
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
|
|
3
|
+
from fusion_bench.mixins import LightningFabricMixin
|
|
4
|
+
from fusion_bench.models.open_clip import ImageClassifier, ImageEncoder
|
|
5
|
+
|
|
6
|
+
log = logging.getLogger(__name__)
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class OpenCLIPClassificationMixin(LightningFabricMixin):
|
|
10
|
+
_train_processor = None
|
|
11
|
+
_test_processor = None
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
from contextlib import contextmanager
|
|
2
|
-
from typing import Generator
|
|
2
|
+
from typing import Generator, Optional
|
|
3
3
|
|
|
4
4
|
from lightning.fabric.utilities.rank_zero import rank_zero_only
|
|
5
5
|
from lightning.pytorch.profilers import SimpleProfiler
|
|
@@ -70,7 +70,9 @@ class SimpleProfilerMixin:
|
|
|
70
70
|
self.profiler.stop(action_name)
|
|
71
71
|
|
|
72
72
|
@rank_zero_only
|
|
73
|
-
def print_profile_summary(self):
|
|
73
|
+
def print_profile_summary(self, title: Optional[str] = None):
|
|
74
|
+
if title is not None:
|
|
75
|
+
print(title)
|
|
74
76
|
print(self.profiler.summary())
|
|
75
77
|
|
|
76
78
|
def __del__(self):
|
|
@@ -6,12 +6,13 @@ from fusion_bench.utils.lazy_imports import LazyImporter
|
|
|
6
6
|
|
|
7
7
|
_import_structure = {
|
|
8
8
|
"base_pool": ["BaseModelPool"],
|
|
9
|
+
"causal_lm": ["CausalLMPool", "CausalLMBackbonePool"],
|
|
9
10
|
"clip_vision": ["CLIPVisionModelPool"],
|
|
10
11
|
"nyuv2_modelpool": ["NYUv2ModelPool"],
|
|
11
12
|
"huggingface_automodel": ["AutoModelPool"],
|
|
12
|
-
"causal_lm": ["CausalLMPool", "CausalLMBackbonePool"],
|
|
13
13
|
"seq2seq_lm": ["Seq2SeqLMPool"],
|
|
14
14
|
"PeftModelForSeq2SeqLM": ["PeftModelForSeq2SeqLMPool"],
|
|
15
|
+
"openclip_vision": ["OpenCLIPVisionModelPool"],
|
|
15
16
|
"huggingface_gpt2_classification": [
|
|
16
17
|
"HuggingFaceGPT2ClassificationPool",
|
|
17
18
|
"GPT2ForSequenceClassificationPool",
|
|
@@ -30,6 +31,7 @@ if TYPE_CHECKING:
|
|
|
30
31
|
HuggingFaceGPT2ClassificationPool,
|
|
31
32
|
)
|
|
32
33
|
from .nyuv2_modelpool import NYUv2ModelPool
|
|
34
|
+
from .openclip_vision import OpenCLIPVisionModelPool
|
|
33
35
|
from .PeftModelForSeq2SeqLM import PeftModelForSeq2SeqLMPool
|
|
34
36
|
from .seq2seq_lm import Seq2SeqLMPool
|
|
35
37
|
from .seq_classification_lm import SeqenceClassificationModelPool
|
|
@@ -7,7 +7,7 @@ from omegaconf import DictConfig
|
|
|
7
7
|
from torch import nn
|
|
8
8
|
from torch.utils.data import Dataset
|
|
9
9
|
|
|
10
|
-
from fusion_bench.mixins import BaseYAMLSerializableModel
|
|
10
|
+
from fusion_bench.mixins import BaseYAMLSerializableModel, HydraConfigMixin
|
|
11
11
|
from fusion_bench.utils import instantiate, timeit_context
|
|
12
12
|
|
|
13
13
|
__all__ = ["BaseModelPool"]
|
|
@@ -15,7 +15,7 @@ __all__ = ["BaseModelPool"]
|
|
|
15
15
|
log = logging.getLogger(__name__)
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
class BaseModelPool(BaseYAMLSerializableModel):
|
|
18
|
+
class BaseModelPool(BaseYAMLSerializableModel, HydraConfigMixin):
|
|
19
19
|
"""
|
|
20
20
|
A class for managing and interacting with a pool of models along with their associated datasets or other specifications. For example, a model pool may contain multiple models, each with its own training, validation, and testing datasets. As for the specifications, a vision model pool may contain image preprocessor, and a language model pool may contain a tokenizer.
|
|
21
21
|
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .modelpool import OpenCLIPVisionModelPool
|
|
@@ -0,0 +1,255 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import pickle
|
|
3
|
+
import sys
|
|
4
|
+
from typing import Callable, Optional, Union, cast
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
from datasets import load_dataset
|
|
8
|
+
from omegaconf import DictConfig, OmegaConf
|
|
9
|
+
from torch import nn
|
|
10
|
+
|
|
11
|
+
from fusion_bench.modelpool import BaseModelPool
|
|
12
|
+
from fusion_bench.models.open_clip import ClassificationHead, ImageEncoder
|
|
13
|
+
from fusion_bench.utils import instantiate
|
|
14
|
+
from fusion_bench.utils.expr import is_expr_match
|
|
15
|
+
from fusion_bench.utils.packages import _get_package_version, compare_versions
|
|
16
|
+
|
|
17
|
+
log = logging.getLogger(__name__)
|
|
18
|
+
|
|
19
|
+
# Add flag to track if warning has been shown
|
|
20
|
+
_openclip_version_warning_shown = False
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def _check_and_redirect_open_clip_modeling():
|
|
24
|
+
global _openclip_version_warning_shown
|
|
25
|
+
if compare_versions(_get_package_version("open-clip-torch").__str__(), "2.0.2") > 0:
|
|
26
|
+
if not _openclip_version_warning_shown:
|
|
27
|
+
log.warning(
|
|
28
|
+
"OpenCLIP version is greater than 2.0.2. This may cause issues with the modelpool."
|
|
29
|
+
)
|
|
30
|
+
_openclip_version_warning_shown = True
|
|
31
|
+
import open_clip.model
|
|
32
|
+
import open_clip.transformer
|
|
33
|
+
|
|
34
|
+
if not hasattr(open_clip.model, "VisualTransformer"):
|
|
35
|
+
open_clip.model.VisualTransformer = open_clip.model.VisionTransformer
|
|
36
|
+
if not hasattr(open_clip.model, "Transformer"):
|
|
37
|
+
open_clip.model.Transformer = open_clip.transformer.Transformer
|
|
38
|
+
if not hasattr(open_clip.model, "ResidualAttentionBlock"):
|
|
39
|
+
open_clip.model.ResidualAttentionBlock = (
|
|
40
|
+
open_clip.transformer.ResidualAttentionBlock
|
|
41
|
+
)
|
|
42
|
+
|
|
43
|
+
try:
|
|
44
|
+
import src
|
|
45
|
+
import src.modeling
|
|
46
|
+
except ImportError:
|
|
47
|
+
if "src" not in sys.modules:
|
|
48
|
+
# redirect the import of `src` to `fusion_bench.models.open_clip`
|
|
49
|
+
import fusion_bench.models.open_clip as open_clip
|
|
50
|
+
|
|
51
|
+
sys.modules["src"] = open_clip
|
|
52
|
+
log.warning(
|
|
53
|
+
"`src` is not imported."
|
|
54
|
+
"Redirecting the import to `fusion_bench.models.open_clip`"
|
|
55
|
+
)
|
|
56
|
+
if "src.modeling" not in sys.modules:
|
|
57
|
+
# redirect the import of `src.modeling` to `fusion_bench.models.open_clip.modeling`
|
|
58
|
+
import fusion_bench.models.open_clip.modeling as open_clip_modeling
|
|
59
|
+
|
|
60
|
+
sys.modules["src.modeling"] = open_clip_modeling
|
|
61
|
+
log.warning(
|
|
62
|
+
"`src.modeling` is not imported."
|
|
63
|
+
"Redirecting the import to `fusion_bench.models.open_clip.modeling`"
|
|
64
|
+
)
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def load_classifier_head(model_config: Union[str, DictConfig], *args, **kwargs):
|
|
68
|
+
if isinstance(model_config, str):
|
|
69
|
+
_check_and_redirect_open_clip_modeling()
|
|
70
|
+
log.info(f"Loading `ClassificationHead` from {model_config}")
|
|
71
|
+
weights_only = kwargs["weights_only"] if "weights_only" in kwargs else False
|
|
72
|
+
head = torch.load(model_config, weights_only=weights_only, *args, **kwargs)
|
|
73
|
+
elif isinstance(model_config, nn.Module):
|
|
74
|
+
log.info(f"Returning existing model: {model_config}")
|
|
75
|
+
head = model_config
|
|
76
|
+
else:
|
|
77
|
+
head = instantiate(model_config, *args, **kwargs)
|
|
78
|
+
head = cast(ClassificationHead, head)
|
|
79
|
+
return head
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
class OpenCLIPVisionModelPool(BaseModelPool):
    """
    A model pool for managing OpenCLIP Vision models (models from task vector paper).

    Besides the image encoders configured in ``models``, the pool can hold
    classification heads (``classification_heads``). It also lazily exposes the
    image preprocessing transforms (``train_processor`` / ``test_processor``)
    taken from the encoders it loads.
    """

    # Cached preprocessing transforms; populated from the first encoder that
    # provides `train_preprocess` / `val_preprocess`.
    _train_processor = None
    _test_processor = None

    def __init__(
        self,
        models: DictConfig,
        classification_heads: Optional[DictConfig] = None,
        **kwargs,
    ):
        """
        Args:
            models: Mapping from model name to model config (see `load_model`).
            classification_heads: Optional mapping from name to classification
                head config (see `load_classification_head`).
            **kwargs: Forwarded to `BaseModelPool.__init__`.
        """
        super().__init__(models, **kwargs)
        self._classification_heads = classification_heads

    def _init_processors_from_pool(self) -> None:
        """Load the pretrained (or first) encoder and fill any unset processor from it."""
        encoder: ImageEncoder = self.load_pretrained_or_first_model()
        if self._train_processor is None:
            self._train_processor = encoder.train_preprocess
        if self._test_processor is None:
            self._test_processor = encoder.val_preprocess

    @property
    def train_processor(self):
        """Training-time image transform (the encoder's `train_preprocess`)."""
        if self._train_processor is None:
            self._init_processors_from_pool()
        return self._train_processor

    @property
    def test_processor(self):
        """Evaluation-time image transform (the encoder's `val_preprocess`)."""
        if self._test_processor is None:
            self._init_processors_from_pool()
        return self._test_processor

    @staticmethod
    def _load_pickled_model(path: str, *args, **kwargs):
        """Load a pickled model with `torch.load`, falling back to plain `pickle`.

        `weights_only` is taken from kwargs (default False, because these
        checkpoints are whole pickled modules); the remaining args/kwargs are
        forwarded to `torch.load`. If `torch.load` raises `RuntimeError`, the
        file is read again with `pickle.load`.

        WARNING: both paths unpickle arbitrary objects -- only use trusted files.
        """
        weights_only = kwargs.pop("weights_only", False)
        try:
            return torch.load(path, *args, weights_only=weights_only, **kwargs)
        except RuntimeError:
            # `with` closes the handle that `pickle.load(open(...))` used to leak.
            with open(path, "rb") as f:
                return pickle.load(f)

    def load_model(
        self, model_name_or_config: Union[str, DictConfig], *args, **kwargs
    ) -> ImageEncoder:
        R"""
        If `model_name_or_config` is a key of the pool's models, the config
        registered under that name is used. The model config can be:

        - A string, which is the path to the model checkpoint in pickle format. Load directly using `torch.load` (falling back to `pickle.load` if `torch.load` raises `RuntimeError`).
        - {"model_name": str, "pickle_path": str}, load the model from the binary file (pickle format). This will first construct the model using `ImageEncoder(model_name)`, and then load the state dict from model located in the pickle file.
        - {"model_name": str, "state_dict_path": str}, load the model from the state dict file. This will first construct the model using `ImageEncoder(model_name)`, and then load the state dict from the file.
        - An existing `nn.Module`, which is returned as is.
        - Default, load the model using `instantiate` from hydra.

        For the pickle-based forms, `weights_only` defaults to False; for the
        state-dict form it defaults to True. In all three forms, the remaining
        `*args` / `**kwargs` are forwarded to `torch.load`.

        Any still-unset train/test processor is initialized from the loaded
        encoder's `train_preprocess` / `val_preprocess`, if present.
        """
        if (
            isinstance(model_name_or_config, str)
            and model_name_or_config in self._models
        ):
            model_config = self._models[model_name_or_config]
        else:
            model_config = model_name_or_config
        if isinstance(model_config, DictConfig):
            model_config = OmegaConf.to_container(model_config, resolve=True)

        if isinstance(model_config, str):
            # the model config is a string, which is the path to the model checkpoint in pickle format
            # this is the original usage in the task arithmetic codebase
            _check_and_redirect_open_clip_modeling()
            log.info(f"loading ImageEncoder from {model_config}")
            encoder = self._load_pickled_model(model_config, *args, **kwargs)
        elif is_expr_match({"model_name": str, "pickle_path": str}, model_config):
            # load the pickled model, then copy its weights into a freshly
            # constructed ImageEncoder
            # this is useful when you use a newer version of torchvision
            _check_and_redirect_open_clip_modeling()
            log.info(
                f"loading ImageEncoder of {model_config['model_name']} from {model_config['pickle_path']}"
            )
            pickled_encoder = self._load_pickled_model(
                model_config["pickle_path"], *args, **kwargs
            )
            encoder = ImageEncoder(model_config["model_name"])
            encoder.load_state_dict(pickled_encoder.state_dict())
        elif is_expr_match({"model_name": str, "state_dict_path": str}, model_config):
            # construct the model by name and load a plain state dict into it
            log.info(
                f"loading ImageEncoder of {model_config['model_name']} from {model_config['state_dict_path']}"
            )
            # Pop so a caller-supplied `weights_only` is honored instead of being
            # passed to `torch.load` twice (which raises TypeError).
            weights_only = kwargs.pop("weights_only", True)
            encoder = ImageEncoder(model_config["model_name"])
            encoder.load_state_dict(
                torch.load(
                    model_config["state_dict_path"],
                    *args,
                    weights_only=weights_only,
                    **kwargs,
                )
            )
        elif isinstance(model_config, nn.Module):
            # the model config is an existing model
            log.info(f"Returning existing model: {model_config}")
            encoder = model_config
        else:
            encoder = super().load_model(model_name_or_config, *args, **kwargs)
        encoder = cast(ImageEncoder, encoder)

        # setup the train and test processors
        if self._train_processor is None and hasattr(encoder, "train_preprocess"):
            self._train_processor = encoder.train_preprocess
        if self._test_processor is None and hasattr(encoder, "val_preprocess"):
            self._test_processor = encoder.val_preprocess

        return encoder

    def load_classification_head(
        self, model_name_or_config: Union[str, DictConfig], *args, **kwargs
    ) -> ClassificationHead:
        R"""
        If `model_name_or_config` is a key of the pool's classification heads,
        the config registered under that name is used. The model config can be:

        - A string, which is the path to the model checkpoint in pickle format. Load directly using `torch.load`.
        - An existing `nn.Module`, which is returned as is.
        - Default, load the model using `instantiate` from hydra.
        """
        # Guard against `classification_heads=None` (the constructor default):
        # `name in None` would raise TypeError.
        if (
            isinstance(model_name_or_config, str)
            and self._classification_heads is not None
            and model_name_or_config in self._classification_heads
        ):
            model_config = self._classification_heads[model_name_or_config]
        else:
            model_config = model_name_or_config

        head = load_classifier_head(model_config, *args, **kwargs)
        return head

    @staticmethod
    def _load_hf_split(dataset_config: str, split: str):
        """Load one split of a dataset by name with `datasets.load_dataset`."""
        log.info(
            f"Loading {split} dataset using `datasets.load_dataset`: {dataset_config}"
        )
        return load_dataset(dataset_config, split=split)

    def load_train_dataset(self, dataset_name: str, *args, **kwargs):
        """Load a train dataset: string configs go through `datasets.load_dataset`."""
        dataset_config = self._train_datasets[dataset_name]
        if isinstance(dataset_config, str):
            return self._load_hf_split(dataset_config, "train")
        return super().load_train_dataset(dataset_name, *args, **kwargs)

    def load_val_dataset(self, dataset_name: str, *args, **kwargs):
        """Load a validation dataset: string configs go through `datasets.load_dataset`."""
        dataset_config = self._val_datasets[dataset_name]
        if isinstance(dataset_config, str):
            return self._load_hf_split(dataset_config, "validation")
        return super().load_val_dataset(dataset_name, *args, **kwargs)

    def load_test_dataset(self, dataset_name: str, *args, **kwargs):
        """Load a test dataset: string configs go through `datasets.load_dataset`."""
        dataset_config = self._test_datasets[dataset_name]
        if isinstance(dataset_config, str):
            return self._load_hf_split(dataset_config, "test")
        return super().load_test_dataset(dataset_name, *args, **kwargs)
|