fusion-bench 0.2.21__py3-none-any.whl → 0.2.22__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/__init__.py +21 -2
- fusion_bench/constants/__init__.py +1 -0
- fusion_bench/constants/runtime.py +57 -0
- fusion_bench/method/__init__.py +8 -2
- fusion_bench/method/bitdelta/__init__.py +1 -0
- fusion_bench/method/classification/clip_finetune.py +1 -1
- fusion_bench/method/fisher_merging/clip_fisher_merging.py +0 -4
- fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +2 -2
- fusion_bench/method/linear/simple_average_for_llama.py +16 -11
- fusion_bench/method/simple_average.py +7 -7
- fusion_bench/method/smile_upscaling/causal_lm_upscaling.py +371 -0
- fusion_bench/method/smile_upscaling/projected_energy.py +1 -2
- fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py +5 -1
- fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +40 -31
- fusion_bench/method/smile_upscaling/smile_upscaling.py +1 -1
- fusion_bench/method/we_moe/__init__.py +1 -0
- fusion_bench/method/we_moe/entropy_loss.py +25 -0
- fusion_bench/method/we_moe/flan_t5_we_moe.py +331 -0
- fusion_bench/method/we_moe/utils.py +15 -0
- fusion_bench/method/weighted_average/llama.py +1 -1
- fusion_bench/mixins/clip_classification.py +11 -42
- fusion_bench/mixins/serialization.py +18 -8
- fusion_bench/modelpool/causal_lm/causal_lm.py +32 -33
- fusion_bench/models/__init__.py +5 -0
- fusion_bench/models/hf_utils.py +65 -87
- fusion_bench/models/model_card_templates/default.md +46 -0
- fusion_bench/models/modeling_smile_llama/__init__.py +7 -0
- fusion_bench/models/modeling_smile_llama/modeling_smile_llama.py +1 -8
- fusion_bench/models/modeling_smile_mistral/__init__.py +1 -1
- fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +1 -5
- fusion_bench/programs/fabric_fusion_program.py +29 -60
- fusion_bench/scripts/cli.py +34 -1
- fusion_bench/taskpool/clip_vision/taskpool.py +9 -4
- fusion_bench/utils/__init__.py +1 -0
- fusion_bench/utils/cache_utils.py +101 -1
- fusion_bench/utils/fabric.py +2 -2
- fusion_bench/utils/lazy_imports.py +23 -0
- fusion_bench/utils/lazy_state_dict.py +38 -3
- fusion_bench/utils/modelscope.py +3 -3
- fusion_bench/utils/path.py +56 -0
- fusion_bench/utils/pylogger.py +1 -1
- {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.22.dist-info}/METADATA +1 -23
- {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.22.dist-info}/RECORD +53 -45
- fusion_bench_config/method/fisher_merging/clip_fisher_merging.yaml +0 -1
- fusion_bench_config/method/linear/simple_average_for_llama.yaml +3 -2
- fusion_bench_config/method/smile_upscaling/causal_lm_upscaling.yaml +21 -0
- fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +1 -1
- fusion_bench_config/method/wemoe/flan_t5_weight_ensembling_moe.yaml +20 -0
- fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml +1 -1
- {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.22.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.22.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.22.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.22.dist-info}/top_level.txt +0 -0
fusion_bench/__init__.py
CHANGED
```diff
@@ -19,9 +19,28 @@ from . import (
     tasks,
     utils,
 )
+from .constants import RuntimeConstants
 from .method import BaseAlgorithm, BaseModelFusionAlgorithm
 from .mixins import auto_register_config
 from .modelpool import BaseModelPool
-from .models import
+from .models import (
+    create_default_model_card,
+    load_model_card_template,
+    save_pretrained_with_remote_code,
+    separate_io,
+)
+from .programs import BaseHydraProgram
 from .taskpool import BaseTaskPool
-from .utils import
+from .utils import (
+    cache_with_joblib,
+    get_rankzero_logger,
+    import_object,
+    instantiate,
+    parse_dtype,
+    print_parameters,
+    seed_everything_by_time,
+    set_default_cache_dir,
+    set_print_function_call,
+    set_print_function_call_permeanent,
+    timeit_context,
+)
```
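The wider re-exports above make these helpers reachable from the package root. A minimal sketch of what that enables (nothing here beyond the names added in the hunk; assumes fusion-bench 0.2.22 is installed):

```python
# Sketch based on the re-exports added above.
import fusion_bench

constants = fusion_bench.RuntimeConstants()   # newly exported runtime singleton
with fusion_bench.timeit_context("toy timing block"):
    pass
```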
fusion_bench/constants/runtime.py
ADDED

```python
import threading
from pathlib import Path
from typing import Optional, Union


class RuntimeConstants:
    """
    This class holds constants related to the runtime environment of the Fusion Bench framework.
    It includes default values for cache directories and other runtime configurations.

    Implemented as a thread-safe singleton to ensure consistent runtime configuration
    across the entire application.
    """

    _instance: Optional["RuntimeConstants"] = None
    _lock = threading.Lock()

    def __new__(cls) -> "RuntimeConstants":
        """Create a new instance using singleton pattern with thread safety."""
        with cls._lock:
            # Double-check locking pattern
            if cls._instance is None:
                cls._instance = super(RuntimeConstants, cls).__new__(cls)
                cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        """Initialize the singleton instance only once."""
        if not self._initialized:
            # Add your runtime constants here
            self._initialized = True

    debug = False

    @property
    def cache_dir(self) -> Path:
        from fusion_bench.utils.cache_utils import DEFAULT_CACHE_DIR

        return DEFAULT_CACHE_DIR

    @cache_dir.setter
    def cache_dir(self, path: Union[str, Path]) -> None:
        from fusion_bench.utils.cache_utils import set_default_cache_dir

        set_default_cache_dir(path)

    @property
    def print_function_call(self) -> bool:
        from fusion_bench.utils.instantiate_utils import PRINT_FUNCTION_CALL

        return PRINT_FUNCTION_CALL

    @print_function_call.setter
    def print_function_call(self, enable: bool) -> None:
        from fusion_bench.utils.instantiate_utils import set_print_function_call

        set_print_function_call(enable)
```
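A short usage sketch of the singleton defined above (the cache path is a placeholder):

```python
from fusion_bench.constants import RuntimeConstants

a = RuntimeConstants()
b = RuntimeConstants()
assert a is b  # __new__ always hands back the same instance

# The properties delegate to fusion_bench.utils at access time,
# so changes here take effect process-wide.
a.print_function_call = False
a.cache_dir = "/tmp/fusion_bench_cache"  # placeholder path

# `debug` is a plain class attribute; the new SMILE CausalLM upscaling loop reads it.
RuntimeConstants.debug = True
```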
fusion_bench/method/__init__.py
CHANGED
```diff
@@ -90,7 +90,10 @@ _import_structure = {
         "MixtralForCausalLMMergingAlgorithm",
     ],
     "dawe": ["DataAdaptiveWeightEnsemblingForCLIP"],
-    "we_moe": [
+    "we_moe": [
+        "CLIPWeightEnsemblingMoEAlgorithm",
+        "FlanT5WeightEnsemblingMoEAlgorithm",
+    ],
     "rankone_moe": ["CLIPRankOneMoEAlgorithm", "RankOneMoEAlgorithm"],
     "sparse_we_moe": [
         "SparseWeightEnsemblingMoEAlgorithm",
@@ -228,7 +231,10 @@ if TYPE_CHECKING:
     from .task_arithmetic import TaskArithmeticAlgorithm
     from .task_singular_vector import TaskSingularVectorMerging
     from .ties_merging import TiesMergingAlgorithm
-    from .we_moe import
+    from .we_moe import (
+        CLIPWeightEnsemblingMoEAlgorithm,
+        FlanT5WeightEnsemblingMoEAlgorithm,
+    )
     from .weighted_average import WeightedAverageAlgorithm, WeightedAverageForLLama
 
 else:
```
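For context, the `_import_structure` dict plus the `TYPE_CHECKING` branch above is the usual lazy-import layout: type checkers see real imports, while at runtime submodules are resolved on first attribute access. A generic sketch of that pattern (illustrative only, not fusion_bench's exact resolution code):

```python
# Generic lazy-import __init__.py sketch; names mirror the hunk above.
import importlib
from typing import TYPE_CHECKING

_import_structure = {
    "we_moe": [
        "CLIPWeightEnsemblingMoEAlgorithm",
        "FlanT5WeightEnsemblingMoEAlgorithm",
    ],
}

if TYPE_CHECKING:
    from .we_moe import (
        CLIPWeightEnsemblingMoEAlgorithm,
        FlanT5WeightEnsemblingMoEAlgorithm,
    )
else:
    def __getattr__(name):
        # PEP 562 module __getattr__: import the owning submodule lazily.
        for submodule, names in _import_structure.items():
            if name in names:
                return getattr(importlib.import_module(f".{submodule}", __name__), name)
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```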
fusion_bench/method/classification/clip_finetune.py
CHANGED

```diff
@@ -393,7 +393,7 @@ def convert_l_lora_state_dict_to_hf(
     base_model_name: Optional[str] = None,
 ):
     """
-    Convert a linearized Lora model's checkpoint to
+    Convert a linearized Lora model's checkpoint to huggingface's format.
 
     Args:
         pretrained_path (str): The path to the pretrained model.
```
fusion_bench/method/fisher_merging/clip_fisher_merging.py
CHANGED

```diff
@@ -32,7 +32,6 @@ class FisherMergingForCLIPVisionModel(
     zeroshot_weights = {}
 
     _config_mapping = FisherMergingAlgorithm._config_mapping | {
-        "zeroshot_weights_cache_dir": "zeroshot_weights_cache_dir",
         "_dataloader_kwargs": "dataloader_kwargs",
     }
 
@@ -44,7 +43,6 @@ class FisherMergingForCLIPVisionModel(
         minimal_fisher_weight,
         num_fisher_examples,
         dataloader_kwargs: DictConfig,
-        zeroshot_weights_cache_dir=None,
         **kwargs,
     ):
         """
@@ -56,7 +54,6 @@ class FisherMergingForCLIPVisionModel(
             minimal_fisher_weight (float): Minimal value for Fisher weights to avoid numerical issues.
             num_fisher_examples (int): Number of examples to compute Fisher weights.
             dataloader_kwargs (DictConfig): Configuration for the dataloader.
-            zeroshot_weights_cache_dir (str, optional): Directory to cache zero-shot weights. Defaults to None.
             **kwargs: Additional keyword arguments.
         """
         super().__init__(
@@ -66,7 +63,6 @@ class FisherMergingForCLIPVisionModel(
             num_fisher_examples=num_fisher_examples,
         )
         self.dataloader_kwargs = dataloader_kwargs
-        self.zeroshot_weights_cache_dir = zeroshot_weights_cache_dir
         for key, value in kwargs.items():
             log.warning(f"Unused argument: {key}={value}")
             setattr(self, key, value)
```
fusion_bench/method/fisher_merging/gpt2_fisher_merging.py
CHANGED

```diff
@@ -15,10 +15,10 @@ from transformers import GPT2ForSequenceClassification, GPT2Model
 from transformers.data import default_data_collator
 from transformers.models.gpt2.modeling_gpt2 import Conv1D
 
-from fusion_bench.mixins import LightningFabricMixin
+from fusion_bench.mixins import LightningFabricMixin, auto_register_config
 from fusion_bench.modelpool import GPT2ForSequenceClassificationPool
 from fusion_bench.utils import timeit_context
-
+
 from .fisher_merging import FisherMergingAlgorithm, get_param_squared_gradients
 
 
```
fusion_bench/method/linear/simple_average_for_llama.py
CHANGED

````diff
@@ -1,3 +1,4 @@
+import os
 from copy import deepcopy
 from typing import TYPE_CHECKING, Optional
 
@@ -7,13 +8,16 @@ from typing_extensions import override
 from fusion_bench import timeit_context
 from fusion_bench.method.base_algorithm import BaseAlgorithm
 from fusion_bench.method.simple_average import SimpleAverageAlgorithm
+from fusion_bench.mixins import auto_register_config
 from fusion_bench.modelpool import CausalLMBackbonePool, CausalLMPool
+from fusion_bench.models.hf_utils import create_default_model_card
 from fusion_bench.utils import instantiate
-from fusion_bench.utils.pylogger import
+from fusion_bench.utils.pylogger import get_rankzero_logger
 
-log =
+log = get_rankzero_logger(__name__)
 
 
+@auto_register_config
 class SimpleAverageForLlama(BaseAlgorithm):
     R"""
     A simple averaging algorithm for LLama models. If `merge_backbone` is set to `True`, the backbone of the model will be averaged and the rest of the model will be loaded from the pre-trained model.
@@ -29,21 +33,14 @@ class SimpleAverageForLlama(BaseAlgorithm):
     ```
     """
 
-    _config_mapping = BaseAlgorithm._config_mapping | {
-        "merge_backbone": "merge_backbone",
-        "show_pbar": "show_pbar",
-    }
-
     def __init__(
         self,
         merge_backbone: bool,
         model_save_path: Optional[str] = None,
         show_pbar: bool = False,
+        **kwargs,
     ):
-        super().__init__()
-        self.merge_backbone = merge_backbone
-        self.model_save_path = model_save_path
-        self.show_pbar = show_pbar
+        super().__init__(**kwargs)
 
     @override
     def run(self, modelpool: CausalLMPool):
@@ -75,4 +72,12 @@ class SimpleAverageForLlama(BaseAlgorithm):
         with timeit_context(f"Saving the model to {self.model_save_path}"):
             tokenizer.save_pretrained(self.model_save_path)
             model.save_pretrained(self.model_save_path)
+            model_card_str = create_default_model_card(
+                models=[modelpool.get_model_path(m) for m in modelpool.model_names],
+                description="Merged model using simple averaging.",
+                algorithm_config=self.config,
+                modelpool_config=modelpool.config,
+            )
+            with open(os.path.join(self.model_save_path, "README.md"), "w") as f:
+                f.write(model_card_str)
         return model
````
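Given the constructor above, instantiating the algorithm directly looks roughly like this (the save path is a placeholder, and building a model pool is not part of this diff):

```python
# Sketch only: argument names come from the __init__ shown above;
# constructing a CausalLMPool is assumed and not shown here.
from fusion_bench.method.linear.simple_average_for_llama import SimpleAverageForLlama

algorithm = SimpleAverageForLlama(
    merge_backbone=False,
    model_save_path="outputs/merged_llama",  # placeholder; a README.md model card is written here
    show_pbar=True,
)
# merged_model = algorithm.run(modelpool)  # modelpool: a fusion_bench CausalLMPool
```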
fusion_bench/method/simple_average.py
CHANGED

```diff
@@ -61,8 +61,8 @@ def simple_average(
 
 @auto_register_config
 class SimpleAverageAlgorithm(
-    BaseAlgorithm,
     SimpleProfilerMixin,
+    BaseAlgorithm,
 ):
     def __init__(self, show_pbar: bool = False, **kwargs):
         """
@@ -120,13 +120,13 @@ class SimpleAverageAlgorithm(
         if isinstance(forward_model, LazyStateDict):
             # if the model is a LazyStateDict, convert it to an empty module
             forward_model = forward_model.meta_module.to_empty(
-                device=(
-                    "cpu"
-                    if forward_model._torch_dtype is None
-                    else forward_model._torch_dtype
-                )
+                device=forward_model._device
             )
-            forward_model.load_state_dict(sd)
+            result = forward_model.load_state_dict(sd, strict=False)
+            if result.unexpected_keys:
+                raise ValueError(f"Unexpected keys in state dict: {result.unexpected_keys}")
+            if result.missing_keys:
+                log.warning(f"Missing keys in state dict: {result.missing_keys}")
         # print profile report and log the merged models
         self.print_profile_summary()
         log.info(f"merged {len(merged_model_names)} models:")
```
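The new loading logic leans on `torch.nn.Module.load_state_dict(..., strict=False)` returning a result with `missing_keys` and `unexpected_keys` instead of raising; a standalone illustration:

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
state = {"weight": torch.zeros(2, 4), "extra": torch.zeros(1)}  # "bias" omitted, "extra" unknown

result = model.load_state_dict(state, strict=False)
print(result.missing_keys)     # ['bias']
print(result.unexpected_keys)  # ['extra']
```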
fusion_bench/method/smile_upscaling/causal_lm_upscaling.py
ADDED

```python
import logging
import os
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, Union

import torch
from accelerate import init_empty_weights
from tqdm.auto import tqdm
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    LlamaForCausalLM,
    MistralForCausalLM,
    PretrainedConfig,
    PreTrainedModel,
    Qwen2ForCausalLM,
)
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from transformers.models.mistral.modeling_mistral import MistralDecoderLayer
from transformers.models.qwen2.modeling_qwen2 import Qwen2DecoderLayer

from fusion_bench import BaseAlgorithm, BaseModelPool
from fusion_bench.compat.modelpool import to_modelpool
from fusion_bench.constants import RuntimeConstants
from fusion_bench.mixins import SimpleProfilerMixin, auto_register_config
from fusion_bench.modelpool import CausalLMPool
from fusion_bench.models.hf_utils import (
    create_default_model_card,
    save_pretrained_with_remote_code,
)
from fusion_bench.models.modeling_smile_llama import (
    SmileLlamaConfig,
    SmileLlamaForCausalLM,
    SmileLlamaModel,
)
from fusion_bench.models.modeling_smile_llama.modeling_smile_llama import (
    SmileLlamaDecoderLayer,
)
from fusion_bench.models.modeling_smile_mistral import (
    SmileMistralConfig,
    SmileMistralForCausalLM,
    SmileMistralModel,
)
from fusion_bench.models.modeling_smile_mistral.modeling_smile_mistral import (
    SmileMistralDecoderLayer,
)

# Import all SMILE configurations and models
from fusion_bench.models.modeling_smile_qwen2 import (
    SmileQwen2Config,
    SmileQwen2ForCausalLM,
    SmileQwen2Model,
)
from fusion_bench.models.modeling_smile_qwen2.modeling_smile_qwen2 import (
    SmileQwen2DecoderLayer,
)
from fusion_bench.models.smile_moe.linear_from_hf_config import (
    ExpertNotTrainedError,
    upscale_to_smile_linear,
)
from fusion_bench.utils.dtype import parse_dtype
from fusion_bench.utils.parameters import print_parameters

log = logging.getLogger(__name__)

# Model type mappings
MODEL_TYPE_MAPPINGS = {
    "qwen2": {
        "base_model_cls": Qwen2ForCausalLM,
        "base_decoder_layer_cls": Qwen2DecoderLayer,
        "smile_config_cls": SmileQwen2Config,
        "smile_model_cls": SmileQwen2ForCausalLM,
        "smile_base_model_cls": SmileQwen2Model,
        "smile_decoder_layer_cls": SmileQwen2DecoderLayer,
        "description": "Qwen2",
    },
    "llama": {
        "base_model_cls": LlamaForCausalLM,
        "base_decoder_layer_cls": LlamaDecoderLayer,
        "smile_config_cls": SmileLlamaConfig,
        "smile_model_cls": SmileLlamaForCausalLM,
        "smile_base_model_cls": SmileLlamaModel,
        "smile_decoder_layer_cls": SmileLlamaDecoderLayer,
        "description": "Llama",
    },
    "mistral": {
        "base_model_cls": MistralForCausalLM,
        "base_decoder_layer_cls": MistralDecoderLayer,
        "smile_config_cls": SmileMistralConfig,
        "smile_model_cls": SmileMistralForCausalLM,
        "smile_base_model_cls": SmileMistralModel,
        "smile_decoder_layer_cls": SmileMistralDecoderLayer,
        "description": "Mistral",
    },
}


def detect_model_type(
    model_or_config: Union[PreTrainedModel, PretrainedConfig, str],
) -> str:
    """
    Detect the model type from a model, config, or model name/path.

    Args:
        model_or_config: Model, config, or model name/path to detect type from

    Returns:
        str: The detected model type ("qwen2", "llama", "mistral")

    Raises:
        ValueError: If model type cannot be detected or is not supported
    """
    if isinstance(model_or_config, str):
        # Load config from path/name
        config = AutoConfig.from_pretrained(model_or_config)
    elif isinstance(model_or_config, PreTrainedModel):
        config = model_or_config.config
    elif isinstance(model_or_config, PretrainedConfig):
        config = model_or_config
    else:
        raise ValueError(
            f"Unsupported type for model type detection: {type(model_or_config)}"
        )

    model_type = getattr(config, "model_type", "").lower()

    # Handle various model type variations
    if model_type in MODEL_TYPE_MAPPINGS:
        return model_type
    else:
        raise ValueError(
            f"Unsupported model type: {model_type}. Supported types: {list(MODEL_TYPE_MAPPINGS.keys())}"
        )


@auto_register_config
class SmileCausalLMUpscalingAlgorithm(
    SimpleProfilerMixin,
    BaseAlgorithm,
):
    R"""
    SmileCausalLMUpscalingAlgorithm is a generic model fusion algorithm designed to upscale
    a pretrained CausalLM model using a set of fine-tuned expert models. The algorithm
    supports Qwen2, Llama, and Mistral model architectures and leverages Singular Value
    Decomposition (SVD) to merge the weights of the pretrained model and the expert models
    into a new upscaled model.

    The algorithm automatically detects the model type and uses the appropriate SMILE
    configuration and model classes.

    Methods:
        run(modelpool: BaseModelPool) -> Union[SmileQwen2ForCausalLM, SmileLlamaForCausalLM, SmileMistralForCausalLM]:
            Executes the upscaling process and returns the upscaled model.

        merge(pretrained_model: PreTrainedModel, finetuned_models: List[PreTrainedModel]) -> PreTrainedModel:
            Merges the pretrained model with the fine-tuned models to create an upscaled model.
    """

    modelpool: CausalLMPool

    def __init__(
        self,
        device,
        accelerator,
        model_save_path,
        model_dtype,
        num_experts_per_tok,
        rank_of_router,
        rank_of_expert,
        save_with_remote_code: bool = True,
        model_type: str = None,  # Optional: explicitly specify model type
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.model_mappings = None  # Will be set during run()

        if not torch.cuda.is_available():
            if "cuda" in self.device:
                self.device = "cpu"
            if "cuda" in self.accelerator:
                self.accelerator = "cpu"

    @torch.no_grad()
    def run(self, modelpool) -> PreTrainedModel:
        """
        Executes the upscaling process.

        Args:
            modelpool (ModelPool): The pool of models to be used for upscaling.

        Returns:
            PreTrainedModel: The upscaled model (specific type depends on detected model architecture).
        """
        self.modelpool = modelpool = to_modelpool(modelpool)
        config = self.config

        # Auto-detect model type if not specified
        if self.model_type is None:
            self.model_type = detect_model_type(
                modelpool.get_model_path("_pretrained_")
            )
            log.info(f"Auto-detected model type: {self.model_type}")

        # Get the appropriate model mappings
        if self.model_type not in MODEL_TYPE_MAPPINGS:
            raise ValueError(
                f"Unsupported model type: {self.model_type}. Supported: {list(MODEL_TYPE_MAPPINGS.keys())}"
            )

        self.model_mappings = MODEL_TYPE_MAPPINGS[self.model_type]
        log.info(f"Using {self.model_mappings['description']} model architecture")

        with self.profile("load pretrained model"):
            pretrained_model = modelpool.load_pretrained_model()

        with self.profile("load fine-tuned model"):
            finetuned_models = [
                m for m in tqdm(modelpool.models(), total=len(modelpool.model_names))
            ]

        if self.device == "cuda" and torch.cuda.is_available():
            pretrained_model = pretrained_model.cuda()
            print("parameter count of pretrained model:")
            print_parameters(pretrained_model)
            finetuned_models = [m.cuda() for m in finetuned_models]

        with self.profile("merge model"):
            model = self.merge(pretrained_model, finetuned_models)

        self.print_profile_summary()
        print("parameter count of upscaled MoE model:")
        print_parameters(model)
        print(model)

        if self.model_dtype is not None:
            model.to(dtype=parse_dtype(self.model_dtype))

        if self.model_save_path is not None:
            if os.path.dirname(self.model_save_path):
                os.makedirs(os.path.dirname(self.model_save_path), exist_ok=True)
            log.info(f"Saving model to {self.model_save_path}")
            tokenizer = self.modelpool.load_tokenizer()
            tokenizer.save_pretrained(self.model_save_path)
            if not self.save_with_remote_code:
                model.save_pretrained(self.model_save_path)
            else:
                # Use the appropriate auto_map for the detected model type
                auto_map = {
                    "AutoConfig": self.model_mappings["smile_config_cls"],
                    "AutoModel": self.model_mappings["smile_base_model_cls"],
                    "AutoModelForCausalLM": self.model_mappings["smile_model_cls"],
                }
                save_pretrained_with_remote_code(
                    model,
                    auto_map=auto_map,
                    save_directory=self.model_save_path,
                )

            # save readme
            model_card_str = create_default_model_card(
                models=[modelpool.get_model_path(m) for m in modelpool.all_model_names],
                description=f"Merged {self.model_mappings['description']} model using SMILE Upscaling",
                algorithm_config=self.config,
                modelpool_config=modelpool.config,
            )
            with open(os.path.join(self.model_save_path, "README.md"), "w") as f:
                f.write(model_card_str)

        return model

    def merge(
        self,
        pretrained_model: PreTrainedModel,
        finetuned_models: List[PreTrainedModel],
    ) -> PreTrainedModel:
        """
        Merges the pretrained model with the fine-tuned models to create an upscaled model.

        Args:
            pretrained_model (PreTrainedModel): The pretrained model.
            finetuned_models (List[PreTrainedModel]): A list of fine-tuned models.

        Returns:
            PreTrainedModel: The upscaled model (specific type depends on model architecture).
        """
        with init_empty_weights():
            pretrained_model_config = self.modelpool.get_model_config("_pretrained_")
            if isinstance(pretrained_model_config, str):
                pretrained_path = pretrained_model_config
            else:
                pretrained_path = pretrained_model_config.get(
                    "path", pretrained_model_config["pretrained_model_name_or_path"]
                )
            base_config = AutoConfig.from_pretrained(pretrained_path)

            # Create the appropriate SMILE config for the detected model type
            SmileConfigClass = self.model_mappings["smile_config_cls"]
            model_config = SmileConfigClass(
                num_experts_per_tok=self.num_experts_per_tok,
                rank_of_router=self.rank_of_router,
                rank_of_expert=self.rank_of_expert,
                num_local_experts=len(finetuned_models),
                **base_config.to_dict(),
            )

            # Create the appropriate SMILE model for the detected model type
            SmileModelClass = self.model_mappings["smile_model_cls"]
            model = SmileModelClass(model_config)

        model.to(dtype=pretrained_model.dtype).to_empty(device="cpu")

        # copy pretrained model weights
        state_dict = model.state_dict()
        pretrained_state_dict = pretrained_model.state_dict()
        for key in list(pretrained_state_dict.keys()):
            if key not in state_dict:
                pretrained_state_dict.pop(key)
        model.load_state_dict(pretrained_state_dict, strict=False)

        # upscale model
        BaseDecoderLayerClass = self.model_mappings["base_decoder_layer_cls"]
        SmileDecoderLayerClass = self.model_mappings["smile_decoder_layer_cls"]

        for layer_idx in tqdm(
            range(len(pretrained_model.model.layers)),
            "Upscaling Modules (layer)",
            dynamic_ncols=True,
        ):
            if RuntimeConstants.debug and layer_idx > 0:
                log.info(
                    "Debug mode enabled: processing only the first layer, skipping remaining layers"
                )
                break

            pretrained_layer = pretrained_model.model.layers[layer_idx]
            finetuned_layers = [m.model.layers[layer_idx] for m in finetuned_models]

            target_layer = model.model.layers[layer_idx]

            for n in ["q_proj", "k_proj", "v_proj", "o_proj"]:
                try:
                    upscale_to_smile_linear(
                        base=getattr(pretrained_layer.self_attn, n),
                        experts=[getattr(m.self_attn, n) for m in finetuned_layers],
                        target=getattr(target_layer.self_attn, n),
                        accelerator=self.accelerator,
                    )
                except ExpertNotTrainedError:
                    setattr(
                        target_layer.self_attn,
                        n,
                        getattr(pretrained_layer.self_attn, n),
                    )

            for n in ["gate_proj", "up_proj", "down_proj"]:
                try:
                    upscale_to_smile_linear(
                        base=getattr(pretrained_layer.mlp, n),
                        experts=[getattr(m.mlp, n) for m in finetuned_layers],
                        target=getattr(target_layer.mlp, n),
                        accelerator=self.accelerator,
                    )
                except ExpertNotTrainedError:
                    setattr(
                        target_layer.mlp,
                        n,
                        getattr(pretrained_layer.mlp, n),
                    )

        return model
```
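A small usage sketch of `detect_model_type` from the new module above (the hub id in the comment is a placeholder):

```python
from transformers import AutoConfig

from fusion_bench.method.smile_upscaling.causal_lm_upscaling import detect_model_type

# From an in-memory config (works offline):
config = AutoConfig.for_model("llama")       # default-valued LlamaConfig
print(detect_model_type(config))             # -> "llama"

# From a model name or local path (reads the config via AutoConfig.from_pretrained):
# print(detect_model_type("Qwen/Qwen2.5-1.5B"))  # -> "qwen2"
```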
fusion_bench/method/smile_upscaling/projected_energy.py
CHANGED

```diff
@@ -3,12 +3,11 @@ from typing import Literal
 
 import pandas as pd
 import torch
+from tqdm import tqdm
 
 from fusion_bench import BaseAlgorithm, BaseModelPool, auto_register_config
 from fusion_bench.mixins import LightningFabricMixin, SimpleProfilerMixin
 
-from tqdm import tqdm
-
 
 class ProjectedEnergyAnalysis(
     SimpleProfilerMixin,
```
fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py
CHANGED

```diff
@@ -20,6 +20,7 @@ from transformers.models.mistral.modeling_mistral import MistralDecoderLayer
 from fusion_bench.compat.modelpool import to_modelpool
 from fusion_bench.method import BaseAlgorithm
 from fusion_bench.method.simple_average import simple_average
+from fusion_bench.mixins import auto_register_config
 from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
 from fusion_bench.modelpool import BaseModelPool
 from fusion_bench.models.modeling_smile_mistral import (
@@ -40,7 +41,10 @@ from fusion_bench.utils.parameters import print_parameters
 log = logging.getLogger(__name__)
 
 
-class SmileMistralUpscalingAlgorithm(
+class SmileMistralUpscalingAlgorithm(
+    SimpleProfilerMixin,
+    BaseAlgorithm,
+):
     R"""
     SmileMistralUpscalingAlgorithm is a model fusion algorithm designed to upscale
     a pretrained Mistral model using a set of fine-tuned expert models. The algorithm
```