fusion-bench 0.2.21__py3-none-any.whl → 0.2.23__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/__init__.py +25 -2
- fusion_bench/compat/method/__init__.py +5 -2
- fusion_bench/compat/method/base_algorithm.py +3 -2
- fusion_bench/compat/modelpool/base_pool.py +3 -3
- fusion_bench/compat/taskpool/clip_image_classification.py +1 -1
- fusion_bench/constants/__init__.py +1 -0
- fusion_bench/constants/runtime.py +57 -0
- fusion_bench/dataset/gpt2_glue.py +1 -1
- fusion_bench/method/__init__.py +12 -4
- fusion_bench/method/analysis/task_vector_cos_similarity.py +95 -12
- fusion_bench/method/analysis/task_vector_violin_plot.py +160 -52
- fusion_bench/method/bitdelta/__init__.py +1 -0
- fusion_bench/method/bitdelta/bitdelta.py +7 -23
- fusion_bench/method/classification/clip_finetune.py +1 -1
- fusion_bench/method/expert_sparsity/mixtral/dynamic_skipping.py +2 -0
- fusion_bench/method/expert_sparsity/mixtral/layer_wise_pruning.py +2 -0
- fusion_bench/method/expert_sparsity/mixtral/progressive_pruning.py +2 -0
- fusion_bench/method/fisher_merging/clip_fisher_merging.py +0 -4
- fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +2 -2
- fusion_bench/method/linear/simple_average_for_llama.py +16 -11
- fusion_bench/method/model_stock/__init__.py +1 -0
- fusion_bench/method/model_stock/model_stock.py +309 -0
- fusion_bench/method/regmean/clip_regmean.py +3 -6
- fusion_bench/method/regmean/regmean.py +27 -56
- fusion_bench/method/regmean/utils.py +56 -0
- fusion_bench/method/regmean_plusplus/regmean_plusplus.py +21 -60
- fusion_bench/method/simple_average.py +7 -7
- fusion_bench/method/slerp/__init__.py +1 -1
- fusion_bench/method/slerp/slerp.py +110 -14
- fusion_bench/method/smile_upscaling/causal_lm_upscaling.py +371 -0
- fusion_bench/method/smile_upscaling/projected_energy.py +1 -2
- fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py +5 -1
- fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +40 -31
- fusion_bench/method/smile_upscaling/smile_upscaling.py +1 -1
- fusion_bench/method/we_moe/__init__.py +1 -0
- fusion_bench/method/we_moe/entropy_loss.py +25 -0
- fusion_bench/method/we_moe/flan_t5_we_moe.py +320 -0
- fusion_bench/method/we_moe/utils.py +15 -0
- fusion_bench/method/weighted_average/llama.py +1 -1
- fusion_bench/mixins/clip_classification.py +37 -48
- fusion_bench/mixins/serialization.py +30 -10
- fusion_bench/modelpool/base_pool.py +1 -1
- fusion_bench/modelpool/causal_lm/causal_lm.py +293 -75
- fusion_bench/modelpool/seq2seq_lm/modelpool.py +146 -0
- fusion_bench/models/__init__.py +5 -0
- fusion_bench/models/hf_utils.py +69 -86
- fusion_bench/models/linearized/vision_model.py +6 -6
- fusion_bench/models/model_card_templates/default.md +46 -0
- fusion_bench/models/modeling_smile_llama/__init__.py +7 -0
- fusion_bench/models/modeling_smile_llama/modeling_smile_llama.py +1 -8
- fusion_bench/models/modeling_smile_mistral/__init__.py +2 -1
- fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +1 -5
- fusion_bench/models/we_moe.py +8 -8
- fusion_bench/programs/fabric_fusion_program.py +29 -60
- fusion_bench/scripts/cli.py +34 -1
- fusion_bench/taskpool/base_pool.py +99 -17
- fusion_bench/taskpool/clip_vision/taskpool.py +10 -5
- fusion_bench/taskpool/dummy.py +101 -13
- fusion_bench/taskpool/lm_eval_harness/taskpool.py +80 -0
- fusion_bench/taskpool/nyuv2_taskpool.py +28 -0
- fusion_bench/utils/__init__.py +2 -0
- fusion_bench/utils/cache_utils.py +101 -1
- fusion_bench/utils/data.py +6 -4
- fusion_bench/utils/devices.py +7 -4
- fusion_bench/utils/dtype.py +3 -2
- fusion_bench/utils/fabric.py +2 -2
- fusion_bench/utils/lazy_imports.py +23 -0
- fusion_bench/utils/lazy_state_dict.py +117 -19
- fusion_bench/utils/modelscope.py +3 -3
- fusion_bench/utils/packages.py +3 -3
- fusion_bench/utils/parameters.py +0 -2
- fusion_bench/utils/path.py +56 -0
- fusion_bench/utils/pylogger.py +1 -1
- fusion_bench/utils/timer.py +92 -10
- {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/METADATA +1 -23
- {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/RECORD +89 -75
- fusion_bench_config/_get_started/llm_slerp.yaml +12 -0
- fusion_bench_config/method/fisher_merging/clip_fisher_merging.yaml +0 -1
- fusion_bench_config/method/linear/simple_average_for_llama.yaml +3 -2
- fusion_bench_config/method/model_stock/model_stock.yaml +12 -0
- fusion_bench_config/method/slerp/slerp_lm.yaml +4 -0
- fusion_bench_config/method/smile_upscaling/causal_lm_upscaling.yaml +21 -0
- fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +1 -1
- fusion_bench_config/method/wemoe/flan_t5_weight_ensembling_moe.yaml +20 -0
- fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml +1 -1
- {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/top_level.txt +0 -0
fusion_bench/method/slerp/slerp.py

@@ -1,16 +1,24 @@
 import logging
-
+import os
+from copy import deepcopy
+from typing import TYPE_CHECKING, Any, Dict, Optional
 
 import torch
 from torch import nn
+from tqdm import tqdm
 from typing_extensions import override
 
+from fusion_bench import LazyStateDict, create_default_model_card, timeit_context
 from fusion_bench.method import BaseAlgorithm
-from fusion_bench.
+from fusion_bench.mixins import SimpleProfilerMixin, auto_register_config
+from fusion_bench.modelpool import BaseModelPool, CausalLMPool
 from fusion_bench.utils.type import StateDictType
 
 from .slerp_utils import slerp
 
+if TYPE_CHECKING:
+    from transformers import PreTrainedModel
+
 log = logging.getLogger(__name__)
 
 
@@ -21,6 +29,7 @@ def slerp_on_state_dicts(
     *,
     DOT_THRESHOLD: float = 0.9995,
     epsilon: float = 1e-8,
+    show_pbar: bool = False,
 ) -> StateDictType:
     """
     Perform spherical linear interpolation (slerp) on the state dictionaries of two models.
@@ -36,7 +45,8 @@ def slerp_on_state_dicts(
         dict: The interpolated state dictionary.
     """
     state_dict = {}
-
+    pbar = secondary_state_dict if not show_pbar else tqdm(secondary_state_dict)
+    for key in pbar:
         v0 = primary_state_dict[key]
         v1 = secondary_state_dict[key]
         if v0.shape != v1.shape:
@@ -49,18 +59,19 @@
     return state_dict
 
 
+@auto_register_config
 class SlerpMergeAlgorithm(BaseAlgorithm):
     """
     General purpose implementation of Slerp (Spherical Linear Interpolation) for PyTorch models.
     """
 
-
-
-
-
-
-
-
+    def __init__(
+        self,
+        t: float,
+        DOT_THRESHOLD: float = 0.9995,
+        epsilon: float = 1e-8,
+        **kwargs,
+    ):
         """
         Initialize the SlerpMergeAlgorithm.
 
@@ -69,10 +80,7 @@ class SlerpMergeAlgorithm(BaseAlgorithm):
             DOT_THRESHOLD (float, optional): The threshold for the dot product of the two vectors. Defaults to 0.9995.
             epsilon (float, optional): The epsilon value for numerical stability. Defaults to 1e-8.
         """
-
-        self.DOT_THRESHOLD = DOT_THRESHOLD
-        self.epsilon = epsilon
-        super().__init__()
+        super().__init__(**kwargs)
 
     @override
     def run(self, modelpool: BaseModelPool) -> nn.Module:
@@ -102,3 +110,91 @@ class SlerpMergeAlgorithm(BaseAlgorithm):
 
         primary_model.load_state_dict(state_dict)
         return primary_model
+
+
+@auto_register_config
+class SlerpForCausalLM(
+    SimpleProfilerMixin,
+    BaseAlgorithm,
+):
+    """
+    Slerp (Spherical Linear Interpolation) for Causal Language Models.
+    """
+
+    def __init__(
+        self,
+        t: float,
+        DOT_THRESHOLD: float = 0.9995,
+        epsilon: float = 1e-8,
+        model_save_path: Optional[str] = None,
+        show_pbar: bool = False,
+        **kwargs,
+    ):
+        """
+        Initialize the SlerpForCausalLM algorithm.
+
+        Args:
+            t (float): The interpolation parameter. Must be in the range [0, 1].
+                t=0 returns the first model, t=1 returns the second model,
+                t=0.5 provides balanced interpolation.
+            DOT_THRESHOLD (float, optional): The threshold for the dot product of normalized vectors.
+                When the absolute dot product exceeds this threshold,
+                vectors are considered nearly collinear and linear
+                interpolation (LERP) is used instead of SLERP for
+                numerical stability. Defaults to 0.9995.
+            epsilon (float, optional): Small value used for numerical stability to avoid
+                division by zero during vector normalization.
+                Defaults to 1e-8.
+            model_save_path (Optional[str], optional): Path where the merged model should be saved.
+                If None, the model is not saved to disk.
+                Defaults to None.
+            show_pbar (bool, optional): Whether to display a progress bar during the interpolation
+                process. Useful for debugging or monitoring progress with
+                large models. Defaults to False.
+            **kwargs: Additional keyword arguments passed to the parent BaseAlgorithm class.
+        """
+        super().__init__(**kwargs)
+
+    @override
+    def run(self, modelpool: CausalLMPool):
+        assert len(modelpool.all_model_names) == 2, "Slerp expect exactly 2 models"
+        primary_model = modelpool.load_model(modelpool.all_model_names[0])
+        secondary_model = modelpool.load_model(modelpool.all_model_names[1])
+
+        with torch.no_grad():
+            primary_state_dict = primary_model.state_dict()
+            secondary_state_dict = secondary_model.state_dict()
+            state_dict = slerp_on_state_dicts(
+                self.t,
+                primary_state_dict,
+                secondary_state_dict,
+                DOT_THRESHOLD=self.DOT_THRESHOLD,
+                epsilon=self.epsilon,
+            )
+
+        if isinstance(primary_model, nn.Module):
+            model = primary_model
+            model.load_state_dict(state_dict)
+        elif isinstance(primary_model, LazyStateDict):
+            model: "PreTrainedModel" = deepcopy(primary_model.meta_module)
+            model.to(device=primary_model._device)
+            model.load_state_dict(state_dict)
+        else:
+            raise TypeError(
+                f"Unsupported model type: {type(primary_model)}. "
+                "Expected nn.Module or LazyStateDict."
+            )
+        if self.model_save_path is not None:
+            with timeit_context(f"Saving the model to {self.model_save_path}"):
+                tokenizer = modelpool.load_tokenizer()
+                tokenizer.save_pretrained(self.model_save_path)
+                model.save_pretrained(self.model_save_path)
+                model_card_str = create_default_model_card(
+                    models=[modelpool.get_model_path(m) for m in modelpool.model_names],
+                    description="Merged model using Slerp.",
+                    algorithm_config=self.config,
+                    modelpool_config=modelpool.config,
+                )
+                with open(os.path.join(self.model_save_path, "README.md"), "w") as f:
+                    f.write(model_card_str)
+        return model
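For orientation while reading the hunks above: `slerp_on_state_dicts` applies spherical linear interpolation key by key to the two state dicts, falling back to plain linear interpolation when the flattened weight vectors are nearly collinear (the `DOT_THRESHOLD` check) and guarding normalization with `epsilon`, as described in the new docstrings. The sketch below shows the standard SLERP formula those parameters refer to; `slerp_tensor` is a hypothetical name written for illustration, not the library's `slerp` helper from `.slerp_utils`, whose implementation is not part of this diff.

```python
import torch


def slerp_tensor(t, v0, v1, DOT_THRESHOLD=0.9995, epsilon=1e-8):
    """Standard spherical linear interpolation between two tensors of the same shape."""
    v0_flat, v1_flat = v0.flatten().float(), v1.flatten().float()
    # Normalize; epsilon guards against division by zero.
    v0_unit = v0_flat / (v0_flat.norm() + epsilon)
    v1_unit = v1_flat / (v1_flat.norm() + epsilon)
    dot = torch.clamp(torch.dot(v0_unit, v1_unit), -1.0, 1.0)
    # Nearly collinear vectors: fall back to plain linear interpolation (LERP).
    if dot.abs() > DOT_THRESHOLD:
        return ((1 - t) * v0_flat + t * v1_flat).reshape(v0.shape).to(v0.dtype)
    theta_0 = torch.acos(dot)  # angle between the two weight vectors
    theta_t = theta_0 * t      # angle at interpolation point t
    s0 = torch.sin(theta_0 - theta_t) / torch.sin(theta_0)
    s1 = torch.sin(theta_t) / torch.sin(theta_0)
    return (s0 * v0_flat + s1 * v1_flat).reshape(v0.shape).to(v0.dtype)


# Example: interpolate halfway between two random weight matrices.
a, b = torch.randn(4, 4), torch.randn(4, 4)
merged = slerp_tensor(0.5, a, b)
```

`SlerpForCausalLM.run` performs this interpolation per parameter tensor across the two models in the pool, then optionally saves the merged model, tokenizer, and a generated model card to `model_save_path`.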
fusion_bench/method/smile_upscaling/causal_lm_upscaling.py (new file)

@@ -0,0 +1,371 @@
+import logging
+import os
+from copy import deepcopy
+from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, Union
+
+import torch
+from accelerate import init_empty_weights
+from tqdm.auto import tqdm
+from transformers import (
+    AutoConfig,
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    LlamaForCausalLM,
+    MistralForCausalLM,
+    PretrainedConfig,
+    PreTrainedModel,
+    Qwen2ForCausalLM,
+)
+from transformers.models.llama.modeling_llama import LlamaDecoderLayer
+from transformers.models.mistral.modeling_mistral import MistralDecoderLayer
+from transformers.models.qwen2.modeling_qwen2 import Qwen2DecoderLayer
+
+from fusion_bench import BaseAlgorithm, BaseModelPool
+from fusion_bench.compat.modelpool import to_modelpool
+from fusion_bench.constants import RuntimeConstants
+from fusion_bench.mixins import SimpleProfilerMixin, auto_register_config
+from fusion_bench.modelpool import CausalLMPool
+from fusion_bench.models.hf_utils import (
+    create_default_model_card,
+    save_pretrained_with_remote_code,
+)
+from fusion_bench.models.modeling_smile_llama import (
+    SmileLlamaConfig,
+    SmileLlamaForCausalLM,
+    SmileLlamaModel,
+)
+from fusion_bench.models.modeling_smile_llama.modeling_smile_llama import (
+    SmileLlamaDecoderLayer,
+)
+from fusion_bench.models.modeling_smile_mistral import (
+    SmileMistralConfig,
+    SmileMistralForCausalLM,
+    SmileMistralModel,
+)
+from fusion_bench.models.modeling_smile_mistral.modeling_smile_mistral import (
+    SmileMistralDecoderLayer,
+)
+
+# Import all SMILE configurations and models
+from fusion_bench.models.modeling_smile_qwen2 import (
+    SmileQwen2Config,
+    SmileQwen2ForCausalLM,
+    SmileQwen2Model,
+)
+from fusion_bench.models.modeling_smile_qwen2.modeling_smile_qwen2 import (
+    SmileQwen2DecoderLayer,
+)
+from fusion_bench.models.smile_moe.linear_from_hf_config import (
+    ExpertNotTrainedError,
+    upscale_to_smile_linear,
+)
+from fusion_bench.utils.dtype import parse_dtype
+from fusion_bench.utils.parameters import print_parameters
+
+log = logging.getLogger(__name__)
+
+# Model type mappings
+MODEL_TYPE_MAPPINGS = {
+    "qwen2": {
+        "base_model_cls": Qwen2ForCausalLM,
+        "base_decoder_layer_cls": Qwen2DecoderLayer,
+        "smile_config_cls": SmileQwen2Config,
+        "smile_model_cls": SmileQwen2ForCausalLM,
+        "smile_base_model_cls": SmileQwen2Model,
+        "smile_decoder_layer_cls": SmileQwen2DecoderLayer,
+        "description": "Qwen2",
+    },
+    "llama": {
+        "base_model_cls": LlamaForCausalLM,
+        "base_decoder_layer_cls": LlamaDecoderLayer,
+        "smile_config_cls": SmileLlamaConfig,
+        "smile_model_cls": SmileLlamaForCausalLM,
+        "smile_base_model_cls": SmileLlamaModel,
+        "smile_decoder_layer_cls": SmileLlamaDecoderLayer,
+        "description": "Llama",
+    },
+    "mistral": {
+        "base_model_cls": MistralForCausalLM,
+        "base_decoder_layer_cls": MistralDecoderLayer,
+        "smile_config_cls": SmileMistralConfig,
+        "smile_model_cls": SmileMistralForCausalLM,
+        "smile_base_model_cls": SmileMistralModel,
+        "smile_decoder_layer_cls": SmileMistralDecoderLayer,
+        "description": "Mistral",
+    },
+}
+
+
+def detect_model_type(
+    model_or_config: Union[PreTrainedModel, PretrainedConfig, str],
+) -> str:
+    """
+    Detect the model type from a model, config, or model name/path.
+
+    Args:
+        model_or_config: Model, config, or model name/path to detect type from
+
+    Returns:
+        str: The detected model type ("qwen2", "llama", "mistral")
+
+    Raises:
+        ValueError: If model type cannot be detected or is not supported
+    """
+    if isinstance(model_or_config, str):
+        # Load config from path/name
+        config = AutoConfig.from_pretrained(model_or_config)
+    elif isinstance(model_or_config, PreTrainedModel):
+        config = model_or_config.config
+    elif isinstance(model_or_config, PretrainedConfig):
+        config = model_or_config
+    else:
+        raise ValueError(
+            f"Unsupported type for model type detection: {type(model_or_config)}"
+        )
+
+    model_type = getattr(config, "model_type", "").lower()
+
+    # Handle various model type variations
+    if model_type in MODEL_TYPE_MAPPINGS:
+        return model_type
+    else:
+        raise ValueError(
+            f"Unsupported model type: {model_type}. Supported types: {list(MODEL_TYPE_MAPPINGS.keys())}"
+        )
+
+
+@auto_register_config
+class SmileCausalLMUpscalingAlgorithm(
+    SimpleProfilerMixin,
+    BaseAlgorithm,
+):
+    R"""
+    SmileCausalLMUpscalingAlgorithm is a generic model fusion algorithm designed to upscale
+    a pretrained CausalLM model using a set of fine-tuned expert models. The algorithm
+    supports Qwen2, Llama, and Mistral model architectures and leverages Singular Value
+    Decomposition (SVD) to merge the weights of the pretrained model and the expert models
+    into a new upscaled model.
+
+    The algorithm automatically detects the model type and uses the appropriate SMILE
+    configuration and model classes.
+
+    Methods:
+        run(modelpool: BaseModelPool) -> Union[SmileQwen2ForCausalLM, SmileLlamaForCausalLM, SmileMistralForCausalLM]:
+            Executes the upscaling process and returns the upscaled model.
+
+        merge(pretrained_model: PreTrainedModel, finetuned_models: List[PreTrainedModel]) -> PreTrainedModel:
+            Merges the pretrained model with the fine-tuned models to create an upscaled model.
+    """
+
+    modelpool: CausalLMPool
+
+    def __init__(
+        self,
+        device,
+        accelerator,
+        model_save_path,
+        model_dtype,
+        num_experts_per_tok,
+        rank_of_router,
+        rank_of_expert,
+        save_with_remote_code: bool = True,
+        model_type: str = None,  # Optional: explicitly specify model type
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.model_mappings = None  # Will be set during run()
+
+        if not torch.cuda.is_available():
+            if "cuda" in self.device:
+                self.device = "cpu"
+            if "cuda" in self.accelerator:
+                self.accelerator = "cpu"
+
+    @torch.no_grad()
+    def run(self, modelpool) -> PreTrainedModel:
+        """
+        Executes the upscaling process.
+
+        Args:
+            modelpool (ModelPool): The pool of models to be used for upscaling.
+
+        Returns:
+            PreTrainedModel: The upscaled model (specific type depends on detected model architecture).
+        """
+        self.modelpool = modelpool = to_modelpool(modelpool)
+        config = self.config
+
+        # Auto-detect model type if not specified
+        if self.model_type is None:
+            self.model_type = detect_model_type(
+                modelpool.get_model_path("_pretrained_")
+            )
+            log.info(f"Auto-detected model type: {self.model_type}")
+
+        # Get the appropriate model mappings
+        if self.model_type not in MODEL_TYPE_MAPPINGS:
+            raise ValueError(
+                f"Unsupported model type: {self.model_type}. Supported: {list(MODEL_TYPE_MAPPINGS.keys())}"
+            )
+
+        self.model_mappings = MODEL_TYPE_MAPPINGS[self.model_type]
+        log.info(f"Using {self.model_mappings['description']} model architecture")
+
+        with self.profile("load pretrained model"):
+            pretrained_model = modelpool.load_pretrained_model()
+
+        with self.profile("load fine-tuned model"):
+            finetuned_models = [
+                m for m in tqdm(modelpool.models(), total=len(modelpool.model_names))
+            ]
+
+        if self.device == "cuda" and torch.cuda.is_available():
+            pretrained_model = pretrained_model.cuda()
+            print("parameter count of pretrained model:")
+            print_parameters(pretrained_model)
+            finetuned_models = [m.cuda() for m in finetuned_models]
+
+        with self.profile("merge model"):
+            model = self.merge(pretrained_model, finetuned_models)
+
+        self.print_profile_summary()
+        print("parameter count of upscaled MoE model:")
+        print_parameters(model)
+        print(model)
+
+        if self.model_dtype is not None:
+            model.to(dtype=parse_dtype(self.model_dtype))
+
+        if self.model_save_path is not None:
+            if os.path.dirname(self.model_save_path):
+                os.makedirs(os.path.dirname(self.model_save_path), exist_ok=True)
+            log.info(f"Saving model to {self.model_save_path}")
+            tokenizer = self.modelpool.load_tokenizer()
+            tokenizer.save_pretrained(self.model_save_path)
+            if not self.save_with_remote_code:
+                model.save_pretrained(self.model_save_path)
+            else:
+                # Use the appropriate auto_map for the detected model type
+                auto_map = {
+                    "AutoConfig": self.model_mappings["smile_config_cls"],
+                    "AutoModel": self.model_mappings["smile_base_model_cls"],
+                    "AutoModelForCausalLM": self.model_mappings["smile_model_cls"],
+                }
+                save_pretrained_with_remote_code(
+                    model,
+                    auto_map=auto_map,
+                    save_directory=self.model_save_path,
+                )
+
+            # save readme
+            model_card_str = create_default_model_card(
+                models=[modelpool.get_model_path(m) for m in modelpool.all_model_names],
+                description=f"Merged {self.model_mappings['description']} model using SMILE Upscaling",
+                algorithm_config=self.config,
+                modelpool_config=modelpool.config,
+            )
+            with open(os.path.join(self.model_save_path, "README.md"), "w") as f:
+                f.write(model_card_str)
+
+        return model
+
+    def merge(
+        self,
+        pretrained_model: PreTrainedModel,
+        finetuned_models: List[PreTrainedModel],
+    ) -> PreTrainedModel:
+        """
+        Merges the pretrained model with the fine-tuned models to create an upscaled model.
+
+        Args:
+            pretrained_model (PreTrainedModel): The pretrained model.
+            finetuned_models (List[PreTrainedModel]): A list of fine-tuned models.
+
+        Returns:
+            PreTrainedModel: The upscaled model (specific type depends on model architecture).
+        """
+        with init_empty_weights():
+            pretrained_model_config = self.modelpool.get_model_config("_pretrained_")
+            if isinstance(pretrained_model_config, str):
+                pretrained_path = pretrained_model_config
+            else:
+                pretrained_path = pretrained_model_config.get(
+                    "path", pretrained_model_config["pretrained_model_name_or_path"]
+                )
+            base_config = AutoConfig.from_pretrained(pretrained_path)
+
+            # Create the appropriate SMILE config for the detected model type
+            SmileConfigClass = self.model_mappings["smile_config_cls"]
+            model_config = SmileConfigClass(
+                num_experts_per_tok=self.num_experts_per_tok,
+                rank_of_router=self.rank_of_router,
+                rank_of_expert=self.rank_of_expert,
+                num_local_experts=len(finetuned_models),
+                **base_config.to_dict(),
+            )
+
+            # Create the appropriate SMILE model for the detected model type
+            SmileModelClass = self.model_mappings["smile_model_cls"]
+            model = SmileModelClass(model_config)
+
+        model.to(dtype=pretrained_model.dtype).to_empty(device="cpu")
+
+        # copy pretrained model weights
+        state_dict = model.state_dict()
+        pretrained_state_dict = pretrained_model.state_dict()
+        for key in list(pretrained_state_dict.keys()):
+            if key not in state_dict:
+                pretrained_state_dict.pop(key)
+        model.load_state_dict(pretrained_state_dict, strict=False)
+
+        # upscale model
+        BaseDecoderLayerClass = self.model_mappings["base_decoder_layer_cls"]
+        SmileDecoderLayerClass = self.model_mappings["smile_decoder_layer_cls"]
+
+        for layer_idx in tqdm(
+            range(len(pretrained_model.model.layers)),
+            "Upscaling Modules (layer)",
+            dynamic_ncols=True,
+        ):
+            if RuntimeConstants.debug and layer_idx > 0:
+                log.info(
+                    "Debug mode enabled: processing only the first layer, skipping remaining layers"
+                )
+                break
+
+            pretrained_layer = pretrained_model.model.layers[layer_idx]
+            finetuned_layers = [m.model.layers[layer_idx] for m in finetuned_models]
+
+            target_layer = model.model.layers[layer_idx]
+
+            for n in ["q_proj", "k_proj", "v_proj", "o_proj"]:
+                try:
+                    upscale_to_smile_linear(
+                        base=getattr(pretrained_layer.self_attn, n),
+                        experts=[getattr(m.self_attn, n) for m in finetuned_layers],
+                        target=getattr(target_layer.self_attn, n),
+                        accelerator=self.accelerator,
+                    )
+                except ExpertNotTrainedError:
+                    setattr(
+                        target_layer.self_attn,
+                        n,
+                        getattr(pretrained_layer.self_attn, n),
+                    )
+
+            for n in ["gate_proj", "up_proj", "down_proj"]:
+                try:
+                    upscale_to_smile_linear(
+                        base=getattr(pretrained_layer.mlp, n),
+                        experts=[getattr(m.mlp, n) for m in finetuned_layers],
+                        target=getattr(target_layer.mlp, n),
+                        accelerator=self.accelerator,
+                    )
+                except ExpertNotTrainedError:
+                    setattr(
+                        target_layer.mlp,
+                        n,
+                        getattr(pretrained_layer.mlp, n),
+                    )
+
+        return model
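The new `SmileCausalLMUpscalingAlgorithm` delegates the per-module work to `upscale_to_smile_linear`, and its docstring notes that the merge "leverages Singular Value Decomposition (SVD)" with ranks controlled by `rank_of_router` and `rank_of_expert`. As a rough mental model only (an assumption-laden sketch, not the `upscale_to_smile_linear` implementation, whose code is not part of this diff), each expert can be thought of as a truncated-SVD, low-rank correction on top of the shared pretrained weight:

```python
import torch


def low_rank_expert(w_base: torch.Tensor, w_finetuned: torch.Tensor, rank: int):
    """Approximate a fine-tuned weight as base plus a rank-`rank` correction via truncated SVD.

    Hypothetical helper for illustration; not part of fusion_bench.
    """
    delta = w_finetuned - w_base  # task-specific update relative to the pretrained weight
    u, s, vh = torch.linalg.svd(delta, full_matrices=False)
    a = u[:, :rank] * s[:rank]    # (out, rank), columns scaled by the top singular values
    b = vh[:rank, :]              # (rank, in)
    return a, b                   # w_base + a @ b approximates w_finetuned


# Toy example: an 8x8 linear weight approximated with a rank-2 expert.
w0, w1 = torch.randn(8, 8), torch.randn(8, 8)
a, b = low_rank_expert(w0, w1, rank=2)
print((w0 + a @ b).shape)  # torch.Size([8, 8])
```

The `rank` argument here plays the role that the `rank_of_expert` hyperparameter plays in the algorithm's configuration above.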
fusion_bench/method/smile_upscaling/projected_energy.py

@@ -3,12 +3,11 @@ from typing import Literal
 
 import pandas as pd
 import torch
+from tqdm import tqdm
 
 from fusion_bench import BaseAlgorithm, BaseModelPool, auto_register_config
 from fusion_bench.mixins import LightningFabricMixin, SimpleProfilerMixin
 
-from tqdm import tqdm
-
 
 class ProjectedEnergyAnalysis(
     SimpleProfilerMixin,
fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py

@@ -20,6 +20,7 @@ from transformers.models.mistral.modeling_mistral import MistralDecoderLayer
 from fusion_bench.compat.modelpool import to_modelpool
 from fusion_bench.method import BaseAlgorithm
 from fusion_bench.method.simple_average import simple_average
+from fusion_bench.mixins import auto_register_config
 from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
 from fusion_bench.modelpool import BaseModelPool
 from fusion_bench.models.modeling_smile_mistral import (
@@ -40,7 +41,10 @@ from fusion_bench.utils.parameters import print_parameters
 log = logging.getLogger(__name__)
 
 
-class SmileMistralUpscalingAlgorithm(
+class SmileMistralUpscalingAlgorithm(
+    SimpleProfilerMixin,
+    BaseAlgorithm,
+):
     R"""
     SmileMistralUpscalingAlgorithm is a model fusion algorithm designed to upscale
     a pretrained Mistral model using a set of fine-tuned expert models. The algorithm