fusion-bench 0.2.11__py3-none-any.whl → 0.2.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/compat/method/__init__.py +3 -1
- fusion_bench/compat/taskpool/flan_t5_glue_text_generation.py +4 -1
- fusion_bench/constants/clip_vision.py +22 -0
- fusion_bench/dataset/clip_dataset.py +10 -2
- fusion_bench/dataset/gsm8k.py +2 -2
- fusion_bench/method/__init__.py +12 -2
- fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +1 -1
- fusion_bench/method/adamerging/clip_task_wise_adamerging.py +1 -29
- fusion_bench/method/doge_ta/__init__.py +2 -0
- fusion_bench/method/{DOGE_TA → doge_ta}/clip_layer_wise_adamerging.py +1 -1
- fusion_bench/method/{DOGE_TA/DOGE_TA.py → doge_ta/doge_ta.py} +1 -1
- fusion_bench/method/fisher_merging/fisher_merging.py +29 -17
- fusion_bench/method/gossip/__init__.py +3 -0
- fusion_bench/method/gossip/clip_layer_wise_gossip.py +43 -0
- fusion_bench/method/gossip/clip_task_wise_gossip.py +190 -0
- fusion_bench/method/gossip/entropy_loss.py +25 -0
- fusion_bench/method/gossip/flan_t5_layer_wise_gossip.py +388 -0
- fusion_bench/method/gossip/layer_wise_gossip.py +434 -0
- fusion_bench/method/gossip/min_norm_solvers.py +227 -0
- fusion_bench/method/gossip/task_wise_gossip.py +265 -0
- fusion_bench/method/gossip/utils.py +74 -0
- fusion_bench/method/isotropic_merging/__init__.py +1 -1
- fusion_bench/method/opcm/opcm.py +102 -84
- fusion_bench/method/opcm/task_arithmetic.py +35 -21
- fusion_bench/method/opcm/ties_merging.py +71 -52
- fusion_bench/method/pwe_moe/module.py +1 -1
- fusion_bench/method/pwe_moe/openclip_pwe_moe.py +476 -0
- fusion_bench/method/regmean/regmean.py +25 -17
- fusion_bench/method/smile_upscaling/__init__.py +1 -1
- fusion_bench/method/smile_upscaling/smile_upscaling.py +13 -10
- fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +7 -0
- fusion_bench/method/task_arithmetic/task_arithmetic.py +8 -6
- fusion_bench/method/ties_merging/ties_merging.py +36 -31
- fusion_bench/method/we_moe/we_moe.py +14 -15
- fusion_bench/mixins/__init__.py +6 -3
- fusion_bench/mixins/hydra_config.py +49 -0
- fusion_bench/mixins/openclip_classification.py +11 -0
- fusion_bench/mixins/simple_profiler.py +4 -2
- fusion_bench/modelpool/__init__.py +3 -1
- fusion_bench/modelpool/base_pool.py +2 -2
- fusion_bench/modelpool/openclip_vision/__init__.py +1 -0
- fusion_bench/modelpool/openclip_vision/modelpool.py +255 -0
- fusion_bench/models/open_clip/__init__.py +6 -0
- fusion_bench/models/open_clip/modeling.py +176 -0
- fusion_bench/models/open_clip/utils.py +311 -0
- fusion_bench/models/open_clip/variables_and_paths.py +56 -0
- fusion_bench/models/parameter_dict.py +54 -13
- fusion_bench/models/wrappers/layer_wise_fusion.py +1 -46
- fusion_bench/models/wrappers/layer_wise_fusion_doge_ta.py +4 -119
- fusion_bench/scripts/nyuv2_mtl_train.py +1 -1
- fusion_bench/taskpool/__init__.py +5 -3
- fusion_bench/taskpool/clip_vision/__init__.py +1 -0
- fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py +2 -30
- fusion_bench/taskpool/clip_vision/clip_smile_taskpool.py +102 -0
- fusion_bench/taskpool/clip_vision/clip_sparse_wemoe_taskpool.py +2 -30
- fusion_bench/taskpool/clip_vision/taskpool.py +1 -2
- fusion_bench/taskpool/clip_vision/utils/__init__.py +0 -0
- fusion_bench/taskpool/clip_vision/utils/routing_analysis_utils.py +65 -0
- fusion_bench/taskpool/gpt2_text_classification.py +30 -1
- fusion_bench/taskpool/openclip_vision/__init__.py +1 -0
- fusion_bench/taskpool/openclip_vision/openclip_taskpool.py +196 -0
- fusion_bench/utils/data.py +12 -0
- fusion_bench/utils/devices.py +14 -0
- fusion_bench/utils/instantiate.py +12 -0
- fusion_bench/utils/misc.py +9 -2
- fusion_bench/utils/packages.py +14 -0
- fusion_bench/utils/parameters.py +1 -1
- fusion_bench/utils/tensorboard.py +1 -1
- {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info}/METADATA +15 -2
- {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info}/RECORD +198 -158
- {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info}/WHEEL +1 -1
- fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +1 -2
- fusion_bench_config/dataset/image_classification/test/TALL20.yaml +0 -1
- fusion_bench_config/dataset/image_classification/test/emnist_letters.yaml +0 -1
- fusion_bench_config/dataset/image_classification/test/fashion_mnist.yaml +1 -1
- fusion_bench_config/dataset/image_classification/train/TALL20.yaml +0 -1
- fusion_bench_config/dataset/image_classification/train/fashion_mnist.yaml +1 -1
- fusion_bench_config/fabric/auto.yaml +0 -1
- fusion_bench_config/fabric/llama_ddp.yaml +0 -1
- fusion_bench_config/fabric/llama_fsdp.yaml +0 -1
- fusion_bench_config/fabric/llama_peft_fsdp.yaml +0 -1
- fusion_bench_config/fabric/strategy/deepspeed.yaml +0 -1
- fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +0 -1
- fusion_bench_config/fabric_model_fusion.yaml +0 -1
- fusion_bench_config/llama_full_finetune.yaml +0 -2
- fusion_bench_config/llama_model_fusion.yaml +0 -2
- fusion_bench_config/method/ada_svd/clip_vision.yaml +0 -1
- fusion_bench_config/method/adamerging/layer_wise_flan_t5.yaml +0 -5
- fusion_bench_config/method/adamerging/layer_wise_gpt2.yaml +0 -5
- fusion_bench_config/method/adamerging/llama_sft.yaml +0 -2
- fusion_bench_config/method/adamerging.yaml +2 -2
- fusion_bench_config/method/analysis/task_vector_cos_similarity.yaml +0 -1
- fusion_bench_config/method/analysis/task_vector_violin_plot.yaml +0 -1
- fusion_bench_config/method/classification/clip_continual_finetune.yaml +0 -1
- fusion_bench_config/method/concrete_subspace/clip_concrete_layer_wise_adamerging.yaml +0 -1
- fusion_bench_config/method/concrete_subspace/clip_concrete_task_wise_adamerging.yaml +0 -1
- fusion_bench_config/method/concrete_subspace/clip_post_defense_AWM.yaml +1 -12
- fusion_bench_config/method/concrete_subspace/clip_post_defense_SAU.yaml +1 -12
- fusion_bench_config/method/concrete_subspace/clip_safe_concrete_layer_wise_adamerging.yaml +1 -10
- fusion_bench_config/method/concrete_subspace/clip_safe_concrete_task_arithmetic.yaml +1 -14
- fusion_bench_config/method/dare/simple_average.yaml +0 -1
- fusion_bench_config/method/dare/task_arithmetic.yaml +0 -1
- fusion_bench_config/method/dare/ties_merging.yaml +0 -2
- fusion_bench_config/method/dawe/dawe_for_clip.yaml +0 -3
- fusion_bench_config/method/{DOGE_TA/DOGE_TA.yaml → doge_ta/doge_ta.yaml} +1 -1
- fusion_bench_config/method/ensemble/max_model_predictor.yaml +1 -1
- fusion_bench_config/method/ensemble/simple_ensemble.yaml +0 -1
- fusion_bench_config/method/ensemble/weighted_ensemble.yaml +0 -1
- fusion_bench_config/method/gossip/layer_wise_clip.yaml +30 -0
- fusion_bench_config/method/gossip/layer_wise_flan_t5.yaml +25 -0
- fusion_bench_config/method/isotropic_merging/iso_c.yaml +0 -1
- fusion_bench_config/method/isotropic_merging/iso_cts.yaml +0 -1
- fusion_bench_config/method/linear/linear_interpolation.yaml +0 -1
- fusion_bench_config/method/linear/llama_expo.yaml +0 -3
- fusion_bench_config/method/linear/llama_expo_with_dare.yaml +0 -5
- fusion_bench_config/method/linear/weighted_average.yaml +0 -1
- fusion_bench_config/method/linear/weighted_average_for_llama.yaml +0 -1
- fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +0 -4
- fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +0 -4
- fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +0 -6
- fusion_bench_config/method/mixtral_moe_upscaling.yaml +1 -2
- fusion_bench_config/method/model_recombination.yaml +0 -1
- fusion_bench_config/method/opcm/opcm.yaml +0 -1
- fusion_bench_config/method/opcm/task_arithmetic.yaml +0 -2
- fusion_bench_config/method/opcm/ties_merging.yaml +0 -2
- fusion_bench_config/method/opcm/weight_average.yaml +0 -1
- fusion_bench_config/method/pwe_moe/epo_for_openclip.yaml +30 -0
- fusion_bench_config/method/pwe_moe/ls_for_openclip.yaml +30 -0
- fusion_bench_config/method/{pwe_moe_ls_for_clip.yaml → pwe_moe/pwe_moe_ls_for_clip.yaml} +7 -6
- fusion_bench_config/method/rankone_moe/rankone_moe.yaml +1 -3
- fusion_bench_config/method/regmean/gpt2_regmean.yaml +0 -1
- fusion_bench_config/method/slerp/slerp.yaml +0 -2
- fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml +1 -1
- fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml +1 -1
- fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml +1 -1
- fusion_bench_config/method/surgery/adamerging_surgery.yaml +1 -2
- fusion_bench_config/method/task_arithmetic.yaml +1 -1
- fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +0 -1
- fusion_bench_config/method/ties_merging.yaml +1 -1
- fusion_bench_config/method/trust_region/clip_task_arithmetic.yaml +0 -1
- fusion_bench_config/method/wemoe/sparse_weight_ensembling_moe.yaml +0 -8
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar10.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford-iiit-pet.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford_flowers102.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_pcam.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_rendered-sst2.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stl10.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml +1 -1
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_lora.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual_lora.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +0 -4
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +0 -3
- fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +0 -4
- fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +0 -4
- fusion_bench_config/modelpool/CausalLMPool/llama_for_causallm.yaml +0 -1
- fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +0 -4
- fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +0 -4
- fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +0 -1
- fusion_bench_config/modelpool/CausalLMPool/single_llama_model.yaml +0 -3
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/README.md +90 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-16_TA8.yaml +27 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA8.yaml +45 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_cars_dtd.yaml +23 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_sun397_cars.yaml +23 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_sun397_dtd.yaml +23 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_individual.yaml +7 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-L-14_TA8.yaml +26 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue.yaml +0 -1
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16.yaml +0 -2
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16_tta.yaml +8 -10
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_tta.yaml +66 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_individual.yaml +0 -1
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-large_glue_lora16.yaml +0 -3
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +0 -4
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +0 -3
- fusion_bench_config/modelpool/gpt-2_glue.yaml +0 -3
- fusion_bench_config/nyuv2_config.yaml +0 -2
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/_template.yaml +0 -3
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_B16.yaml +0 -2
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml +0 -2
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_sparse_wemoe_clip-vit-classification_TA8.yaml +0 -2
- fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-16_TA8.yaml +24 -0
- fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-32_TA8.yaml +24 -0
- fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-L-14_TA8.yaml +24 -0
- fusion_bench_config/taskpool/gpt-2_glue.yaml +0 -1
- fusion_bench_config/taskpool/reward_model_evaluation.yaml +0 -4
- fusion_bench/method/DOGE_TA/__init__.py +0 -2
- /fusion_bench/method/{DOGE_TA → doge_ta}/layer_wise_adamerging.py +0 -0
- {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info/licenses}/LICENSE +0 -0
- {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info}/top_level.txt +0 -0
|
@@ -20,9 +20,11 @@ class AlgorithmFactory:
|
|
|
20
20
|
# model merging methods
|
|
21
21
|
"clip_task_wise_adamerging": ".adamerging.clip_task_wise_adamerging.CLIPTaskWiseAdaMergingAlgorithm",
|
|
22
22
|
"clip_layer_wise_adamerging": ".adamerging.clip_layer_wise_adamerging.CLIPLayerWiseAdaMergingAlgorithm",
|
|
23
|
-
"clip_layer_wise_adamerging_doge_ta": ".
|
|
23
|
+
"clip_layer_wise_adamerging_doge_ta": ".doge_ta.clip_layer_wise_adamerging.CLIPLayerWiseAdaMergingAlgorithm",
|
|
24
24
|
"singular_projection_merging": "fusion_bench.method.smile_upscaling.singular_projection_merging.SingularProjectionMergingAlgorithm",
|
|
25
25
|
"clip_layer_wise_adamerging_surgery": ".surgery.clip_layer_wise_adamerging_surgery.CLIPLayerWiseAdaMergingSurgeryAlgorithm",
|
|
26
|
+
"clip_task_wise_gossip": ".gossip.clip_task_wise_gossip.CLIPTaskWiseGossipAlgorithm",
|
|
27
|
+
"clip_layer_wise_gossip": ".gossip.clip_layer_wise_gossip.CLIPLayerWiseGossipAlgorithm",
|
|
26
28
|
# plug-and-play model merging methods
|
|
27
29
|
"clip_concrete_task_arithmetic": ".concrete_subspace.clip_concrete_task_arithmetic.ConcreteTaskArithmeticAlgorithmForCLIP",
|
|
28
30
|
"clip_concrete_task_wise_adamerging": ".concrete_subspace.clip_concrete_adamerging.ConcreteTaskWiseAdaMergingForCLIP",
|
|
@@ -148,12 +148,13 @@ class FlanT5GLUETextGenerationTaskPool(LightningFabricMixin, TaskPool):
|
|
|
148
148
|
else:
|
|
149
149
|
raise ValueError(f"Unknown task {task_config.name}")
|
|
150
150
|
|
|
151
|
-
def evaluate(self, model: T5ForConditionalGeneration):
|
|
151
|
+
def evaluate(self, model: T5ForConditionalGeneration, name: str = None):
|
|
152
152
|
"""
|
|
153
153
|
Evaluate the model on the FlanT5 GLUE text generation tasks.
|
|
154
154
|
|
|
155
155
|
Args:
|
|
156
156
|
model (T5ForConditionalGeneration): The model to evaluate.
|
|
157
|
+
name (str, optional): The name of the model. Defaults to None. This is used to identify the model in the report.
|
|
157
158
|
|
|
158
159
|
Returns:
|
|
159
160
|
dict: A dictionary containing the evaluation results for each task.
|
|
@@ -169,6 +170,8 @@ class FlanT5GLUETextGenerationTaskPool(LightningFabricMixin, TaskPool):
|
|
|
169
170
|
"all_params": all_params,
|
|
170
171
|
"trainable_percentage": training_params / all_params,
|
|
171
172
|
}
|
|
173
|
+
if name is not None:
|
|
174
|
+
report["model_info"]["name"] = name
|
|
172
175
|
model = self.fabric.setup(model)
|
|
173
176
|
report.update(super().evaluate(model))
|
|
174
177
|
log.info(f"evaluation report: {report}")
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
# Constants for CLIP Vision Model Merging
|
|
2
|
+
TASK_NAMES_TA8 = [
|
|
3
|
+
"sun397",
|
|
4
|
+
"stanford-cars",
|
|
5
|
+
"resisc45",
|
|
6
|
+
"eurosat",
|
|
7
|
+
"svhn",
|
|
8
|
+
"gtsrb",
|
|
9
|
+
"mnist",
|
|
10
|
+
"dtd",
|
|
11
|
+
]
|
|
12
|
+
|
|
13
|
+
TASK_NAMES_TA8_CAP = [
|
|
14
|
+
"SUN397",
|
|
15
|
+
"Cars",
|
|
16
|
+
"RESISC45",
|
|
17
|
+
"EuroSAT",
|
|
18
|
+
"SVHN",
|
|
19
|
+
"GTSRB",
|
|
20
|
+
"MNIST",
|
|
21
|
+
"DTD",
|
|
22
|
+
]
|
|
@@ -2,11 +2,13 @@
|
|
|
2
2
|
This module provides a class to convert a dataset whose object is a list of dictionaries with keys "image" and "label" to a dataset whose object is a tuple of tensors (inputs, label) for CLIP models.
|
|
3
3
|
"""
|
|
4
4
|
|
|
5
|
-
from typing import Optional
|
|
5
|
+
from typing import Optional, Tuple
|
|
6
6
|
|
|
7
7
|
import torch
|
|
8
8
|
from transformers import CLIPProcessor, ProcessorMixin
|
|
9
9
|
|
|
10
|
+
__all__ = ["CLIPDataset"]
|
|
11
|
+
|
|
10
12
|
|
|
11
13
|
class CLIPDataset(torch.utils.data.Dataset):
|
|
12
14
|
"""
|
|
@@ -34,7 +36,7 @@ class CLIPDataset(torch.utils.data.Dataset):
|
|
|
34
36
|
"""Returns the number of items in the dataset."""
|
|
35
37
|
return len(self.dataset)
|
|
36
38
|
|
|
37
|
-
def __getitem__(self, idx: int):
|
|
39
|
+
def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
|
|
38
40
|
"""
|
|
39
41
|
Retrieves and processes an item from the dataset.
|
|
40
42
|
|
|
@@ -62,6 +64,12 @@ class CLIPDataset(torch.utils.data.Dataset):
|
|
|
62
64
|
inputs = self.processor(images=[image], return_tensors="pt")[
|
|
63
65
|
"pixel_values"
|
|
64
66
|
][0]
|
|
67
|
+
elif callable(self.processor):
|
|
68
|
+
inputs = self.processor(image)
|
|
69
|
+
else:
|
|
70
|
+
raise ValueError(
|
|
71
|
+
"The processor should be a CLIPProcessor or a callable function"
|
|
72
|
+
)
|
|
65
73
|
else:
|
|
66
74
|
# if processor is None, return the raw image directly
|
|
67
75
|
inputs = image
|
fusion_bench/dataset/gsm8k.py
CHANGED
|
@@ -6,7 +6,7 @@ from datasets import load_dataset
|
|
|
6
6
|
|
|
7
7
|
|
|
8
8
|
def load_gsm8k_question_label_data(
|
|
9
|
-
dataset_name: Literal["train", "test", "train_socratic", "test_socratic"]
|
|
9
|
+
dataset_name: Literal["train", "test", "train_socratic", "test_socratic"],
|
|
10
10
|
):
|
|
11
11
|
R"""
|
|
12
12
|
Load the GSM8K dataset and extract questions and labels.
|
|
@@ -45,7 +45,7 @@ def load_gsm8k_question_label_data(
|
|
|
45
45
|
|
|
46
46
|
|
|
47
47
|
def load_gsm8k_question_label_dataset(
|
|
48
|
-
dataset_name: Literal["train", "test", "train_socratic", "test_socratic"]
|
|
48
|
+
dataset_name: Literal["train", "test", "train_socratic", "test_socratic"],
|
|
49
49
|
):
|
|
50
50
|
"""
|
|
51
51
|
Load the GSM8K dataset and return it as a Hugging Face Dataset object.
|
fusion_bench/method/__init__.py
CHANGED
|
@@ -53,7 +53,7 @@ _import_structure = {
|
|
|
53
53
|
"PWEMoExactParetoOptimalForCLIP",
|
|
54
54
|
],
|
|
55
55
|
"ada_svd": ["AdaSVDMergingForCLIPVisionModel"],
|
|
56
|
-
"
|
|
56
|
+
"doge_ta": ["DOGE_TA_Algorithm"],
|
|
57
57
|
"task_singular_vector": ["TaskSingularVectorMerging"],
|
|
58
58
|
"isotropic_merging": [
|
|
59
59
|
"ISO_C_Merge", # alias
|
|
@@ -62,6 +62,11 @@ _import_structure = {
|
|
|
62
62
|
"IsotropicMergingInCommonSubspace",
|
|
63
63
|
],
|
|
64
64
|
"opcm": ["OPCMForCLIP"],
|
|
65
|
+
"gossip": [
|
|
66
|
+
"CLIPLayerWiseGossipAlgorithm",
|
|
67
|
+
"CLIPTaskWiseGossipAlgorithm",
|
|
68
|
+
"FlanT5LayerWiseGossipAlgorithm",
|
|
69
|
+
],
|
|
65
70
|
# plug-and-play model merging methods
|
|
66
71
|
"concrete_subspace": [
|
|
67
72
|
"ConcreteTaskArithmeticAlgorithmForCLIP",
|
|
@@ -128,7 +133,7 @@ if TYPE_CHECKING:
|
|
|
128
133
|
from .dare import DareSimpleAverage, DareTaskArithmetic, DareTiesMerging
|
|
129
134
|
from .dawe import DataAdaptiveWeightEnsemblingForCLIP
|
|
130
135
|
from .depth_upscaling import DepthUpscalingAlgorithm, DepthUpscalingForLlama
|
|
131
|
-
from .
|
|
136
|
+
from .doge_ta import DOGE_TA_Algorithm
|
|
132
137
|
from .dummy import DummyAlgorithm
|
|
133
138
|
from .ensemble import (
|
|
134
139
|
MaxModelPredictorAlgorithm,
|
|
@@ -136,6 +141,11 @@ if TYPE_CHECKING:
|
|
|
136
141
|
WeightedEnsembleAlgorithm,
|
|
137
142
|
)
|
|
138
143
|
from .fisher_merging import FisherMergingForCLIPVisionModel
|
|
144
|
+
from .gossip import (
|
|
145
|
+
CLIPLayerWiseGossipAlgorithm,
|
|
146
|
+
CLIPTaskWiseGossipAlgorithm,
|
|
147
|
+
FlanT5LayerWiseGossipAlgorithm,
|
|
148
|
+
)
|
|
139
149
|
from .isotropic_merging import (
|
|
140
150
|
ISO_C_Merge,
|
|
141
151
|
ISO_CTS_Merge,
|
|
@@ -9,7 +9,7 @@ fusion_bench \
|
|
|
9
9
|
modelpool=clip-vit-base-patch32_TA8 \
|
|
10
10
|
taskpool=clip-vit-classification_TA8 \
|
|
11
11
|
fabric.loggers.root_dir=outputs/logs/ViT-B-32 \
|
|
12
|
-
fabric.loggers.name=
|
|
12
|
+
fabric.loggers.name=clip_layer_wise_adamerging_adamerging
|
|
13
13
|
```
|
|
14
14
|
"""
|
|
15
15
|
|
|
@@ -13,41 +13,13 @@ from fusion_bench.modelpool import CLIPVisionModelPool
|
|
|
13
13
|
from fusion_bench.models.hf_clip import HFCLIPClassifier
|
|
14
14
|
from fusion_bench.tasks.clip_classification import get_classnames_and_templates
|
|
15
15
|
from fusion_bench.utils import timeit_context
|
|
16
|
+
from fusion_bench.utils.data import InfiniteDataLoader
|
|
16
17
|
|
|
17
18
|
from .task_wise_adamerging import TaskWiseAdaMergingAlgorithm
|
|
18
19
|
|
|
19
20
|
log = logging.getLogger(__name__)
|
|
20
21
|
|
|
21
22
|
|
|
22
|
-
class InfiniteDataLoader:
|
|
23
|
-
"""
|
|
24
|
-
A wrapper class for DataLoader to create an infinite data loader.
|
|
25
|
-
This is useful in case we are only interested in the number of steps and not the number of epochs.
|
|
26
|
-
|
|
27
|
-
This class wraps a DataLoader and provides an iterator that resets
|
|
28
|
-
when the end of the dataset is reached, creating an infinite loop.
|
|
29
|
-
|
|
30
|
-
Attributes:
|
|
31
|
-
data_loader (DataLoader): The DataLoader to wrap.
|
|
32
|
-
data_iter (iterator): An iterator over the DataLoader.
|
|
33
|
-
"""
|
|
34
|
-
|
|
35
|
-
def __init__(self, data_loader):
|
|
36
|
-
self.data_loader = data_loader
|
|
37
|
-
self.data_iter = iter(data_loader)
|
|
38
|
-
|
|
39
|
-
def __iter__(self):
|
|
40
|
-
return self
|
|
41
|
-
|
|
42
|
-
def __next__(self):
|
|
43
|
-
try:
|
|
44
|
-
data = next(self.data_iter)
|
|
45
|
-
except StopIteration:
|
|
46
|
-
self.data_iter = iter(self.data_loader) # Reset the data loader
|
|
47
|
-
data = next(self.data_iter)
|
|
48
|
-
return data
|
|
49
|
-
|
|
50
|
-
|
|
51
23
|
class CLIPTaskWiseAdaMergingAlgorithm(TaskWiseAdaMergingAlgorithm):
|
|
52
24
|
"""
|
|
53
25
|
A class for task-wise adaptive merging of CLIP models.
|
|
@@ -9,7 +9,7 @@ fusion_bench \
|
|
|
9
9
|
modelpool=clip-vit-base-patch32_TA8 \
|
|
10
10
|
taskpool=clip-vit-classification_TA8 \
|
|
11
11
|
fabric.loggers.root_dir=outputs/logs/ViT-B-32 \
|
|
12
|
-
fabric.loggers.name=
|
|
12
|
+
fabric.loggers.name=clip_layer_wise_adamerging_adamerging
|
|
13
13
|
```
|
|
14
14
|
"""
|
|
15
15
|
|
|
@@ -12,6 +12,7 @@ from torch import Tensor, nn
|
|
|
12
12
|
from tqdm.autonotebook import tqdm
|
|
13
13
|
|
|
14
14
|
from fusion_bench.method import BaseAlgorithm
|
|
15
|
+
from fusion_bench.mixins import SimpleProfilerMixin
|
|
15
16
|
from fusion_bench.modelpool import BaseModelPool
|
|
16
17
|
|
|
17
18
|
log = logging.getLogger(__name__)
|
|
@@ -352,7 +353,7 @@ def filter_state_dict(
|
|
|
352
353
|
return filtered_state_dict
|
|
353
354
|
|
|
354
355
|
|
|
355
|
-
class FisherMergingAlgorithm(BaseAlgorithm):
|
|
356
|
+
class FisherMergingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
|
|
356
357
|
"""
|
|
357
358
|
Implements the Fisher Merging Algorithm.
|
|
358
359
|
|
|
@@ -432,25 +433,36 @@ class FisherMergingAlgorithm(BaseAlgorithm):
|
|
|
432
433
|
for param_name in param_names_to_merge:
|
|
433
434
|
models_to_merge_param_dict[param_name].append(param_dict[param_name])
|
|
434
435
|
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
|
|
436
|
+
with (
|
|
437
|
+
self.profile("merging models"),
|
|
438
|
+
self.profile("computing fisher weights"),
|
|
439
|
+
):
|
|
440
|
+
model_to_merge_fisher_weights = self.get_fisher_weights(
|
|
441
|
+
model_name=name,
|
|
442
|
+
model=model,
|
|
443
|
+
train_dataset=modelpool.load_train_dataset(name),
|
|
444
|
+
param_names_to_merge=param_names_to_merge,
|
|
445
|
+
)
|
|
441
446
|
|
|
442
|
-
|
|
447
|
+
models_to_merge_fisher_weights_list.append(
|
|
448
|
+
model_to_merge_fisher_weights
|
|
449
|
+
)
|
|
443
450
|
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
|
|
451
|
+
with self.profile("merging models"):
|
|
452
|
+
merged_params = merging_with_fisher_weights(
|
|
453
|
+
models_to_merge_param_dict=models_to_merge_param_dict,
|
|
454
|
+
models_to_merge_fisher_weights_list=models_to_merge_fisher_weights_list,
|
|
455
|
+
fisher_scaling_coefficients=torch.ones(len(modelpool)) / len(modelpool),
|
|
456
|
+
normalize_fisher_weight=self.config.get(
|
|
457
|
+
"normalize_fisher_weight", True
|
|
458
|
+
),
|
|
459
|
+
minimal_fisher_weight=self.config.get("minimal_fisher_weight", 1e-6),
|
|
460
|
+
)
|
|
461
|
+
|
|
462
|
+
merged_model = modelpool.load_model("_pretrained_")
|
|
463
|
+
merged_model.load_state_dict(merged_params, strict=False)
|
|
451
464
|
|
|
452
|
-
|
|
453
|
-
merged_model.load_state_dict(merged_params, strict=False)
|
|
465
|
+
self.print_profile_summary()
|
|
454
466
|
return merged_model
|
|
455
467
|
|
|
456
468
|
def get_fisher_weights(
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Example Usage:
|
|
3
|
+
|
|
4
|
+
```bash
|
|
5
|
+
fusion_bench \
|
|
6
|
+
method=adamerging \
|
|
7
|
+
method.name=clip_layer_wise_adamerging \
|
|
8
|
+
method.save_merging_weights=merging_weights.pt \
|
|
9
|
+
modelpool=clip-vit-base-patch32_TA8 \
|
|
10
|
+
taskpool=clip-vit-classification_TA8 \
|
|
11
|
+
fabric_logger.root_dir=outputs/logs/ViT-B-32 \
|
|
12
|
+
fabric_logger.name=clip_layer_wise_adamerging_adam
|
|
13
|
+
```
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
import functools
|
|
17
|
+
import logging
|
|
18
|
+
|
|
19
|
+
from fusion_bench.mixins import CLIPClassificationMixin
|
|
20
|
+
|
|
21
|
+
from .layer_wise_gossip import LayerWiseGossipAlgorithm
|
|
22
|
+
|
|
23
|
+
log = logging.getLogger(__name__)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class CLIPLayerWiseGossipAlgorithm(
|
|
27
|
+
CLIPClassificationMixin,
|
|
28
|
+
LayerWiseGossipAlgorithm,
|
|
29
|
+
):
|
|
30
|
+
def on_test_time_adaptation_start(self):
|
|
31
|
+
"""
|
|
32
|
+
Here we load the CLIP processor and construct the zero-shot classification head for each task.
|
|
33
|
+
"""
|
|
34
|
+
if self.whether_setup_zero_shot_classification_head == False:
|
|
35
|
+
self.setup_zero_shot_classification_head()
|
|
36
|
+
|
|
37
|
+
@functools.cache
|
|
38
|
+
def get_shuffled_test_loader_iter(self, task: str):
|
|
39
|
+
return super().get_shuffled_test_loader_iter(
|
|
40
|
+
task,
|
|
41
|
+
batch_size=self.config.batch_size,
|
|
42
|
+
num_workers=self.config.num_workers,
|
|
43
|
+
)
|
|
@@ -0,0 +1,190 @@
|
|
|
1
|
+
import functools
|
|
2
|
+
import logging
|
|
3
|
+
import os
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
from omegaconf import DictConfig
|
|
7
|
+
from torch import Tensor
|
|
8
|
+
from torch.utils.data import DataLoader
|
|
9
|
+
from transformers import CLIPModel, CLIPProcessor
|
|
10
|
+
|
|
11
|
+
from fusion_bench.dataset import CLIPDataset
|
|
12
|
+
from fusion_bench.modelpool import CLIPVisionModelPool
|
|
13
|
+
from fusion_bench.models.hf_clip import HFCLIPClassifier
|
|
14
|
+
from fusion_bench.tasks.clip_classification import get_classnames_and_templates
|
|
15
|
+
from fusion_bench.utils import timeit_context
|
|
16
|
+
|
|
17
|
+
from .task_wise_gossip import TaskWiseGossipAlgorithm
|
|
18
|
+
|
|
19
|
+
log = logging.getLogger(__name__)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class InfiniteDataLoader:
|
|
23
|
+
"""
|
|
24
|
+
A wrapper class for DataLoader to create an infinite data loader.
|
|
25
|
+
This is useful in case we are only interested in the number of steps and not the number of epochs.
|
|
26
|
+
|
|
27
|
+
This class wraps a DataLoader and provides an iterator that resets
|
|
28
|
+
when the end of the dataset is reached, creating an infinite loop.
|
|
29
|
+
|
|
30
|
+
Attributes:
|
|
31
|
+
data_loader (DataLoader): The DataLoader to wrap.
|
|
32
|
+
data_iter (iterator): An iterator over the DataLoader.
|
|
33
|
+
"""
|
|
34
|
+
|
|
35
|
+
def __init__(self, data_loader):
|
|
36
|
+
self.data_loader = data_loader
|
|
37
|
+
self.data_iter = iter(data_loader)
|
|
38
|
+
|
|
39
|
+
def __iter__(self):
|
|
40
|
+
return self
|
|
41
|
+
|
|
42
|
+
def __next__(self):
|
|
43
|
+
try:
|
|
44
|
+
data = next(self.data_iter)
|
|
45
|
+
except StopIteration:
|
|
46
|
+
self.data_iter = iter(self.data_loader) # Reset the data loader
|
|
47
|
+
data = next(self.data_iter)
|
|
48
|
+
return data
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class CLIPTaskWiseGossipAlgorithm(TaskWiseGossipAlgorithm):
|
|
52
|
+
"""
|
|
53
|
+
A class for task-wise adaptive merging of CLIP models.
|
|
54
|
+
|
|
55
|
+
This class extends the TaskWiseGossipAlgorithm to provide specific
|
|
56
|
+
functionality for CLIP models, including loading datasets, constructing
|
|
57
|
+
zero-shot classification heads, and computing logits.
|
|
58
|
+
|
|
59
|
+
Attributes:
|
|
60
|
+
modelpool (CLIPVisionModelPool): The model pool containing CLIP models.
|
|
61
|
+
_clip_processor (CLIPProcessor): The CLIP processor for preparing inputs.
|
|
62
|
+
zeroshot_weights (dict): A dictionary to store zero-shot weights for each task.
|
|
63
|
+
"""
|
|
64
|
+
|
|
65
|
+
modelpool: CLIPVisionModelPool = None
|
|
66
|
+
_clip_processor: CLIPProcessor = None
|
|
67
|
+
zeroshot_weights = {}
|
|
68
|
+
|
|
69
|
+
def __init__(self, algorithm_config: DictConfig):
|
|
70
|
+
super().__init__(algorithm_config)
|
|
71
|
+
|
|
72
|
+
@functools.cache
|
|
73
|
+
def get_test_dataset(self, task: str):
|
|
74
|
+
"""
|
|
75
|
+
Load the test dataset for the task.
|
|
76
|
+
This method is cached, so the dataset is loaded only once.
|
|
77
|
+
|
|
78
|
+
Args:
|
|
79
|
+
task (str): The name of the task.
|
|
80
|
+
|
|
81
|
+
Returns:
|
|
82
|
+
CLIPDataset: The test dataset for the task.
|
|
83
|
+
"""
|
|
84
|
+
log.info(f"Loading test dataset: {task}")
|
|
85
|
+
dataset = self.modelpool.load_test_dataset(task)
|
|
86
|
+
dataset = CLIPDataset(dataset, self._clip_processor)
|
|
87
|
+
return dataset
|
|
88
|
+
|
|
89
|
+
@functools.cache
|
|
90
|
+
def get_shuffled_test_loader_iter(self, task: str):
|
|
91
|
+
"""
|
|
92
|
+
Get an iterator over the shuffled test DataLoader for the task.
|
|
93
|
+
|
|
94
|
+
Args:
|
|
95
|
+
task (str): The name of the task.
|
|
96
|
+
|
|
97
|
+
Returns:
|
|
98
|
+
iterator: An iterator over the shuffled test DataLoader.
|
|
99
|
+
"""
|
|
100
|
+
loader = DataLoader(
|
|
101
|
+
self.get_test_dataset(task),
|
|
102
|
+
batch_size=self.config.batch_size,
|
|
103
|
+
shuffle=True,
|
|
104
|
+
num_workers=self.config.num_workers,
|
|
105
|
+
pin_memory=True,
|
|
106
|
+
)
|
|
107
|
+
if self._fabric is not None:
|
|
108
|
+
loader = self._fabric.setup_dataloaders(loader)
|
|
109
|
+
return iter(InfiniteDataLoader(loader))
|
|
110
|
+
|
|
111
|
+
def on_test_time_adaptation_start(self):
    """
    Prepare for test-time adaptation.

    Loads the CLIP processor and the pretrained CLIP model, keeps the frozen
    visual projection and the exponentiated logit scale, and, for every model
    name in the pool, loads a zero-shot classification head from
    ``self.config.cache_dir`` or builds and caches it.

    Returns immediately if the processor and at least one zero-shot head are
    already available, so repeated calls (e.g. from Gossip) are cheap.
    """
    # ``zeroshot_weights`` is a dict and is never ``None``; test for
    # emptiness so setup is not skipped when no head has been built yet.
    if self._clip_processor is not None and self.zeroshot_weights:
        return  # this can be reused in Gossip

    clip_model_config = self.modelpool.get_model_config("_pretrained_")
    if isinstance(clip_model_config, str):
        # The pretrained model may be given directly as a name or path.
        pretrained_path = clip_model_config
    elif hasattr(clip_model_config, "pretrained_model_name_or_path"):
        pretrained_path = clip_model_config.pretrained_model_name_or_path
    else:
        pretrained_path = clip_model_config.path

    with timeit_context("Loading CLIP processor and pretrained CLIP model."):
        self._clip_processor = CLIPProcessor.from_pretrained(pretrained_path)
        clip_model: CLIPModel = CLIPModel.from_pretrained(pretrained_path)

        clip_classifier = HFCLIPClassifier(clip_model, self._clip_processor)
        self.visual_projection = clip_model.visual_projection.requires_grad_(False)
        self.logit_scale_exp = clip_model.logit_scale.exp()
        if self._fabric is not None:
            self.visual_projection = self._fabric.to_device(self.visual_projection)
            self.logit_scale_exp = self._fabric.to_device(self.logit_scale_exp)

    # normpath strips a trailing separator; otherwise basename() would be ""
    # and different local models would collide on the same cache file name.
    model_tag = os.path.basename(os.path.normpath(pretrained_path))

    for task in self.modelpool.model_names:
        cache_file = os.path.join(
            self.config.cache_dir,
            f"{model_tag}_{task}_zeroshot_weights.pt",
        )
        if os.path.exists(cache_file):
            log.info(f"Loading cached zeroshot weights for task: {task}")
            # NOTE(review): torch.load unpickles the file; cache_dir should only
            # contain files written by this code.
            zeroshot_weights = torch.load(cache_file, map_location="cpu")
        else:
            log.info(f"Construct zero shot classification head for task: {task}")
            classnames, templates = get_classnames_and_templates(task)
            clip_classifier.set_classification_task(classnames, templates)
            zeroshot_weights = clip_classifier.zeroshot_weights
            log.info(f"save zeroshot weights to {cache_file}")
            # torch.save does not create missing parent directories.
            parent_dir = os.path.dirname(cache_file)
            if parent_dir:
                os.makedirs(parent_dir, exist_ok=True)
            torch.save(zeroshot_weights, cache_file)
        self.zeroshot_weights[task] = zeroshot_weights
        if self._fabric is not None:
            self.zeroshot_weights[task] = self._fabric.to_device(
                self.zeroshot_weights[task]
            )
|
|
159
|
+
|
|
160
|
+
def compute_logits(self, module, batch, task: str) -> Tensor:
    """
    Score a batch of images against the zero-shot head of ``task``.

    Element ``[1]`` of ``module(images)`` is passed through
    ``self.visual_projection`` and L2-normalised along the last dimension. It
    is then multiplied with the task's text embeddings and scaled by
    ``self.logit_scale_exp``. This is a cosine similarity if the text
    embeddings are themselves unit-norm, which is assumed and not checked here.

    Args:
        module (nn.Module): Vision model; element ``[1]`` of its output is used
            as the image feature (presumably the pooled output; confirm).
        batch (tuple): Two-element ``(images, labels)`` pair; labels are unused.
        task (str): Key into ``self.zeroshot_weights``.

    Returns:
        Tensor: Logits with one row per image and one column per row of the
        task's zero-shot weight matrix (one per class, assuming one
        embedding per class).
    """
    images, _labels = batch
    class_embeds = self.zeroshot_weights[task]

    features = self.visual_projection(module(images)[1])
    features = features / features.norm(p=2, dim=-1, keepdim=True)

    # (text rows x images), scaled, then flipped to (images x text rows)
    scores = torch.matmul(class_embeds, features.t()) * self.logit_scale_exp
    return scores.t()
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from torch import Tensor
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def entropy_loss(logits: Tensor, eps: float = 1e-8) -> Tensor:
    """
    Compute the batch-mean entropy of the softmax distributions in ``logits``.

    Each row of ``logits`` is converted to a probability distribution with a
    softmax over the last dimension. The entropy of every row is computed,
    and the per-row entropies are averaged.

    Args:
        logits (Tensor): 2-D tensor of unnormalized scores, one row per sample.
        eps (float): A small value added inside the logarithm to avoid
            ``log(0)``. Default is 1e-8.

    Returns:
        Tensor: Scalar tensor holding the mean entropy over the rows.

    Raises:
        ValueError: If ``logits`` is not 2-dimensional.
    """
    # Explicit check instead of ``assert`` so validation survives ``python -O``.
    if logits.dim() != 2:
        raise ValueError(
            f"Expected logits to have 2 dimensions, found {logits.dim()}, {logits.size()=}"
        )

    # Compute the softmax probabilities
    probs = torch.softmax(logits, dim=-1)

    # H(p) = -sum_i p_i * log(p_i + eps) per row, averaged over the batch
    return -torch.sum(probs * torch.log(probs + eps), dim=-1).mean()
|