fusion-bench 0.2.20__py3-none-any.whl → 0.2.22__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/__init__.py +22 -2
- fusion_bench/_get_started/__init__.py +3 -0
- fusion_bench/_get_started/greeting_program.py +49 -0
- fusion_bench/compat/method/base_algorithm.py +14 -0
- fusion_bench/constants/__init__.py +6 -0
- fusion_bench/constants/clip_vision.py +26 -2
- fusion_bench/constants/paths.py +4 -0
- fusion_bench/constants/runtime.py +57 -0
- fusion_bench/dataset/clip_dataset.py +2 -1
- fusion_bench/dataset/gpt2_glue.py +9 -9
- fusion_bench/dataset/image_corruption/__init__.py +0 -0
- fusion_bench/dataset/image_corruption/make_corruption.py +179 -0
- fusion_bench/dataset/image_dataset.py +1 -1
- fusion_bench/dataset/nyuv2.py +2 -2
- fusion_bench/method/__init__.py +24 -5
- fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +1 -1
- fusion_bench/method/adamerging/clip_task_wise_adamerging.py +11 -7
- fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -5
- fusion_bench/method/base_algorithm.py +195 -12
- fusion_bench/method/bitdelta/__init__.py +5 -0
- fusion_bench/method/bitdelta/bitdelta.py +156 -0
- fusion_bench/method/bitdelta/bitdelta_utils/__init__.py +0 -0
- fusion_bench/method/bitdelta/bitdelta_utils/binary_gemm_kernel.py +462 -0
- fusion_bench/method/bitdelta/bitdelta_utils/data.py +35 -0
- fusion_bench/method/bitdelta/bitdelta_utils/diff.py +129 -0
- fusion_bench/method/classification/clip_finetune.py +1 -1
- fusion_bench/method/concrete_subspace/clip_concrete_adamerging.py +0 -1
- fusion_bench/method/depth_upscaling/depth_upscaling.py +4 -9
- fusion_bench/method/doge_ta/clip_layer_wise_adamerging.py +4 -5
- fusion_bench/method/doge_ta/doge_ta.py +1 -1
- fusion_bench/method/ensemble.py +12 -12
- fusion_bench/method/expert_sparsity/utils/calibration_data.py +1 -1
- fusion_bench/method/fisher_merging/clip_fisher_merging.py +2 -6
- fusion_bench/method/fisher_merging/fisher_merging.py +6 -15
- fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +3 -10
- fusion_bench/method/fw_merging/fw_hard.py +1 -1
- fusion_bench/method/fw_merging/fw_soft.py +1 -1
- fusion_bench/method/gossip/clip_layer_wise_gossip.py +4 -5
- fusion_bench/method/linear/expo.py +2 -1
- fusion_bench/method/linear/linear_interpolation.py +6 -4
- fusion_bench/method/linear/simple_average_for_llama.py +17 -13
- fusion_bench/method/lm_finetune/bradley_terry_rm.py +2 -2
- fusion_bench/method/mixture_of_experts/mixtral_upcycling.py +9 -26
- fusion_bench/method/model_recombination.py +2 -5
- fusion_bench/method/moe_pruner/hooks/__init__.py +1 -2
- fusion_bench/method/moe_pruner/utils/data.py +2 -1
- fusion_bench/method/moe_pruner/utils/prune.py +6 -1
- fusion_bench/method/pruning/llama_magnitude_prune.py +1 -1
- fusion_bench/method/pruning/wanda_utils/data.py +1 -2
- fusion_bench/method/pwe_moe/clip_pwe_moe.py +12 -34
- fusion_bench/method/randes/modelsoup.py +1 -3
- fusion_bench/method/regmean/clip_regmean.py +2 -2
- fusion_bench/method/regmean/gpt2_regmean.py +3 -10
- fusion_bench/method/regmean/regmean.py +2 -11
- fusion_bench/method/regmean_plusplus/__init__.py +1 -1
- fusion_bench/method/regmean_plusplus/clip_regmean_plusplus.py +24 -17
- fusion_bench/method/regmean_plusplus/regmean_plusplus.py +56 -38
- fusion_bench/method/simple_average.py +12 -16
- fusion_bench/method/slerp/slerp.py +5 -2
- fusion_bench/method/smile_upscaling/causal_lm_upscaling.py +371 -0
- fusion_bench/method/smile_upscaling/error_accumulation.py +177 -0
- fusion_bench/method/smile_upscaling/projected_energy.py +144 -0
- fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py +5 -1
- fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +71 -51
- fusion_bench/method/smile_upscaling/smile_upscaling.py +12 -5
- fusion_bench/method/tall_mask/task_arithmetic.py +3 -11
- fusion_bench/method/task_arithmetic/task_arithmetic.py +6 -10
- fusion_bench/method/ties_merging/ties_merging.py +13 -26
- fusion_bench/method/we_moe/__init__.py +1 -0
- fusion_bench/method/we_moe/clip_we_moe.py +5 -4
- fusion_bench/method/we_moe/entropy_loss.py +25 -0
- fusion_bench/method/we_moe/flan_t5_we_moe.py +331 -0
- fusion_bench/method/we_moe/utils.py +15 -0
- fusion_bench/method/we_moe/we_moe.py +6 -6
- fusion_bench/method/weighted_average/llama.py +4 -16
- fusion_bench/metrics/continual_learning/__init__.py +1 -0
- fusion_bench/metrics/continual_learning/backward_transfer.py +1 -1
- fusion_bench/metrics/nyuv2/__init__.py +2 -2
- fusion_bench/metrics/nyuv2/segmentation.py +1 -1
- fusion_bench/mixins/__init__.py +10 -2
- fusion_bench/mixins/clip_classification.py +15 -45
- fusion_bench/mixins/hydra_config.py +105 -7
- fusion_bench/mixins/lightning_fabric.py +2 -0
- fusion_bench/mixins/serialization.py +275 -48
- fusion_bench/modelpool/__init__.py +2 -2
- fusion_bench/modelpool/base_pool.py +29 -9
- fusion_bench/modelpool/causal_lm/causal_lm.py +41 -33
- fusion_bench/modelpool/clip_vision/modelpool.py +1 -3
- fusion_bench/modelpool/seq_classification_lm/__init__.py +1 -1
- fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +1 -1
- fusion_bench/models/__init__.py +7 -1
- fusion_bench/models/expert_sparsity/mixtral/__init__.py +1 -1
- fusion_bench/models/hf_utils.py +160 -0
- fusion_bench/models/linearized/linearized_model_utils.py +4 -4
- fusion_bench/models/linearized/vision_model.py +1 -1
- fusion_bench/models/model_card_templates/default.md +46 -0
- fusion_bench/models/modeling_deepseek_v2/__init__.py +1 -1
- fusion_bench/models/modeling_deepseek_v2/modeling_deepseek.py +4 -4
- fusion_bench/models/modeling_deepseek_v2/tokenization_deepseek_fast.py +0 -1
- fusion_bench/models/modeling_smile_gemma2/__init__.py +9 -0
- fusion_bench/models/modeling_smile_gemma2/configuration_smile_gemma2.py +20 -0
- fusion_bench/models/modeling_smile_gemma2/modeling_smile_gemma2.py +986 -0
- fusion_bench/models/modeling_smile_gemma2/register.py +26 -0
- fusion_bench/models/modeling_smile_llama/__init__.py +7 -0
- fusion_bench/models/modeling_smile_llama/configuration_smile_llama.py +20 -0
- fusion_bench/models/modeling_smile_llama/modeling_smile_llama.py +698 -0
- fusion_bench/models/modeling_smile_llama/register.py +8 -0
- fusion_bench/models/modeling_smile_mistral/__init__.py +5 -47
- fusion_bench/models/modeling_smile_qwen2/__init__.py +1 -1
- fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +7 -12
- fusion_bench/models/modeling_smile_qwen2/register.py +1 -4
- fusion_bench/models/parameter_dict.py +1 -1
- fusion_bench/models/sparse_we_moe.py +1 -53
- fusion_bench/models/utils.py +26 -0
- fusion_bench/models/we_moe.py +1 -53
- fusion_bench/models/wrappers/ensemble.py +6 -4
- fusion_bench/models/wrappers/layer_wise_fusion.py +1 -1
- fusion_bench/models/wrappers/task_wise_fusion.py +250 -72
- fusion_bench/programs/base_program.py +81 -2
- fusion_bench/programs/fabric_fusion_program.py +46 -61
- fusion_bench/scripts/cli.py +38 -5
- fusion_bench/taskpool/base_pool.py +4 -3
- fusion_bench/taskpool/clip_vision/taskpool.py +43 -22
- fusion_bench/taskpool/dummy.py +1 -1
- fusion_bench/taskpool/lm_eval_harness/taskpool.py +1 -2
- fusion_bench/tasks/clip_classification/__init__.py +6 -4
- fusion_bench/utils/__init__.py +7 -1
- fusion_bench/utils/cache_utils.py +101 -1
- fusion_bench/utils/devices.py +14 -4
- fusion_bench/utils/fabric.py +2 -2
- fusion_bench/utils/instantiate_utils.py +3 -1
- fusion_bench/utils/lazy_imports.py +23 -0
- fusion_bench/utils/lazy_state_dict.py +38 -3
- fusion_bench/utils/modelscope.py +127 -8
- fusion_bench/utils/parameters.py +2 -2
- fusion_bench/utils/path.py +56 -0
- fusion_bench/utils/pylogger.py +1 -1
- fusion_bench/utils/rich_utils.py +3 -0
- fusion_bench/utils/state_dict_arithmetic.py +25 -23
- {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/METADATA +24 -47
- {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/RECORD +184 -145
- fusion_bench_config/_get_started/clip_evaluate_single_model.yaml +21 -0
- fusion_bench_config/_get_started/clip_simple_average.yaml +23 -0
- fusion_bench_config/_get_started/clip_task_arithmetic.yaml +24 -0
- fusion_bench_config/_get_started/greeting_program.yaml +4 -0
- fusion_bench_config/fabric/loggers/csv_logger.yaml +3 -3
- fusion_bench_config/fabric/loggers/tensorboard_logger.yaml +3 -3
- fusion_bench_config/fabric_model_fusion.yaml +45 -17
- fusion_bench_config/hydra/default.yaml +6 -2
- fusion_bench_config/llama_full_finetune.yaml +1 -0
- fusion_bench_config/method/adamerging/clip.yaml +1 -1
- fusion_bench_config/method/bitdelta/bitdelta.yaml +12 -0
- fusion_bench_config/method/depth_upscaling.yaml +4 -1
- fusion_bench_config/method/fisher_merging/clip_fisher_merging.yaml +0 -1
- fusion_bench_config/method/linear/simple_average_for_llama.yaml +3 -2
- fusion_bench_config/method/smile_upscaling/causal_lm_upscaling.yaml +21 -0
- fusion_bench_config/method/smile_upscaling/error_accumulation.yaml +5 -0
- fusion_bench_config/method/smile_upscaling/projected_energy.yaml +2 -0
- fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +2 -1
- fusion_bench_config/method/wemoe/flan_t5_weight_ensembling_moe.yaml +20 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/_template.yaml +1 -4
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +4 -9
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_finetuned.yaml +1 -1
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_svhn_and_mnist.yaml +0 -6
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8.yaml +1 -1
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8_model_only.yaml +1 -1
- fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml +3 -3
- fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-7B-math_and_coder.yaml +9 -0
- fusion_bench_config/modelpool/CausalLMPool/mistral-7b.yaml +6 -0
- fusion_bench_config/modelpool/CausalLMPool/mixtral_moe_merging.yaml +10 -0
- fusion_bench_config/modelpool/CausalLMPool/qwen2_math_1.5B_and_R1.yaml +4 -12
- fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +6 -16
- fusion_bench_config/modelpool/CausalLMPool/vicuna-7b-v1.5.yaml +8 -0
- fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/llama_preference700k.yaml +1 -1
- fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/single_reward_model.yaml +1 -1
- fusion_bench_config/nyuv2_config.yaml +3 -1
- fusion_bench_config/nyuv2_mtl_train.yaml +1 -0
- fusion_bench_config/path/default.yaml +28 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_svhn_and_mnist.yaml +24 -0
- fusion_bench_config/method/adamerging.yaml +0 -23
- fusion_bench_config/modelpool/mixtral_moe_merging.yaml +0 -14
- fusion_bench_config/modelpool/mixtral_moe_upscaling.yaml +0 -6
- fusion_bench_config/taskpool/clip-vit-base-patch32_svhn_and_mnist.yaml +0 -22
- {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/top_level.txt +0 -0
- /fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/roberta-base_glue.yaml +0 -0
fusion_bench/modelpool/causal_lm/causal_lm.py
CHANGED

@@ -8,7 +8,7 @@ from copy import deepcopy
 from typing import Any, Dict, Optional, TypeAlias, Union, cast  # noqa: F401
 
 import peft
-from omegaconf import DictConfig, flag_override
+from omegaconf import DictConfig, OmegaConf, flag_override
 from torch import nn
 from torch.nn.modules import Module
 from transformers import (
@@ -19,43 +19,51 @@ from transformers import (
 )
 from typing_extensions import override
 
-from fusion_bench
-
-
+from fusion_bench import (
+    BaseModelPool,
+    auto_register_config,
+    import_object,
+    instantiate,
+    parse_dtype,
+)
 from fusion_bench.utils.lazy_state_dict import LazyStateDict
-from fusion_bench.utils.packages import import_object
 
 log = logging.getLogger(__name__)
 
 
+@auto_register_config
 class CausalLMPool(BaseModelPool):
-    _config_mapping = BaseModelPool._config_mapping | {
-        "_tokenizer": "tokenizer",
-        "_model_kwargs": "model_kwargs",
-        "load_lazy": "load_lazy",
-    }
-
     def __init__(
         self,
         models,
         *,
-        tokenizer: Optional[DictConfig],
+        tokenizer: Optional[DictConfig | str],
         model_kwargs: Optional[DictConfig] = None,
-
+        enable_lazy_loading: bool = False,
         **kwargs,
     ):
         super().__init__(models, **kwargs)
-
-
-
-
-
-
-
-
-
-
-
+        if model_kwargs is None:
+            self.model_kwargs = DictConfig({})
+
+    def get_model_path(self, model_name: str):
+        model_name_or_config = self._models[model_name]
+        if isinstance(model_name_or_config, str):
+            return model_name_or_config
+        elif isinstance(model_name_or_config, (DictConfig, dict)):
+            return model_name_or_config.get("pretrained_model_name_or_path")
+        else:
+            raise RuntimeError("Invalid model configuration")
+
+    def get_model_kwargs(self):
+        model_kwargs = (
+            OmegaConf.to_container(self.model_kwargs, resolve=True)
+            if isinstance(self.model_kwargs, DictConfig)
+            else self.model_kwargs
+        )
+        if "torch_dtype" in model_kwargs:
+            model_kwargs["torch_dtype"] = parse_dtype(model_kwargs["torch_dtype"])
+        return model_kwargs
 
     @override
     def load_model(
@@ -89,7 +97,7 @@ class CausalLMPool(BaseModelPool):
            pretrained_model_name_or_path: path_to_model_b
        ```
        """
-        model_kwargs =
+        model_kwargs = self.get_model_kwargs()
        model_kwargs.update(kwargs)
 
        if isinstance(model_name_or_config, str):
@@ -99,7 +107,7 @@ class CausalLMPool(BaseModelPool):
            model_config = self._models[model_name_or_config]
            if isinstance(model_config, str):
                # model_config is a string
-                if not self.
+                if not self.enable_lazy_loading:
                    model = AutoModelForCausalLM.from_pretrained(
                        model_config,
                        *args,
@@ -117,7 +125,7 @@ class CausalLMPool(BaseModelPool):
        elif isinstance(model_name_or_config, (DictConfig, Dict)):
            model_config = model_name_or_config
 
-            if not self.
+            if not self.enable_lazy_loading:
                model = instantiate(model_config, *args, **model_kwargs)
            else:
                meta_module_class = model_config.pop("_target_")
@@ -149,12 +157,12 @@ class CausalLMPool(BaseModelPool):
        Returns:
            PreTrainedTokenizer: The tokenizer.
        """
-        assert self.
+        assert self.tokenizer is not None, "Tokenizer is not defined in the config"
        log.info("Loading tokenizer.", stacklevel=2)
-        if isinstance(self.
-            tokenizer = AutoTokenizer.from_pretrained(self.
+        if isinstance(self.tokenizer, str):
+            tokenizer = AutoTokenizer.from_pretrained(self.tokenizer, *args, **kwargs)
        else:
-            tokenizer = instantiate(self.
+            tokenizer = instantiate(self.tokenizer, *args, **kwargs)
        return tokenizer
 
    @override
@@ -204,12 +212,12 @@ class CausalLMBackbonePool(CausalLMPool):
    def load_model(
        self, model_name_or_config: str | DictConfig, *args, **kwargs
    ) -> Module:
-        if self.
+        if self.enable_lazy_loading:
            log.warning(
                "CausalLMBackbonePool does not support lazy loading. "
                "Falling back to normal loading."
            )
-            self.
+            self.enable_lazy_loading = False
        model: AutoModelForCausalLM = super().load_model(
            model_name_or_config, *args, **kwargs
        )
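The net effect of this change is an API cleanup: `tokenizer` may now be a plain string, the manual `_config_mapping` is replaced by the `@auto_register_config` decorator, lazy loading is controlled by `enable_lazy_loading`, and `torch_dtype` strings in `model_kwargs` are resolved through `parse_dtype`. A minimal usage sketch follows (not part of the diff; the model paths are placeholders and `load_tokenizer` is assumed to be the tokenizer-loading method shown in the hunk above):

```python
from omegaconf import DictConfig

from fusion_bench.modelpool.causal_lm.causal_lm import CausalLMPool

# Placeholder identifiers; any Hugging Face repo id or local path could be used.
pool = CausalLMPool(
    models={
        "_pretrained_": "path/to/base-model",
        "expert": "path/to/finetuned-model",
    },
    tokenizer="path/to/base-model",  # a plain string is now accepted in 0.2.22
    model_kwargs=DictConfig({"torch_dtype": "bfloat16"}),
    enable_lazy_loading=False,  # replaces the old lazy-loading flag
)

model_kwargs = pool.get_model_kwargs()  # "bfloat16" parsed into torch.bfloat16
model = pool.load_model("expert")       # AutoModelForCausalLM.from_pretrained(...)
tokenizer = pool.load_tokenizer()       # assumed name of the tokenizer loader above
```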
fusion_bench/modelpool/clip_vision/modelpool.py
CHANGED

@@ -11,9 +11,7 @@ from transformers import CLIPModel, CLIPProcessor, CLIPVisionModel
 from typing_extensions import override
 
 from fusion_bench.utils import instantiate, timeit_context
-from fusion_bench.utils.modelscope import
-    resolve_repo_path,
-)
+from fusion_bench.utils.modelscope import resolve_repo_path
 
 from ..base_pool import BaseModelPool
 
fusion_bench/modelpool/seq_classification_lm/__init__.py
CHANGED

@@ -1,2 +1,2 @@
 from .reward_model import create_reward_model_from_pretrained
-from .seq_classification_lm import
+from .seq_classification_lm import SequenceClassificationModelPool
fusion_bench/models/__init__.py
CHANGED

@@ -1,4 +1,10 @@
 # flake8: noqa F401
+from fusion_bench.utils import LazyStateDict
+
 from . import separate_io, utils
+from .hf_utils import (
+    create_default_model_card,
+    load_model_card_template,
+    save_pretrained_with_remote_code,
+)
 from .parameter_dict import ParameterDictModel
-from fusion_bench.utils import LazyStateDict
fusion_bench/models/hf_utils.py
ADDED

@@ -0,0 +1,160 @@
+"""
+This module contains utilities for working with Hugging Face models.
+"""
+
+import inspect
+import os
+import shutil
+from typing import List, Optional, cast
+
+from omegaconf import DictConfig, OmegaConf
+from transformers.modeling_utils import PreTrainedModel
+
+from fusion_bench.utils.pylogger import get_rankzero_logger
+
+log = get_rankzero_logger(__name__)
+
+__all__ = [
+    "load_model_card_template",
+    "save_pretrained_with_remote_code",
+    "create_default_model_card",
+]
+
+MODEL_CARD_TEMPLATE_DIRS = [
+    os.path.join(os.path.dirname(__file__), "model_card_templates")
+]
+
+
+def load_model_card_template(basename: str) -> str:
+    """
+    Load a model card template from file.
+
+    Searches for a template file by name, first checking if the name is a direct file path,
+    then searching through predefined template directories.
+
+    Args:
+        name (str): The name of the template file or a direct file path to the template.
+
+    Returns:
+        str: The contents of the template file as a string.
+
+    Raises:
+        FileNotFoundError: If the template file is not found in any of the search locations.
+    """
+    if os.path.exists(basename):
+        return open(basename).read()
+
+    for template_dir in MODEL_CARD_TEMPLATE_DIRS:
+        template_path = os.path.join(template_dir, basename)
+        if os.path.exists(template_path):
+            return open(template_path).read()
+
+    raise FileNotFoundError(f"Model card template '{basename}' not found.")
+
+
+def try_to_yaml(config):
+    if config is None:
+        return None
+
+    try:
+        return OmegaConf.to_yaml(config, resolve=True, sort_keys=True)
+    except Exception as e:
+        log.error(f"Failed to convert config to YAML: {e}. Return `None`.")
+        return None
+
+
+def save_pretrained_with_remote_code(
+    model: PreTrainedModel,
+    auto_map: dict[str, object],
+    save_directory,
+    **kwargs,
+):
+    """
+    Saves a model with custom code to a directory.
+
+    This function facilitates saving a Hugging Face `PreTrainedModel` along with its
+    associated custom code. It inspects the objects provided in the `auto_map`,
+    copies their source files to the `save_directory`, and generates an `__init__.py`
+    to make them importable. It also updates the model's configuration with an
+    `auto_map` attribute, which allows `AutoModel.from_pretrained` to correctly
+    instantiate the custom model classes when `trust_remote_code=True`.
+
+    Args:
+        model (PreTrainedModel): The model instance to be saved.
+        auto_map (dict[str, object]): A dictionary mapping auto class names
+            (e.g., "AutoModelForCausalLM") to the corresponding custom class objects.
+        save_directory (str or os.PathLike): The directory where the model and
+            custom code files will be saved.
+        **kwargs: Additional keyword arguments to be passed to the
+            `model.save_pretrained` method.
+
+    Example:
+        ```python
+        # Assuming `model` is an instance of `SmileQwen2ForCausalLM`
+        # and `SmileQwen2Config`, `SmileQwen2Model`, `SmileQwen2ForCausalLM`
+        # are custom classes defined in your project.
+
+        save_pretrained_with_remote_code(
+            model,
+            auto_map={
+                "AutoConfig": SmileQwen2Config,
+                "AutoModel": SmileQwen2Model,
+                "AutoModelForCausalLM": SmileQwen2ForCausalLM,
+            },
+            save_directory="./my-custom-model",
+        )

+        # The model can then be loaded with `trust_remote_code=True`:
+        # from transformers import AutoModelForCausalLM
+        # loaded_model = AutoModelForCausalLM.from_pretrained(
+        #     "./my-custom-model", trust_remote_code=True
+        # )
+        ```
+    """
+    auto_map_files = {}
+    auto_map_strs = {}
+    for key, obj in auto_map.items():
+        auto_map_files[key] = inspect.getfile(obj)
+
+    for key, obj in auto_map.items():
+        auto_map_strs[key] = (
+            f"{(inspect.getmodule(obj).__name__).split('.')[-1]}.{obj.__name__}"
+        )
+
+    model.config.auto_map = auto_map_strs
+
+    # save model to `save_directory`
+    model.save_pretrained(save_directory=save_directory, **kwargs)
+
+    # copy source files to `save_directory`
+    for key, file_path in auto_map_files.items():
+        shutil.copy(
+            src=file_path, dst=os.path.join(save_directory, os.path.basename(file_path))
+        )
+    # construct `__init__.py`
+    init_file = os.path.join(save_directory, "__init__.py")
+    with open(init_file, "w") as f:
+        for key, file_name in auto_map_files.items():
+            base_name = os.path.basename(file_name).split(".")[0]
+            f.write(f"from .{base_name} import {auto_map[key].__name__}\n")
+
+
+def create_default_model_card(
+    models: list[str],
+    description=None,
+    algorithm_config: DictConfig = None,
+    modelpool_config: DictConfig = None,
+):
+    from jinja2 import Template
+
+    template: Template = Template(load_model_card_template("default.md"))
+    card = template.render(
+        models=models,
+        library_name="transformers",
+        tags=["fusion-bench", "merge"],
+        title="Deep Model Fusion",
+        description=description,
+        algorithm_config_str=try_to_yaml(algorithm_config),
+        modelpool_config_str=try_to_yaml(modelpool_config),
+    )
+    return card
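A short sketch of how the new model-card helpers compose; the configs and output path below are illustrative placeholders, not taken from the diff:

```python
from omegaconf import DictConfig

from fusion_bench.models.hf_utils import create_default_model_card

# Render the default.md Jinja2 template shipped with the package into a model card.
card = create_default_model_card(
    models=["org/model-a", "org/model-b"],  # placeholder base-model ids
    description="Merged with FusionBench.",
    algorithm_config=DictConfig({"name": "simple_average"}),  # illustrative config
    modelpool_config=DictConfig({"models": ["org/model-a", "org/model-b"]}),
)

# Write the rendered card next to the saved weights (path is an example).
with open("outputs/merged-model/README.md", "w") as f:
    f.write(card)
```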
fusion_bench/models/linearized/linearized_model_utils.py
CHANGED

@@ -1,7 +1,7 @@
 import logging
 from collections import OrderedDict
 from copy import deepcopy
-from typing import Optional
+from typing import Any, Dict, Optional, Tuple
 
 import torch.nn as nn
 from torch.func import functional_call, jvp
@@ -9,7 +9,7 @@ from torch.func import functional_call, jvp
 log = logging.getLogger(__name__)
 
 
-def dict_params_to_tuple(dict_params: dict):
+def dict_params_to_tuple(dict_params: dict) -> Tuple:
     return tuple(v for k, v in dict_params.items())
 
 
@@ -33,7 +33,7 @@ class LinearizedModelWraper(nn.Module):
         for p in self.params0_values:
             p.requires_grad_(False)
 
-    def tuple_params_to_dict(self, tuple_params):
+    def tuple_params_to_dict(self, tuple_params) -> Dict[str, Any]:
         """
         Converts a tuple of parameters to a dictionary with keys corresponding to the parameter names.
 
@@ -50,7 +50,7 @@ class LinearizedModelWraper(nn.Module):
             state_dict[k] = p
         return state_dict
 
-    def forward(self, *args, **kwargs):
+    def forward(self, *args: Any, **kwargs: Any) -> Any:
         """
         Computes the linearized model output using a first-order Taylor decomposition.
 
fusion_bench/models/model_card_templates/default.md
ADDED

@@ -0,0 +1,46 @@
+---
+base_model:
+{%- for model in models %}
+- {{ model }}
+{%- endfor %}
+library_name: {{ library_name }}
+tags:
+{%- for tag in tags %}
+- {{ tag }}
+{%- endfor %}
+---
+# {{ title }}
+
+{% if description is not none %}{{ description }}{% endif %}
+
+## Models Merged
+
+This is a merged model created using [fusion-bench](https://github.com/tanganke/fusion_bench).
+
+The following models were included in the merge:
+{% for model in models %}
+- {{ model }}
+{%- endfor %}
+
+{% if algorithm_config_str is not none or modelpool_config_str is not none %}
+## Configuration
+
+The following YAML configuration was used to produce this model:
+
+{% if algorithm_config_str is not none -%}
+### Algorithm Configuration
+
+```yaml
+{{ algorithm_config_str -}}
+```
+{%- endif %}
+
+{% if modelpool_config_str is not none -%}
+### Model Pool Configuration
+
+```yaml
+{{ modelpool_config_str -}}
+```
+{%- endif %}
+
+{% endif %}
fusion_bench/models/modeling_deepseek_v2/__init__.py
CHANGED

@@ -4,12 +4,12 @@ This is a direct copy of the DeepSeek-V2-Lite model from HuggingFace https://hug
 
 from .configuration_deepseek import DeepseekV2Config
 from .modeling_deepseek import (
+    DeepseekV2DecoderLayer,
     DeepseekV2ForCausalLM,
     DeepseekV2ForSequenceClassification,
     DeepseekV2MLP,
     DeepseekV2Model,
     DeepseekV2MoE,
-    DeepseekV2DecoderLayer,
 )
 from .modeling_deepseek import MoEGate as DeepseekV2MoEGate
 from .tokenization_deepseek_fast import DeepseekTokenizerFast
fusion_bench/models/modeling_deepseek_v2/modeling_deepseek.py
CHANGED

@@ -17,17 +17,18 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""
+"""PyTorch DeepSeek model."""
 import math
 import warnings
 from typing import List, Optional, Tuple, Union
 
+import numpy as np
 import torch
+import torch.distributed as dist
 import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
-
 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache, DynamicCache
 from transformers.modeling_attn_mask_utils import (
@@ -54,9 +55,8 @@ from transformers.utils import (
     replace_return_docstrings,
 )
 from transformers.utils.import_utils import is_torch_fx_available
+
 from .configuration_deepseek import DeepseekV2Config
-import torch.distributed as dist
-import numpy as np
 
 if is_flash_attn_2_available():
     from flash_attn import flash_attn_func, flash_attn_varlen_func
fusion_bench/models/modeling_smile_gemma2/__init__.py
ADDED

@@ -0,0 +1,9 @@
+from . import register
+from .configuration_smile_gemma2 import SmileGemma2Config
+from .modeling_smile_gemma2 import (
+    SmileGemma2ForCausalLM,
+    SmileGemma2ForSequenceClassification,
+    SmileGemma2ForTokenClassification,
+    SmileGemma2Model,
+    SmileGemma2PreTrainedModel,
+)
fusion_bench/models/modeling_smile_gemma2/configuration_smile_gemma2.py
ADDED

@@ -0,0 +1,20 @@
+from transformers.models.gemma2.configuration_gemma2 import Gemma2Config
+
+
+class SmileGemma2Config(Gemma2Config):
+    model_type = "smile_gemma2"
+
+    def __init__(
+        self,
+        num_experts_per_tok: int = 1,
+        rank_of_router: int = None,
+        rank_of_expert: int = None,
+        num_local_experts: int = None,
+        **kwargs,
+    ):
+        self.num_experts_per_tok = num_experts_per_tok
+        self.rank_of_router = rank_of_router
+        self.rank_of_expert = rank_of_expert
+        self.num_local_experts = num_local_experts
+
+        super().__init__(**kwargs)
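For reference, a minimal sketch of constructing the new config (the numeric values are illustrative only; the four extra fields are simply stored on the config and the remaining settings fall back to the Gemma2Config defaults via **kwargs):

```python
from fusion_bench.models.modeling_smile_gemma2 import SmileGemma2Config

# Illustrative values for the SMILE-specific fields added by this config class.
config = SmileGemma2Config(
    num_experts_per_tok=1,
    rank_of_router=8,
    rank_of_expert=32,
    num_local_experts=8,
)
print(config.model_type)  # "smile_gemma2"
```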