fusion-bench 0.2.20__py3-none-any.whl → 0.2.22__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/__init__.py +22 -2
- fusion_bench/_get_started/__init__.py +3 -0
- fusion_bench/_get_started/greeting_program.py +49 -0
- fusion_bench/compat/method/base_algorithm.py +14 -0
- fusion_bench/constants/__init__.py +6 -0
- fusion_bench/constants/clip_vision.py +26 -2
- fusion_bench/constants/paths.py +4 -0
- fusion_bench/constants/runtime.py +57 -0
- fusion_bench/dataset/clip_dataset.py +2 -1
- fusion_bench/dataset/gpt2_glue.py +9 -9
- fusion_bench/dataset/image_corruption/__init__.py +0 -0
- fusion_bench/dataset/image_corruption/make_corruption.py +179 -0
- fusion_bench/dataset/image_dataset.py +1 -1
- fusion_bench/dataset/nyuv2.py +2 -2
- fusion_bench/method/__init__.py +24 -5
- fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +1 -1
- fusion_bench/method/adamerging/clip_task_wise_adamerging.py +11 -7
- fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -5
- fusion_bench/method/base_algorithm.py +195 -12
- fusion_bench/method/bitdelta/__init__.py +5 -0
- fusion_bench/method/bitdelta/bitdelta.py +156 -0
- fusion_bench/method/bitdelta/bitdelta_utils/__init__.py +0 -0
- fusion_bench/method/bitdelta/bitdelta_utils/binary_gemm_kernel.py +462 -0
- fusion_bench/method/bitdelta/bitdelta_utils/data.py +35 -0
- fusion_bench/method/bitdelta/bitdelta_utils/diff.py +129 -0
- fusion_bench/method/classification/clip_finetune.py +1 -1
- fusion_bench/method/concrete_subspace/clip_concrete_adamerging.py +0 -1
- fusion_bench/method/depth_upscaling/depth_upscaling.py +4 -9
- fusion_bench/method/doge_ta/clip_layer_wise_adamerging.py +4 -5
- fusion_bench/method/doge_ta/doge_ta.py +1 -1
- fusion_bench/method/ensemble.py +12 -12
- fusion_bench/method/expert_sparsity/utils/calibration_data.py +1 -1
- fusion_bench/method/fisher_merging/clip_fisher_merging.py +2 -6
- fusion_bench/method/fisher_merging/fisher_merging.py +6 -15
- fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +3 -10
- fusion_bench/method/fw_merging/fw_hard.py +1 -1
- fusion_bench/method/fw_merging/fw_soft.py +1 -1
- fusion_bench/method/gossip/clip_layer_wise_gossip.py +4 -5
- fusion_bench/method/linear/expo.py +2 -1
- fusion_bench/method/linear/linear_interpolation.py +6 -4
- fusion_bench/method/linear/simple_average_for_llama.py +17 -13
- fusion_bench/method/lm_finetune/bradley_terry_rm.py +2 -2
- fusion_bench/method/mixture_of_experts/mixtral_upcycling.py +9 -26
- fusion_bench/method/model_recombination.py +2 -5
- fusion_bench/method/moe_pruner/hooks/__init__.py +1 -2
- fusion_bench/method/moe_pruner/utils/data.py +2 -1
- fusion_bench/method/moe_pruner/utils/prune.py +6 -1
- fusion_bench/method/pruning/llama_magnitude_prune.py +1 -1
- fusion_bench/method/pruning/wanda_utils/data.py +1 -2
- fusion_bench/method/pwe_moe/clip_pwe_moe.py +12 -34
- fusion_bench/method/randes/modelsoup.py +1 -3
- fusion_bench/method/regmean/clip_regmean.py +2 -2
- fusion_bench/method/regmean/gpt2_regmean.py +3 -10
- fusion_bench/method/regmean/regmean.py +2 -11
- fusion_bench/method/regmean_plusplus/__init__.py +1 -1
- fusion_bench/method/regmean_plusplus/clip_regmean_plusplus.py +24 -17
- fusion_bench/method/regmean_plusplus/regmean_plusplus.py +56 -38
- fusion_bench/method/simple_average.py +12 -16
- fusion_bench/method/slerp/slerp.py +5 -2
- fusion_bench/method/smile_upscaling/causal_lm_upscaling.py +371 -0
- fusion_bench/method/smile_upscaling/error_accumulation.py +177 -0
- fusion_bench/method/smile_upscaling/projected_energy.py +144 -0
- fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py +5 -1
- fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +71 -51
- fusion_bench/method/smile_upscaling/smile_upscaling.py +12 -5
- fusion_bench/method/tall_mask/task_arithmetic.py +3 -11
- fusion_bench/method/task_arithmetic/task_arithmetic.py +6 -10
- fusion_bench/method/ties_merging/ties_merging.py +13 -26
- fusion_bench/method/we_moe/__init__.py +1 -0
- fusion_bench/method/we_moe/clip_we_moe.py +5 -4
- fusion_bench/method/we_moe/entropy_loss.py +25 -0
- fusion_bench/method/we_moe/flan_t5_we_moe.py +331 -0
- fusion_bench/method/we_moe/utils.py +15 -0
- fusion_bench/method/we_moe/we_moe.py +6 -6
- fusion_bench/method/weighted_average/llama.py +4 -16
- fusion_bench/metrics/continual_learning/__init__.py +1 -0
- fusion_bench/metrics/continual_learning/backward_transfer.py +1 -1
- fusion_bench/metrics/nyuv2/__init__.py +2 -2
- fusion_bench/metrics/nyuv2/segmentation.py +1 -1
- fusion_bench/mixins/__init__.py +10 -2
- fusion_bench/mixins/clip_classification.py +15 -45
- fusion_bench/mixins/hydra_config.py +105 -7
- fusion_bench/mixins/lightning_fabric.py +2 -0
- fusion_bench/mixins/serialization.py +275 -48
- fusion_bench/modelpool/__init__.py +2 -2
- fusion_bench/modelpool/base_pool.py +29 -9
- fusion_bench/modelpool/causal_lm/causal_lm.py +41 -33
- fusion_bench/modelpool/clip_vision/modelpool.py +1 -3
- fusion_bench/modelpool/seq_classification_lm/__init__.py +1 -1
- fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +1 -1
- fusion_bench/models/__init__.py +7 -1
- fusion_bench/models/expert_sparsity/mixtral/__init__.py +1 -1
- fusion_bench/models/hf_utils.py +160 -0
- fusion_bench/models/linearized/linearized_model_utils.py +4 -4
- fusion_bench/models/linearized/vision_model.py +1 -1
- fusion_bench/models/model_card_templates/default.md +46 -0
- fusion_bench/models/modeling_deepseek_v2/__init__.py +1 -1
- fusion_bench/models/modeling_deepseek_v2/modeling_deepseek.py +4 -4
- fusion_bench/models/modeling_deepseek_v2/tokenization_deepseek_fast.py +0 -1
- fusion_bench/models/modeling_smile_gemma2/__init__.py +9 -0
- fusion_bench/models/modeling_smile_gemma2/configuration_smile_gemma2.py +20 -0
- fusion_bench/models/modeling_smile_gemma2/modeling_smile_gemma2.py +986 -0
- fusion_bench/models/modeling_smile_gemma2/register.py +26 -0
- fusion_bench/models/modeling_smile_llama/__init__.py +7 -0
- fusion_bench/models/modeling_smile_llama/configuration_smile_llama.py +20 -0
- fusion_bench/models/modeling_smile_llama/modeling_smile_llama.py +698 -0
- fusion_bench/models/modeling_smile_llama/register.py +8 -0
- fusion_bench/models/modeling_smile_mistral/__init__.py +5 -47
- fusion_bench/models/modeling_smile_qwen2/__init__.py +1 -1
- fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +7 -12
- fusion_bench/models/modeling_smile_qwen2/register.py +1 -4
- fusion_bench/models/parameter_dict.py +1 -1
- fusion_bench/models/sparse_we_moe.py +1 -53
- fusion_bench/models/utils.py +26 -0
- fusion_bench/models/we_moe.py +1 -53
- fusion_bench/models/wrappers/ensemble.py +6 -4
- fusion_bench/models/wrappers/layer_wise_fusion.py +1 -1
- fusion_bench/models/wrappers/task_wise_fusion.py +250 -72
- fusion_bench/programs/base_program.py +81 -2
- fusion_bench/programs/fabric_fusion_program.py +46 -61
- fusion_bench/scripts/cli.py +38 -5
- fusion_bench/taskpool/base_pool.py +4 -3
- fusion_bench/taskpool/clip_vision/taskpool.py +43 -22
- fusion_bench/taskpool/dummy.py +1 -1
- fusion_bench/taskpool/lm_eval_harness/taskpool.py +1 -2
- fusion_bench/tasks/clip_classification/__init__.py +6 -4
- fusion_bench/utils/__init__.py +7 -1
- fusion_bench/utils/cache_utils.py +101 -1
- fusion_bench/utils/devices.py +14 -4
- fusion_bench/utils/fabric.py +2 -2
- fusion_bench/utils/instantiate_utils.py +3 -1
- fusion_bench/utils/lazy_imports.py +23 -0
- fusion_bench/utils/lazy_state_dict.py +38 -3
- fusion_bench/utils/modelscope.py +127 -8
- fusion_bench/utils/parameters.py +2 -2
- fusion_bench/utils/path.py +56 -0
- fusion_bench/utils/pylogger.py +1 -1
- fusion_bench/utils/rich_utils.py +3 -0
- fusion_bench/utils/state_dict_arithmetic.py +25 -23
- {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/METADATA +24 -47
- {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/RECORD +184 -145
- fusion_bench_config/_get_started/clip_evaluate_single_model.yaml +21 -0
- fusion_bench_config/_get_started/clip_simple_average.yaml +23 -0
- fusion_bench_config/_get_started/clip_task_arithmetic.yaml +24 -0
- fusion_bench_config/_get_started/greeting_program.yaml +4 -0
- fusion_bench_config/fabric/loggers/csv_logger.yaml +3 -3
- fusion_bench_config/fabric/loggers/tensorboard_logger.yaml +3 -3
- fusion_bench_config/fabric_model_fusion.yaml +45 -17
- fusion_bench_config/hydra/default.yaml +6 -2
- fusion_bench_config/llama_full_finetune.yaml +1 -0
- fusion_bench_config/method/adamerging/clip.yaml +1 -1
- fusion_bench_config/method/bitdelta/bitdelta.yaml +12 -0
- fusion_bench_config/method/depth_upscaling.yaml +4 -1
- fusion_bench_config/method/fisher_merging/clip_fisher_merging.yaml +0 -1
- fusion_bench_config/method/linear/simple_average_for_llama.yaml +3 -2
- fusion_bench_config/method/smile_upscaling/causal_lm_upscaling.yaml +21 -0
- fusion_bench_config/method/smile_upscaling/error_accumulation.yaml +5 -0
- fusion_bench_config/method/smile_upscaling/projected_energy.yaml +2 -0
- fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +2 -1
- fusion_bench_config/method/wemoe/flan_t5_weight_ensembling_moe.yaml +20 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/_template.yaml +1 -4
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +4 -9
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_finetuned.yaml +1 -1
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_svhn_and_mnist.yaml +0 -6
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8.yaml +1 -1
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8_model_only.yaml +1 -1
- fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml +3 -3
- fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-7B-math_and_coder.yaml +9 -0
- fusion_bench_config/modelpool/CausalLMPool/mistral-7b.yaml +6 -0
- fusion_bench_config/modelpool/CausalLMPool/mixtral_moe_merging.yaml +10 -0
- fusion_bench_config/modelpool/CausalLMPool/qwen2_math_1.5B_and_R1.yaml +4 -12
- fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +6 -16
- fusion_bench_config/modelpool/CausalLMPool/vicuna-7b-v1.5.yaml +8 -0
- fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/llama_preference700k.yaml +1 -1
- fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/single_reward_model.yaml +1 -1
- fusion_bench_config/nyuv2_config.yaml +3 -1
- fusion_bench_config/nyuv2_mtl_train.yaml +1 -0
- fusion_bench_config/path/default.yaml +28 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_svhn_and_mnist.yaml +24 -0
- fusion_bench_config/method/adamerging.yaml +0 -23
- fusion_bench_config/modelpool/mixtral_moe_merging.yaml +0 -14
- fusion_bench_config/modelpool/mixtral_moe_upscaling.yaml +0 -6
- fusion_bench_config/taskpool/clip-vit-base-patch32_svhn_and_mnist.yaml +0 -22
- {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/top_level.txt +0 -0
- /fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/roberta-base_glue.yaml +0 -0
fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py

@@ -20,6 +20,7 @@ from transformers.models.mistral.modeling_mistral import MistralDecoderLayer
 from fusion_bench.compat.modelpool import to_modelpool
 from fusion_bench.method import BaseAlgorithm
 from fusion_bench.method.simple_average import simple_average
+from fusion_bench.mixins import auto_register_config
 from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
 from fusion_bench.modelpool import BaseModelPool
 from fusion_bench.models.modeling_smile_mistral import (
@@ -40,7 +41,10 @@ from fusion_bench.utils.parameters import print_parameters
 log = logging.getLogger(__name__)


-class SmileMistralUpscalingAlgorithm(
+class SmileMistralUpscalingAlgorithm(
+    SimpleProfilerMixin,
+    BaseAlgorithm,
+):
     R"""
     SmileMistralUpscalingAlgorithm is a model fusion algorithm designed to upscale
     a pretrained Mistral model using a set of fine-tuned expert models. The algorithm
fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py

@@ -16,10 +16,17 @@ from transformers.models.qwen2.modeling_qwen2 import Qwen2DecoderLayer

 from fusion_bench import BaseAlgorithm, BaseModelPool
 from fusion_bench.compat.modelpool import to_modelpool
-from fusion_bench.
+from fusion_bench.constants import RuntimeConstants
+from fusion_bench.mixins import SimpleProfilerMixin, auto_register_config
+from fusion_bench.modelpool import CausalLMPool
+from fusion_bench.models.hf_utils import (
+    create_default_model_card,
+    save_pretrained_with_remote_code,
+)
 from fusion_bench.models.modeling_smile_qwen2 import (
     SmileQwen2Config,
     SmileQwen2ForCausalLM,
+    SmileQwen2Model,
 )
 from fusion_bench.models.modeling_smile_qwen2.modeling_smile_qwen2 import (
     SmileQwen2DecoderLayer,
@@ -34,7 +41,11 @@ from fusion_bench.utils.parameters import print_parameters
 log = logging.getLogger(__name__)


-class SmileQwen2UpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
+@auto_register_config
+class SmileQwen2UpscalingAlgorithm(
+    SimpleProfilerMixin,
+    BaseAlgorithm,
+):
     R"""
     SmileQwen2UpscalingAlgorithm is a model fusion algorithm designed to upscale
     a pretrained Qwen2 model using a set of fine-tuned expert models. The algorithm
@@ -49,39 +60,29 @@ class SmileQwen2UpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
     Merges the pretrained model with the fine-tuned models to create an upscaled model.
     """

-    _config_mapping = BaseAlgorithm._config_mapping | {
-        "device": "device",
-        "accelerator": "accelerator",
-        "model_path": "model_path",
-        "model_dtype": "model_dtype",
-        "num_experts_per_tok": "num_experts_per_tok",
-        "rank_of_router": "rank_of_router",
-        "rank_of_expert": "rank_of_expert",
-    }
+    modelpool: CausalLMPool

     def __init__(
         self,
         device,
         accelerator,
-        model_path,
+        model_save_path,
         model_dtype,
         num_experts_per_tok,
         rank_of_router,
         rank_of_expert,
+        save_with_remote_code: bool = True,
         **kwargs,
     ):
-        self.device = device
-        self.accelerator = accelerator
-        self.model_path = model_path
-        self.model_dtype = model_dtype
-        # SmileMoE parameters, except `num_local_experts` which is set later according to the number of finetuned models
-        self.num_experts_per_tok = num_experts_per_tok
-        self.rank_of_router = rank_of_router
-        self.rank_of_expert = rank_of_expert
         super().__init__(**kwargs)
+        if not torch.cuda.is_available():
+            if "cuda" in self.device:
+                self.device = "cpu"
+            if "cuda" in self.accelerator:
+                self.accelerator = "cpu"

     @torch.no_grad()
-    def run(self, modelpool
+    def run(self, modelpool) -> SmileQwen2ForCausalLM:
         """
         Executes the upscaling process.

@@ -94,13 +95,6 @@ class SmileQwen2UpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
         self.modelpool = modelpool = to_modelpool(modelpool)
         config = self.config

-        # load model from path if provided and return directly
-        if config.model_path is not None and os.path.exists(config.model_path):
-            log.info(f"Loading model from {config.model_path}")
-            model = AutoModelForCausalLM.from_pretrained(config.model_path)
-            print_parameters(model)
-            return model
-
         with self.profile("load pretrained model"):
             pretrained_model = modelpool.load_pretrained_model()
         with self.profile("load fine-tuned model"):
@@ -108,7 +102,7 @@ class SmileQwen2UpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
             m for m in tqdm(modelpool.models(), total=len(modelpool.model_names))
         ]

-        if
+        if self.device == "cuda" and torch.cuda.is_available():
             pretrained_model = pretrained_model.cuda()
         print("parameter count of pretrained model:")
         print_parameters(pretrained_model)
@@ -122,20 +116,37 @@ class SmileQwen2UpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
         print_parameters(model)
         print(model)

-        if
-            model.to(dtype=parse_dtype(
-
-        if
-            if os.path.dirname(
-                os.makedirs(os.path.dirname(
-            log.info(f"Saving model to {
-
-
-
+        if self.model_dtype is not None:
+            model.to(dtype=parse_dtype(self.model_dtype))
+
+        if self.model_save_path is not None:
+            if os.path.dirname(self.model_save_path):
+                os.makedirs(os.path.dirname(self.model_save_path), exist_ok=True)
+            log.info(f"Saving model to {self.model_save_path}")
+            tokenizer = self.modelpool.load_tokenizer()
+            tokenizer.save_pretrained(self.model_save_path)
+            if not self.save_with_remote_code:
+                model.save_pretrained(self.model_save_path)
+            else:
+                save_pretrained_with_remote_code(
+                    model,
+                    auto_map={
+                        "AutoConfig": SmileQwen2Config,
+                        "AutoModel": SmileQwen2Model,
+                        "AutoModelForCausalLM": SmileQwen2ForCausalLM,
+                    },
+                    save_directory=self.model_save_path,
+                )
+
+            # save readme
+            model_card_str = create_default_model_card(
+                models=[modelpool.get_model_path(m) for m in modelpool.all_model_names],
+                description="Merged Qwen model using SMILE Upscaling",
+                algorithm_config=self.config,
+                modelpool_config=modelpool.config,
             )
-
-
-            model.save_pretrained(config.model_path)
+            with open(os.path.join(self.model_save_path, "README.md"), "w") as f:
+                f.write(model_card_str)

         return model

@@ -158,14 +169,17 @@ class SmileQwen2UpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):

         with init_empty_weights():
             pretrained_model_config = self.modelpool.get_model_config("_pretrained_")
-
-
-
+            if isinstance(pretrained_model_config, str):
+                pretrained_path = pretrained_model_config
+            else:
+                pretrained_path = pretrained_model_config.get(
+                    "path", pretrained_model_config["pretrained_model_name_or_path"]
+                )
             base_config = AutoConfig.from_pretrained(pretrained_path)
             model_config = SmileQwen2Config(
-                num_experts_per_tok=
-                rank_of_router=
-                rank_of_expert=
+                num_experts_per_tok=self.num_experts_per_tok,
+                rank_of_router=self.rank_of_router,
+                rank_of_expert=self.rank_of_expert,
                 num_local_experts=len(finetuned_models),
                 **base_config.to_dict(),
             )
@@ -175,7 +189,7 @@ class SmileQwen2UpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):

         # copy pretrained model weights
         state_dict = model.state_dict()
-        pretrained_state_dict =
+        pretrained_state_dict = pretrained_model.state_dict()
         for key in list(pretrained_state_dict.keys()):
             if key not in state_dict:
                 pretrained_state_dict.pop(key)
@@ -187,6 +201,12 @@ class SmileQwen2UpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
             "Upscaling Modules (layer)",
             dynamic_ncols=True,
         ):
+            if RuntimeConstants.debug and layer_idx > 0:
+                log.info(
+                    "Debug mode enabled: processing only the first layer, skipping remaining layers"
+                )
+                break
+
             pretrained_layer: Qwen2DecoderLayer = pretrained_model.model.layers[
                 layer_idx
             ]
@@ -202,7 +222,7 @@ class SmileQwen2UpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
                         base=getattr(pretrained_layer.self_attn, n),
                         experts=[getattr(m.self_attn, n) for m in finetuned_layers],
                         target=getattr(target_layer.self_attn, n),
-                        accelerator=
+                        accelerator=self.accelerator,
                     )
                 except ExpertNotTrainedError:
                     setattr(
@@ -217,7 +237,7 @@ class SmileQwen2UpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
                         base=getattr(pretrained_layer.mlp, n),
                         experts=[getattr(m.mlp, n) for m in finetuned_layers],
                         target=getattr(target_layer.mlp, n),
-                        accelerator=
+                        accelerator=self.accelerator,
                     )
                 except ExpertNotTrainedError:
                     setattr(
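The Qwen2 upscaling changes above replace the plain `model.save_pretrained` call with `save_pretrained_with_remote_code`, which bundles the custom `SmileQwen2*` modeling code with the checkpoint and records it in an `auto_map`. A minimal sketch of how such a checkpoint would then be reloaded through the standard transformers API (the directory name is hypothetical):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hypothetical directory matching `model_save_path` in the method config.
save_dir = "outputs/smile_qwen2_upscaled"

# trust_remote_code=True lets transformers import the modeling files
# referenced by the auto_map entries written at save time.
model = AutoModelForCausalLM.from_pretrained(save_dir, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(save_dir)
```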
fusion_bench/method/smile_upscaling/smile_upscaling.py

@@ -1,7 +1,7 @@
 import logging
 import os
 from copy import deepcopy
-from typing import Dict, List, Tuple  # noqa: F401
+from typing import Any, Dict, List, Tuple  # noqa: F401

 import torch
 import torch.nn.functional as F
@@ -20,6 +20,7 @@ from fusion_bench.models.smile_moe.linear_from_module import (
     SmileMoELinear,
 )
 from fusion_bench.models.utils import get_attr, set_attr
+from fusion_bench.utils.devices import get_device
 from fusion_bench.utils.parameters import print_parameters

 log = logging.getLogger(__name__)
@@ -54,7 +55,7 @@ class SmileUpscalingAlgorithm(
         routing_use_diff: bool = True,
         average_experts: bool = False,
         model_path: str = None,
-        **kwargs,
+        **kwargs: Any,
     ):
         """
         Initialize the SmileUpscalingAlgorithm.
@@ -91,7 +92,7 @@ class SmileUpscalingAlgorithm(
         print(f"=== Config for `{type(self).__name__}` ===")

     @torch.no_grad()
-    def run(self, modelpool: BaseModelPool):
+    def run(self, modelpool: BaseModelPool) -> nn.Module:
         """
         Executes the upscaling process.

@@ -142,7 +143,7 @@ class SmileUpscalingAlgorithm(
         pretrained_model: nn.Module,
         finetuned_models: List[nn.Module],
         in_place: bool = True,
-    ):
+    ) -> nn.Module:
         """
         Merges the pretrained model with the fine-tuned models to create an upscaled model.

@@ -180,7 +181,12 @@ class SmileUpscalingAlgorithm(

         name_list = name.split(".")
         module = get_attr(pretrained_model, name_list)
-
+        original_device = get_device(module)
+        module = module.to(self.device, non_blocking=True)
+        experts = [
+            get_attr(m, name_list).to(self.device, non_blocking=True)
+            for m in finetuned_models
+        ]
         try:
             moe_linear = SmileMoELinear(
                 module,
@@ -192,6 +198,7 @@ class SmileUpscalingAlgorithm(
                 full_matrices=self.full_matrices,
                 upscaling_accelerator=self.upscaling_accelerator,
             )
+            moe_linear = moe_linear.to(original_device, non_blocking=True)
         except ExpertNotTrainedError:
             print(f"skip {name} because the experts are not trained.")
             return
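The new device-handling lines above follow a record/move/compute/restore pattern: remember each linear module's original device, move it and the expert copies to the working device for the SVD-based `SmileMoELinear` construction, then move the fused module back. A generic sketch of the pattern, assuming the module has at least one parameter (`build_fused` is an illustrative placeholder, not fusion_bench API):

```python
import torch
from torch import nn


def build_on_device(module: nn.Module, work_device: str, build_fused) -> nn.Module:
    # Run an expensive construction step on `work_device`, then restore
    # the result to the module's original placement.
    original_device = next(module.parameters()).device
    module = module.to(work_device, non_blocking=True)
    fused = build_fused(module)
    return fused.to(original_device, non_blocking=True)
```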
fusion_bench/method/tall_mask/task_arithmetic.py

@@ -9,7 +9,7 @@ from copy import deepcopy
 import torch

 from fusion_bench import BaseAlgorithm
-from fusion_bench.mixins import SimpleProfilerMixin
+from fusion_bench.mixins import SimpleProfilerMixin, auto_register_config
 from fusion_bench.modelpool import BaseModelPool
 from fusion_bench.utils.state_dict_arithmetic import (
     state_dict_add,
@@ -58,16 +58,11 @@ def generate_task_masks(
     return final_mask


+@auto_register_config
 class TallMaskTaskArithmeticAlgorithm(
-    BaseAlgorithm,
     SimpleProfilerMixin,
+    BaseAlgorithm,
 ):
-    _config_mapping = BaseAlgorithm._config_mapping | {
-        "tall_mask_lambda": "tall_mask_lambda",
-        "debug": "debug",
-        "verbose": "verbose",
-    }
-
     def __init__(
         self,
         tall_mask_lambda: float,
@@ -76,9 +71,6 @@ class TallMaskTaskArithmeticAlgorithm(
         **kwargs,
     ):
         super().__init__(**kwargs)
-        self.tall_mask_lambda = tall_mask_lambda
-        self.debug = debug
-        self.verbose = verbose

     @torch.no_grad()
     def run(self, modelpool: BaseModelPool):
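This file shows the refactor pattern repeated across these method files: the hand-written `_config_mapping` dict and the `self.x = x` assignments are dropped in favor of the `@auto_register_config` decorator. Judging from the removed code, the decorator binds named `__init__` arguments to instance attributes (and the serialized config) automatically. A rough stand-in illustrating the idea, not the actual fusion_bench implementation:

```python
import functools
import inspect


def auto_register_config(cls):
    # Illustrative stand-in: copy every named __init__ argument onto the
    # instance before the original __init__ body runs.
    original_init = cls.__init__
    sig = inspect.signature(original_init)

    @functools.wraps(original_init)
    def wrapped_init(self, *args, **kwargs):
        bound = sig.bind(self, *args, **kwargs)
        bound.apply_defaults()
        skip = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
        for name, value in bound.arguments.items():
            if name != "self" and sig.parameters[name].kind not in skip:
                setattr(self, name, value)
        original_init(self, *args, **kwargs)

    cls.__init__ = wrapped_init
    return cls
```

With something like this in place, `TallMaskTaskArithmeticAlgorithm(tall_mask_lambda=0.4, debug=False, verbose=True)` would expose `self.tall_mask_lambda` and friends without any boilerplate in `__init__`.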
fusion_bench/method/task_arithmetic/task_arithmetic.py

@@ -12,7 +12,7 @@ import torch
 from torch import nn

 from fusion_bench.method.base_algorithm import BaseAlgorithm
-from fusion_bench.mixins
+from fusion_bench.mixins import SimpleProfilerMixin, auto_register_config
 from fusion_bench.modelpool import BaseModelPool
 from fusion_bench.utils.state_dict_arithmetic import (
     state_dict_add,
@@ -74,9 +74,10 @@ def task_arithmetic_merge(
     return pretrained_model


+@auto_register_config
 class TaskArithmeticAlgorithm(
-    BaseAlgorithm,
     SimpleProfilerMixin,
+    BaseAlgorithm,
 ):
     """
     Task Arithmetic Algorithm for model fusion.
@@ -89,22 +90,17 @@ class TaskArithmeticAlgorithm(
         scaling_factor (int): The factor by which the task vectors will be scaled before merging.
     """

-    _config_mapping = BaseAlgorithm._config_mapping | {
-        "scaling_factor": "scaling_factor"
-    }
-
-    def __init__(self, scaling_factor: int):
+    def __init__(self, scaling_factor: int, **kwargs):
         """
         Initializes the TaskArithmeticAlgorithm with the given scaling factor.

         Args:
             scaling_factor (int): The factor by which the task vectors will be scaled before merging.
         """
-
-        super().__init__()
+        super().__init__(**kwargs)

     @torch.no_grad()
-    def run(self, modelpool: Union[BaseModelPool, Dict[str, nn.Module]]):
+    def run(self, modelpool: Union[BaseModelPool, Dict[str, nn.Module]]) -> nn.Module:
         """
         Runs the Task Arithmetic Algorithm to fuse models in the given model pool.

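For context, the merge rule implemented by `TaskArithmeticAlgorithm` is untouched by this refactor: each fine-tuned model contributes a task vector (its parameter delta from the pretrained checkpoint), and the merged weights are the pretrained weights plus the scaled sum of those deltas. Roughly, in state-dict form:

```python
from typing import Dict, List

from torch import Tensor


def task_arithmetic_sketch(
    pretrained: Dict[str, Tensor],
    finetuned: List[Dict[str, Tensor]],
    scaling_factor: float,
) -> Dict[str, Tensor]:
    # merged = pretrained + scaling_factor * sum_i (finetuned_i - pretrained)
    merged = {}
    for key, base in pretrained.items():
        delta_sum = sum(sd[key] - base for sd in finetuned)
        merged[key] = base + scaling_factor * delta_sum
    return merged
```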
fusion_bench/method/ties_merging/ties_merging.py

@@ -9,14 +9,14 @@ Overview of Ties-Merging:
 """

 import logging
-from typing import Dict, List, Literal, Mapping, Union  # noqa: F401
+from typing import Any, Dict, List, Literal, Mapping, Union  # noqa: F401

 import torch
 from torch import Tensor, nn

 from fusion_bench.compat.modelpool import to_modelpool
 from fusion_bench.method import BaseAlgorithm
-from fusion_bench.mixins import SimpleProfilerMixin
+from fusion_bench.mixins import SimpleProfilerMixin, auto_register_config
 from fusion_bench.modelpool import BaseModelPool
 from fusion_bench.utils.type import StateDictType

@@ -25,33 +25,22 @@ from .ties_merging_utils import state_dict_to_vector, ties_merging, vector_to_state_dict
 log = logging.getLogger(__name__)


-class TiesMergingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
-
-
-
-
-        scaling_factor (float): The scaling factor to apply to the merged task vector.
-        threshold (float): The threshold for resetting values in the task vector.
-        remove_keys (List[str]): List of keys to remove from the state dictionary.
-        merge_func (Literal["sum", "mean", "max"]): The merge function to use for disjoint merging.
-    """
-
-    _config_mapping = BaseAlgorithm._config_mapping | {
-        "scaling_factor": "scaling_factor",
-        "threshold": "threshold",
-        "remove_keys": "remove_keys",
-        "merge_func": "merge_func",
-    }
-
+@auto_register_config
+class TiesMergingAlgorithm(
+    SimpleProfilerMixin,
+    BaseAlgorithm,
+):
     def __init__(
         self,
         scaling_factor: float,
         threshold: float,
         remove_keys: List[str],
         merge_func: Literal["sum", "mean", "max"],
-        **kwargs,
+        **kwargs: Any,
     ):
         """
+        TiesMergingAlgorithm is a class for fusing multiple models using the TIES merging technique.
+
         Initialize the TiesMergingAlgorithm with the given parameters.

         Args:
@@ -61,14 +50,12 @@ class TiesMergingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
             merge_func (Literal["sum", "mean", "max"]): The merge function to use for disjoint merging.
             **kwargs: Additional keyword arguments for the base class.
         """
-        self.scaling_factor = scaling_factor
-        self.threshold = threshold
-        self.remove_keys = remove_keys
-        self.merge_func = merge_func
         super().__init__(**kwargs)

     @torch.no_grad()
-    def run(
+    def run(
+        self, modelpool: BaseModelPool | Dict[str, nn.Module], **kwargs: Any
+    ) -> nn.Module:
         """
         Run the TIES merging algorithm to fuse models in the model pool.

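TIES merging, by contrast, filters the task vectors before combining them: trim low-magnitude entries (controlled by `threshold`), elect a per-parameter sign, and disjoint-merge only the values that agree with the elected sign (the real logic lives in `ties_merging_utils`, imported above). A condensed sketch of those three steps over stacked, flattened task vectors:

```python
import torch


def ties_sketch(task_vectors: torch.Tensor, keep_fraction: float) -> torch.Tensor:
    # task_vectors: (num_models, num_params), one flattened delta per model.
    # 1) Trim: zero out all but the top `keep_fraction` of entries by magnitude.
    magnitudes = task_vectors.abs()
    cutoff = magnitudes.quantile(1.0 - keep_fraction, dim=1, keepdim=True)
    trimmed = torch.where(magnitudes >= cutoff, task_vectors, torch.zeros_like(task_vectors))
    # 2) Elect sign: per parameter, the sign with the larger total magnitude wins.
    elected = torch.sign(trimmed.sum(dim=0))
    # 3) Disjoint merge: average only the entries that agree with the elected sign.
    agrees = torch.sign(trimmed) == elected
    counts = agrees.sum(dim=0).clamp(min=1)
    return (trimmed * agrees).sum(dim=0) / counts
```

The merged vector is then scaled by `scaling_factor` and added back to the pretrained weights, as in task arithmetic.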
fusion_bench/method/we_moe/clip_we_moe.py

@@ -2,6 +2,7 @@ import functools
 import logging
 import os
 from copy import deepcopy
+from typing import Any, Iterator

 import torch
 from torch import Tensor
@@ -38,7 +39,7 @@ class CLIPWeightEnsemblingMoEAlgorithm(

     modelpool: CLIPVisionModelPool = None

-    def load_checkpoint(self, model, checkpoint):
+    def load_checkpoint(self, model: Any, checkpoint: Any):
         """
         Load the checkpoint file.

@@ -49,7 +50,7 @@ class CLIPWeightEnsemblingMoEAlgorithm(
         state = {"model": model}
         self._fabric.load(checkpoint, state)

-    def save_checkpoint(self, model, checkpoint):
+    def save_checkpoint(self, model: Any, checkpoint: Any):
         """
         Save the checkpoint file.

@@ -102,7 +103,7 @@ class CLIPWeightEnsemblingMoEAlgorithm(
         return moe_model

     @functools.cache
-    def get_shuffled_test_loader_iter(self, tta_dataset: str):
+    def get_shuffled_test_loader_iter(self, tta_dataset: str) -> Iterator:
         """
         Get an iterator for the shuffled test data loader.

@@ -131,7 +132,7 @@ class CLIPWeightEnsemblingMoEAlgorithm(
         """
         self.setup_zero_shot_classification_head()

-    def compute_logits(self, module, batch, task) -> Tensor:
+    def compute_logits(self, module: Any, batch: Any, task: Any) -> Tensor:
         """
         Compute the logits for the given batch and task.

fusion_bench/method/we_moe/entropy_loss.py (new file)

@@ -0,0 +1,25 @@
+import torch
+from torch import Tensor
+
+
+def entropy_loss(logits: Tensor, eps: float = 1e-8) -> Tensor:
+    """
+    Compute the entropy loss of a set of logits.
+
+    Args:
+        logits (Tensor): The logits to compute the entropy loss of.
+        eps (float): A small value to avoid log(0). Default is 1e-8.
+
+    Returns:
+        Tensor: The entropy loss of the logits.
+    """
+    # Ensure the logits tensor has 2 dimensions
+    assert (
+        logits.dim() == 2
+    ), f"Expected logits to have 2 dimensions, found {logits.dim()}, {logits.size()=}"
+
+    # Compute the softmax probabilities
+    probs = torch.softmax(logits, dim=-1)
+
+    # Compute the entropy loss
+    return -torch.sum(probs * torch.log(probs + eps), dim=-1).mean()
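This `entropy_loss` appears to serve as the test-time adaptation objective for the WeMoE methods: minimizing prediction entropy on unlabeled test batches pushes the merged model toward confident outputs. A quick sanity check of its behavior, assuming the function is imported from the new module (`from fusion_bench.method.we_moe.entropy_loss import entropy_loss`):

```python
import torch

# Near-one-hot logits give low entropy; uniform logits give the maximum,
# log(num_classes) = log(3) ≈ 1.0986 for three classes.
confident = torch.tensor([[10.0, 0.0, 0.0]])
uniform = torch.tensor([[1.0, 1.0, 1.0]])
print(entropy_loss(confident))  # tensor close to 0
print(entropy_loss(uniform))    # tensor close to 1.0986
```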