fusion-bench 0.2.18__py3-none-any.whl → 0.2.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/__init__.py +6 -0
- fusion_bench/constants/banner.py +12 -0
- fusion_bench/method/linear/simple_average_for_llama.py +17 -3
- fusion_bench/method/simple_average.py +10 -0
- fusion_bench/modelpool/causal_lm/causal_lm.py +37 -6
- fusion_bench/tasks/clip_classification/imagenet.py +1008 -2004
- fusion_bench/utils/lazy_state_dict.py +75 -3
- fusion_bench/utils/misc.py +19 -1
- {fusion_bench-0.2.18.dist-info → fusion_bench-0.2.19.dist-info}/METADATA +1 -1
- {fusion_bench-0.2.18.dist-info → fusion_bench-0.2.19.dist-info}/RECORD +15 -13
- fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml +11 -0
- {fusion_bench-0.2.18.dist-info → fusion_bench-0.2.19.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.18.dist-info → fusion_bench-0.2.19.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.18.dist-info → fusion_bench-0.2.19.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.18.dist-info → fusion_bench-0.2.19.dist-info}/top_level.txt +0 -0
fusion_bench/__init__.py
CHANGED
|
@@ -1,3 +1,9 @@
|
|
|
1
|
+
# ███████╗██╗ ██╗███████╗██╗ ██████╗ ███╗ ██╗ ██████╗ ███████╗███╗ ██╗ ██████╗██╗ ██╗
|
|
2
|
+
# ██╔════╝██║ ██║██╔════╝██║██╔═══██╗████╗ ██║ ██╔══██╗██╔════╝████╗ ██║██╔════╝██║ ██║
|
|
3
|
+
# █████╗ ██║ ██║███████╗██║██║ ██║██╔██╗ ██║█████╗██████╔╝█████╗ ██╔██╗ ██║██║ ███████║
|
|
4
|
+
# ██╔══╝ ██║ ██║╚════██║██║██║ ██║██║╚██╗██║╚════╝██╔══██╗██╔══╝ ██║╚██╗██║██║ ██╔══██║
|
|
5
|
+
# ██║ ╚██████╔╝███████║██║╚██████╔╝██║ ╚████║ ██████╔╝███████╗██║ ╚████║╚██████╗██║ ██║
|
|
6
|
+
# ╚═╝ ╚═════╝ ╚══════╝╚═╝ ╚═════╝ ╚═╝ ╚═══╝ ╚═════╝ ╚══════╝╚═╝ ╚═══╝ ╚═════╝╚═╝ ╚═╝
|
|
1
7
|
# flake8: noqa: F401
|
|
2
8
|
from . import (
|
|
3
9
|
constants,
|
|
fusion_bench/constants/banner.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
FUSION_BENCH_BANNER = (
|
|
2
|
+
""
|
|
3
|
+
+ "███████╗██╗ ██╗███████╗██╗ ██████╗ ███╗ ██╗ ██████╗ ███████╗███╗ ██╗ ██████╗██╗ ██╗\n"
|
|
4
|
+
+ "██╔════╝██║ ██║██╔════╝██║██╔═══██╗████╗ ██║ ██╔══██╗██╔════╝████╗ ██║██╔════╝██║ ██║\n"
|
|
5
|
+
+ "█████╗ ██║ ██║███████╗██║██║ ██║██╔██╗ ██║█████╗██████╔╝█████╗ ██╔██╗ ██║██║ ███████║\n"
|
|
6
|
+
+ "██╔══╝ ██║ ██║╚════██║██║██║ ██║██║╚██╗██║╚════╝██╔══██╗██╔══╝ ██║╚██╗██║██║ ██╔══██║\n"
|
|
7
|
+
+ "██║ ╚██████╔╝███████║██║╚██████╔╝██║ ╚████║ ██████╔╝███████╗██║ ╚████║╚██████╗██║ ██║\n"
|
|
8
|
+
+ "╚═╝ ╚═════╝ ╚══════╝╚═╝ ╚═════╝ ╚═╝ ╚═══╝ ╚═════╝ ╚══════╝╚═╝ ╚═══╝ ╚═════╝╚═╝ ╚═╝\n"
|
|
9
|
+
)
|
|
10
|
+
|
|
11
|
+
if __name__ == "__main__":
|
|
12
|
+
print(FUSION_BENCH_BANNER)
|
|
fusion_bench/method/linear/simple_average_for_llama.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
|
-
from
|
|
1
|
+
from copy import deepcopy
|
|
2
|
+
from typing import TYPE_CHECKING, Optional
|
|
2
3
|
|
|
3
4
|
from typing_extensions import override
|
|
4
5
|
|
|
@@ -6,6 +7,11 @@ from fusion_bench import timeit_context
|
|
|
6
7
|
from fusion_bench.method.base_algorithm import BaseAlgorithm
|
|
7
8
|
from fusion_bench.method.simple_average import SimpleAverageAlgorithm
|
|
8
9
|
from fusion_bench.modelpool import CausalLMBackbonePool, CausalLMPool
|
|
10
|
+
from fusion_bench.utils.pylogger import getRankZeroLogger
|
|
11
|
+
from omegaconf import flag_override
|
|
12
|
+
from fusion_bench.utils import instantiate
|
|
13
|
+
|
|
14
|
+
log = getRankZeroLogger(__name__)
|
|
9
15
|
|
|
10
16
|
|
|
11
17
|
class SimpleAverageForLlama(BaseAlgorithm):
|
|
@@ -40,12 +46,20 @@ class SimpleAverageForLlama(BaseAlgorithm):
|
|
|
40
46
|
|
|
41
47
|
if self.merge_backbone:
|
|
42
48
|
assert modelpool.has_pretrained
|
|
43
|
-
|
|
49
|
+
log.info(
|
|
50
|
+
"Merging backbone of the model pool, use CausalLMBackbonePool instead of CausalLMPool."
|
|
51
|
+
)
|
|
52
|
+
modelpool_config = deepcopy(modelpool.config)
|
|
53
|
+
with flag_override(modelpool_config, "allow_objects", True):
|
|
54
|
+
modelpool_config._target_ = (
|
|
55
|
+
"fusion_bench.modelpool.causal_lm.CausalLMBackbonePool"
|
|
56
|
+
)
|
|
57
|
+
backbone_modelpool = instantiate(modelpool_config)
|
|
44
58
|
model = modelpool.load_model("_pretrained_")
|
|
45
59
|
backbone_model = SimpleAverageAlgorithm().run(backbone_modelpool)
|
|
46
60
|
model.model.layers = backbone_model
|
|
47
61
|
else:
|
|
48
|
-
model = SimpleAverageAlgorithm().run()
|
|
62
|
+
model = SimpleAverageAlgorithm().run(modelpool=modelpool)
|
|
49
63
|
|
|
50
64
|
if self.model_save_path is not None:
|
|
51
65
|
with timeit_context(f"Saving the model to {self.model_save_path}"):
|
|
fusion_bench/method/simple_average.py
CHANGED
|
@@ -8,6 +8,7 @@ from torch import nn
|
|
|
8
8
|
from fusion_bench.method.base_algorithm import BaseAlgorithm
|
|
9
9
|
from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
|
|
10
10
|
from fusion_bench.modelpool import BaseModelPool
|
|
11
|
+
from fusion_bench.utils import LazyStateDict
|
|
11
12
|
from fusion_bench.utils.state_dict_arithmetic import (
|
|
12
13
|
state_dict_add,
|
|
13
14
|
state_dict_avg,
|
|
@@ -104,6 +105,15 @@ class SimpleAverageAlgorithm(
|
|
|
104
105
|
# Divide the accumulated state dictionary by the number of models to get the average
|
|
105
106
|
sd = state_dict_div(sd, len(modelpool.model_names))
|
|
106
107
|
|
|
108
|
+
if isinstance(forward_model, LazyStateDict):
|
|
109
|
+
# if the model is a LazyStateDict, convert it to an empty module
|
|
110
|
+
forward_model = forward_model.meta_module.to_empty(
|
|
111
|
+
device=(
|
|
112
|
+
"cpu"
|
|
113
|
+
if forward_model._torch_dtype is None
|
|
114
|
+
else forward_model._torch_dtype
|
|
115
|
+
)
|
|
116
|
+
)
|
|
107
117
|
forward_model.load_state_dict(sd)
|
|
108
118
|
# print profile report and log the merged models
|
|
109
119
|
self.print_profile_summary()
|
|
fusion_bench/modelpool/causal_lm/causal_lm.py
CHANGED
|
@@ -22,6 +22,8 @@ from typing_extensions import override
|
|
|
22
22
|
from fusion_bench.modelpool import BaseModelPool
|
|
23
23
|
from fusion_bench.utils import instantiate
|
|
24
24
|
from fusion_bench.utils.dtype import parse_dtype
|
|
25
|
+
from fusion_bench.utils.lazy_state_dict import LazyStateDict
|
|
26
|
+
from fusion_bench.utils.packages import import_object
|
|
25
27
|
|
|
26
28
|
log = logging.getLogger(__name__)
|
|
27
29
|
|
|
@@ -30,6 +32,7 @@ class CausalLMPool(BaseModelPool):
|
|
|
30
32
|
_config_mapping = BaseModelPool._config_mapping | {
|
|
31
33
|
"_tokenizer": "tokenizer",
|
|
32
34
|
"_model_kwargs": "model_kwargs",
|
|
35
|
+
"load_lazy": "load_lazy",
|
|
33
36
|
}
|
|
34
37
|
|
|
35
38
|
def __init__(
|
|
@@ -38,6 +41,7 @@ class CausalLMPool(BaseModelPool):
|
|
|
38
41
|
*,
|
|
39
42
|
tokenizer: Optional[DictConfig],
|
|
40
43
|
model_kwargs: Optional[DictConfig] = None,
|
|
44
|
+
load_lazy: bool = False,
|
|
41
45
|
**kwargs,
|
|
42
46
|
):
|
|
43
47
|
super().__init__(models, **kwargs)
|
|
@@ -51,6 +55,7 @@ class CausalLMPool(BaseModelPool):
|
|
|
51
55
|
self._model_kwargs.torch_dtype = parse_dtype(
|
|
52
56
|
self._model_kwargs.torch_dtype
|
|
53
57
|
)
|
|
58
|
+
self.load_lazy = load_lazy
|
|
54
59
|
|
|
55
60
|
@override
|
|
56
61
|
def load_model(
|
|
@@ -88,21 +93,41 @@ class CausalLMPool(BaseModelPool):
|
|
|
88
93
|
model_kwargs.update(kwargs)
|
|
89
94
|
|
|
90
95
|
if isinstance(model_name_or_config, str):
|
|
96
|
+
# If model_name_or_config is a string, it is the name or the path of the model
|
|
91
97
|
log.info(f"Loading model: {model_name_or_config}", stacklevel=2)
|
|
92
98
|
if model_name_or_config in self._models.keys():
|
|
93
99
|
model_config = self._models[model_name_or_config]
|
|
94
100
|
if isinstance(model_config, str):
|
|
95
101
|
# model_config is a string
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
102
|
+
if not self.load_lazy:
|
|
103
|
+
model = AutoModelForCausalLM.from_pretrained(
|
|
104
|
+
model_config,
|
|
105
|
+
*args,
|
|
106
|
+
**model_kwargs,
|
|
107
|
+
)
|
|
108
|
+
else:
|
|
109
|
+
# model_config is a string, but we want to use LazyStateDict
|
|
110
|
+
model = LazyStateDict(
|
|
111
|
+
checkpoint=model_config,
|
|
112
|
+
meta_module_class=AutoModelForCausalLM,
|
|
113
|
+
*args,
|
|
114
|
+
**model_kwargs,
|
|
115
|
+
)
|
|
101
116
|
return model
|
|
102
117
|
elif isinstance(model_name_or_config, (DictConfig, Dict)):
|
|
103
118
|
model_config = model_name_or_config
|
|
104
119
|
|
|
105
|
-
|
|
120
|
+
if not self.load_lazy:
|
|
121
|
+
model = instantiate(model_config, *args, **model_kwargs)
|
|
122
|
+
else:
|
|
123
|
+
meta_module_class = model_config.pop("_target_")
|
|
124
|
+
checkpoint = model_config.pop("pretrained_model_name_or_path")
|
|
125
|
+
model = LazyStateDict(
|
|
126
|
+
checkpoint=checkpoint,
|
|
127
|
+
meta_module_class=meta_module_class,
|
|
128
|
+
*args,
|
|
129
|
+
**model_kwargs,
|
|
130
|
+
)
|
|
106
131
|
return model
|
|
107
132
|
|
|
108
133
|
def load_tokenizer(self, *args, **kwargs) -> PreTrainedTokenizer:
|
|
@@ -179,6 +204,12 @@ class CausalLMBackbonePool(CausalLMPool):
|
|
|
179
204
|
def load_model(
|
|
180
205
|
self, model_name_or_config: str | DictConfig, *args, **kwargs
|
|
181
206
|
) -> Module:
|
|
207
|
+
if self.load_lazy:
|
|
208
|
+
log.warning(
|
|
209
|
+
"CausalLMBackbonePool does not support lazy loading. "
|
|
210
|
+
"Falling back to normal loading."
|
|
211
|
+
)
|
|
212
|
+
self.load_lazy = False
|
|
182
213
|
model: AutoModelForCausalLM = super().load_model(
|
|
183
214
|
model_name_or_config, *args, **kwargs
|
|
184
215
|
)
|