fusion-bench 0.2.12__py3-none-any.whl → 0.2.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/compat/method/__init__.py +2 -0
- fusion_bench/compat/taskpool/flan_t5_glue_text_generation.py +4 -1
- fusion_bench/constants/clip_vision.py +22 -0
- fusion_bench/dataset/clip_dataset.py +10 -2
- fusion_bench/dataset/fer2013.py +1 -0
- fusion_bench/dataset/gsm8k.py +2 -2
- fusion_bench/method/__init__.py +10 -0
- fusion_bench/method/ada_svd/clip_vision.py +4 -1
- fusion_bench/method/adamerging/clip_task_wise_adamerging.py +1 -29
- fusion_bench/method/fisher_merging/fisher_merging.py +29 -17
- fusion_bench/method/gossip/__init__.py +3 -0
- fusion_bench/method/gossip/clip_layer_wise_gossip.py +43 -0
- fusion_bench/method/gossip/clip_task_wise_gossip.py +190 -0
- fusion_bench/method/gossip/entropy_loss.py +25 -0
- fusion_bench/method/gossip/flan_t5_layer_wise_gossip.py +388 -0
- fusion_bench/method/gossip/layer_wise_gossip.py +434 -0
- fusion_bench/method/gossip/min_norm_solvers.py +227 -0
- fusion_bench/method/gossip/task_wise_gossip.py +265 -0
- fusion_bench/method/gossip/utils.py +74 -0
- fusion_bench/method/isotropic_merging/__init__.py +1 -1
- fusion_bench/method/opcm/opcm.py +16 -7
- fusion_bench/method/pwe_moe/module.py +1 -1
- fusion_bench/method/pwe_moe/openclip_pwe_moe.py +476 -0
- fusion_bench/method/regmean/regmean.py +25 -17
- fusion_bench/method/smile_upscaling/__init__.py +1 -1
- fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py +46 -145
- fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +229 -0
- fusion_bench/method/smile_upscaling/smile_upscaling.py +19 -346
- fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +7 -0
- fusion_bench/method/task_arithmetic/task_arithmetic.py +8 -6
- fusion_bench/method/ties_merging/ties_merging.py +36 -31
- fusion_bench/method/we_moe/we_moe.py +14 -15
- fusion_bench/mixins/__init__.py +6 -3
- fusion_bench/mixins/hydra_config.py +49 -0
- fusion_bench/mixins/openclip_classification.py +11 -0
- fusion_bench/mixins/simple_profiler.py +4 -2
- fusion_bench/modelpool/__init__.py +3 -1
- fusion_bench/modelpool/base_pool.py +2 -2
- fusion_bench/modelpool/openclip_vision/__init__.py +1 -0
- fusion_bench/modelpool/openclip_vision/modelpool.py +255 -0
- fusion_bench/models/modeling_smile_mistral/modeling_smile_mistral.py +2 -203
- fusion_bench/models/modeling_smile_qwen2/__init__.py +8 -0
- fusion_bench/models/modeling_smile_qwen2/configuration_smile_qwen2.py +21 -0
- fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +922 -0
- fusion_bench/models/modeling_smile_qwen2/register.py +11 -0
- fusion_bench/models/open_clip/__init__.py +6 -0
- fusion_bench/models/open_clip/modeling.py +176 -0
- fusion_bench/models/open_clip/utils.py +311 -0
- fusion_bench/models/open_clip/variables_and_paths.py +56 -0
- fusion_bench/models/parameter_dict.py +54 -13
- fusion_bench/models/rankone_moe.py +2 -88
- fusion_bench/models/smile_moe/linear_from_hf_config.py +373 -0
- fusion_bench/models/smile_moe/{linear.py → linear_from_module.py} +103 -33
- fusion_bench/models/smile_moe/utils/__init__.py +24 -0
- fusion_bench/models/smile_moe/utils/svd_utils.py +46 -0
- fusion_bench/scripts/nyuv2_mtl_train.py +1 -1
- fusion_bench/taskpool/__init__.py +7 -3
- fusion_bench/taskpool/clip_vision/__init__.py +1 -0
- fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py +2 -30
- fusion_bench/taskpool/clip_vision/clip_smile_taskpool.py +102 -0
- fusion_bench/taskpool/clip_vision/clip_sparse_wemoe_taskpool.py +2 -30
- fusion_bench/taskpool/clip_vision/taskpool.py +1 -2
- fusion_bench/taskpool/clip_vision/utils/__init__.py +0 -0
- fusion_bench/taskpool/clip_vision/utils/routing_analysis_utils.py +65 -0
- fusion_bench/taskpool/gpt2_text_classification.py +30 -1
- fusion_bench/taskpool/lm_eval_harness/__init__.py +3 -0
- fusion_bench/taskpool/lm_eval_harness/taskpool.py +87 -0
- fusion_bench/taskpool/openclip_vision/__init__.py +1 -0
- fusion_bench/taskpool/openclip_vision/openclip_taskpool.py +196 -0
- fusion_bench/utils/data.py +12 -0
- fusion_bench/utils/devices.py +14 -0
- fusion_bench/utils/instantiate.py +12 -0
- fusion_bench/utils/misc.py +9 -2
- fusion_bench/utils/packages.py +14 -0
- fusion_bench/utils/parameters.py +1 -1
- fusion_bench/utils/tensorboard.py +1 -1
- {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/METADATA +22 -2
- {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/RECORD +209 -157
- {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/WHEEL +1 -1
- fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +1 -2
- fusion_bench_config/dataset/image_classification/test/TALL20.yaml +0 -1
- fusion_bench_config/dataset/image_classification/test/emnist_letters.yaml +0 -1
- fusion_bench_config/dataset/image_classification/test/fashion_mnist.yaml +1 -1
- fusion_bench_config/dataset/image_classification/train/TALL20.yaml +0 -1
- fusion_bench_config/dataset/image_classification/train/fashion_mnist.yaml +1 -1
- fusion_bench_config/fabric/auto.yaml +0 -1
- fusion_bench_config/fabric/llama_ddp.yaml +0 -1
- fusion_bench_config/fabric/llama_fsdp.yaml +0 -1
- fusion_bench_config/fabric/llama_peft_fsdp.yaml +0 -1
- fusion_bench_config/fabric/strategy/deepspeed.yaml +0 -1
- fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +0 -1
- fusion_bench_config/fabric_model_fusion.yaml +0 -1
- fusion_bench_config/llama_full_finetune.yaml +0 -2
- fusion_bench_config/llama_model_fusion.yaml +0 -2
- fusion_bench_config/method/ada_svd/clip_vision.yaml +0 -1
- fusion_bench_config/method/adamerging/layer_wise_flan_t5.yaml +0 -5
- fusion_bench_config/method/adamerging/layer_wise_gpt2.yaml +0 -5
- fusion_bench_config/method/adamerging/llama_sft.yaml +0 -2
- fusion_bench_config/method/adamerging.yaml +2 -2
- fusion_bench_config/method/analysis/task_vector_cos_similarity.yaml +0 -1
- fusion_bench_config/method/analysis/task_vector_violin_plot.yaml +0 -1
- fusion_bench_config/method/classification/clip_continual_finetune.yaml +0 -1
- fusion_bench_config/method/concrete_subspace/clip_concrete_layer_wise_adamerging.yaml +0 -1
- fusion_bench_config/method/concrete_subspace/clip_concrete_task_wise_adamerging.yaml +0 -1
- fusion_bench_config/method/concrete_subspace/clip_post_defense_AWM.yaml +1 -12
- fusion_bench_config/method/concrete_subspace/clip_post_defense_SAU.yaml +1 -12
- fusion_bench_config/method/concrete_subspace/clip_safe_concrete_layer_wise_adamerging.yaml +1 -10
- fusion_bench_config/method/concrete_subspace/clip_safe_concrete_task_arithmetic.yaml +1 -14
- fusion_bench_config/method/dare/simple_average.yaml +0 -1
- fusion_bench_config/method/dare/task_arithmetic.yaml +0 -1
- fusion_bench_config/method/dare/ties_merging.yaml +0 -2
- fusion_bench_config/method/dawe/dawe_for_clip.yaml +0 -3
- fusion_bench_config/method/doge_ta/doge_ta.yaml +1 -1
- fusion_bench_config/method/ensemble/max_model_predictor.yaml +1 -1
- fusion_bench_config/method/ensemble/simple_ensemble.yaml +0 -1
- fusion_bench_config/method/ensemble/weighted_ensemble.yaml +0 -1
- fusion_bench_config/method/gossip/layer_wise_clip.yaml +30 -0
- fusion_bench_config/method/gossip/layer_wise_flan_t5.yaml +25 -0
- fusion_bench_config/method/isotropic_merging/iso_c.yaml +0 -1
- fusion_bench_config/method/isotropic_merging/iso_cts.yaml +0 -1
- fusion_bench_config/method/linear/linear_interpolation.yaml +0 -1
- fusion_bench_config/method/linear/llama_expo.yaml +0 -3
- fusion_bench_config/method/linear/llama_expo_with_dare.yaml +0 -5
- fusion_bench_config/method/linear/weighted_average.yaml +0 -1
- fusion_bench_config/method/linear/weighted_average_for_llama.yaml +0 -1
- fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +0 -4
- fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +0 -4
- fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +0 -6
- fusion_bench_config/method/mixtral_moe_upscaling.yaml +1 -2
- fusion_bench_config/method/model_recombination.yaml +0 -1
- fusion_bench_config/method/opcm/opcm.yaml +0 -1
- fusion_bench_config/method/opcm/task_arithmetic.yaml +0 -2
- fusion_bench_config/method/opcm/ties_merging.yaml +0 -2
- fusion_bench_config/method/opcm/weight_average.yaml +0 -1
- fusion_bench_config/method/pwe_moe/epo_for_openclip.yaml +30 -0
- fusion_bench_config/method/pwe_moe/ls_for_openclip.yaml +30 -0
- fusion_bench_config/method/{pwe_moe_ls_for_clip.yaml → pwe_moe/pwe_moe_ls_for_clip.yaml} +7 -6
- fusion_bench_config/method/rankone_moe/rankone_moe.yaml +1 -3
- fusion_bench_config/method/regmean/gpt2_regmean.yaml +0 -1
- fusion_bench_config/method/slerp/slerp.yaml +0 -2
- fusion_bench_config/method/smile_upscaling/smile_mistral_upscaling.yaml +5 -2
- fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +13 -0
- fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml +1 -1
- fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml +1 -1
- fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml +1 -1
- fusion_bench_config/method/surgery/adamerging_surgery.yaml +1 -2
- fusion_bench_config/method/task_arithmetic.yaml +1 -1
- fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +0 -1
- fusion_bench_config/method/ties_merging.yaml +1 -1
- fusion_bench_config/method/trust_region/clip_task_arithmetic.yaml +0 -1
- fusion_bench_config/method/wemoe/sparse_weight_ensembling_moe.yaml +0 -8
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar10.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford-iiit-pet.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford_flowers102.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_pcam.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_rendered-sst2.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stl10.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml +1 -1
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_lora.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual_lora.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +0 -4
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +0 -3
- fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +0 -4
- fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +0 -4
- fusion_bench_config/modelpool/CausalLMPool/llama_for_causallm.yaml +0 -1
- fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +0 -4
- fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +0 -4
- fusion_bench_config/modelpool/CausalLMPool/qwen2_math_1.5B_and_R1.yaml +17 -0
- fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +0 -1
- fusion_bench_config/modelpool/CausalLMPool/single_llama_model.yaml +0 -3
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/README.md +90 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-16_TA8.yaml +27 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA8.yaml +45 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_cars_dtd.yaml +23 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_sun397_cars.yaml +23 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_sun397_dtd.yaml +23 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_individual.yaml +7 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-L-14_TA8.yaml +26 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue.yaml +0 -1
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16.yaml +0 -2
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16_tta.yaml +0 -2
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_tta.yaml +1 -3
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_individual.yaml +0 -1
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-large_glue_lora16.yaml +0 -3
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +0 -4
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +0 -3
- fusion_bench_config/modelpool/gpt-2_glue.yaml +0 -3
- fusion_bench_config/nyuv2_config.yaml +0 -2
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/_template.yaml +0 -3
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_B16.yaml +0 -2
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml +0 -2
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_sparse_wemoe_clip-vit-classification_TA8.yaml +0 -2
- fusion_bench_config/taskpool/LMEvalHarnessTaskPool/lm_eval.yaml +12 -0
- fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-16_TA8.yaml +24 -0
- fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-32_TA8.yaml +24 -0
- fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-L-14_TA8.yaml +24 -0
- fusion_bench_config/taskpool/gpt-2_glue.yaml +0 -1
- fusion_bench_config/taskpool/reward_model_evaluation.yaml +0 -4
- {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,265 @@
|
|
|
1
|
+
import copy
|
|
2
|
+
import gc
|
|
3
|
+
import logging
|
|
4
|
+
from abc import abstractmethod
|
|
5
|
+
from typing import List, Mapping, Union # noqa: F401
|
|
6
|
+
|
|
7
|
+
import lightning as L
|
|
8
|
+
import numpy as np
|
|
9
|
+
import torch
|
|
10
|
+
import torch.nn as nn
|
|
11
|
+
from omegaconf import DictConfig
|
|
12
|
+
from torch import Tensor
|
|
13
|
+
from torch.utils.data import DataLoader
|
|
14
|
+
from tqdm.autonotebook import tqdm
|
|
15
|
+
|
|
16
|
+
from fusion_bench.compat.method import ModelFusionAlgorithm
|
|
17
|
+
from fusion_bench.compat.modelpool import ModelPool
|
|
18
|
+
from fusion_bench.models.wrappers.task_wise_fusion import (
|
|
19
|
+
TaskWiseMergedModel,
|
|
20
|
+
get_task_wise_weights,
|
|
21
|
+
)
|
|
22
|
+
|
|
23
|
+
# Module-level logger (used by TaskWiseGossipAlgorithm below).
log = logging.getLogger(__name__)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def print_memory_usage(desc):
    """Print *desc* followed by the current CUDA memory statistics.

    Both figures are converted from bytes to MB (divided by 1024**2) and
    printed with two decimals:

    * "Allocated Memory" -- ``torch.cuda.memory_allocated()``
    * "Cached Memory"    -- ``torch.cuda.memory_reserved()``
    """
    bytes_per_mb = 1024**2
    print(desc)
    allocated_mb = torch.cuda.memory_allocated() / bytes_per_mb
    reserved_mb = torch.cuda.memory_reserved() / bytes_per_mb
    for label, amount in (
        ("Allocated Memory", allocated_mb),
        ("Cached Memory", reserved_mb),
    ):
        print(f"{label}: {amount:.2f} MB")
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def entropy_loss(logits: Tensor) -> Tensor:
    """
    Mean Shannon entropy of the softmax distribution over the last dimension.

    Args:
        logits (Tensor): Unnormalised scores; softmax is taken over ``dim=-1``
            and the per-row entropies are averaged over all leading dims.

    Returns:
        Tensor: A scalar tensor holding the mean entropy.
    """
    eps = 1e-8  # keeps log() finite when a probability underflows to 0
    probabilities = logits.softmax(dim=-1)
    log_probabilities = (probabilities + eps).log()
    per_row_sum = (probabilities * log_probabilities).sum(dim=-1)
    return -per_row_sum.mean()
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class ModelScheduler:
    """
    Manage the models used by gossip merging.

    Keeps the pretrained model and the fine-tuned models, builds the merged
    module for one local AdaMerging step (``__call__``), buffers the adapted
    results (``store_model`` / ``update_models``) and assembles the final list
    of models (``get_final_models``).
    """

    def __init__(
        self,
        modelpool: "ModelPool",
        config: "DictConfig",
    ):
        """
        Args:
            modelpool: pool providing ``load_model(name)`` and ``model_names``;
                the model named ``"_pretrained_"`` is loaded as the pretrained model.
            config: algorithm config; ``weights``, ``init_values``,
                ``clamp_weights``, ``tie_weights`` and ``strict`` are read.
        """
        self.pretrained_model = modelpool.load_model("_pretrained_")
        self.finetuned_models = [
            modelpool.load_model(name) for name in modelpool.model_names
        ]
        self.num_finetuned_models = len(self.finetuned_models)
        # Second buffer: results of the current gossip round are stored here, so
        # every model in the round is still merged with its neighbours' *old*
        # weights; ``update_models`` publishes them afterwards.
        self.new_finetuned_models = copy.deepcopy(self.finetuned_models)
        self.finetuned_model_names = list(modelpool.model_names)

        self.config = config

    @torch.no_grad()  # not sure whether to use this
    def __call__(self, model_id):
        """
        Build the task-wise merged module for local AdaMerging of ``model_id``.

        The current fine-tuned model ``model_id`` is used as the base
        ("pretrained") model, and its two ring neighbours ``model_id + 1`` and
        ``model_id - 1`` (modulo the number of models) are merged into it. All
        three are deep copies, so adapting the returned module leaves the stored
        models untouched.

        Raises:
            NotImplementedError: if ``config.weights`` is set; only freshly
                initialised merging weights (``config.init_values``) are supported.
        """
        # TODO: use a mixing matrix to determine which models to use in step idx
        if self.config.weights is not None:
            # There is no code path that turns pre-specified weights into
            # ``task_wise_weight``; fail clearly instead of with an unbound name.
            raise NotImplementedError(
                "Gossip merging does not support pre-specified `weights`; "
                "set `weights` to null and use `init_values` instead."
            )

        num_models = self.num_finetuned_models
        pretrained_model = copy.deepcopy(self.finetuned_models[model_id])
        finetuned_models = [
            copy.deepcopy(self.finetuned_models[(model_id + 1) % num_models]),
            copy.deepcopy(self.finetuned_models[(model_id - 1) % num_models]),
        ]

        task_wise_weight = get_task_wise_weights(
            num_models=len(finetuned_models),
            init_values=self.config.init_values,
        )

        module = TaskWiseMergedModel(
            task_wise_weight=task_wise_weight,
            pretrained_model=pretrained_model,
            finetuned_models=finetuned_models,
            clamp_weights=self.config.clamp_weights,
            tie_weights=self.config.tie_weights,
            strict=self.config.strict,
        )
        return module

    def store_model(self, new_finetuned_model_dict, model_id):
        """
        Store the merged state dict of ``model_id`` after a local AdaMerging turn.

        The weights go into the second buffer and are not used by ``__call__``
        until ``update_models`` is called.
        """
        self.new_finetuned_models[model_id].load_state_dict(new_finetuned_model_dict)

    def update_models(self):
        """Publish the models stored during this round for the next round."""
        self.finetuned_models = copy.deepcopy(self.new_finetuned_models)

    def get_final_models(self):
        """
        Return the fine-tuned models plus their uniform parameter average.

        Returns:
            list[dict]: one ``{"name": name, "model": model}`` entry per
            fine-tuned model, followed by ``{"name": "average model", ...}``.
            The average is written into ``self.pretrained_model`` in place, and
            that model is returned as the last entry.
        """
        # need a check
        final_models = [
            {"name": name, "model": model}
            for name, model in zip(self.finetuned_model_names, self.finetuned_models)
        ]
        num_finetuned_models = len(self.finetuned_models)

        # Average only the parameters. Buffers (e.g. integer position ids or
        # normalisation statistics) keep the pretrained model's values instead
        # of being zeroed, and no autograd graph is recorded.
        averaged = dict(self.pretrained_model.named_parameters())
        with torch.no_grad():
            for param in averaged.values():
                param.zero_()
            for model in self.finetuned_models:
                for name, param in model.named_parameters():
                    averaged[name].add_(param, alpha=1 / num_finetuned_models)

        final_models.append({"name": "average model", "model": self.pretrained_model})

        return final_models
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
class TaskWiseGossipAlgorithm(ModelFusionAlgorithm):
    """
    Task-wise AdaMerging performed in gossip rounds.

    In each of ``config.gossip_max_steps`` rounds, every fine-tuned model is
    locally merged with its neighbours (built by ``ModelScheduler``) by
    test-time adaptation of task-wise merging weights with an entropy loss.
    Subclasses supply the data (``get_shuffled_test_loader_iter``) and the
    forward pass (``compute_logits``).
    """

    _fabric: L.Fabric = None

    def __init__(self, algorithm_config: DictConfig):
        super().__init__(algorithm_config)

        # A Fabric instance is only created when CUDA is available; otherwise
        # the algorithm runs on plain PyTorch (see ``test_time_adaptation``).
        if self._fabric is None and torch.cuda.is_available():
            self._fabric = L.Fabric(devices=self.config.get("devices", 1))
            self._fabric.launch()

        self.optimizer = None  # we want to reuse it in Gossip using single GPU

    def free_gpu_memory(self, module: TaskWiseMergedModel):
        """
        Move ``module``'s models back to the CPU and release cached CUDA memory.

        Note: ``del module`` only drops this function's local reference; the
        caller's reference keeps the module alive.
        """
        module.pretrained_model.to("cpu")
        for model in module.task_vectors:
            model.to("cpu")
        del module
        gc.collect()
        torch.cuda.empty_cache()
        print_memory_usage(
            "finish local adamerging, after freeing memory, the memory usage of GPU is:"
        )

    def run(self, modelpool: ModelPool):
        """
        Run the gossip rounds and return the final models.

        Within a round, each model's merged weights are stored in the
        scheduler's second buffer and only become visible to other models after
        ``update_models()`` at the end of the round.

        Returns:
            list[dict]: the result of ``ModelScheduler.get_final_models()``.
        """
        log.info("Fusing models using task-wise adaptive merging with gossip.")
        self.modelpool = modelpool
        self.num_finetuned_models = len(modelpool.model_names)

        model_scheduler = ModelScheduler(self.modelpool, self.config)

        pbar = tqdm(
            range(self.config.gossip_max_steps), "Gossip merging", dynamic_ncols=True
        )
        for step_idx in pbar:
            log.info(f"step: {step_idx}")
            for model_id in tqdm(
                range(self.num_finetuned_models), "local adamerging", dynamic_ncols=True
            ):
                module = model_scheduler(model_id)
                module = self.test_time_adaptation(module)
                print_memory_usage(
                    "local adamerging almost done, the memory usage of GPU is:"
                )
                model_scheduler.store_model(module.merge_weights(), model_id)
                print_memory_usage(
                    "local adamerging almost done, the memory usage of GPU is:"
                )
                # simulate distributed GPU memory usage as much as possible
                self.free_gpu_memory(module)

            model_scheduler.update_models()

        return model_scheduler.get_final_models()

    def on_test_time_adaptation_start(self):
        """Hook called at the start of every ``test_time_adaptation``; no-op by default."""
        pass

    @abstractmethod
    def get_shuffled_test_loader_iter(self, task: str) -> DataLoader:
        """
        Return shuffled test batches for ``task``.

        NOTE(review): the result is passed directly to ``next()``, so it must be
        an iterator, not a ``DataLoader`` as the annotation says.
        """
        pass

    @abstractmethod
    def compute_logits(self, module: nn.Module, batch, task: str) -> Tensor:
        """
        Compute the logits for the given batch and task.

        Args:
            module (nn.Module): The model module.
            batch (tuple): A batch of input data.
            task (str): The name of the task.

        Returns:
            Tensor: The classification logits for the batch (must be 2-D).
        """
        pass

    def test_time_adaptation(self, module: TaskWiseMergedModel):
        """
        Adapt ``module.merge_weight`` by minimising prediction entropy.

        Each step draws one unlabeled batch per task in ``self.modelpool``,
        accumulates the entropy-loss gradients over all tasks, then takes one
        optimizer step and re-merges the weights. Runs ``config.max_steps``
        steps (one step if ``config.fast_dev_run``).

        Returns:
            The adapted module (wrapped by Fabric when Fabric is in use).

        Raises:
            ValueError: if ``config.optimizer`` is not ``"adam"``.
        """
        self.on_test_time_adaptation_start()

        # configure optimizer
        if self.config.optimizer == "adam":
            self.optimizer = torch.optim.Adam([module.merge_weight], lr=self.config.lr)
        else:
            raise ValueError(f"Unsupported optimizer: {self.config.optimizer}")

        if self._fabric is not None:
            module, self.optimizer = self._fabric.setup(module, self.optimizer)
        print_memory_usage(
            "load model and optimizer to GPU, the memory usage of GPU is:"
        )
        module.train()
        module.merge_weights()

        if self.config.get("fast_dev_run", False):
            log.info("Running fast_dev_run, only one step")
            num_steps = 1
        else:
            num_steps = self.config.max_steps
        pbar = tqdm(
            range(num_steps),
            "AdaMerging Test-time adaptation",
            dynamic_ncols=True,
        )
        for _ in pbar:
            for task in self.modelpool.model_names:
                batch = next(self.get_shuffled_test_loader_iter(task))
                logits = self.compute_logits(module, batch, task)
                assert (
                    logits.dim() == 2
                ), f"Expected logits to be 2D, got {logits.dim()}"
                loss = entropy_loss(logits)
                # .backward() accumulates when .zero_grad() wasn't called, so
                # each task's graph can be freed before the next task's forward
                # pass, which saves memory. retain_graph=True because
                # merge_weights() runs only once per step, so its part of the
                # graph is shared by every task's backward.
                if self._fabric is not None:
                    self._fabric.backward(loss, retain_graph=True)
                else:
                    # Fabric is only created when CUDA is available.
                    loss.backward(retain_graph=True)

            self.optimizer.step()
            self.optimizer.zero_grad()
            module.merge_weights()

        # Drop the optimizer (and its state) before the next local merge while
        # keeping the attribute declared in ``__init__``.
        self.optimizer = None
        gc.collect()
        torch.cuda.empty_cache()
        return module
|
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
import copy
|
|
2
|
+
from collections import OrderedDict
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
from torch import nn
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def get_memory_usage(desc):
    """
    Build a human-readable report of the current CUDA memory usage.

    Args:
        desc: Header placed on the first line of the report.

    Returns:
        str: Three newline-separated lines: ``desc``, the allocated memory and
        the reserved ("cached") memory, both converted from bytes to MB
        (divided by 1024**2) with two decimals.
    """
    bytes_per_mb = 1024**2
    allocated_mb = torch.cuda.memory_allocated() / bytes_per_mb
    reserved_mb = torch.cuda.memory_reserved() / bytes_per_mb
    report_lines = [
        f"{desc}",
        f"Allocated Memory: {allocated_mb:.2f} MB",
        f"Cached Memory: {reserved_mb:.2f} MB",
    ]
    return "\n".join(report_lines)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
# Model conversion utils
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def state_dict_to_vector(state_dict, remove_keys=None):
    """
    Flatten a state dictionary into a single 1-D tensor.

    Entries are concatenated in sorted key order, which is the same order
    ``vector_to_state_dict`` uses to unflatten.

    Args:
        state_dict (dict): Mapping from names to tensors. It is not modified.
        remove_keys (Iterable[str], optional): Keys to leave out of the vector.
            Keys that are not present are ignored. Defaults to no keys.

    Returns:
        torch.Tensor: The remaining tensors, flattened and concatenated.
    """
    excluded = set(remove_keys or ())
    # No deepcopy is needed: the tensors are only read, and
    # ``parameters_to_vector`` concatenates them into a new tensor.
    return nn.utils.parameters_to_vector(
        [state_dict[key].reshape(-1) for key in sorted(state_dict) if key not in excluded]
    )
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def vector_to_state_dict(vector, state_dict, remove_keys=None):
    """
    Unflatten a vector into a state dictionary (inverse of ``state_dict_to_vector``).

    Args:
        vector (torch.Tensor): 1-D tensor whose length equals the total number
            of elements in ``state_dict`` after removing ``remove_keys``.
        state_dict (dict): Reference state dict defining the names, the shapes
            and the (sorted-key) order of the vector. It is deep-copied, not
            modified.
        remove_keys (Iterable[str], optional): Keys excluded from the vector.
            Defaults to no keys.

    Returns:
        OrderedDict: Names (sorted) mapped to tensors holding the values from
        ``vector``. If ``"transformer.shared.weight"`` is present, every key
        in ``remove_keys`` is added back and points at that same tensor (the
        encoder/decoder embedding weights).
    """
    # Materialise once: the keys are iterated twice below.
    remove_keys = [] if remove_keys is None else list(remove_keys)

    # create a reference dict to define the order of the vector. It is a deep
    # copy because ``vector_to_parameters`` rebinds ``.data`` on the tensors it
    # receives, which would otherwise alter the caller's tensors.
    reference_dict = copy.deepcopy(state_dict)
    for key in remove_keys:
        reference_dict.pop(key, None)
    sorted_reference_dict = OrderedDict(sorted(reference_dict.items()))

    # create a shared state dict using the reference dict
    nn.utils.vector_to_parameters(vector, sorted_reference_dict.values())

    # add back the encoder and decoder embedding weights.
    if "transformer.shared.weight" in sorted_reference_dict:
        shared_weight = sorted_reference_dict["transformer.shared.weight"]
        for key in remove_keys:
            sorted_reference_dict[key] = shared_weight
    return sorted_reference_dict
|
|
@@ -3,7 +3,7 @@ This module contains the implementation of the Isotropic Merging in Common Subsp
|
|
|
3
3
|
Modified from the original implementation: https://github.com/danielm1405/iso-merging
|
|
4
4
|
|
|
5
5
|
Reference:
|
|
6
|
-
- Daniel Marczak, et al. No Task Left Behind: Isotropic Model Merging with Common and Task-Specific Subspaces. 2025.
|
|
6
|
+
- Daniel Marczak, et al. No Task Left Behind: Isotropic Model Merging with Common and Task-Specific Subspaces. 2025.
|
|
7
7
|
https://arxiv.org/abs/2502.04959
|
|
8
8
|
"""
|
|
9
9
|
|
fusion_bench/method/opcm/opcm.py
CHANGED
|
@@ -126,10 +126,14 @@ class OPCMForCLIP(
|
|
|
126
126
|
)
|
|
127
127
|
self.avg_task_vector_norm = np.mean(self.all_task_vector_norm)
|
|
128
128
|
self.fabric.log(
|
|
129
|
-
"model/task_vector_norm",
|
|
129
|
+
"model/task_vector_norm",
|
|
130
|
+
self.all_task_vector_norm[-1],
|
|
131
|
+
step=model_idx,
|
|
130
132
|
)
|
|
131
133
|
self.fabric.log(
|
|
132
|
-
"model/avg_task_vector_norm",
|
|
134
|
+
"model/avg_task_vector_norm",
|
|
135
|
+
self.avg_task_vector_norm,
|
|
136
|
+
step=model_idx,
|
|
133
137
|
)
|
|
134
138
|
|
|
135
139
|
self.lambda_t = 1 # temporary value
|
|
@@ -166,9 +170,9 @@ class OPCMForCLIP(
|
|
|
166
170
|
pretrained_W=pretrained_model.get_submodule(
|
|
167
171
|
module_name
|
|
168
172
|
).get_parameter(param_name),
|
|
169
|
-
task_W=task_model.get_submodule(
|
|
170
|
-
|
|
171
|
-
),
|
|
173
|
+
task_W=task_model.get_submodule(
|
|
174
|
+
module_name
|
|
175
|
+
).get_parameter(param_name),
|
|
172
176
|
param_name=".".join([module_name, param_name]),
|
|
173
177
|
accelerator=accelerator,
|
|
174
178
|
)
|
|
@@ -200,10 +204,15 @@ class OPCMForCLIP(
|
|
|
200
204
|
with self.profile("evaluating model"):
|
|
201
205
|
self.taskpool._is_setup = False
|
|
202
206
|
self.taskpool._test_datasets = DictConfig(
|
|
203
|
-
{
|
|
207
|
+
{
|
|
208
|
+
n: self._test_datasets[n]
|
|
209
|
+
for n in model_names[: model_idx + 1]
|
|
210
|
+
}
|
|
204
211
|
)
|
|
205
212
|
report = self.taskpool.evaluate(deepcopy(merged_model))
|
|
206
|
-
save_to_json(
|
|
213
|
+
save_to_json(
|
|
214
|
+
report, Path(self.log_dir) / f"report_{model_idx}.json"
|
|
215
|
+
)
|
|
207
216
|
|
|
208
217
|
self.print_profile_summary()
|
|
209
218
|
return merged_model
|