fusion-bench 0.2.19__py3-none-any.whl → 0.2.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/__init__.py +1 -0
- fusion_bench/_get_started/__init__.py +3 -0
- fusion_bench/_get_started/greeting_program.py +49 -0
- fusion_bench/compat/method/base_algorithm.py +14 -0
- fusion_bench/constants/__init__.py +5 -0
- fusion_bench/constants/clip_vision.py +26 -2
- fusion_bench/constants/paths.py +4 -0
- fusion_bench/dataset/clip_dataset.py +2 -1
- fusion_bench/dataset/gpt2_glue.py +9 -9
- fusion_bench/dataset/image_corruption/__init__.py +0 -0
- fusion_bench/dataset/image_corruption/make_corruption.py +179 -0
- fusion_bench/dataset/image_dataset.py +1 -1
- fusion_bench/dataset/nyuv2.py +2 -2
- fusion_bench/method/__init__.py +16 -1
- fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +1 -1
- fusion_bench/method/adamerging/clip_task_wise_adamerging.py +11 -7
- fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -5
- fusion_bench/method/base_algorithm.py +195 -12
- fusion_bench/method/bitdelta/__init__.py +4 -0
- fusion_bench/method/bitdelta/bitdelta.py +156 -0
- fusion_bench/method/bitdelta/bitdelta_utils/__init__.py +0 -0
- fusion_bench/method/bitdelta/bitdelta_utils/binary_gemm_kernel.py +462 -0
- fusion_bench/method/bitdelta/bitdelta_utils/data.py +35 -0
- fusion_bench/method/bitdelta/bitdelta_utils/diff.py +129 -0
- fusion_bench/method/concrete_subspace/clip_concrete_adamerging.py +0 -1
- fusion_bench/method/depth_upscaling/depth_upscaling.py +4 -9
- fusion_bench/method/doge_ta/clip_layer_wise_adamerging.py +4 -5
- fusion_bench/method/doge_ta/doge_ta.py +1 -1
- fusion_bench/method/ensemble.py +12 -12
- fusion_bench/method/expert_sparsity/utils/calibration_data.py +1 -1
- fusion_bench/method/fisher_merging/clip_fisher_merging.py +2 -2
- fusion_bench/method/fisher_merging/fisher_merging.py +6 -15
- fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +3 -10
- fusion_bench/method/fw_merging/fw_hard.py +1 -1
- fusion_bench/method/fw_merging/fw_soft.py +1 -1
- fusion_bench/method/gossip/clip_layer_wise_gossip.py +4 -5
- fusion_bench/method/linear/expo.py +2 -1
- fusion_bench/method/linear/linear_interpolation.py +6 -4
- fusion_bench/method/linear/simple_average_for_llama.py +16 -6
- fusion_bench/method/lm_finetune/bradley_terry_rm.py +2 -2
- fusion_bench/method/mixture_of_experts/mixtral_upcycling.py +9 -26
- fusion_bench/method/model_recombination.py +2 -5
- fusion_bench/method/moe_pruner/hooks/__init__.py +1 -2
- fusion_bench/method/moe_pruner/utils/data.py +2 -1
- fusion_bench/method/moe_pruner/utils/prune.py +6 -1
- fusion_bench/method/pruning/llama_magnitude_prune.py +1 -1
- fusion_bench/method/pruning/wanda_utils/data.py +1 -2
- fusion_bench/method/pwe_moe/clip_pwe_moe.py +12 -34
- fusion_bench/method/randes/modelsoup.py +1 -3
- fusion_bench/method/regmean/clip_regmean.py +2 -2
- fusion_bench/method/regmean/gpt2_regmean.py +3 -10
- fusion_bench/method/regmean/regmean.py +2 -11
- fusion_bench/method/regmean_plusplus/__init__.py +3 -0
- fusion_bench/method/regmean_plusplus/clip_regmean_plusplus.py +199 -0
- fusion_bench/method/regmean_plusplus/regmean_plusplus.py +383 -0
- fusion_bench/method/simple_average.py +16 -4
- fusion_bench/method/slerp/slerp.py +5 -2
- fusion_bench/method/smile_upscaling/error_accumulation.py +177 -0
- fusion_bench/method/smile_upscaling/projected_energy.py +145 -0
- fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +39 -28
- fusion_bench/method/smile_upscaling/smile_upscaling.py +12 -5
- fusion_bench/method/tall_mask/task_arithmetic.py +3 -11
- fusion_bench/method/task_arithmetic/task_arithmetic.py +6 -10
- fusion_bench/method/ties_merging/ties_merging.py +13 -26
- fusion_bench/method/we_moe/clip_we_moe.py +5 -4
- fusion_bench/method/we_moe/we_moe.py +6 -6
- fusion_bench/method/weighted_average/llama.py +4 -16
- fusion_bench/metrics/continual_learning/__init__.py +1 -0
- fusion_bench/metrics/continual_learning/backward_transfer.py +1 -1
- fusion_bench/metrics/nyuv2/__init__.py +2 -2
- fusion_bench/metrics/nyuv2/segmentation.py +1 -1
- fusion_bench/mixins/__init__.py +10 -2
- fusion_bench/mixins/clip_classification.py +4 -3
- fusion_bench/mixins/hydra_config.py +105 -7
- fusion_bench/mixins/lightning_fabric.py +2 -0
- fusion_bench/mixins/serialization.py +265 -48
- fusion_bench/modelpool/__init__.py +2 -2
- fusion_bench/modelpool/base_pool.py +29 -9
- fusion_bench/modelpool/causal_lm/causal_lm.py +9 -0
- fusion_bench/modelpool/clip_vision/modelpool.py +43 -12
- fusion_bench/modelpool/seq_classification_lm/__init__.py +1 -1
- fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +1 -1
- fusion_bench/models/__init__.py +2 -1
- fusion_bench/models/expert_sparsity/mixtral/__init__.py +1 -1
- fusion_bench/models/hf_utils.py +182 -0
- fusion_bench/models/linearized/linearized_model_utils.py +4 -4
- fusion_bench/models/linearized/vision_model.py +1 -1
- fusion_bench/models/modeling_deepseek_v2/__init__.py +1 -1
- fusion_bench/models/modeling_deepseek_v2/modeling_deepseek.py +4 -4
- fusion_bench/models/modeling_deepseek_v2/tokenization_deepseek_fast.py +0 -1
- fusion_bench/models/modeling_smile_gemma2/__init__.py +9 -0
- fusion_bench/models/modeling_smile_gemma2/configuration_smile_gemma2.py +20 -0
- fusion_bench/models/modeling_smile_gemma2/modeling_smile_gemma2.py +986 -0
- fusion_bench/models/modeling_smile_gemma2/register.py +26 -0
- fusion_bench/models/modeling_smile_llama/__init__.py +0 -0
- fusion_bench/models/modeling_smile_llama/configuration_smile_llama.py +20 -0
- fusion_bench/models/modeling_smile_llama/modeling_smile_llama.py +705 -0
- fusion_bench/models/modeling_smile_llama/register.py +8 -0
- fusion_bench/models/modeling_smile_mistral/__init__.py +5 -47
- fusion_bench/models/modeling_smile_qwen2/__init__.py +1 -1
- fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +6 -7
- fusion_bench/models/modeling_smile_qwen2/register.py +1 -4
- fusion_bench/models/parameter_dict.py +1 -1
- fusion_bench/models/sparse_we_moe.py +1 -53
- fusion_bench/models/utils.py +26 -0
- fusion_bench/models/we_moe.py +1 -53
- fusion_bench/models/wrappers/ensemble.py +6 -4
- fusion_bench/models/wrappers/layer_wise_fusion.py +1 -1
- fusion_bench/models/wrappers/task_wise_fusion.py +250 -72
- fusion_bench/programs/base_program.py +81 -2
- fusion_bench/programs/fabric_fusion_program.py +24 -8
- fusion_bench/scripts/cli.py +6 -6
- fusion_bench/taskpool/base_pool.py +4 -3
- fusion_bench/taskpool/clip_vision/taskpool.py +34 -18
- fusion_bench/taskpool/dummy.py +1 -1
- fusion_bench/taskpool/lm_eval_harness/taskpool.py +1 -2
- fusion_bench/tasks/clip_classification/__init__.py +6 -4
- fusion_bench/utils/__init__.py +6 -1
- fusion_bench/utils/devices.py +14 -4
- fusion_bench/utils/instantiate_utils.py +3 -1
- fusion_bench/utils/misc.py +48 -2
- fusion_bench/utils/modelscope.py +265 -0
- fusion_bench/utils/parameters.py +2 -2
- fusion_bench/utils/rich_utils.py +3 -0
- fusion_bench/utils/state_dict_arithmetic.py +34 -27
- {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/METADATA +31 -24
- {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/RECORD +189 -153
- fusion_bench_config/_get_started/clip_evaluate_single_model.yaml +21 -0
- fusion_bench_config/_get_started/clip_simple_average.yaml +23 -0
- fusion_bench_config/_get_started/clip_task_arithmetic.yaml +24 -0
- fusion_bench_config/_get_started/greeting_program.yaml +4 -0
- fusion_bench_config/fabric/loggers/csv_logger.yaml +3 -3
- fusion_bench_config/fabric/loggers/tensorboard_logger.yaml +3 -3
- fusion_bench_config/fabric_model_fusion.yaml +45 -17
- fusion_bench_config/hydra/default.yaml +6 -2
- fusion_bench_config/llama_full_finetune.yaml +1 -0
- fusion_bench_config/method/adamerging/clip.yaml +1 -1
- fusion_bench_config/method/bitdelta/bitdelta.yaml +12 -0
- fusion_bench_config/method/depth_upscaling.yaml +4 -1
- fusion_bench_config/method/regmean/clip_regmean.yaml +1 -1
- fusion_bench_config/method/regmean_plusplus/clip_regmean_plusplus.yaml +11 -0
- fusion_bench_config/method/smile_upscaling/error_accumulation.yaml +5 -0
- fusion_bench_config/method/smile_upscaling/projected_energy.yaml +2 -0
- fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +1 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/_template.yaml +1 -4
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20.yaml +73 -8
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20_model_only.yaml +27 -7
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8.yaml +34 -4
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +14 -17
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_model_only.yaml +14 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL10.yaml +39 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL12.yaml +49 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14.yaml +55 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14_model_only.yaml +21 -4
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL16.yaml +61 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL18.yaml +67 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20.yaml +73 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20_model_only.yaml +26 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +4 -9
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_finetuned.yaml +7 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +6 -10
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_and_cars.yaml +6 -7
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_and_dtd.yaml +6 -7
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_cars_and_dtd.yaml +7 -8
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_svhn_and_mnist.yaml +8 -6
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +4 -6
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8.yaml +32 -7
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8_model_only.yaml +14 -6
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20.yaml +73 -8
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20_model_only.yaml +27 -7
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +6 -10
- fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml +2 -2
- fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-7B-math_and_coder.yaml +9 -0
- fusion_bench_config/modelpool/CausalLMPool/mistral-7b.yaml +6 -0
- fusion_bench_config/modelpool/CausalLMPool/mixtral_moe_merging.yaml +10 -0
- fusion_bench_config/modelpool/CausalLMPool/qwen2_math_1.5B_and_R1.yaml +4 -12
- fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +6 -16
- fusion_bench_config/modelpool/CausalLMPool/vicuna-7b-v1.5.yaml +8 -0
- fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/llama_preference700k.yaml +1 -1
- fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/single_reward_model.yaml +1 -1
- fusion_bench_config/nyuv2_config.yaml +3 -1
- fusion_bench_config/nyuv2_mtl_train.yaml +1 -0
- fusion_bench_config/path/default.yaml +28 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_svhn_and_mnist.yaml +24 -0
- fusion_bench_config/method/adamerging.yaml +0 -23
- fusion_bench_config/modelpool/mixtral_moe_merging.yaml +0 -14
- fusion_bench_config/modelpool/mixtral_moe_upscaling.yaml +0 -6
- fusion_bench_config/taskpool/clip-vit-base-patch32_svhn_and_mnist.yaml +0 -22
- {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/top_level.txt +0 -0
- /fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/roberta-base_glue.yaml +0 -0
|
@@ -0,0 +1,182 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This module contains utilities for working with Hugging Face models.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import inspect
|
|
6
|
+
import os
|
|
7
|
+
import shutil
|
|
8
|
+
from typing import Optional, cast
|
|
9
|
+
|
|
10
|
+
from omegaconf import OmegaConf
|
|
11
|
+
from transformers.modeling_utils import PreTrainedModel
|
|
12
|
+
|
|
13
|
+
from fusion_bench import BaseAlgorithm, BaseModelPool
|
|
14
|
+
from fusion_bench.utils.pylogger import getRankZeroLogger
|
|
15
|
+
|
|
16
|
+
log = getRankZeroLogger(__name__)
|
|
17
|
+
|
|
18
|
+
__all__ = [
|
|
19
|
+
"save_pretrained_with_remote_code",
|
|
20
|
+
"generate_readme_head",
|
|
21
|
+
"generate_readme_body",
|
|
22
|
+
"generate_complete_readme",
|
|
23
|
+
]
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def save_pretrained_with_remote_code(
    model: "PreTrainedModel",
    auto_map: dict[str, object],
    save_directory,
    **kwargs,
):
    """
    Saves a model with custom code to a directory.

    This function facilitates saving a Hugging Face `PreTrainedModel` along
    with its associated custom code. It inspects the objects provided in the
    `auto_map`, copies their source files to the `save_directory`, and
    generates an `__init__.py` to make them importable. It also updates the
    model's configuration with an `auto_map` attribute, which allows
    `AutoModel.from_pretrained` to correctly instantiate the custom model
    classes when `trust_remote_code=True`.

    Args:
        model (PreTrainedModel): The model instance to be saved.
        auto_map (dict[str, object]): A dictionary mapping auto class names
            (e.g., "AutoModelForCausalLM") to the corresponding custom class objects.
        save_directory (str or os.PathLike): The directory where the model and
            custom code files will be saved.
        **kwargs: Additional keyword arguments to be passed to the
            `model.save_pretrained` method.

    Example:
        ```python
        # Assuming `model` is an instance of `SmileQwen2ForCausalLM`
        # and `SmileQwen2Config`, `SmileQwen2Model`, `SmileQwen2ForCausalLM`
        # are custom classes defined in your project.

        save_pretrained_with_remote_code(
            model,
            auto_map={
                "AutoConfig": SmileQwen2Config,
                "AutoModel": SmileQwen2Model,
                "AutoModelForCausalLM": SmileQwen2ForCausalLM,
            },
            save_directory="./my-custom-model",
        )

        # The model can then be loaded with `trust_remote_code=True`:
        # from transformers import AutoModelForCausalLM
        # loaded_model = AutoModelForCausalLM.from_pretrained(
        #     "./my-custom-model", trust_remote_code=True
        # )
        ```
    """
    # Resolve, in a single pass, each object's defining source file and its
    # "<module>.<ClassName>" reference string (the original iterated the
    # same mapping twice).
    auto_map_files: dict[str, str] = {}
    auto_map_strs: dict[str, str] = {}
    for key, obj in auto_map.items():
        auto_map_files[key] = inspect.getfile(obj)
        module_basename = inspect.getmodule(obj).__name__.rsplit(".", 1)[-1]
        auto_map_strs[key] = f"{module_basename}.{obj.__name__}"

    # `from_pretrained(..., trust_remote_code=True)` reads this mapping from
    # the saved config to locate the custom classes.
    model.config.auto_map = auto_map_strs

    # save model to `save_directory`
    model.save_pretrained(save_directory=save_directory, **kwargs)

    # copy source files to `save_directory`
    for file_path in auto_map_files.values():
        shutil.copy(
            src=file_path, dst=os.path.join(save_directory, os.path.basename(file_path))
        )

    # construct `__init__.py` so the copied modules are importable as a package
    init_file = os.path.join(save_directory, "__init__.py")
    with open(init_file, "w") as f:
        for key, file_name in auto_map_files.items():
            # `splitext` (not `split(".")[0]`) so a module file name that
            # contains extra dots is not truncated.
            base_name = os.path.splitext(os.path.basename(file_name))[0]
            f.write(f"from .{base_name} import {auto_map[key].__name__}\n")
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def generate_readme_head(
    models: "list[str] | BaseModelPool",
    library_name: str = "transformers",
    tags: "tuple[str, ...] | list[str]" = ("fusion-bench", "merge"),
):
    """
    Generate the YAML front-matter header for a model-card README.

    Args:
        models: Iterable of base-model identifiers (e.g. Hugging Face repo
            ids or local paths) listed under ``base_model``.
        library_name: Value for the ``library_name`` field; skipped when falsy.
        tags: Tags listed under ``tags``. The default is an immutable tuple
            to avoid the shared-mutable-default-argument pitfall.

    Returns:
        str: The ``---``-delimited YAML front matter, ending with a newline.
    """
    # Build the header line-by-line and join once, instead of repeated
    # string concatenation.
    lines = ["---", "base_model:"]
    lines.extend(f"- {model_name}" for model_name in models)
    if library_name:
        lines.append(f"library_name: {library_name}")
    lines.append("tags:")
    lines.extend(f"- {tag}" for tag in tags)
    lines.append("---")
    return "\n".join(lines) + "\n"
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def generate_readme_body(
    algorithm: "BaseAlgorithm",
    models_or_modelpool: Optional["list[str] | BaseModelPool"] = None,
    models: Optional[list[str]] = None,
):
    """
    Generate the markdown body of a model-card README for a merged model.

    Args:
        algorithm: The fusion algorithm; its resolved config is rendered as
            a YAML section.
        models_or_modelpool: When this is a ``BaseModelPool``, its resolved
            config is appended as an extra YAML section. Any other value
            (including a plain list of model names) is ignored here.
        models: Optional list of model identifiers for the "Models Merged"
            section; ``None`` omits the section.

    Returns:
        str: The README body. Config sections that cannot be converted to
        YAML are skipped (best-effort rendering).
    """
    text = """\
# Merge

This is a merge of pre-trained language models created using [fusion-bench](https://github.com/tanganke/fusion_bench).

"""

    if models is not None:
        text += """
## Models Merged

The following models were included in the merge:

"""
        for model_name in models:
            text += f"- {model_name}\n"
        text += "\n"

    try:
        text += f"""\
## Configuration

The following YAML configuration was used to produce this model:

```yaml
{OmegaConf.to_yaml(algorithm.config, resolve=True, sort_keys=True)}
```
"""
    except Exception:
        # If the algorithm config cannot be converted to YAML, we skip it.
        # NOTE: this early return also skips the model pool section below,
        # preserving the original behavior.
        return text

    if isinstance(models_or_modelpool, BaseModelPool):
        try:
            text += f"""
```yaml
{OmegaConf.to_yaml(models_or_modelpool.config, resolve=True, sort_keys=True)}
```
"""
        except Exception:
            pass  # If the model pool config cannot be converted to YAML, we skip it.
    return text
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
def generate_complete_readme(
    algorithm: BaseAlgorithm, modelpool: BaseModelPool, models: list[str]
):
    """
    Generate a complete model-card README (YAML front matter + body).

    Args:
        algorithm: The fusion algorithm whose config is embedded in the body.
        modelpool: Pool providing the merged models; its model paths populate
            both the front matter and the "Models Merged" section.
        models: Currently unused — the model list is always derived from
            ``modelpool``. NOTE(review): either honor this argument or drop
            it from callers; kept here for interface compatibility.

    Returns:
        str: The full README content.
    """
    # Resolve the model paths once; the original evaluated the same list
    # comprehension twice (for the head and for the body).
    model_paths = [modelpool.get_model_path(m) for m in modelpool.model_names]
    head = generate_readme_head(model_paths)
    body = generate_readme_body(
        algorithm,
        models_or_modelpool=modelpool,
        models=model_paths,
    )
    return head + "\n" + body
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
import logging
|
|
2
2
|
from collections import OrderedDict
|
|
3
3
|
from copy import deepcopy
|
|
4
|
-
from typing import Optional
|
|
4
|
+
from typing import Any, Dict, Optional, Tuple
|
|
5
5
|
|
|
6
6
|
import torch.nn as nn
|
|
7
7
|
from torch.func import functional_call, jvp
|
|
@@ -9,7 +9,7 @@ from torch.func import functional_call, jvp
|
|
|
9
9
|
log = logging.getLogger(__name__)
|
|
10
10
|
|
|
11
11
|
|
|
12
|
-
def dict_params_to_tuple(dict_params: dict):
|
|
12
|
+
def dict_params_to_tuple(dict_params: dict) -> Tuple:
    """Return the values of *dict_params* as a tuple, preserving insertion order."""
    return tuple(dict_params.values())
|
|
14
14
|
|
|
15
15
|
|
|
@@ -33,7 +33,7 @@ class LinearizedModelWraper(nn.Module):
|
|
|
33
33
|
for p in self.params0_values:
|
|
34
34
|
p.requires_grad_(False)
|
|
35
35
|
|
|
36
|
-
def tuple_params_to_dict(self, tuple_params):
|
|
36
|
+
def tuple_params_to_dict(self, tuple_params) -> Dict[str, Any]:
|
|
37
37
|
"""
|
|
38
38
|
Converts a tuple of parameters to a dictionary with keys corresponding to the parameter names.
|
|
39
39
|
|
|
@@ -50,7 +50,7 @@ class LinearizedModelWraper(nn.Module):
|
|
|
50
50
|
state_dict[k] = p
|
|
51
51
|
return state_dict
|
|
52
52
|
|
|
53
|
-
def forward(self, *args, **kwargs):
|
|
53
|
+
def forward(self, *args: Any, **kwargs: Any) -> Any:
|
|
54
54
|
"""
|
|
55
55
|
Computes the linearized model output using a first-order Taylor decomposition.
|
|
56
56
|
|
|
@@ -4,12 +4,12 @@ This is a direct copy of the DeepSeek-V2-Lite model from HuggingFace https://hug
|
|
|
4
4
|
|
|
5
5
|
from .configuration_deepseek import DeepseekV2Config
|
|
6
6
|
from .modeling_deepseek import (
|
|
7
|
+
DeepseekV2DecoderLayer,
|
|
7
8
|
DeepseekV2ForCausalLM,
|
|
8
9
|
DeepseekV2ForSequenceClassification,
|
|
9
10
|
DeepseekV2MLP,
|
|
10
11
|
DeepseekV2Model,
|
|
11
12
|
DeepseekV2MoE,
|
|
12
|
-
DeepseekV2DecoderLayer,
|
|
13
13
|
)
|
|
14
14
|
from .modeling_deepseek import MoEGate as DeepseekV2MoEGate
|
|
15
15
|
from .tokenization_deepseek_fast import DeepseekTokenizerFast
|
|
@@ -17,17 +17,18 @@
|
|
|
17
17
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
18
18
|
# See the License for the specific language governing permissions and
|
|
19
19
|
# limitations under the License.
|
|
20
|
-
"""
|
|
20
|
+
"""PyTorch DeepSeek model."""
|
|
21
21
|
import math
|
|
22
22
|
import warnings
|
|
23
23
|
from typing import List, Optional, Tuple, Union
|
|
24
24
|
|
|
25
|
+
import numpy as np
|
|
25
26
|
import torch
|
|
27
|
+
import torch.distributed as dist
|
|
26
28
|
import torch.nn.functional as F
|
|
27
29
|
import torch.utils.checkpoint
|
|
28
30
|
from torch import nn
|
|
29
31
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
|
30
|
-
|
|
31
32
|
from transformers.activations import ACT2FN
|
|
32
33
|
from transformers.cache_utils import Cache, DynamicCache
|
|
33
34
|
from transformers.modeling_attn_mask_utils import (
|
|
@@ -54,9 +55,8 @@ from transformers.utils import (
|
|
|
54
55
|
replace_return_docstrings,
|
|
55
56
|
)
|
|
56
57
|
from transformers.utils.import_utils import is_torch_fx_available
|
|
58
|
+
|
|
57
59
|
from .configuration_deepseek import DeepseekV2Config
|
|
58
|
-
import torch.distributed as dist
|
|
59
|
-
import numpy as np
|
|
60
60
|
|
|
61
61
|
if is_flash_attn_2_available():
|
|
62
62
|
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
from . import register
|
|
2
|
+
from .configuration_smile_gemma2 import SmileGemma2Config
|
|
3
|
+
from .modeling_smile_gemma2 import (
|
|
4
|
+
SmileGemma2ForCausalLM,
|
|
5
|
+
SmileGemma2ForSequenceClassification,
|
|
6
|
+
SmileGemma2ForTokenClassification,
|
|
7
|
+
SmileGemma2Model,
|
|
8
|
+
SmileGemma2PreTrainedModel,
|
|
9
|
+
)
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
from transformers.models.gemma2.configuration_gemma2 import Gemma2Config
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class SmileGemma2Config(Gemma2Config):
    """
    Configuration for SMILE-upscaled Gemma-2 models.

    Extends ``Gemma2Config`` with SMILE mixture-of-experts hyperparameters;
    all remaining keyword arguments are forwarded to the base Gemma-2
    configuration.

    Args:
        num_experts_per_tok: Number of experts routed per token.
        rank_of_router: Presumably the rank of the low-rank router
            projection — confirm against the modeling code.
        rank_of_expert: Presumably the rank of each low-rank expert —
            confirm against the modeling code.
        num_local_experts: Number of local experts, if set.
    """

    # `model_type` is how `AutoConfig` maps a saved config back to this class.
    model_type = "smile_gemma2"

    def __init__(
        self,
        num_experts_per_tok: int = 1,
        rank_of_router: int | None = None,
        rank_of_expert: int | None = None,
        num_local_experts: int | None = None,
        **kwargs,
    ):
        # Set the SMILE-specific attributes before delegating, so they are
        # already present if the base initializer touches the config.
        self.num_experts_per_tok = num_experts_per_tok
        self.rank_of_router = rank_of_router
        self.rank_of_expert = rank_of_expert
        self.num_local_experts = num_local_experts

        super().__init__(**kwargs)
|