fusion-bench 0.2.21__py3-none-any.whl → 0.2.22__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/__init__.py +21 -2
- fusion_bench/constants/__init__.py +1 -0
- fusion_bench/constants/runtime.py +57 -0
- fusion_bench/method/__init__.py +8 -2
- fusion_bench/method/bitdelta/__init__.py +1 -0
- fusion_bench/method/classification/clip_finetune.py +1 -1
- fusion_bench/method/fisher_merging/clip_fisher_merging.py +0 -4
- fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +2 -2
- fusion_bench/method/linear/simple_average_for_llama.py +16 -11
- fusion_bench/method/simple_average.py +7 -7
- fusion_bench/method/smile_upscaling/causal_lm_upscaling.py +371 -0
- fusion_bench/method/smile_upscaling/projected_energy.py +1 -2
- fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py +5 -1
- fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +40 -31
- fusion_bench/method/smile_upscaling/smile_upscaling.py +1 -1
- fusion_bench/method/we_moe/__init__.py +1 -0
- fusion_bench/method/we_moe/entropy_loss.py +25 -0
- fusion_bench/method/we_moe/flan_t5_we_moe.py +331 -0
- fusion_bench/method/we_moe/utils.py +15 -0
- fusion_bench/method/weighted_average/llama.py +1 -1
- fusion_bench/mixins/clip_classification.py +11 -42
- fusion_bench/mixins/serialization.py +18 -8
- fusion_bench/modelpool/causal_lm/causal_lm.py +32 -33
- fusion_bench/models/__init__.py +5 -0
- fusion_bench/models/hf_utils.py +65 -87
- fusion_bench/models/model_card_templates/default.md +46 -0
- fusion_bench/models/modeling_smile_llama/__init__.py +7 -0
- fusion_bench/models/modeling_smile_llama/modeling_smile_llama.py +1 -8
- fusion_bench/models/modeling_smile_mistral/__init__.py +1 -1
- fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +1 -5
- fusion_bench/programs/fabric_fusion_program.py +29 -60
- fusion_bench/scripts/cli.py +34 -1
- fusion_bench/taskpool/clip_vision/taskpool.py +9 -4
- fusion_bench/utils/__init__.py +1 -0
- fusion_bench/utils/cache_utils.py +101 -1
- fusion_bench/utils/fabric.py +2 -2
- fusion_bench/utils/lazy_imports.py +23 -0
- fusion_bench/utils/lazy_state_dict.py +38 -3
- fusion_bench/utils/modelscope.py +3 -3
- fusion_bench/utils/path.py +56 -0
- fusion_bench/utils/pylogger.py +1 -1
- {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.22.dist-info}/METADATA +1 -23
- {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.22.dist-info}/RECORD +53 -45
- fusion_bench_config/method/fisher_merging/clip_fisher_merging.yaml +0 -1
- fusion_bench_config/method/linear/simple_average_for_llama.yaml +3 -2
- fusion_bench_config/method/smile_upscaling/causal_lm_upscaling.yaml +21 -0
- fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +1 -1
- fusion_bench_config/method/wemoe/flan_t5_weight_ensembling_moe.yaml +20 -0
- fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml +1 -1
- {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.22.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.22.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.22.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.22.dist-info}/top_level.txt +0 -0
fusion_bench/mixins/serialization.py
CHANGED

@@ -90,14 +90,21 @@ def auto_register_config(cls):
     # Auto-register parameters in _config_mapping
     if not "_config_mapping" in cls.__dict__:
         cls._config_mapping = deepcopy(getattr(cls, "_config_mapping", {}))
+    registered_parameters = tuple(cls._config_mapping.values())
+
     for param_name in list(sig.parameters.keys())[1:]:  # Skip 'self'
-        if
-
-
-
+        if (
+            sig.parameters[param_name].kind
+            not in [
+                _ParameterKind.VAR_POSITIONAL,
+                _ParameterKind.VAR_KEYWORD,
+            ]
+        ) and (param_name not in registered_parameters):
             cls._config_mapping[param_name] = param_name
 
     def __init__(self, *args, **kwargs):
+        nonlocal original_init, registered_parameters
+
         # auto-register the attributes based on the signature
         sig = inspect.signature(original_init)
         param_names = list(sig.parameters.keys())[1:]  # Skip 'self'

@@ -114,10 +121,13 @@ def auto_register_config(cls):
 
         # Handle keyword arguments and defaults
         for param_name in param_names:
-            if
-
-
-
+            if (
+                sig.parameters[param_name].kind
+                not in [
+                    _ParameterKind.VAR_POSITIONAL,
+                    _ParameterKind.VAR_KEYWORD,
+                ]
+            ) and (param_name not in registered_parameters):
                 # Skip if already set by positional argument
                 param_index = param_names.index(param_name)
                 if param_index >= 0 and param_index < len(args):
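The `auto_register_config` change above skips `*args`/`**kwargs` and any parameter already present in an inherited `_config_mapping`. A minimal usage sketch (the toy class is hypothetical; only the decorator and its top-level export are taken from this diff):

```python
from fusion_bench import auto_register_config  # exported at the package top level in 0.2.22


@auto_register_config
class ToyPool:
    # After decoration, _config_mapping roughly gains {"models": "models",
    # "scaling_factor": "scaling_factor"}; *args and **kwargs are skipped,
    # as is any name an inherited _config_mapping already registers.
    def __init__(self, models, scaling_factor: float = 0.5, *args, **kwargs):
        pass
```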
fusion_bench/modelpool/causal_lm/causal_lm.py
CHANGED

@@ -8,7 +8,7 @@ from copy import deepcopy
 from typing import Any, Dict, Optional, TypeAlias, Union, cast  # noqa: F401
 
 import peft
-from omegaconf import DictConfig, flag_override
+from omegaconf import DictConfig, OmegaConf, flag_override
 from torch import nn
 from torch.nn.modules import Module
 from transformers import (

@@ -19,43 +19,32 @@ from transformers import (
 )
 from typing_extensions import override
 
-from fusion_bench
-
-
+from fusion_bench import (
+    BaseModelPool,
+    auto_register_config,
+    import_object,
+    instantiate,
+    parse_dtype,
+)
 from fusion_bench.utils.lazy_state_dict import LazyStateDict
-from fusion_bench.utils.packages import import_object
 
 log = logging.getLogger(__name__)
 
 
+@auto_register_config
 class CausalLMPool(BaseModelPool):
-    _config_mapping = BaseModelPool._config_mapping | {
-        "_tokenizer": "tokenizer",
-        "_model_kwargs": "model_kwargs",
-        "load_lazy": "load_lazy",
-    }
-
     def __init__(
         self,
         models,
         *,
-        tokenizer: Optional[DictConfig],
+        tokenizer: Optional[DictConfig | str],
         model_kwargs: Optional[DictConfig] = None,
-
+        enable_lazy_loading: bool = False,
         **kwargs,
     ):
         super().__init__(models, **kwargs)
-
-
-        self._model_kwargs = model_kwargs
-        if self._model_kwargs is None:
-            self._model_kwargs = DictConfig({})
-        with flag_override(self._model_kwargs, "allow_objects", True):
-            if hasattr(self._model_kwargs, "torch_dtype"):
-                self._model_kwargs.torch_dtype = parse_dtype(
-                    self._model_kwargs.torch_dtype
-                )
-        self.load_lazy = load_lazy
+        if model_kwargs is None:
+            self.model_kwargs = DictConfig({})
 
     def get_model_path(self, model_name: str):
         model_name_or_config = self._models[model_name]

@@ -66,6 +55,16 @@ class CausalLMPool(BaseModelPool):
         else:
             raise RuntimeError("Invalid model configuration")
 
+    def get_model_kwargs(self):
+        model_kwargs = (
+            OmegaConf.to_container(self.model_kwargs, resolve=True)
+            if isinstance(self.model_kwargs, DictConfig)
+            else self.model_kwargs
+        )
+        if "torch_dtype" in model_kwargs:
+            model_kwargs["torch_dtype"] = parse_dtype(model_kwargs["torch_dtype"])
+        return model_kwargs
+
     @override
     def load_model(
         self,

@@ -98,7 +97,7 @@ class CausalLMPool(BaseModelPool):
         pretrained_model_name_or_path: path_to_model_b
         ```
         """
-        model_kwargs =
+        model_kwargs = self.get_model_kwargs()
         model_kwargs.update(kwargs)
 
         if isinstance(model_name_or_config, str):

@@ -108,7 +107,7 @@ class CausalLMPool(BaseModelPool):
             model_config = self._models[model_name_or_config]
             if isinstance(model_config, str):
                 # model_config is a string
-                if not self.
+                if not self.enable_lazy_loading:
                     model = AutoModelForCausalLM.from_pretrained(
                         model_config,
                         *args,

@@ -126,7 +125,7 @@ class CausalLMPool(BaseModelPool):
         elif isinstance(model_name_or_config, (DictConfig, Dict)):
             model_config = model_name_or_config
 
-            if not self.
+            if not self.enable_lazy_loading:
                 model = instantiate(model_config, *args, **model_kwargs)
             else:
                 meta_module_class = model_config.pop("_target_")

@@ -158,12 +157,12 @@ class CausalLMPool(BaseModelPool):
         Returns:
             PreTrainedTokenizer: The tokenizer.
         """
-        assert self.
+        assert self.tokenizer is not None, "Tokenizer is not defined in the config"
         log.info("Loading tokenizer.", stacklevel=2)
-        if isinstance(self.
-            tokenizer = AutoTokenizer.from_pretrained(self.
+        if isinstance(self.tokenizer, str):
+            tokenizer = AutoTokenizer.from_pretrained(self.tokenizer, *args, **kwargs)
         else:
-            tokenizer = instantiate(self.
+            tokenizer = instantiate(self.tokenizer, *args, **kwargs)
         return tokenizer
 
     @override

@@ -213,12 +212,12 @@ class CausalLMBackbonePool(CausalLMPool):
     def load_model(
         self, model_name_or_config: str | DictConfig, *args, **kwargs
     ) -> Module:
-        if self.
+        if self.enable_lazy_loading:
             log.warning(
                 "CausalLMBackbonePool does not support lazy loading. "
                 "Falling back to normal loading."
             )
-            self.
+            self.enable_lazy_loading = False
         model: AutoModelForCausalLM = super().load_model(
             model_name_or_config, *args, **kwargs
         )
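Taken together, these changes rename `load_lazy` to `enable_lazy_loading`, let `tokenizer` be a plain path string, and route model kwargs through the new `get_model_kwargs()` helper. A hedged usage sketch; model names and paths are placeholders, and the import path simply mirrors the module location in this diff:

```python
from omegaconf import DictConfig

from fusion_bench.modelpool.causal_lm.causal_lm import CausalLMPool

pool = CausalLMPool(
    models=DictConfig({"math": "path_to_model_a", "coder": "path_to_model_b"}),
    tokenizer="path_to_model_a",                            # now a str or a DictConfig
    model_kwargs=DictConfig({"torch_dtype": "bfloat16"}),
    enable_lazy_loading=False,                              # formerly `load_lazy`
)

model_kwargs = pool.get_model_kwargs()  # plain dict; "torch_dtype" resolved via parse_dtype
model = pool.load_model("math")         # falls through to AutoModelForCausalLM.from_pretrained
tokenizer = pool.load_tokenizer()       # AutoTokenizer.from_pretrained, since tokenizer is a str
```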
fusion_bench/models/__init__.py
CHANGED
fusion_bench/models/hf_utils.py
CHANGED
@@ -5,23 +5,63 @@ This module contains utilities for working with Hugging Face models.
 import inspect
 import os
 import shutil
-from typing import Optional, cast
+from typing import List, Optional, cast
 
-from omegaconf import OmegaConf
+from omegaconf import DictConfig, OmegaConf
 from transformers.modeling_utils import PreTrainedModel
 
-from fusion_bench import
-from fusion_bench.utils.pylogger import getRankZeroLogger
+from fusion_bench.utils.pylogger import get_rankzero_logger
 
-log =
+log = get_rankzero_logger(__name__)
 
 __all__ = [
+    "load_model_card_template",
     "save_pretrained_with_remote_code",
-    "
-    "generate_readme_body",
-    "generate_complete_readme",
+    "create_default_model_card",
 ]
 
+MODEL_CARD_TEMPLATE_DIRS = [
+    os.path.join(os.path.dirname(__file__), "model_card_templates")
+]
+
+
+def load_model_card_template(basename: str) -> str:
+    """
+    Load a model card template from file.
+
+    Searches for a template file by name, first checking if the name is a direct file path,
+    then searching through predefined template directories.
+
+    Args:
+        name (str): The name of the template file or a direct file path to the template.
+
+    Returns:
+        str: The contents of the template file as a string.
+
+    Raises:
+        FileNotFoundError: If the template file is not found in any of the search locations.
+    """
+    if os.path.exists(basename):
+        return open(basename).read()
+
+    for template_dir in MODEL_CARD_TEMPLATE_DIRS:
+        template_path = os.path.join(template_dir, basename)
+        if os.path.exists(template_path):
+            return open(template_path).read()
+
+    raise FileNotFoundError(f"Model card template '{basename}' not found.")
+
+
+def try_to_yaml(config):
+    if config is None:
+        return None
+
+    try:
+        return OmegaConf.to_yaml(config, resolve=True, sort_keys=True)
+    except Exception as e:
+        log.error(f"Failed to convert config to YAML: {e}. Return `None`.")
+        return None
+
 
 def save_pretrained_with_remote_code(
     model: PreTrainedModel,

@@ -99,84 +139,22 @@ def save_pretrained_with_remote_code(
             f.write(f"from .{base_name} import {auto_map[key].__name__}\n")
 
 
-def
-    models: list[str]
-
-
-
-    text = "---\nbase_model:\n"
-    for model_name in models:
-        text += f"- {model_name}\n"
-    if library_name:
-        text += f"library_name: {library_name}\n"
-    text += "tags:\n"
-    for tag in tags:
-        text += f"- {tag}\n"
-    text += "---\n"
-    return text
-
-
-def generate_readme_body(
-    algorithm: BaseAlgorithm,
-    models_or_modelpool: Optional[list[str] | BaseModelPool] = None,
-    models: list[str] = None,
+def create_default_model_card(
+    models: list[str],
+    description=None,
+    algorithm_config: DictConfig = None,
+    modelpool_config: DictConfig = None,
 ):
-
-
-
-
-
-    ""
-
-
-
-
-
-The following models were included in the merge:
-
-    """
-    for model_name in models:
-        text += f"- {model_name}\n"
-    text += "\n"
-
-    try:
-        text += f"""\
-## Configuration
-
-The following YAML configuration was used to produce this model:
-
-```yaml
-{OmegaConf.to_yaml(algorithm.config, resolve=True, sort_keys=True)}
-```
-"""
-    except Exception as e:
-        return (
-            text  # If the algorithm config cannot be converted to YAML, we skip it.
-        )
-
-    if isinstance(models_or_modelpool, BaseModelPool):
-        try:
-            text += f"""
-```yaml
-{OmegaConf.to_yaml(models_or_modelpool.config, resolve=True, sort_keys=True)}
-```
-"""
-        except Exception as e:
-            pass  # If the model pool config cannot be converted to YAML, we skip it.
-    return text
-
-
-def generate_complete_readme(
-    algorithm: BaseAlgorithm, modelpool: BaseModelPool, models: list[str]
-):
-    # Generate the complete README content
-    text = generate_readme_head(
-        [modelpool.get_model_path(m) for m in modelpool.model_names]
-    )
-    readme_body = generate_readme_body(
-        algorithm,
-        models_or_modelpool=modelpool,
-        models=[modelpool.get_model_path(m) for m in modelpool.model_names],
+    from jinja2 import Template
+
+    template: Template = Template(load_model_card_template("default.md"))
+    card = template.render(
+        models=models,
+        library_name="transformers",
+        tags=["fusion-bench", "merge"],
+        title="Deep Model Fusion",
+        description=description,
+        algorithm_config_str=try_to_yaml(algorithm_config),
+        modelpool_config_str=try_to_yaml(modelpool_config),
     )
-
-    return complete_readme
+    return card
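The three `generate_readme_*` helpers collapse into a single template-driven `create_default_model_card`. A hedged sketch of the new API, using placeholder model paths and configs and only the signatures visible in this diff:

```python
from omegaconf import DictConfig

from fusion_bench.models.hf_utils import create_default_model_card, load_model_card_template

card_text = create_default_model_card(
    models=["path_to_model_a", "path_to_model_b"],
    description="Merged with a simple weighted average.",        # optional free-text section
    algorithm_config=DictConfig({"_target_": "SimpleAverage"}),  # rendered as a YAML block
    modelpool_config=None,                                        # None sections are omitted
)
open("README.md", "w").write(card_text)

# The underlying Jinja2 template can also be loaded (or overridden) directly:
template_source = load_model_card_template("default.md")
```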
fusion_bench/models/model_card_templates/default.md
ADDED

@@ -0,0 +1,46 @@
+---
+base_model:
+{%- for model in models %}
+- {{ model }}
+{%- endfor %}
+library_name: {{ library_name }}
+tags:
+{%- for tag in tags %}
+- {{ tag }}
+{%- endfor %}
+---
+# {{ title }}
+
+{% if description is not none %}{{ description }}{% endif %}
+
+## Models Merged
+
+This is a merged model created using [fusion-bench](https://github.com/tanganke/fusion_bench).
+
+The following models were included in the merge:
+{% for model in models %}
+- {{ model }}
+{%- endfor %}
+
+{% if algorithm_config_str is not none or modelpool_config_str is not none %}
+## Configuration
+
+The following YAML configuration was used to produce this model:
+
+{% if algorithm_config_str is not none -%}
+### Algorithm Configuration
+
+```yaml
+{{ algorithm_config_str -}}
+```
+{%- endif %}
+
+{% if modelpool_config_str is not none -%}
+### Model Pool Configuration
+
+```yaml
+{{ modelpool_config_str -}}
+```
+{%- endif %}
+
+{% endif %}
fusion_bench/models/modeling_smile_llama/modeling_smile_llama.py
CHANGED

@@ -17,7 +17,6 @@ from transformers.modeling_outputs import (
 )
 from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
 from transformers.models.llama.modeling_llama import (
-    LLAMA_INPUTS_DOCSTRING,
     LlamaRMSNorm,
     LlamaRotaryEmbedding,
     apply_rotary_pos_emb,

@@ -25,7 +24,6 @@ from transformers.models.llama.modeling_llama import (
 )
 from transformers.processing_utils import Unpack
 from transformers.utils import (
-    LossKwargs,
     add_start_docstrings_to_model_forward,
     can_return_tuple,
     is_torch_flex_attn_available,

@@ -296,7 +294,6 @@ class SmileLlamaModel(SmileLlamaPreTrainedModel):
         self.embed_tokens = value
 
     @can_return_tuple
-    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,

@@ -566,9 +563,6 @@ class SmileLlamaModel(SmileLlamaPreTrainedModel):
         return causal_mask
 
 
-class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
-
-
 class SmileLlamaForCausalLM(SmileLlamaPreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]
     _tp_plan = {"lm_head": "colwise_rep"}

@@ -603,7 +597,6 @@ class SmileLlamaForCausalLM(SmileLlamaPreTrainedModel, GenerationMixin):
 
     @can_return_tuple
     @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
     @replace_return_docstrings(
         output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
     )

@@ -620,7 +613,7 @@ class SmileLlamaForCausalLM(SmileLlamaPreTrainedModel, GenerationMixin):
         output_hidden_states: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
-        **kwargs
+        **kwargs,
     ) -> CausalLMOutputWithPast:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py
CHANGED

@@ -31,7 +31,6 @@ from transformers.models.qwen2.modeling_qwen2 import (
 )
 from transformers.processing_utils import Unpack
 from transformers.utils import (
-    LossKwargs,
     add_code_sample_docstrings,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,

@@ -607,9 +606,6 @@ class SmileQwen2Model(SmileQwen2PreTrainedModel):
         return causal_mask
 
 
-class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
-
-
 class SmileQwen2ForCausalLM(SmileQwen2PreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]
     _tp_plan = {"lm_head": "colwise_rep"}

@@ -660,7 +656,7 @@ class SmileQwen2ForCausalLM(SmileQwen2PreTrainedModel, GenerationMixin):
         output_hidden_states: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
-        **kwargs
+        **kwargs,
     ) -> CausalLMOutputWithPast:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
fusion_bench/programs/fabric_fusion_program.py
CHANGED

@@ -1,6 +1,7 @@
 import json
 import logging
 import os
+from pathlib import Path
 from typing import Any, Callable, Dict, Iterable, List, Optional, Union  # noqa: F401
 
 import lightning as L

@@ -9,19 +10,24 @@ from omegaconf import DictConfig, OmegaConf
 from torch import nn
 from tqdm.auto import tqdm
 
-import fusion_bench
-from fusion_bench
+import fusion_bench
+from fusion_bench import (
+    BaseAlgorithm,
+    BaseHydraProgram,
+    BaseModelPool,
+    BaseTaskPool,
+    RuntimeConstants,
+    import_object,
+    instantiate,
+    timeit_context,
+)
 from fusion_bench.mixins import LightningFabricMixin
-from fusion_bench.modelpool import BaseModelPool
-from fusion_bench.programs import BaseHydraProgram
-from fusion_bench.taskpool import BaseTaskPool
-from fusion_bench.utils import import_object, instantiate, timeit_context
 from fusion_bench.utils.hydra_utils import get_hydra_output_dir
 from fusion_bench.utils.json import print_json
-from fusion_bench.utils.
+from fusion_bench.utils.path import create_symlink
 from fusion_bench.utils.rich_utils import print_bordered, print_config_tree
 
-log =
+log = fusion_bench.get_rankzero_logger(__name__)
 
 
 class FabricModelFusionProgram(

@@ -60,6 +66,7 @@ class FabricModelFusionProgram(
         path: DictConfig = None,
         **kwargs,
     ):
+        super().__init__(**kwargs)
         self._method = method
         self._modelpool = modelpool
         self._taskpool = taskpool

@@ -70,8 +77,10 @@ class FabricModelFusionProgram(
         self.fast_dev_run = fast_dev_run
         self.seed = seed
         self.path = path
-
-
+        RuntimeConstants.debug = fast_dev_run
+        RuntimeConstants.print_function_call = print_function_call
+        if path is not None:
+            RuntimeConstants.cache_dir = path.get("cache_dir", None)
 
         if print_config:
             print_config_tree(

@@ -224,8 +233,16 @@ class FabricModelFusionProgram(
         fabric = self.fabric
         if self.seed is not None:
             L.seed_everything(self.seed)
-
-
+
+        # create symbol link to hydra output directory
+        if (
+            self.fabric.is_global_zero
+            and self.log_dir is not None
+            and os.path.abspath(self.log_dir) != os.path.abspath(get_hydra_output_dir())
+        ):
+            create_symlink(
+                get_hydra_output_dir(), self.log_dir, link_name="hydra_output"
+            )
 
         log.info("Running the model fusion program.")
         # setup the modelpool, method, and taskpool

@@ -278,51 +295,3 @@ class FabricModelFusionProgram(
             json.dump(report, open(self.report_save_path, "w"))
         else:
             log.info("No task pool specified. Skipping evaluation.")
-
-    @rank_zero_only
-    def _link_hydra_output(self):
-        """
-        Creates a symbolic link to the Hydra output directory within the specified log directory.
-
-        If `self.log_dir` is not None, this method will:
-        1. Retrieve the Hydra output directory using `get_hydra_output_dir()`.
-        2. Create the log directory if it does not already exist.
-        3. Create a symbolic link named "hydra_output_<basename_of_hydra_output_dir>"
-           within the log directory, pointing to the Hydra output directory.
-
-        Note:
-        - The symbolic link is created only if the Hydra output directory is not None.
-        - The `target_is_directory` parameter is set to True to indicate that the target is a directory.
-
-        Raises:
-            OSError: If the symbolic link creation fails.
-        """
-        if self.log_dir is not None:
-            # make symlink to the hydra output directory
-            try:
-                hydra_output_dir = get_hydra_output_dir()
-            except Exception as e:
-                hydra_output_dir = None
-
-            if hydra_output_dir is not None:
-                if os.path.abspath(hydra_output_dir) == os.path.abspath(self.log_dir):
-                    return
-
-                os.makedirs(self.log_dir, exist_ok=True)
-                try:
-                    # if the system is windows, use the `mklink` command in "CMD" to create the symlink
-                    if os.name == "nt":
-                        os.system(
-                            f"mklink /J {os.path.abspath(os.path.join(self.log_dir, 'hydra_output_' + os.path.basename(hydra_output_dir)))} {os.path.abspath(hydra_output_dir)}"
-                        )
-                    else:
-                        os.symlink(
-                            hydra_output_dir,
-                            os.path.join(
-                                self.log_dir,
-                                "hydra_output_" + os.path.basename(hydra_output_dir),
-                            ),
-                            target_is_directory=True,
-                        )
-                except OSError as e:
-                    log.warning(f"Failed to create symbolic link: {e}")