fusion-bench 0.2.21__py3-none-any.whl → 0.2.23__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/__init__.py +25 -2
- fusion_bench/compat/method/__init__.py +5 -2
- fusion_bench/compat/method/base_algorithm.py +3 -2
- fusion_bench/compat/modelpool/base_pool.py +3 -3
- fusion_bench/compat/taskpool/clip_image_classification.py +1 -1
- fusion_bench/constants/__init__.py +1 -0
- fusion_bench/constants/runtime.py +57 -0
- fusion_bench/dataset/gpt2_glue.py +1 -1
- fusion_bench/method/__init__.py +12 -4
- fusion_bench/method/analysis/task_vector_cos_similarity.py +95 -12
- fusion_bench/method/analysis/task_vector_violin_plot.py +160 -52
- fusion_bench/method/bitdelta/__init__.py +1 -0
- fusion_bench/method/bitdelta/bitdelta.py +7 -23
- fusion_bench/method/classification/clip_finetune.py +1 -1
- fusion_bench/method/expert_sparsity/mixtral/dynamic_skipping.py +2 -0
- fusion_bench/method/expert_sparsity/mixtral/layer_wise_pruning.py +2 -0
- fusion_bench/method/expert_sparsity/mixtral/progressive_pruning.py +2 -0
- fusion_bench/method/fisher_merging/clip_fisher_merging.py +0 -4
- fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +2 -2
- fusion_bench/method/linear/simple_average_for_llama.py +16 -11
- fusion_bench/method/model_stock/__init__.py +1 -0
- fusion_bench/method/model_stock/model_stock.py +309 -0
- fusion_bench/method/regmean/clip_regmean.py +3 -6
- fusion_bench/method/regmean/regmean.py +27 -56
- fusion_bench/method/regmean/utils.py +56 -0
- fusion_bench/method/regmean_plusplus/regmean_plusplus.py +21 -60
- fusion_bench/method/simple_average.py +7 -7
- fusion_bench/method/slerp/__init__.py +1 -1
- fusion_bench/method/slerp/slerp.py +110 -14
- fusion_bench/method/smile_upscaling/causal_lm_upscaling.py +371 -0
- fusion_bench/method/smile_upscaling/projected_energy.py +1 -2
- fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py +5 -1
- fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +40 -31
- fusion_bench/method/smile_upscaling/smile_upscaling.py +1 -1
- fusion_bench/method/we_moe/__init__.py +1 -0
- fusion_bench/method/we_moe/entropy_loss.py +25 -0
- fusion_bench/method/we_moe/flan_t5_we_moe.py +320 -0
- fusion_bench/method/we_moe/utils.py +15 -0
- fusion_bench/method/weighted_average/llama.py +1 -1
- fusion_bench/mixins/clip_classification.py +37 -48
- fusion_bench/mixins/serialization.py +30 -10
- fusion_bench/modelpool/base_pool.py +1 -1
- fusion_bench/modelpool/causal_lm/causal_lm.py +293 -75
- fusion_bench/modelpool/seq2seq_lm/modelpool.py +146 -0
- fusion_bench/models/__init__.py +5 -0
- fusion_bench/models/hf_utils.py +69 -86
- fusion_bench/models/linearized/vision_model.py +6 -6
- fusion_bench/models/model_card_templates/default.md +46 -0
- fusion_bench/models/modeling_smile_llama/__init__.py +7 -0
- fusion_bench/models/modeling_smile_llama/modeling_smile_llama.py +1 -8
- fusion_bench/models/modeling_smile_mistral/__init__.py +2 -1
- fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +1 -5
- fusion_bench/models/we_moe.py +8 -8
- fusion_bench/programs/fabric_fusion_program.py +29 -60
- fusion_bench/scripts/cli.py +34 -1
- fusion_bench/taskpool/base_pool.py +99 -17
- fusion_bench/taskpool/clip_vision/taskpool.py +10 -5
- fusion_bench/taskpool/dummy.py +101 -13
- fusion_bench/taskpool/lm_eval_harness/taskpool.py +80 -0
- fusion_bench/taskpool/nyuv2_taskpool.py +28 -0
- fusion_bench/utils/__init__.py +2 -0
- fusion_bench/utils/cache_utils.py +101 -1
- fusion_bench/utils/data.py +6 -4
- fusion_bench/utils/devices.py +7 -4
- fusion_bench/utils/dtype.py +3 -2
- fusion_bench/utils/fabric.py +2 -2
- fusion_bench/utils/lazy_imports.py +23 -0
- fusion_bench/utils/lazy_state_dict.py +117 -19
- fusion_bench/utils/modelscope.py +3 -3
- fusion_bench/utils/packages.py +3 -3
- fusion_bench/utils/parameters.py +0 -2
- fusion_bench/utils/path.py +56 -0
- fusion_bench/utils/pylogger.py +1 -1
- fusion_bench/utils/timer.py +92 -10
- {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/METADATA +1 -23
- {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/RECORD +89 -75
- fusion_bench_config/_get_started/llm_slerp.yaml +12 -0
- fusion_bench_config/method/fisher_merging/clip_fisher_merging.yaml +0 -1
- fusion_bench_config/method/linear/simple_average_for_llama.yaml +3 -2
- fusion_bench_config/method/model_stock/model_stock.yaml +12 -0
- fusion_bench_config/method/slerp/slerp_lm.yaml +4 -0
- fusion_bench_config/method/smile_upscaling/causal_lm_upscaling.yaml +21 -0
- fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +1 -1
- fusion_bench_config/method/wemoe/flan_t5_weight_ensembling_moe.yaml +20 -0
- fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml +1 -1
- {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/top_level.txt +0 -0
fusion_bench/models/hf_utils.py
CHANGED
|
@@ -5,23 +5,65 @@ This module contains utilities for working with Hugging Face models.
|
|
|
5
5
|
import inspect
|
|
6
6
|
import os
|
|
7
7
|
import shutil
|
|
8
|
-
from typing import Optional, cast
|
|
8
|
+
from typing import List, Optional, cast
|
|
9
9
|
|
|
10
|
-
from omegaconf import OmegaConf
|
|
10
|
+
from omegaconf import DictConfig, OmegaConf
|
|
11
11
|
from transformers.modeling_utils import PreTrainedModel
|
|
12
12
|
|
|
13
|
-
from fusion_bench import
|
|
14
|
-
from fusion_bench.utils.pylogger import getRankZeroLogger
|
|
13
|
+
from fusion_bench.utils.pylogger import get_rankzero_logger
|
|
15
14
|
|
|
16
|
-
log =
|
|
15
|
+
log = get_rankzero_logger(__name__)
|
|
17
16
|
|
|
18
17
|
__all__ = [
|
|
18
|
+
"load_model_card_template",
|
|
19
19
|
"save_pretrained_with_remote_code",
|
|
20
|
-
"
|
|
21
|
-
"generate_readme_body",
|
|
22
|
-
"generate_complete_readme",
|
|
20
|
+
"create_default_model_card",
|
|
23
21
|
]
|
|
24
22
|
|
|
23
|
+
MODEL_CARD_TEMPLATE_DIRS = [
|
|
24
|
+
os.path.join(os.path.dirname(__file__), "model_card_templates")
|
|
25
|
+
]
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def load_model_card_template(basename: str) -> str:
|
|
29
|
+
"""
|
|
30
|
+
Load a model card template from file.
|
|
31
|
+
|
|
32
|
+
Searches for a template file by name, first checking if the name is a direct file path,
|
|
33
|
+
then searching through predefined template directories.
|
|
34
|
+
|
|
35
|
+
Args:
|
|
36
|
+
name (str): The name of the template file or a direct file path to the template.
|
|
37
|
+
|
|
38
|
+
Returns:
|
|
39
|
+
str: The contents of the template file as a string.
|
|
40
|
+
|
|
41
|
+
Raises:
|
|
42
|
+
FileNotFoundError: If the template file is not found in any of the search locations.
|
|
43
|
+
"""
|
|
44
|
+
if os.path.exists(basename):
|
|
45
|
+
with open(basename, "r") as f:
|
|
46
|
+
return f.read()
|
|
47
|
+
|
|
48
|
+
for template_dir in MODEL_CARD_TEMPLATE_DIRS:
|
|
49
|
+
template_path = os.path.join(template_dir, basename)
|
|
50
|
+
if os.path.exists(template_path):
|
|
51
|
+
with open(template_path, "r") as f:
|
|
52
|
+
return f.read()
|
|
53
|
+
|
|
54
|
+
raise FileNotFoundError(f"Model card template '{basename}' not found.")
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def try_to_yaml(config):
|
|
58
|
+
if config is None:
|
|
59
|
+
return None
|
|
60
|
+
|
|
61
|
+
try:
|
|
62
|
+
return OmegaConf.to_yaml(config, resolve=True, sort_keys=True)
|
|
63
|
+
except Exception as e:
|
|
64
|
+
log.error(f"Failed to convert config to YAML: {e}. Return `None`.")
|
|
65
|
+
return None
|
|
66
|
+
|
|
25
67
|
|
|
26
68
|
def save_pretrained_with_remote_code(
|
|
27
69
|
model: PreTrainedModel,
|
|
@@ -99,84 +141,25 @@ def save_pretrained_with_remote_code(
|
|
|
99
141
|
f.write(f"from .{base_name} import {auto_map[key].__name__}\n")
|
|
100
142
|
|
|
101
143
|
|
|
102
|
-
def
|
|
103
|
-
models: list[str]
|
|
104
|
-
|
|
144
|
+
def create_default_model_card(
|
|
145
|
+
models: list[str],
|
|
146
|
+
*,
|
|
147
|
+
title: str = "Deep Model Fusion",
|
|
105
148
|
tags: list[str] = ["fusion-bench", "merge"],
|
|
149
|
+
description=None,
|
|
150
|
+
algorithm_config: DictConfig = None,
|
|
151
|
+
modelpool_config: DictConfig = None,
|
|
106
152
|
):
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
def generate_readme_body(
|
|
120
|
-
algorithm: BaseAlgorithm,
|
|
121
|
-
models_or_modelpool: Optional[list[str] | BaseModelPool] = None,
|
|
122
|
-
models: list[str] = None,
|
|
123
|
-
):
|
|
124
|
-
text = """\
|
|
125
|
-
# Merge
|
|
126
|
-
|
|
127
|
-
This is a merge of pre-trained language models created using [fusion-bench](https://github.com/tanganke/fusion_bench).
|
|
128
|
-
|
|
129
|
-
"""
|
|
130
|
-
|
|
131
|
-
if models is not None:
|
|
132
|
-
text += """
|
|
133
|
-
## Models Merged
|
|
134
|
-
|
|
135
|
-
The following models were included in the merge:
|
|
136
|
-
|
|
137
|
-
"""
|
|
138
|
-
for model_name in models:
|
|
139
|
-
text += f"- {model_name}\n"
|
|
140
|
-
text += "\n"
|
|
141
|
-
|
|
142
|
-
try:
|
|
143
|
-
text += f"""\
|
|
144
|
-
## Configuration
|
|
145
|
-
|
|
146
|
-
The following YAML configuration was used to produce this model:
|
|
147
|
-
|
|
148
|
-
```yaml
|
|
149
|
-
{OmegaConf.to_yaml(algorithm.config, resolve=True, sort_keys=True)}
|
|
150
|
-
```
|
|
151
|
-
"""
|
|
152
|
-
except Exception as e:
|
|
153
|
-
return (
|
|
154
|
-
text # If the algorithm config cannot be converted to YAML, we skip it.
|
|
155
|
-
)
|
|
156
|
-
|
|
157
|
-
if isinstance(models_or_modelpool, BaseModelPool):
|
|
158
|
-
try:
|
|
159
|
-
text += f"""
|
|
160
|
-
```yaml
|
|
161
|
-
{OmegaConf.to_yaml(models_or_modelpool.config, resolve=True, sort_keys=True)}
|
|
162
|
-
```
|
|
163
|
-
"""
|
|
164
|
-
except Exception as e:
|
|
165
|
-
pass # If the model pool config cannot be converted to YAML, we skip it.
|
|
166
|
-
return text
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
def generate_complete_readme(
|
|
170
|
-
algorithm: BaseAlgorithm, modelpool: BaseModelPool, models: list[str]
|
|
171
|
-
):
|
|
172
|
-
# Generate the complete README content
|
|
173
|
-
text = generate_readme_head(
|
|
174
|
-
[modelpool.get_model_path(m) for m in modelpool.model_names]
|
|
175
|
-
)
|
|
176
|
-
readme_body = generate_readme_body(
|
|
177
|
-
algorithm,
|
|
178
|
-
models_or_modelpool=modelpool,
|
|
179
|
-
models=[modelpool.get_model_path(m) for m in modelpool.model_names],
|
|
153
|
+
from jinja2 import Template
|
|
154
|
+
|
|
155
|
+
template: Template = Template(load_model_card_template("default.md"))
|
|
156
|
+
card = template.render(
|
|
157
|
+
models=models,
|
|
158
|
+
library_name="transformers",
|
|
159
|
+
title=title,
|
|
160
|
+
tags=tags,
|
|
161
|
+
description=description,
|
|
162
|
+
algorithm_config_str=try_to_yaml(algorithm_config),
|
|
163
|
+
modelpool_config_str=try_to_yaml(modelpool_config),
|
|
180
164
|
)
|
|
181
|
-
|
|
182
|
-
return complete_readme
|
|
165
|
+
return card
|
|
@@ -45,21 +45,21 @@ def linearize_lora_model_(model):
|
|
|
45
45
|
|
|
46
46
|
|
|
47
47
|
def load_fft_vision_model_hf(
|
|
48
|
-
model_name: str,
|
|
48
|
+
model_name: str, return_vision_model=True
|
|
49
49
|
) -> Union[CLIPVisionTransformer, CLIPVisionModel]:
|
|
50
50
|
"""
|
|
51
51
|
Load a CLIP vision model from Hugging Face.
|
|
52
52
|
|
|
53
53
|
Args:
|
|
54
54
|
model_name (str): The name of the CLIP vision model to load from Hugging Face.
|
|
55
|
-
|
|
55
|
+
return_vision_model (bool, optional): If False, the full CLIPVisionModel is returned. If True, only the vision model (`CLIPVisionTransformer`) is returned. Defaults to True.
|
|
56
56
|
|
|
57
57
|
Returns:
|
|
58
58
|
Union[CLIPVisionTransformer, CLIPVisionModel]: The vision model.
|
|
59
59
|
"""
|
|
60
60
|
model = CLIPVisionModel.from_pretrained(model_name)
|
|
61
61
|
|
|
62
|
-
if
|
|
62
|
+
if return_vision_model:
|
|
63
63
|
return CLIPVisionModel.from_pretrained(model_name).vision_model
|
|
64
64
|
else:
|
|
65
65
|
return model
|
|
@@ -69,7 +69,7 @@ def load_lora_vision_model_hf(
|
|
|
69
69
|
base_model_name: str,
|
|
70
70
|
peft_name: str,
|
|
71
71
|
merge_and_unload: bool = False,
|
|
72
|
-
|
|
72
|
+
return_vision_model=True,
|
|
73
73
|
) -> PeftModel:
|
|
74
74
|
"""
|
|
75
75
|
Load a LoRA (Low-Rank Adaptation) vision model from Hugging Face.
|
|
@@ -80,7 +80,7 @@ def load_lora_vision_model_hf(
|
|
|
80
80
|
base_model_name (str): The name of the base vision model to load from Hugging Face.
|
|
81
81
|
peft_name (str): The name of the LoRA adaptation to apply to the base model.
|
|
82
82
|
merge_and_unload (bool, optional): If True, the LoRA adaptation is merged into the base model and the LoRA layers are removed. Defaults to False.
|
|
83
|
-
|
|
83
|
+
return_vision_model (bool, optional): If False, the full CLIPVisionModel is returned. If True, only the vision model (`CLIPVisionTransformer`) is returned. Defaults to True.
|
|
84
84
|
|
|
85
85
|
Returns:
|
|
86
86
|
PeftModel: The adapted vision model, optionally merged and unloaded.
|
|
@@ -97,7 +97,7 @@ def load_lora_vision_model_hf(
|
|
|
97
97
|
vision_model = peft_model
|
|
98
98
|
|
|
99
99
|
# Return the vision model
|
|
100
|
-
if
|
|
100
|
+
if return_vision_model:
|
|
101
101
|
return vision_model
|
|
102
102
|
else:
|
|
103
103
|
model.vision_model = vision_model
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
---
|
|
2
|
+
base_model:
|
|
3
|
+
{%- for model in models %}
|
|
4
|
+
- {{ model }}
|
|
5
|
+
{%- endfor %}
|
|
6
|
+
library_name: {{ library_name }}
|
|
7
|
+
tags:
|
|
8
|
+
{%- for tag in tags %}
|
|
9
|
+
- {{ tag }}
|
|
10
|
+
{%- endfor %}
|
|
11
|
+
---
|
|
12
|
+
# {{ title }}
|
|
13
|
+
|
|
14
|
+
{% if description is not none %}{{ description }}{% endif %}
|
|
15
|
+
|
|
16
|
+
## Models Merged
|
|
17
|
+
|
|
18
|
+
This is a merged model created using [fusion-bench](https://github.com/tanganke/fusion_bench).
|
|
19
|
+
|
|
20
|
+
The following models were included in the merge:
|
|
21
|
+
{% for model in models %}
|
|
22
|
+
- {{ model }}
|
|
23
|
+
{%- endfor %}
|
|
24
|
+
|
|
25
|
+
{% if algorithm_config_str is not none or modelpool_config_str is not none %}
|
|
26
|
+
## Configuration
|
|
27
|
+
|
|
28
|
+
The following YAML configuration was used to produce this model:
|
|
29
|
+
|
|
30
|
+
{% if algorithm_config_str is not none -%}
|
|
31
|
+
### Algorithm Configuration
|
|
32
|
+
|
|
33
|
+
```yaml
|
|
34
|
+
{{ algorithm_config_str -}}
|
|
35
|
+
```
|
|
36
|
+
{%- endif %}
|
|
37
|
+
|
|
38
|
+
{% if modelpool_config_str is not none -%}
|
|
39
|
+
### Model Pool Configuration
|
|
40
|
+
|
|
41
|
+
```yaml
|
|
42
|
+
{{ modelpool_config_str -}}
|
|
43
|
+
```
|
|
44
|
+
{%- endif %}
|
|
45
|
+
|
|
46
|
+
{% endif %}
|
|
@@ -17,7 +17,6 @@ from transformers.modeling_outputs import (
|
|
|
17
17
|
)
|
|
18
18
|
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
|
|
19
19
|
from transformers.models.llama.modeling_llama import (
|
|
20
|
-
LLAMA_INPUTS_DOCSTRING,
|
|
21
20
|
LlamaRMSNorm,
|
|
22
21
|
LlamaRotaryEmbedding,
|
|
23
22
|
apply_rotary_pos_emb,
|
|
@@ -25,7 +24,6 @@ from transformers.models.llama.modeling_llama import (
|
|
|
25
24
|
)
|
|
26
25
|
from transformers.processing_utils import Unpack
|
|
27
26
|
from transformers.utils import (
|
|
28
|
-
LossKwargs,
|
|
29
27
|
add_start_docstrings_to_model_forward,
|
|
30
28
|
can_return_tuple,
|
|
31
29
|
is_torch_flex_attn_available,
|
|
@@ -296,7 +294,6 @@ class SmileLlamaModel(SmileLlamaPreTrainedModel):
|
|
|
296
294
|
self.embed_tokens = value
|
|
297
295
|
|
|
298
296
|
@can_return_tuple
|
|
299
|
-
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
|
|
300
297
|
def forward(
|
|
301
298
|
self,
|
|
302
299
|
input_ids: Optional[torch.LongTensor] = None,
|
|
@@ -566,9 +563,6 @@ class SmileLlamaModel(SmileLlamaPreTrainedModel):
|
|
|
566
563
|
return causal_mask
|
|
567
564
|
|
|
568
565
|
|
|
569
|
-
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
|
570
|
-
|
|
571
|
-
|
|
572
566
|
class SmileLlamaForCausalLM(SmileLlamaPreTrainedModel, GenerationMixin):
|
|
573
567
|
_tied_weights_keys = ["lm_head.weight"]
|
|
574
568
|
_tp_plan = {"lm_head": "colwise_rep"}
|
|
@@ -603,7 +597,6 @@ class SmileLlamaForCausalLM(SmileLlamaPreTrainedModel, GenerationMixin):
|
|
|
603
597
|
|
|
604
598
|
@can_return_tuple
|
|
605
599
|
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
|
606
|
-
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
|
|
607
600
|
@replace_return_docstrings(
|
|
608
601
|
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
|
609
602
|
)
|
|
@@ -620,7 +613,7 @@ class SmileLlamaForCausalLM(SmileLlamaPreTrainedModel, GenerationMixin):
|
|
|
620
613
|
output_hidden_states: Optional[bool] = None,
|
|
621
614
|
cache_position: Optional[torch.LongTensor] = None,
|
|
622
615
|
logits_to_keep: Union[int, torch.Tensor] = 0,
|
|
623
|
-
**kwargs
|
|
616
|
+
**kwargs,
|
|
624
617
|
) -> CausalLMOutputWithPast:
|
|
625
618
|
r"""
|
|
626
619
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
@@ -31,7 +31,6 @@ from transformers.models.qwen2.modeling_qwen2 import (
|
|
|
31
31
|
)
|
|
32
32
|
from transformers.processing_utils import Unpack
|
|
33
33
|
from transformers.utils import (
|
|
34
|
-
LossKwargs,
|
|
35
34
|
add_code_sample_docstrings,
|
|
36
35
|
add_start_docstrings,
|
|
37
36
|
add_start_docstrings_to_model_forward,
|
|
@@ -607,9 +606,6 @@ class SmileQwen2Model(SmileQwen2PreTrainedModel):
|
|
|
607
606
|
return causal_mask
|
|
608
607
|
|
|
609
608
|
|
|
610
|
-
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
|
611
|
-
|
|
612
|
-
|
|
613
609
|
class SmileQwen2ForCausalLM(SmileQwen2PreTrainedModel, GenerationMixin):
|
|
614
610
|
_tied_weights_keys = ["lm_head.weight"]
|
|
615
611
|
_tp_plan = {"lm_head": "colwise_rep"}
|
|
@@ -660,7 +656,7 @@ class SmileQwen2ForCausalLM(SmileQwen2PreTrainedModel, GenerationMixin):
|
|
|
660
656
|
output_hidden_states: Optional[bool] = None,
|
|
661
657
|
cache_position: Optional[torch.LongTensor] = None,
|
|
662
658
|
logits_to_keep: Union[int, torch.Tensor] = 0,
|
|
663
|
-
**kwargs
|
|
659
|
+
**kwargs,
|
|
664
660
|
) -> CausalLMOutputWithPast:
|
|
665
661
|
r"""
|
|
666
662
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
fusion_bench/models/we_moe.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
import functools
|
|
2
2
|
import logging
|
|
3
|
-
from typing import List
|
|
3
|
+
from typing import Generic, List
|
|
4
4
|
|
|
5
5
|
import torch
|
|
6
6
|
import torch.func
|
|
@@ -9,7 +9,7 @@ from torch.func import functional_call
|
|
|
9
9
|
from torch.nn import functional as F
|
|
10
10
|
|
|
11
11
|
from fusion_bench.models.utils import del_attr, get_attr, set_attr
|
|
12
|
-
from fusion_bench.utils.type import StateDictType
|
|
12
|
+
from fusion_bench.utils.type import StateDictType, TorchModelType
|
|
13
13
|
|
|
14
14
|
log = logging.getLogger(__name__)
|
|
15
15
|
|
|
@@ -76,15 +76,15 @@ def construct_weight_ensembling_gate(
|
|
|
76
76
|
return gate
|
|
77
77
|
|
|
78
78
|
|
|
79
|
-
class WeightEnsemblingMoE(nn.Module):
|
|
79
|
+
class WeightEnsemblingMoE(nn.Module, Generic[TorchModelType]):
|
|
80
80
|
# variable to store the merged state dict temporarily
|
|
81
81
|
_merged_state_dict: StateDictType = None
|
|
82
82
|
|
|
83
83
|
def __init__(
|
|
84
84
|
self,
|
|
85
85
|
hidden_size: int,
|
|
86
|
-
base_model:
|
|
87
|
-
expert_models: List[
|
|
86
|
+
base_model: TorchModelType,
|
|
87
|
+
expert_models: List[TorchModelType],
|
|
88
88
|
init_lambda: float = 0.2,
|
|
89
89
|
batch_first: bool = False,
|
|
90
90
|
router_hidden_layers: int = 2,
|
|
@@ -101,8 +101,8 @@ class WeightEnsemblingMoE(nn.Module):
|
|
|
101
101
|
Args:
|
|
102
102
|
|
|
103
103
|
hidden_size (int): The size of the hidden layer in the models.
|
|
104
|
-
base_model (
|
|
105
|
-
expert_models (List[
|
|
104
|
+
base_model (TorchModelType): The base model that will be used as a reference for the expert models.
|
|
105
|
+
expert_models (List[TorchModelType]): A list of expert models that will be combined.
|
|
106
106
|
init_lambda (float, optional): The initial lambda value for the weight ensembling gate. Defaults to 0.2.
|
|
107
107
|
batch_first (bool, optional): If True, the input tensors are expected to have the batch size as the first dimension. Defaults to False.
|
|
108
108
|
router_hidden_layers (int, optional): The number of hidden layers in the router. Defaults to 2.
|
|
@@ -145,7 +145,7 @@ class WeightEnsemblingMoE(nn.Module):
|
|
|
145
145
|
self._merged_state_dict,
|
|
146
146
|
)
|
|
147
147
|
|
|
148
|
-
def merge_weights(self, expert_weights):
|
|
148
|
+
def merge_weights(self, expert_weights) -> StateDictType:
|
|
149
149
|
state_dict = self.base_model.state_dict(keep_vars=True)
|
|
150
150
|
for weight, task_vector in zip(expert_weights, self.task_vectors):
|
|
151
151
|
for name, param in task_vector.named_parameters():
|
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
import json
|
|
2
2
|
import logging
|
|
3
3
|
import os
|
|
4
|
+
from pathlib import Path
|
|
4
5
|
from typing import Any, Callable, Dict, Iterable, List, Optional, Union # noqa: F401
|
|
5
6
|
|
|
6
7
|
import lightning as L
|
|
@@ -9,19 +10,24 @@ from omegaconf import DictConfig, OmegaConf
|
|
|
9
10
|
from torch import nn
|
|
10
11
|
from tqdm.auto import tqdm
|
|
11
12
|
|
|
12
|
-
import fusion_bench
|
|
13
|
-
from fusion_bench
|
|
13
|
+
import fusion_bench
|
|
14
|
+
from fusion_bench import (
|
|
15
|
+
BaseAlgorithm,
|
|
16
|
+
BaseHydraProgram,
|
|
17
|
+
BaseModelPool,
|
|
18
|
+
BaseTaskPool,
|
|
19
|
+
RuntimeConstants,
|
|
20
|
+
import_object,
|
|
21
|
+
instantiate,
|
|
22
|
+
timeit_context,
|
|
23
|
+
)
|
|
14
24
|
from fusion_bench.mixins import LightningFabricMixin
|
|
15
|
-
from fusion_bench.modelpool import BaseModelPool
|
|
16
|
-
from fusion_bench.programs import BaseHydraProgram
|
|
17
|
-
from fusion_bench.taskpool import BaseTaskPool
|
|
18
|
-
from fusion_bench.utils import import_object, instantiate, timeit_context
|
|
19
25
|
from fusion_bench.utils.hydra_utils import get_hydra_output_dir
|
|
20
26
|
from fusion_bench.utils.json import print_json
|
|
21
|
-
from fusion_bench.utils.
|
|
27
|
+
from fusion_bench.utils.path import create_symlink
|
|
22
28
|
from fusion_bench.utils.rich_utils import print_bordered, print_config_tree
|
|
23
29
|
|
|
24
|
-
log =
|
|
30
|
+
log = fusion_bench.get_rankzero_logger(__name__)
|
|
25
31
|
|
|
26
32
|
|
|
27
33
|
class FabricModelFusionProgram(
|
|
@@ -60,6 +66,7 @@ class FabricModelFusionProgram(
|
|
|
60
66
|
path: DictConfig = None,
|
|
61
67
|
**kwargs,
|
|
62
68
|
):
|
|
69
|
+
super().__init__(**kwargs)
|
|
63
70
|
self._method = method
|
|
64
71
|
self._modelpool = modelpool
|
|
65
72
|
self._taskpool = taskpool
|
|
@@ -70,8 +77,10 @@ class FabricModelFusionProgram(
|
|
|
70
77
|
self.fast_dev_run = fast_dev_run
|
|
71
78
|
self.seed = seed
|
|
72
79
|
self.path = path
|
|
73
|
-
|
|
74
|
-
|
|
80
|
+
RuntimeConstants.debug = fast_dev_run
|
|
81
|
+
RuntimeConstants.print_function_call = print_function_call
|
|
82
|
+
if path is not None:
|
|
83
|
+
RuntimeConstants.cache_dir = path.get("cache_dir", None)
|
|
75
84
|
|
|
76
85
|
if print_config:
|
|
77
86
|
print_config_tree(
|
|
@@ -224,8 +233,16 @@ class FabricModelFusionProgram(
|
|
|
224
233
|
fabric = self.fabric
|
|
225
234
|
if self.seed is not None:
|
|
226
235
|
L.seed_everything(self.seed)
|
|
227
|
-
|
|
228
|
-
|
|
236
|
+
|
|
237
|
+
# create symbol link to hydra output directory
|
|
238
|
+
if (
|
|
239
|
+
self.fabric.is_global_zero
|
|
240
|
+
and self.log_dir is not None
|
|
241
|
+
and os.path.abspath(self.log_dir) != os.path.abspath(get_hydra_output_dir())
|
|
242
|
+
):
|
|
243
|
+
create_symlink(
|
|
244
|
+
get_hydra_output_dir(), self.log_dir, link_name="hydra_output"
|
|
245
|
+
)
|
|
229
246
|
|
|
230
247
|
log.info("Running the model fusion program.")
|
|
231
248
|
# setup the modelpool, method, and taskpool
|
|
@@ -278,51 +295,3 @@ class FabricModelFusionProgram(
|
|
|
278
295
|
json.dump(report, open(self.report_save_path, "w"))
|
|
279
296
|
else:
|
|
280
297
|
log.info("No task pool specified. Skipping evaluation.")
|
|
281
|
-
|
|
282
|
-
@rank_zero_only
|
|
283
|
-
def _link_hydra_output(self):
|
|
284
|
-
"""
|
|
285
|
-
Creates a symbolic link to the Hydra output directory within the specified log directory.
|
|
286
|
-
|
|
287
|
-
If `self.log_dir` is not None, this method will:
|
|
288
|
-
1. Retrieve the Hydra output directory using `get_hydra_output_dir()`.
|
|
289
|
-
2. Create the log directory if it does not already exist.
|
|
290
|
-
3. Create a symbolic link named "hydra_output_<basename_of_hydra_output_dir>"
|
|
291
|
-
within the log directory, pointing to the Hydra output directory.
|
|
292
|
-
|
|
293
|
-
Note:
|
|
294
|
-
- The symbolic link is created only if the Hydra output directory is not None.
|
|
295
|
-
- The `target_is_directory` parameter is set to True to indicate that the target is a directory.
|
|
296
|
-
|
|
297
|
-
Raises:
|
|
298
|
-
OSError: If the symbolic link creation fails.
|
|
299
|
-
"""
|
|
300
|
-
if self.log_dir is not None:
|
|
301
|
-
# make symlink to the hydra output directory
|
|
302
|
-
try:
|
|
303
|
-
hydra_output_dir = get_hydra_output_dir()
|
|
304
|
-
except Exception as e:
|
|
305
|
-
hydra_output_dir = None
|
|
306
|
-
|
|
307
|
-
if hydra_output_dir is not None:
|
|
308
|
-
if os.path.abspath(hydra_output_dir) == os.path.abspath(self.log_dir):
|
|
309
|
-
return
|
|
310
|
-
|
|
311
|
-
os.makedirs(self.log_dir, exist_ok=True)
|
|
312
|
-
try:
|
|
313
|
-
# if the system is windows, use the `mklink` command in "CMD" to create the symlink
|
|
314
|
-
if os.name == "nt":
|
|
315
|
-
os.system(
|
|
316
|
-
f"mklink /J {os.path.abspath(os.path.join(self.log_dir, 'hydra_output_' + os.path.basename(hydra_output_dir)))} {os.path.abspath(hydra_output_dir)}"
|
|
317
|
-
)
|
|
318
|
-
else:
|
|
319
|
-
os.symlink(
|
|
320
|
-
hydra_output_dir,
|
|
321
|
-
os.path.join(
|
|
322
|
-
self.log_dir,
|
|
323
|
-
"hydra_output_" + os.path.basename(hydra_output_dir),
|
|
324
|
-
),
|
|
325
|
-
target_is_directory=True,
|
|
326
|
-
)
|
|
327
|
-
except OSError as e:
|
|
328
|
-
log.warning(f"Failed to create symbolic link: {e}")
|
fusion_bench/scripts/cli.py
CHANGED
|
@@ -12,9 +12,9 @@ import os
|
|
|
12
12
|
import hydra
|
|
13
13
|
from omegaconf import DictConfig, OmegaConf
|
|
14
14
|
|
|
15
|
+
from fusion_bench.constants import PROJECT_ROOT_PATH
|
|
15
16
|
from fusion_bench.programs import BaseHydraProgram
|
|
16
17
|
from fusion_bench.utils import instantiate
|
|
17
|
-
from fusion_bench.constants import PROJECT_ROOT_PATH
|
|
18
18
|
|
|
19
19
|
log = logging.getLogger(__name__)
|
|
20
20
|
|
|
@@ -34,6 +34,39 @@ def _get_default_config_path():
|
|
|
34
34
|
version_base=None,
|
|
35
35
|
)
|
|
36
36
|
def main(cfg: DictConfig) -> None:
|
|
37
|
+
"""
|
|
38
|
+
Main entry point for the FusionBench command-line interface.
|
|
39
|
+
|
|
40
|
+
This function serves as the primary entry point for the `fusion_bench` CLI command.
|
|
41
|
+
It is decorated with Hydra's main decorator to handle configuration management,
|
|
42
|
+
command-line argument parsing, and configuration file loading.
|
|
43
|
+
|
|
44
|
+
The function performs the following operations:
|
|
45
|
+
1. Resolves any interpolations in the configuration using OmegaConf
|
|
46
|
+
2. Instantiates the appropriate program class based on the configuration
|
|
47
|
+
3. Executes the program's run method to perform the fusion task
|
|
48
|
+
|
|
49
|
+
Args:
|
|
50
|
+
cfg (DictConfig): The Hydra configuration object containing all settings
|
|
51
|
+
for the fusion task. This includes method configuration, model pool
|
|
52
|
+
configuration, task pool configuration, and other runtime parameters.
|
|
53
|
+
The configuration is automatically loaded by Hydra from the specified
|
|
54
|
+
config files and command-line overrides.
|
|
55
|
+
|
|
56
|
+
Returns:
|
|
57
|
+
None: This function doesn't return a value but executes the fusion
|
|
58
|
+
program which may save results, log outputs, or perform other
|
|
59
|
+
side effects as configured.
|
|
60
|
+
|
|
61
|
+
Example:
|
|
62
|
+
This function is typically called automatically when running:
|
|
63
|
+
```bash
|
|
64
|
+
fusion_bench method=... modelpool=... taskpool=...
|
|
65
|
+
```
|
|
66
|
+
|
|
67
|
+
The Hydra decorator handles parsing these command-line arguments and
|
|
68
|
+
loading the corresponding configuration files to populate the cfg parameter.
|
|
69
|
+
"""
|
|
37
70
|
OmegaConf.resolve(cfg)
|
|
38
71
|
program: BaseHydraProgram = instantiate(cfg)
|
|
39
72
|
program.run()
|