fusion-bench 0.2.20__py3-none-any.whl → 0.2.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/__init__.py +1 -0
- fusion_bench/_get_started/__init__.py +3 -0
- fusion_bench/_get_started/greeting_program.py +49 -0
- fusion_bench/compat/method/base_algorithm.py +14 -0
- fusion_bench/constants/__init__.py +5 -0
- fusion_bench/constants/clip_vision.py +26 -2
- fusion_bench/constants/paths.py +4 -0
- fusion_bench/dataset/clip_dataset.py +2 -1
- fusion_bench/dataset/gpt2_glue.py +9 -9
- fusion_bench/dataset/image_corruption/__init__.py +0 -0
- fusion_bench/dataset/image_corruption/make_corruption.py +179 -0
- fusion_bench/dataset/image_dataset.py +1 -1
- fusion_bench/dataset/nyuv2.py +2 -2
- fusion_bench/method/__init__.py +16 -3
- fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +1 -1
- fusion_bench/method/adamerging/clip_task_wise_adamerging.py +11 -7
- fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -5
- fusion_bench/method/base_algorithm.py +195 -12
- fusion_bench/method/bitdelta/__init__.py +4 -0
- fusion_bench/method/bitdelta/bitdelta.py +156 -0
- fusion_bench/method/bitdelta/bitdelta_utils/__init__.py +0 -0
- fusion_bench/method/bitdelta/bitdelta_utils/binary_gemm_kernel.py +462 -0
- fusion_bench/method/bitdelta/bitdelta_utils/data.py +35 -0
- fusion_bench/method/bitdelta/bitdelta_utils/diff.py +129 -0
- fusion_bench/method/concrete_subspace/clip_concrete_adamerging.py +0 -1
- fusion_bench/method/depth_upscaling/depth_upscaling.py +4 -9
- fusion_bench/method/doge_ta/clip_layer_wise_adamerging.py +4 -5
- fusion_bench/method/doge_ta/doge_ta.py +1 -1
- fusion_bench/method/ensemble.py +12 -12
- fusion_bench/method/expert_sparsity/utils/calibration_data.py +1 -1
- fusion_bench/method/fisher_merging/clip_fisher_merging.py +2 -2
- fusion_bench/method/fisher_merging/fisher_merging.py +6 -15
- fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +3 -10
- fusion_bench/method/fw_merging/fw_hard.py +1 -1
- fusion_bench/method/fw_merging/fw_soft.py +1 -1
- fusion_bench/method/gossip/clip_layer_wise_gossip.py +4 -5
- fusion_bench/method/linear/expo.py +2 -1
- fusion_bench/method/linear/linear_interpolation.py +6 -4
- fusion_bench/method/linear/simple_average_for_llama.py +2 -3
- fusion_bench/method/lm_finetune/bradley_terry_rm.py +2 -2
- fusion_bench/method/mixture_of_experts/mixtral_upcycling.py +9 -26
- fusion_bench/method/model_recombination.py +2 -5
- fusion_bench/method/moe_pruner/hooks/__init__.py +1 -2
- fusion_bench/method/moe_pruner/utils/data.py +2 -1
- fusion_bench/method/moe_pruner/utils/prune.py +6 -1
- fusion_bench/method/pruning/llama_magnitude_prune.py +1 -1
- fusion_bench/method/pruning/wanda_utils/data.py +1 -2
- fusion_bench/method/pwe_moe/clip_pwe_moe.py +12 -34
- fusion_bench/method/randes/modelsoup.py +1 -3
- fusion_bench/method/regmean/clip_regmean.py +2 -2
- fusion_bench/method/regmean/gpt2_regmean.py +3 -10
- fusion_bench/method/regmean/regmean.py +2 -11
- fusion_bench/method/regmean_plusplus/__init__.py +1 -1
- fusion_bench/method/regmean_plusplus/clip_regmean_plusplus.py +24 -17
- fusion_bench/method/regmean_plusplus/regmean_plusplus.py +56 -38
- fusion_bench/method/simple_average.py +5 -9
- fusion_bench/method/slerp/slerp.py +5 -2
- fusion_bench/method/smile_upscaling/error_accumulation.py +177 -0
- fusion_bench/method/smile_upscaling/projected_energy.py +145 -0
- fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +39 -28
- fusion_bench/method/smile_upscaling/smile_upscaling.py +12 -5
- fusion_bench/method/tall_mask/task_arithmetic.py +3 -11
- fusion_bench/method/task_arithmetic/task_arithmetic.py +6 -10
- fusion_bench/method/ties_merging/ties_merging.py +13 -26
- fusion_bench/method/we_moe/clip_we_moe.py +5 -4
- fusion_bench/method/we_moe/we_moe.py +6 -6
- fusion_bench/method/weighted_average/llama.py +4 -16
- fusion_bench/metrics/continual_learning/__init__.py +1 -0
- fusion_bench/metrics/continual_learning/backward_transfer.py +1 -1
- fusion_bench/metrics/nyuv2/__init__.py +2 -2
- fusion_bench/metrics/nyuv2/segmentation.py +1 -1
- fusion_bench/mixins/__init__.py +10 -2
- fusion_bench/mixins/clip_classification.py +4 -3
- fusion_bench/mixins/hydra_config.py +105 -7
- fusion_bench/mixins/lightning_fabric.py +2 -0
- fusion_bench/mixins/serialization.py +265 -48
- fusion_bench/modelpool/__init__.py +2 -2
- fusion_bench/modelpool/base_pool.py +29 -9
- fusion_bench/modelpool/causal_lm/causal_lm.py +9 -0
- fusion_bench/modelpool/clip_vision/modelpool.py +1 -3
- fusion_bench/modelpool/seq_classification_lm/__init__.py +1 -1
- fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +1 -1
- fusion_bench/models/__init__.py +2 -1
- fusion_bench/models/expert_sparsity/mixtral/__init__.py +1 -1
- fusion_bench/models/hf_utils.py +182 -0
- fusion_bench/models/linearized/linearized_model_utils.py +4 -4
- fusion_bench/models/linearized/vision_model.py +1 -1
- fusion_bench/models/modeling_deepseek_v2/__init__.py +1 -1
- fusion_bench/models/modeling_deepseek_v2/modeling_deepseek.py +4 -4
- fusion_bench/models/modeling_deepseek_v2/tokenization_deepseek_fast.py +0 -1
- fusion_bench/models/modeling_smile_gemma2/__init__.py +9 -0
- fusion_bench/models/modeling_smile_gemma2/configuration_smile_gemma2.py +20 -0
- fusion_bench/models/modeling_smile_gemma2/modeling_smile_gemma2.py +986 -0
- fusion_bench/models/modeling_smile_gemma2/register.py +26 -0
- fusion_bench/models/modeling_smile_llama/__init__.py +0 -0
- fusion_bench/models/modeling_smile_llama/configuration_smile_llama.py +20 -0
- fusion_bench/models/modeling_smile_llama/modeling_smile_llama.py +705 -0
- fusion_bench/models/modeling_smile_llama/register.py +8 -0
- fusion_bench/models/modeling_smile_mistral/__init__.py +5 -47
- fusion_bench/models/modeling_smile_qwen2/__init__.py +1 -1
- fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +6 -7
- fusion_bench/models/modeling_smile_qwen2/register.py +1 -4
- fusion_bench/models/parameter_dict.py +1 -1
- fusion_bench/models/sparse_we_moe.py +1 -53
- fusion_bench/models/utils.py +26 -0
- fusion_bench/models/we_moe.py +1 -53
- fusion_bench/models/wrappers/ensemble.py +6 -4
- fusion_bench/models/wrappers/layer_wise_fusion.py +1 -1
- fusion_bench/models/wrappers/task_wise_fusion.py +250 -72
- fusion_bench/programs/base_program.py +81 -2
- fusion_bench/programs/fabric_fusion_program.py +24 -8
- fusion_bench/scripts/cli.py +5 -5
- fusion_bench/taskpool/base_pool.py +4 -3
- fusion_bench/taskpool/clip_vision/taskpool.py +34 -18
- fusion_bench/taskpool/dummy.py +1 -1
- fusion_bench/taskpool/lm_eval_harness/taskpool.py +1 -2
- fusion_bench/tasks/clip_classification/__init__.py +6 -4
- fusion_bench/utils/__init__.py +6 -1
- fusion_bench/utils/devices.py +14 -4
- fusion_bench/utils/instantiate_utils.py +3 -1
- fusion_bench/utils/modelscope.py +127 -8
- fusion_bench/utils/parameters.py +2 -2
- fusion_bench/utils/rich_utils.py +3 -0
- fusion_bench/utils/state_dict_arithmetic.py +25 -23
- {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.21.dist-info}/METADATA +24 -25
- {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.21.dist-info}/RECORD +165 -134
- fusion_bench_config/_get_started/clip_evaluate_single_model.yaml +21 -0
- fusion_bench_config/_get_started/clip_simple_average.yaml +23 -0
- fusion_bench_config/_get_started/clip_task_arithmetic.yaml +24 -0
- fusion_bench_config/_get_started/greeting_program.yaml +4 -0
- fusion_bench_config/fabric/loggers/csv_logger.yaml +3 -3
- fusion_bench_config/fabric/loggers/tensorboard_logger.yaml +3 -3
- fusion_bench_config/fabric_model_fusion.yaml +45 -17
- fusion_bench_config/hydra/default.yaml +6 -2
- fusion_bench_config/llama_full_finetune.yaml +1 -0
- fusion_bench_config/method/adamerging/clip.yaml +1 -1
- fusion_bench_config/method/bitdelta/bitdelta.yaml +12 -0
- fusion_bench_config/method/depth_upscaling.yaml +4 -1
- fusion_bench_config/method/smile_upscaling/error_accumulation.yaml +5 -0
- fusion_bench_config/method/smile_upscaling/projected_energy.yaml +2 -0
- fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +1 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/_template.yaml +1 -4
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +4 -9
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_finetuned.yaml +1 -1
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_svhn_and_mnist.yaml +0 -6
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8.yaml +1 -1
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8_model_only.yaml +1 -1
- fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml +2 -2
- fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-7B-math_and_coder.yaml +9 -0
- fusion_bench_config/modelpool/CausalLMPool/mistral-7b.yaml +6 -0
- fusion_bench_config/modelpool/CausalLMPool/mixtral_moe_merging.yaml +10 -0
- fusion_bench_config/modelpool/CausalLMPool/qwen2_math_1.5B_and_R1.yaml +4 -12
- fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +6 -16
- fusion_bench_config/modelpool/CausalLMPool/vicuna-7b-v1.5.yaml +8 -0
- fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/llama_preference700k.yaml +1 -1
- fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/single_reward_model.yaml +1 -1
- fusion_bench_config/nyuv2_config.yaml +3 -1
- fusion_bench_config/nyuv2_mtl_train.yaml +1 -0
- fusion_bench_config/path/default.yaml +28 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_svhn_and_mnist.yaml +24 -0
- fusion_bench_config/method/adamerging.yaml +0 -23
- fusion_bench_config/modelpool/mixtral_moe_merging.yaml +0 -14
- fusion_bench_config/modelpool/mixtral_moe_upscaling.yaml +0 -6
- fusion_bench_config/taskpool/clip-vit-base-patch32_svhn_and_mnist.yaml +0 -22
- {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.21.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.21.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.21.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.21.dist-info}/top_level.txt +0 -0
- /fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/roberta-base_glue.yaml +0 -0
|
@@ -1,13 +1,13 @@
|
|
|
1
1
|
import logging
|
|
2
2
|
from copy import deepcopy
|
|
3
|
-
from typing import Dict, List, Optional, Union
|
|
3
|
+
from typing import Dict, Generator, List, Optional, Tuple, Union
|
|
4
4
|
|
|
5
5
|
import torch
|
|
6
6
|
from omegaconf import DictConfig
|
|
7
7
|
from torch import nn
|
|
8
8
|
from torch.utils.data import Dataset
|
|
9
9
|
|
|
10
|
-
from fusion_bench.mixins import
|
|
10
|
+
from fusion_bench.mixins import BaseYAMLSerializable, HydraConfigMixin
|
|
11
11
|
from fusion_bench.utils import instantiate, timeit_context
|
|
12
12
|
|
|
13
13
|
__all__ = ["BaseModelPool"]
|
|
@@ -15,7 +15,10 @@ __all__ = ["BaseModelPool"]
|
|
|
15
15
|
log = logging.getLogger(__name__)
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
class BaseModelPool(
|
|
18
|
+
class BaseModelPool(
|
|
19
|
+
HydraConfigMixin,
|
|
20
|
+
BaseYAMLSerializable,
|
|
21
|
+
):
|
|
19
22
|
"""
|
|
20
23
|
A class for managing and interacting with a pool of models along with their associated datasets or other specifications. For example, a model pool may contain multiple models, each with its own training, validation, and testing datasets. As for the specifications, a vision model pool may contain image preprocessor, and a language model pool may contain a tokenizer.
|
|
21
24
|
|
|
@@ -31,7 +34,7 @@ class BaseModelPool(BaseYAMLSerializableModel, HydraConfigMixin):
|
|
|
31
34
|
_program = None
|
|
32
35
|
_config_key = "modelpool"
|
|
33
36
|
_models: Union[DictConfig, Dict[str, nn.Module]]
|
|
34
|
-
_config_mapping =
|
|
37
|
+
_config_mapping = BaseYAMLSerializable._config_mapping | {
|
|
35
38
|
"_models": "models",
|
|
36
39
|
"_train_datasets": "train_datasets",
|
|
37
40
|
"_val_datasets": "val_datasets",
|
|
@@ -56,7 +59,7 @@ class BaseModelPool(BaseYAMLSerializableModel, HydraConfigMixin):
|
|
|
56
59
|
super().__init__(**kwargs)
|
|
57
60
|
|
|
58
61
|
@property
|
|
59
|
-
def has_pretrained(self):
|
|
62
|
+
def has_pretrained(self) -> bool:
|
|
60
63
|
"""
|
|
61
64
|
Check if the model pool contains a pretrained model.
|
|
62
65
|
|
|
@@ -125,7 +128,7 @@ class BaseModelPool(BaseYAMLSerializableModel, HydraConfigMixin):
|
|
|
125
128
|
return len(self.model_names)
|
|
126
129
|
|
|
127
130
|
@staticmethod
|
|
128
|
-
def is_special_model(model_name: str):
|
|
131
|
+
def is_special_model(model_name: str) -> bool:
|
|
129
132
|
"""
|
|
130
133
|
Determine if a model is special based on its name.
|
|
131
134
|
|
|
@@ -152,6 +155,23 @@ class BaseModelPool(BaseYAMLSerializableModel, HydraConfigMixin):
|
|
|
152
155
|
model_config = deepcopy(model_config)
|
|
153
156
|
return model_config
|
|
154
157
|
|
|
158
|
+
def get_model_path(self, model_name: str) -> str:
|
|
159
|
+
"""
|
|
160
|
+
Get the path for the specified model.
|
|
161
|
+
|
|
162
|
+
Args:
|
|
163
|
+
model_name (str): The name of the model.
|
|
164
|
+
|
|
165
|
+
Returns:
|
|
166
|
+
str: The path for the specified model.
|
|
167
|
+
"""
|
|
168
|
+
if isinstance(self._models[model_name], str):
|
|
169
|
+
return self._models[model_name]
|
|
170
|
+
else:
|
|
171
|
+
raise ValueError(
|
|
172
|
+
"Model path is not a string. Try to override this method in derived modelpool class."
|
|
173
|
+
)
|
|
174
|
+
|
|
155
175
|
def load_model(
|
|
156
176
|
self, model_name_or_config: Union[str, DictConfig], *args, **kwargs
|
|
157
177
|
) -> nn.Module:
|
|
@@ -159,7 +179,7 @@ class BaseModelPool(BaseYAMLSerializableModel, HydraConfigMixin):
|
|
|
159
179
|
Load a model from the pool based on the provided configuration.
|
|
160
180
|
|
|
161
181
|
Args:
|
|
162
|
-
|
|
182
|
+
model_name_or_config (Union[str, DictConfig]): The model name or configuration.
|
|
163
183
|
|
|
164
184
|
Returns:
|
|
165
185
|
nn.Module: The instantiated model.
|
|
@@ -201,11 +221,11 @@ class BaseModelPool(BaseYAMLSerializableModel, HydraConfigMixin):
|
|
|
201
221
|
model = self.load_model(self.model_names[0], *args, **kwargs)
|
|
202
222
|
return model
|
|
203
223
|
|
|
204
|
-
def models(self):
|
|
224
|
+
def models(self) -> Generator[nn.Module, None, None]:
|
|
205
225
|
for model_name in self.model_names:
|
|
206
226
|
yield self.load_model(model_name)
|
|
207
227
|
|
|
208
|
-
def named_models(self):
|
|
228
|
+
def named_models(self) -> Generator[Tuple[str, nn.Module], None, None]:
|
|
209
229
|
for model_name in self.model_names:
|
|
210
230
|
yield model_name, self.load_model(model_name)
|
|
211
231
|
|
|
@@ -57,6 +57,15 @@ class CausalLMPool(BaseModelPool):
|
|
|
57
57
|
)
|
|
58
58
|
self.load_lazy = load_lazy
|
|
59
59
|
|
|
60
|
+
def get_model_path(self, model_name: str):
|
|
61
|
+
model_name_or_config = self._models[model_name]
|
|
62
|
+
if isinstance(model_name_or_config, str):
|
|
63
|
+
return model_name_or_config
|
|
64
|
+
elif isinstance(model_name_or_config, (DictConfig, dict)):
|
|
65
|
+
return model_name_or_config.get("pretrained_model_name_or_path")
|
|
66
|
+
else:
|
|
67
|
+
raise RuntimeError("Invalid model configuration")
|
|
68
|
+
|
|
60
69
|
@override
|
|
61
70
|
def load_model(
|
|
62
71
|
self,
|
|
@@ -11,9 +11,7 @@ from transformers import CLIPModel, CLIPProcessor, CLIPVisionModel
|
|
|
11
11
|
from typing_extensions import override
|
|
12
12
|
|
|
13
13
|
from fusion_bench.utils import instantiate, timeit_context
|
|
14
|
-
from fusion_bench.utils.modelscope import
|
|
15
|
-
resolve_repo_path,
|
|
16
|
-
)
|
|
14
|
+
from fusion_bench.utils.modelscope import resolve_repo_path
|
|
17
15
|
|
|
18
16
|
from ..base_pool import BaseModelPool
|
|
19
17
|
|
|
@@ -1,2 +1,2 @@
|
|
|
1
1
|
from .reward_model import create_reward_model_from_pretrained
|
|
2
|
-
from .seq_classification_lm import
|
|
2
|
+
from .seq_classification_lm import SequenceClassificationModelPool
|
fusion_bench/models/__init__.py
CHANGED
|
@@ -0,0 +1,182 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This module contains utilities for working with Hugging Face models.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import inspect
|
|
6
|
+
import os
|
|
7
|
+
import shutil
|
|
8
|
+
from typing import Optional, cast
|
|
9
|
+
|
|
10
|
+
from omegaconf import OmegaConf
|
|
11
|
+
from transformers.modeling_utils import PreTrainedModel
|
|
12
|
+
|
|
13
|
+
from fusion_bench import BaseAlgorithm, BaseModelPool
|
|
14
|
+
from fusion_bench.utils.pylogger import getRankZeroLogger
|
|
15
|
+
|
|
16
|
+
log = getRankZeroLogger(__name__)
|
|
17
|
+
|
|
18
|
+
__all__ = [
|
|
19
|
+
"save_pretrained_with_remote_code",
|
|
20
|
+
"generate_readme_head",
|
|
21
|
+
"generate_readme_body",
|
|
22
|
+
"generate_complete_readme",
|
|
23
|
+
]
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def save_pretrained_with_remote_code(
|
|
27
|
+
model: PreTrainedModel,
|
|
28
|
+
auto_map: dict[str, object],
|
|
29
|
+
save_directory,
|
|
30
|
+
**kwargs,
|
|
31
|
+
):
|
|
32
|
+
"""
|
|
33
|
+
Saves a model with custom code to a directory.
|
|
34
|
+
|
|
35
|
+
This function facilitates saving a Hugging Face `PreTrainedModel` along with its
|
|
36
|
+
associated custom code. It inspects the objects provided in the `auto_map`,
|
|
37
|
+
copies their source files to the `save_directory`, and generates an `__init__.py`
|
|
38
|
+
to make them importable. It also updates the model's configuration with an
|
|
39
|
+
`auto_map` attribute, which allows `AutoModel.from_pretrained` to correctly
|
|
40
|
+
instantiate the custom model classes when `trust_remote_code=True`.
|
|
41
|
+
|
|
42
|
+
Args:
|
|
43
|
+
model (PreTrainedModel): The model instance to be saved.
|
|
44
|
+
auto_map (dict[str, object]): A dictionary mapping auto class names
|
|
45
|
+
(e.g., "AutoModelForCausalLM") to the corresponding custom class objects.
|
|
46
|
+
save_directory (str or os.PathLike): The directory where the model and
|
|
47
|
+
custom code files will be saved.
|
|
48
|
+
**kwargs: Additional keyword arguments to be passed to the
|
|
49
|
+
`model.save_pretrained` method.
|
|
50
|
+
|
|
51
|
+
Example:
|
|
52
|
+
```python
|
|
53
|
+
# Assuming `model` is an instance of `SmileQwen2ForCausalLM`
|
|
54
|
+
# and `SmileQwen2Config`, `SmileQwen2Model`, `SmileQwen2ForCausalLM`
|
|
55
|
+
# are custom classes defined in your project.
|
|
56
|
+
|
|
57
|
+
save_pretrained_with_remote_code(
|
|
58
|
+
model,
|
|
59
|
+
auto_map={
|
|
60
|
+
"AutoConfig": SmileQwen2Config,
|
|
61
|
+
"AutoModel": SmileQwen2Model,
|
|
62
|
+
"AutoModelForCausalLM": SmileQwen2ForCausalLM,
|
|
63
|
+
},
|
|
64
|
+
save_directory="./my-custom-model",
|
|
65
|
+
)
|
|
66
|
+
|
|
67
|
+
# The model can then be loaded with `trust_remote_code=True`:
|
|
68
|
+
# from transformers import AutoModelForCausalLM
|
|
69
|
+
# loaded_model = AutoModelForCausalLM.from_pretrained(
|
|
70
|
+
# "./my-custom-model", trust_remote_code=True
|
|
71
|
+
# )
|
|
72
|
+
```
|
|
73
|
+
"""
|
|
74
|
+
auto_map_files = {}
|
|
75
|
+
auto_map_strs = {}
|
|
76
|
+
for key, obj in auto_map.items():
|
|
77
|
+
auto_map_files[key] = inspect.getfile(obj)
|
|
78
|
+
|
|
79
|
+
for key, obj in auto_map.items():
|
|
80
|
+
auto_map_strs[key] = (
|
|
81
|
+
f"{(inspect.getmodule(obj).__name__).split('.')[-1]}.{obj.__name__}"
|
|
82
|
+
)
|
|
83
|
+
|
|
84
|
+
model.config.auto_map = auto_map_strs
|
|
85
|
+
|
|
86
|
+
# save model to `save_directory`
|
|
87
|
+
model.save_pretrained(save_directory=save_directory, **kwargs)
|
|
88
|
+
|
|
89
|
+
# copy source files to `save_directory`
|
|
90
|
+
for key, file_path in auto_map_files.items():
|
|
91
|
+
shutil.copy(
|
|
92
|
+
src=file_path, dst=os.path.join(save_directory, os.path.basename(file_path))
|
|
93
|
+
)
|
|
94
|
+
# construct `__init__.py`
|
|
95
|
+
init_file = os.path.join(save_directory, "__init__.py")
|
|
96
|
+
with open(init_file, "w") as f:
|
|
97
|
+
for key, file_name in auto_map_files.items():
|
|
98
|
+
base_name = os.path.basename(file_name).split(".")[0]
|
|
99
|
+
f.write(f"from .{base_name} import {auto_map[key].__name__}\n")
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def generate_readme_head(
|
|
103
|
+
models: list[str] | BaseModelPool,
|
|
104
|
+
library_name: str = "transformers",
|
|
105
|
+
tags: list[str] = ["fusion-bench", "merge"],
|
|
106
|
+
):
|
|
107
|
+
text = "---\nbase_model:\n"
|
|
108
|
+
for model_name in models:
|
|
109
|
+
text += f"- {model_name}\n"
|
|
110
|
+
if library_name:
|
|
111
|
+
text += f"library_name: {library_name}\n"
|
|
112
|
+
text += "tags:\n"
|
|
113
|
+
for tag in tags:
|
|
114
|
+
text += f"- {tag}\n"
|
|
115
|
+
text += "---\n"
|
|
116
|
+
return text
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def generate_readme_body(
|
|
120
|
+
algorithm: BaseAlgorithm,
|
|
121
|
+
models_or_modelpool: Optional[list[str] | BaseModelPool] = None,
|
|
122
|
+
models: list[str] = None,
|
|
123
|
+
):
|
|
124
|
+
text = """\
|
|
125
|
+
# Merge
|
|
126
|
+
|
|
127
|
+
This is a merge of pre-trained language models created using [fusion-bench](https://github.com/tanganke/fusion_bench).
|
|
128
|
+
|
|
129
|
+
"""
|
|
130
|
+
|
|
131
|
+
if models is not None:
|
|
132
|
+
text += """
|
|
133
|
+
## Models Merged
|
|
134
|
+
|
|
135
|
+
The following models were included in the merge:
|
|
136
|
+
|
|
137
|
+
"""
|
|
138
|
+
for model_name in models:
|
|
139
|
+
text += f"- {model_name}\n"
|
|
140
|
+
text += "\n"
|
|
141
|
+
|
|
142
|
+
try:
|
|
143
|
+
text += f"""\
|
|
144
|
+
## Configuration
|
|
145
|
+
|
|
146
|
+
The following YAML configuration was used to produce this model:
|
|
147
|
+
|
|
148
|
+
```yaml
|
|
149
|
+
{OmegaConf.to_yaml(algorithm.config, resolve=True, sort_keys=True)}
|
|
150
|
+
```
|
|
151
|
+
"""
|
|
152
|
+
except Exception as e:
|
|
153
|
+
return (
|
|
154
|
+
text # If the algorithm config cannot be converted to YAML, we skip it.
|
|
155
|
+
)
|
|
156
|
+
|
|
157
|
+
if isinstance(models_or_modelpool, BaseModelPool):
|
|
158
|
+
try:
|
|
159
|
+
text += f"""
|
|
160
|
+
```yaml
|
|
161
|
+
{OmegaConf.to_yaml(models_or_modelpool.config, resolve=True, sort_keys=True)}
|
|
162
|
+
```
|
|
163
|
+
"""
|
|
164
|
+
except Exception as e:
|
|
165
|
+
pass # If the model pool config cannot be converted to YAML, we skip it.
|
|
166
|
+
return text
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
def generate_complete_readme(
|
|
170
|
+
algorithm: BaseAlgorithm, modelpool: BaseModelPool, models: list[str]
|
|
171
|
+
):
|
|
172
|
+
# Generate the complete README content
|
|
173
|
+
text = generate_readme_head(
|
|
174
|
+
[modelpool.get_model_path(m) for m in modelpool.model_names]
|
|
175
|
+
)
|
|
176
|
+
readme_body = generate_readme_body(
|
|
177
|
+
algorithm,
|
|
178
|
+
models_or_modelpool=modelpool,
|
|
179
|
+
models=[modelpool.get_model_path(m) for m in modelpool.model_names],
|
|
180
|
+
)
|
|
181
|
+
complete_readme = text + "\n" + readme_body
|
|
182
|
+
return complete_readme
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
import logging
|
|
2
2
|
from collections import OrderedDict
|
|
3
3
|
from copy import deepcopy
|
|
4
|
-
from typing import Optional
|
|
4
|
+
from typing import Any, Dict, Optional, Tuple
|
|
5
5
|
|
|
6
6
|
import torch.nn as nn
|
|
7
7
|
from torch.func import functional_call, jvp
|
|
@@ -9,7 +9,7 @@ from torch.func import functional_call, jvp
|
|
|
9
9
|
log = logging.getLogger(__name__)
|
|
10
10
|
|
|
11
11
|
|
|
12
|
-
def dict_params_to_tuple(dict_params: dict):
|
|
12
|
+
def dict_params_to_tuple(dict_params: dict) -> Tuple:
|
|
13
13
|
return tuple(v for k, v in dict_params.items())
|
|
14
14
|
|
|
15
15
|
|
|
@@ -33,7 +33,7 @@ class LinearizedModelWraper(nn.Module):
|
|
|
33
33
|
for p in self.params0_values:
|
|
34
34
|
p.requires_grad_(False)
|
|
35
35
|
|
|
36
|
-
def tuple_params_to_dict(self, tuple_params):
|
|
36
|
+
def tuple_params_to_dict(self, tuple_params) -> Dict[str, Any]:
|
|
37
37
|
"""
|
|
38
38
|
Converts a tuple of parameters to a dictionary with keys corresponding to the parameter names.
|
|
39
39
|
|
|
@@ -50,7 +50,7 @@ class LinearizedModelWraper(nn.Module):
|
|
|
50
50
|
state_dict[k] = p
|
|
51
51
|
return state_dict
|
|
52
52
|
|
|
53
|
-
def forward(self, *args, **kwargs):
|
|
53
|
+
def forward(self, *args: Any, **kwargs: Any) -> Any:
|
|
54
54
|
"""
|
|
55
55
|
Computes the linearized model output using a first-order Taylor decomposition.
|
|
56
56
|
|
|
@@ -4,12 +4,12 @@ This is a direct copy of the DeepSeek-V2-Lite model from HuggingFace https://hug
|
|
|
4
4
|
|
|
5
5
|
from .configuration_deepseek import DeepseekV2Config
|
|
6
6
|
from .modeling_deepseek import (
|
|
7
|
+
DeepseekV2DecoderLayer,
|
|
7
8
|
DeepseekV2ForCausalLM,
|
|
8
9
|
DeepseekV2ForSequenceClassification,
|
|
9
10
|
DeepseekV2MLP,
|
|
10
11
|
DeepseekV2Model,
|
|
11
12
|
DeepseekV2MoE,
|
|
12
|
-
DeepseekV2DecoderLayer,
|
|
13
13
|
)
|
|
14
14
|
from .modeling_deepseek import MoEGate as DeepseekV2MoEGate
|
|
15
15
|
from .tokenization_deepseek_fast import DeepseekTokenizerFast
|
|
@@ -17,17 +17,18 @@
|
|
|
17
17
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
18
18
|
# See the License for the specific language governing permissions and
|
|
19
19
|
# limitations under the License.
|
|
20
|
-
"""
|
|
20
|
+
"""PyTorch DeepSeek model."""
|
|
21
21
|
import math
|
|
22
22
|
import warnings
|
|
23
23
|
from typing import List, Optional, Tuple, Union
|
|
24
24
|
|
|
25
|
+
import numpy as np
|
|
25
26
|
import torch
|
|
27
|
+
import torch.distributed as dist
|
|
26
28
|
import torch.nn.functional as F
|
|
27
29
|
import torch.utils.checkpoint
|
|
28
30
|
from torch import nn
|
|
29
31
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
|
30
|
-
|
|
31
32
|
from transformers.activations import ACT2FN
|
|
32
33
|
from transformers.cache_utils import Cache, DynamicCache
|
|
33
34
|
from transformers.modeling_attn_mask_utils import (
|
|
@@ -54,9 +55,8 @@ from transformers.utils import (
|
|
|
54
55
|
replace_return_docstrings,
|
|
55
56
|
)
|
|
56
57
|
from transformers.utils.import_utils import is_torch_fx_available
|
|
58
|
+
|
|
57
59
|
from .configuration_deepseek import DeepseekV2Config
|
|
58
|
-
import torch.distributed as dist
|
|
59
|
-
import numpy as np
|
|
60
60
|
|
|
61
61
|
if is_flash_attn_2_available():
|
|
62
62
|
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
from . import register
|
|
2
|
+
from .configuration_smile_gemma2 import SmileGemma2Config
|
|
3
|
+
from .modeling_smile_gemma2 import (
|
|
4
|
+
SmileGemma2ForCausalLM,
|
|
5
|
+
SmileGemma2ForSequenceClassification,
|
|
6
|
+
SmileGemma2ForTokenClassification,
|
|
7
|
+
SmileGemma2Model,
|
|
8
|
+
SmileGemma2PreTrainedModel,
|
|
9
|
+
)
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
from transformers.models.gemma2.configuration_gemma2 import Gemma2Config
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class SmileGemma2Config(Gemma2Config):
|
|
5
|
+
model_type = "smile_gemma2"
|
|
6
|
+
|
|
7
|
+
def __init__(
|
|
8
|
+
self,
|
|
9
|
+
num_experts_per_tok: int = 1,
|
|
10
|
+
rank_of_router: int = None,
|
|
11
|
+
rank_of_expert: int = None,
|
|
12
|
+
num_local_experts: int = None,
|
|
13
|
+
**kwargs,
|
|
14
|
+
):
|
|
15
|
+
self.num_experts_per_tok = num_experts_per_tok
|
|
16
|
+
self.rank_of_router = rank_of_router
|
|
17
|
+
self.rank_of_expert = rank_of_expert
|
|
18
|
+
self.num_local_experts = num_local_experts
|
|
19
|
+
|
|
20
|
+
super().__init__(**kwargs)
|