fusion-bench 0.2.23__py3-none-any.whl → 0.2.24__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/method/__init__.py +8 -0
- fusion_bench/method/ensemble.py +17 -2
- fusion_bench/method/linear/__init__.py +6 -2
- fusion_bench/method/linear/{simple_average_for_llama.py → simple_average_for_causallm.py} +8 -4
- fusion_bench/method/linear/{task_arithmetic_for_llama.py → task_arithmetic_for_causallm.py} +22 -12
- fusion_bench/method/linear/ties_merging_for_causallm.py +70 -0
- fusion_bench/method/simple_average.py +2 -2
- fusion_bench/method/task_arithmetic/task_arithmetic.py +35 -10
- fusion_bench/method/ties_merging/ties_merging.py +22 -6
- fusion_bench/method/wudi/__init__.py +1 -0
- fusion_bench/method/wudi/wudi.py +105 -0
- fusion_bench/mixins/lightning_fabric.py +4 -0
- fusion_bench/mixins/serialization.py +25 -78
- fusion_bench/modelpool/causal_lm/causal_lm.py +32 -10
- fusion_bench/models/hf_clip.py +4 -0
- fusion_bench/models/hf_utils.py +2 -1
- fusion_bench/models/model_card_templates/default.md +8 -1
- fusion_bench/models/wrappers/ensemble.py +136 -7
- fusion_bench/scripts/cli.py +2 -2
- fusion_bench/taskpool/clip_vision/taskpool.py +11 -4
- fusion_bench/utils/devices.py +30 -8
- fusion_bench/utils/lazy_state_dict.py +3 -0
- fusion_bench/utils/rich_utils.py +7 -3
- {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.24.dist-info}/METADATA +10 -3
- {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.24.dist-info}/RECORD +37 -30
- fusion_bench_config/method/ensemble/simple_ensemble.yaml +1 -0
- fusion_bench_config/method/linear/{simple_average_for_llama.yaml → simple_average_for_causallm.yaml} +1 -1
- fusion_bench_config/method/linear/task_arithmetic_for_causallm.yaml +4 -0
- fusion_bench_config/method/linear/ties_merging_for_causallm.yaml +13 -0
- fusion_bench_config/method/wudi/wudi.yaml +4 -0
- fusion_bench_config/modelpool/CausalLMPool/{Qwen2.5-1.5B_math_and_coder.yaml → Qwen2.5-1.5B_math_and_code.yaml} +1 -2
- fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_three_models.yaml +11 -0
- fusion_bench_config/modelpool/CausalLMPool/llama-7b_3-models_v1.yaml +11 -0
- fusion_bench_config/method/linear/task_arithmetic_for_llama.yaml +0 -4
- {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.24.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.24.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.24.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.24.dist-info}/top_level.txt +0 -0
fusion_bench/method/__init__.py
CHANGED
|
@@ -26,9 +26,12 @@ _import_structure = {
|
|
|
26
26
|
"linear": [
|
|
27
27
|
"ExPOAlgorithm",
|
|
28
28
|
"ExPOAlgorithmForLlama",
|
|
29
|
+
"SimpleAverageForCausalLM",
|
|
29
30
|
"SimpleAverageForLlama",
|
|
31
|
+
"TaskArithmeticForCausalLM",
|
|
30
32
|
"TaskArithmeticForLlama",
|
|
31
33
|
"LinearInterpolationAlgorithm",
|
|
34
|
+
"TiesMergingForCausalLM",
|
|
32
35
|
],
|
|
33
36
|
"slerp": ["SlerpMergeAlgorithm", "SlerpForCausalLM"],
|
|
34
37
|
"simple_average": ["SimpleAverageAlgorithm"],
|
|
@@ -72,6 +75,7 @@ _import_structure = {
|
|
|
72
75
|
"fw_merging": ["FrankWolfeHardAlgorithm", "FrankWolfeSoftAlgorithm"],
|
|
73
76
|
"tall_mask": ["TallMaskTaskArithmeticAlgorithm"],
|
|
74
77
|
"model_stock": ["ModelStock"],
|
|
78
|
+
"wudi": ["wudi_merging", "WUDIMerging"],
|
|
75
79
|
# plug-and-play model merging methods
|
|
76
80
|
"concrete_subspace": [
|
|
77
81
|
"ConcreteTaskArithmeticAlgorithmForCLIP",
|
|
@@ -184,8 +188,11 @@ if TYPE_CHECKING:
|
|
|
184
188
|
ExPOAlgorithm,
|
|
185
189
|
ExPOAlgorithmForLlama,
|
|
186
190
|
LinearInterpolationAlgorithm,
|
|
191
|
+
SimpleAverageForCausalLM,
|
|
187
192
|
SimpleAverageForLlama,
|
|
193
|
+
TaskArithmeticForCausalLM,
|
|
188
194
|
TaskArithmeticForLlama,
|
|
195
|
+
TiesMergingForCausalLM,
|
|
189
196
|
)
|
|
190
197
|
from .lm_finetune import *
|
|
191
198
|
from .mixture_of_experts import (
|
|
@@ -238,6 +245,7 @@ if TYPE_CHECKING:
|
|
|
238
245
|
FlanT5WeightEnsemblingMoEAlgorithm,
|
|
239
246
|
)
|
|
240
247
|
from .weighted_average import WeightedAverageAlgorithm, WeightedAverageForLLama
|
|
248
|
+
from .wudi import WUDIMerging, wudi_merging
|
|
241
249
|
|
|
242
250
|
else:
|
|
243
251
|
sys.modules[__name__] = LazyImporter(
|
fusion_bench/method/ensemble.py
CHANGED
|
@@ -17,7 +17,21 @@ from fusion_bench.models.wrappers.ensemble import (
|
|
|
17
17
|
log = logging.getLogger(__name__)
|
|
18
18
|
|
|
19
19
|
|
|
20
|
+
@auto_register_config
|
|
20
21
|
class SimpleEnsembleAlgorithm(BaseAlgorithm):
|
|
22
|
+
def __init__(
    self,
    device_map: Optional[Mapping[int, Union[str, torch.device]]] = None,
    **kwargs,
):
    """
    Create the ensemble algorithm, optionally pinning ensemble members to devices.

    This constructor does not assign ``device_map`` itself; ``run`` later
    reads ``self.device_map`` and hands it to ``EnsembleModule``.
    NOTE(review): the attribute is presumably populated by the
    ``@auto_register_config`` class decorator -- confirm.

    Args:
        device_map: Optional mapping from a model's index to the device
            (name or ``torch.device``) it should be placed on. ``None``
            (the default) supplies no mapping.
        **kwargs: Forwarded unchanged to the parent class constructor.
    """
    super().__init__(**kwargs)
|
|
34
|
+
|
|
21
35
|
@torch.no_grad()
|
|
22
36
|
def run(self, modelpool: BaseModelPool | List[nn.Module]) -> EnsembleModule:
|
|
23
37
|
"""
|
|
@@ -30,9 +44,10 @@ class SimpleEnsembleAlgorithm(BaseAlgorithm):
|
|
|
30
44
|
EnsembleModule: The ensembled model.
|
|
31
45
|
"""
|
|
32
46
|
log.info(f"Running ensemble algorithm with {len(modelpool)} models")
|
|
33
|
-
|
|
34
47
|
models = [modelpool.load_model(m) for m in modelpool.model_names]
|
|
35
|
-
|
|
48
|
+
|
|
49
|
+
log.info("creating ensemble module")
|
|
50
|
+
ensemble = EnsembleModule(models=models, device_map=self.device_map)
|
|
36
51
|
return ensemble
|
|
37
52
|
|
|
38
53
|
|
|
@@ -2,5 +2,9 @@
|
|
|
2
2
|
from .expo import ExPOAlgorithm
|
|
3
3
|
from .linear_interpolation import LinearInterpolationAlgorithm
|
|
4
4
|
from .llama_expo import ExPOAlgorithmForLlama
|
|
5
|
-
from .
|
|
6
|
-
from .
|
|
5
|
+
from .simple_average_for_causallm import SimpleAverageForCausalLM, SimpleAverageForLlama
|
|
6
|
+
from .task_arithmetic_for_causallm import (
|
|
7
|
+
TaskArithmeticForCausalLM,
|
|
8
|
+
TaskArithmeticForLlama,
|
|
9
|
+
)
|
|
10
|
+
from .ties_merging_for_causallm import TiesMergingForCausalLM
|
|
@@ -18,16 +18,16 @@ log = get_rankzero_logger(__name__)
|
|
|
18
18
|
|
|
19
19
|
|
|
20
20
|
@auto_register_config
|
|
21
|
-
class
|
|
21
|
+
class SimpleAverageForCausalLM(BaseAlgorithm):
|
|
22
22
|
R"""
|
|
23
23
|
A simple averaging algorithm for LLama models. If `merge_backbone` is set to `True`, the backbone of the model will be averaged and the rest of the model will be loaded from the pre-trained model.
|
|
24
24
|
|
|
25
25
|
Examples:
|
|
26
|
-
The following example demonstrates how to use the `
|
|
26
|
+
The following example demonstrates how to use the `SimpleAverageForCausalLM` algorithm to merge Mistral models.
|
|
27
27
|
|
|
28
28
|
```bash
|
|
29
29
|
fusion_bench \
|
|
30
|
-
method=linear/
|
|
30
|
+
method=linear/simple_average_for_causallm \
|
|
31
31
|
method.model_save_path=outputs/simle_mixtral_exp_v4/simple_average \
|
|
32
32
|
modelpool=CausalLMPool/simle_mixtral_exp_v4.yaml
|
|
33
33
|
```
|
|
@@ -35,7 +35,7 @@ class SimpleAverageForLlama(BaseAlgorithm):
|
|
|
35
35
|
|
|
36
36
|
def __init__(
|
|
37
37
|
self,
|
|
38
|
-
merge_backbone: bool,
|
|
38
|
+
merge_backbone: bool = False,
|
|
39
39
|
model_save_path: Optional[str] = None,
|
|
40
40
|
show_pbar: bool = False,
|
|
41
41
|
**kwargs,
|
|
@@ -81,3 +81,7 @@ class SimpleAverageForLlama(BaseAlgorithm):
|
|
|
81
81
|
with open(os.path.join(self.model_save_path, "README.md"), "w") as f:
|
|
82
82
|
f.write(model_card_str)
|
|
83
83
|
return model
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
SimpleAverageForLlama = SimpleAverageForCausalLM
|
|
87
|
+
"""Alias for SimpleAverageForCausalLM"""
|
|
@@ -1,22 +1,27 @@
|
|
|
1
1
|
import logging
|
|
2
|
+
import os
|
|
2
3
|
from typing import Dict, List, Mapping, Optional, TypeVar, Union # noqa: F401
|
|
3
4
|
|
|
4
5
|
from typing_extensions import override
|
|
5
6
|
|
|
6
|
-
from fusion_bench import timeit_context
|
|
7
|
+
from fusion_bench import auto_register_config, timeit_context
|
|
7
8
|
from fusion_bench.method import TaskArithmeticAlgorithm
|
|
8
9
|
from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
|
|
9
10
|
from fusion_bench.modelpool import CausalLMBackbonePool, CausalLMPool
|
|
11
|
+
from fusion_bench.models.hf_utils import create_default_model_card
|
|
10
12
|
|
|
11
13
|
log = logging.getLogger(__name__)
|
|
12
14
|
|
|
13
15
|
|
|
14
|
-
|
|
16
|
+
@auto_register_config
|
|
17
|
+
class TaskArithmeticForCausalLM(
|
|
18
|
+
TaskArithmeticAlgorithm,
|
|
19
|
+
):
|
|
15
20
|
R"""
|
|
16
21
|
Examples:
|
|
17
22
|
|
|
18
23
|
fusion_bench \
|
|
19
|
-
method=linear/
|
|
24
|
+
method=linear/task_arithmetic_for_causallm \
|
|
20
25
|
method.scaling_factor=0.3 \
|
|
21
26
|
method.model_save_path=outputs/simle_mixtral_exp_v4/task_arithmetic_0.3 \
|
|
22
27
|
modelpool=CausalLMPool/simle_mixtral_exp_v4.yaml
|
|
@@ -29,18 +34,14 @@ class TaskArithmeticForLlama(TaskArithmeticAlgorithm, SimpleProfilerMixin):
|
|
|
29
34
|
def __init__(
|
|
30
35
|
self,
|
|
31
36
|
scaling_factor: float,
|
|
32
|
-
merge_backbone: bool,
|
|
37
|
+
merge_backbone: bool = False,
|
|
33
38
|
model_save_path: Optional[str] = None,
|
|
39
|
+
**kwargs,
|
|
34
40
|
):
|
|
35
|
-
|
|
36
|
-
self.model_save_path = model_save_path
|
|
37
|
-
super().__init__(scaling_factor=scaling_factor)
|
|
41
|
+
super().__init__(scaling_factor=scaling_factor, **kwargs)
|
|
38
42
|
|
|
39
43
|
@override
|
|
40
44
|
def run(self, modelpool: CausalLMPool):
|
|
41
|
-
if self.model_save_path:
|
|
42
|
-
tokenizer = modelpool.load_tokenizer()
|
|
43
|
-
|
|
44
45
|
if self.merge_backbone:
|
|
45
46
|
assert modelpool.has_pretrained
|
|
46
47
|
backbone_modelpool = CausalLMBackbonePool(**modelpool.config)
|
|
@@ -52,6 +53,15 @@ class TaskArithmeticForLlama(TaskArithmeticAlgorithm, SimpleProfilerMixin):
|
|
|
52
53
|
|
|
53
54
|
if self.model_save_path is not None:
|
|
54
55
|
with timeit_context(f"Saving the model to {self.model_save_path}"):
|
|
55
|
-
|
|
56
|
-
|
|
56
|
+
description = f"Merged model using task arithmetic with scaling factor {self.scaling_factor}."
|
|
57
|
+
modelpool.save_model(
|
|
58
|
+
model=model,
|
|
59
|
+
path=self.model_save_path,
|
|
60
|
+
save_tokenizer=True,
|
|
61
|
+
algorithm_config=self.config,
|
|
62
|
+
description=description,
|
|
63
|
+
)
|
|
57
64
|
return model
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
TaskArithmeticForLlama = TaskArithmeticForCausalLM
|
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import os
|
|
3
|
+
from typing import Dict, List, Mapping, Optional, TypeVar, Union # noqa: F401
|
|
4
|
+
|
|
5
|
+
from typing_extensions import override
|
|
6
|
+
|
|
7
|
+
from fusion_bench import auto_register_config, timeit_context
|
|
8
|
+
from fusion_bench.method import TiesMergingAlgorithm
|
|
9
|
+
from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
|
|
10
|
+
from fusion_bench.modelpool import CausalLMBackbonePool, CausalLMPool
|
|
11
|
+
from fusion_bench.models.hf_utils import create_default_model_card
|
|
12
|
+
|
|
13
|
+
log = logging.getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@auto_register_config
class TiesMergingForCausalLM(
    TiesMergingAlgorithm,
):
    R"""
    TIES merging specialised for causal language models.

    Builds on ``TiesMergingAlgorithm`` and adds two things on top of the
    parent's ``run``:

    - optional *backbone-only* merging (``merge_backbone=True``): the merge is
      run on a ``CausalLMBackbonePool`` built from the pool's config and the
      result is assigned to ``model.model.layers`` of the pretrained model;
    - optional saving of the merged model (with tokenizer and a description)
      through ``modelpool.save_model`` when ``model_save_path`` is set.

    NOTE(review): ``merge_backbone`` and ``model_save_path`` are read from
    ``self`` in ``run`` but never assigned in ``__init__``; they are presumably
    registered by ``@auto_register_config`` -- confirm.
    """

    # Extend the parent's config mapping with the option introduced here.
    _config_mapping = TiesMergingAlgorithm._config_mapping | {
        "merge_backbone": "merge_backbone",
    }

    def __init__(
        self,
        scaling_factor: float,
        threshold: float,
        remove_keys: List[str] = None,
        merge_func: str = "sum",
        merge_backbone: bool = False,
        model_save_path: Optional[str] = None,
        **kwargs,
    ):
        """
        Args:
            scaling_factor: Forwarded to ``TiesMergingAlgorithm``.
            threshold: Forwarded to ``TiesMergingAlgorithm``.
            remove_keys: Forwarded to ``TiesMergingAlgorithm``.
            merge_func: Forwarded to ``TiesMergingAlgorithm`` (default ``"sum"``).
            merge_backbone: When true, only the backbone is merged and grafted
                onto the pretrained model (see ``run``).
            model_save_path: Directory to save the merged model to; ``None``
                disables saving.
            **kwargs: Extra keyword arguments forwarded to the parent.
        """
        parent_options = dict(
            scaling_factor=scaling_factor,
            threshold=threshold,
            remove_keys=remove_keys,
            merge_func=merge_func,
        )
        super().__init__(**parent_options, **kwargs)

    @override
    def run(self, modelpool: CausalLMPool):
        """
        Merge the models in ``modelpool`` and optionally save the result.

        Args:
            modelpool: Pool providing the models to merge; it must contain a
                pretrained model when ``merge_backbone`` is enabled.

        Returns:
            The merged model.
        """
        if not self.merge_backbone:
            merged = super().run(modelpool)
        else:
            assert modelpool.has_pretrained
            backbone_pool = CausalLMBackbonePool(**modelpool.config)
            # Load the full pretrained model first, then replace its layers
            # with the merged backbone.
            merged = modelpool.load_model("_pretrained_")
            merged_backbone = super().run(backbone_pool)
            merged.model.layers = merged_backbone

        if self.model_save_path is None:
            return merged

        with timeit_context(f"Saving the model to {self.model_save_path}"):
            description = f"Merged model using TIES merging with scaling factor {self.scaling_factor} and threshold {self.threshold}."
            modelpool.save_model(
                model=merged,
                path=self.model_save_path,
                save_tokenizer=True,
                algorithm_config=self.config,
                description=description,
            )
        return merged
|
|
@@ -89,7 +89,7 @@ class SimpleAverageAlgorithm(
|
|
|
89
89
|
modelpool = BaseModelPool(modelpool)
|
|
90
90
|
|
|
91
91
|
log.info(
|
|
92
|
-
f"Fusing models using simple average on {len(modelpool.model_names)} models."
|
|
92
|
+
f"Fusing models using simple average on {len(modelpool.model_names)} models. "
|
|
93
93
|
f"models: {modelpool.model_names}"
|
|
94
94
|
)
|
|
95
95
|
sd: Optional[StateDictType] = None
|
|
@@ -119,7 +119,7 @@ class SimpleAverageAlgorithm(
|
|
|
119
119
|
|
|
120
120
|
if isinstance(forward_model, LazyStateDict):
|
|
121
121
|
# if the model is a LazyStateDict, convert it to an empty module
|
|
122
|
-
forward_model = forward_model.meta_module.to_empty(
|
|
122
|
+
forward_model = deepcopy(forward_model.meta_module).to_empty(
|
|
123
123
|
device=forward_model._device
|
|
124
124
|
)
|
|
125
125
|
result = forward_model.load_state_dict(sd, strict=False)
|
|
@@ -6,11 +6,20 @@ http://arxiv.org/abs/2212.04089
|
|
|
6
6
|
|
|
7
7
|
import logging
|
|
8
8
|
from copy import deepcopy
|
|
9
|
-
from typing import
|
|
9
|
+
from typing import ( # noqa: F401
|
|
10
|
+
TYPE_CHECKING,
|
|
11
|
+
Dict,
|
|
12
|
+
List,
|
|
13
|
+
Mapping,
|
|
14
|
+
Optional,
|
|
15
|
+
TypeVar,
|
|
16
|
+
Union,
|
|
17
|
+
)
|
|
10
18
|
|
|
11
19
|
import torch
|
|
12
20
|
from torch import nn
|
|
13
21
|
|
|
22
|
+
from fusion_bench import LazyStateDict
|
|
14
23
|
from fusion_bench.method.base_algorithm import BaseAlgorithm
|
|
15
24
|
from fusion_bench.mixins import SimpleProfilerMixin, auto_register_config
|
|
16
25
|
from fusion_bench.modelpool import BaseModelPool
|
|
@@ -21,6 +30,8 @@ from fusion_bench.utils.state_dict_arithmetic import (
|
|
|
21
30
|
)
|
|
22
31
|
from fusion_bench.utils.type import StateDictType, TorchModelType
|
|
23
32
|
|
|
33
|
+
if TYPE_CHECKING:
|
|
34
|
+
from transformers import PreTrainedModel
|
|
24
35
|
log = logging.getLogger(__name__)
|
|
25
36
|
|
|
26
37
|
|
|
@@ -125,25 +136,39 @@ class TaskArithmeticAlgorithm(
|
|
|
125
136
|
with self.profile("merge weights"):
|
|
126
137
|
if task_vector is None:
|
|
127
138
|
task_vector = state_dict_sub(
|
|
128
|
-
model.state_dict(
|
|
129
|
-
pretrained_model.state_dict(
|
|
139
|
+
model.state_dict(),
|
|
140
|
+
pretrained_model.state_dict(),
|
|
130
141
|
)
|
|
131
142
|
else:
|
|
132
143
|
task_vector = state_dict_add(
|
|
133
144
|
task_vector,
|
|
134
145
|
state_dict_sub(
|
|
135
|
-
model.state_dict(
|
|
136
|
-
pretrained_model.state_dict(
|
|
146
|
+
model.state_dict(),
|
|
147
|
+
pretrained_model.state_dict(),
|
|
137
148
|
),
|
|
138
149
|
)
|
|
139
150
|
with self.profile("merge weights"):
|
|
140
151
|
# scale the task vector
|
|
141
152
|
task_vector = state_dict_mul(task_vector, self.config.scaling_factor)
|
|
142
153
|
# add the task vector to the pretrained model
|
|
143
|
-
state_dict = state_dict_add(
|
|
144
|
-
pretrained_model.state_dict(keep_vars=True), task_vector
|
|
145
|
-
)
|
|
154
|
+
state_dict = state_dict_add(pretrained_model.state_dict(), task_vector)
|
|
146
155
|
|
|
147
156
|
self.print_profile_summary()
|
|
148
|
-
|
|
149
|
-
|
|
157
|
+
|
|
158
|
+
# apply state dict to model
|
|
159
|
+
if isinstance(pretrained_model, nn.Module):
|
|
160
|
+
model = pretrained_model
|
|
161
|
+
model.load_state_dict(state_dict)
|
|
162
|
+
elif isinstance(pretrained_model, LazyStateDict):
|
|
163
|
+
model = deepcopy(pretrained_model.meta_module)
|
|
164
|
+
model = model.to_empty(device=pretrained_model._device)
|
|
165
|
+
result = model.load_state_dict(state_dict, strict=False)
|
|
166
|
+
if result.unexpected_keys:
|
|
167
|
+
raise ValueError(
|
|
168
|
+
f"Unexpected keys in state dict: {result.unexpected_keys}"
|
|
169
|
+
)
|
|
170
|
+
if result.missing_keys:
|
|
171
|
+
log.warning(f"Missing keys in state dict: {result.missing_keys}")
|
|
172
|
+
else:
|
|
173
|
+
raise TypeError(f"Unsupported model type: {type(pretrained_model)}")
|
|
174
|
+
return model
|
|
@@ -9,11 +9,14 @@ Overview of Ties-Merging:
|
|
|
9
9
|
"""
|
|
10
10
|
|
|
11
11
|
import logging
|
|
12
|
+
from copy import deepcopy
|
|
12
13
|
from typing import Any, Dict, List, Literal, Mapping, Union # noqa: F401
|
|
13
14
|
|
|
14
15
|
import torch
|
|
15
16
|
from torch import Tensor, nn
|
|
17
|
+
from transformers import PreTrainedModel
|
|
16
18
|
|
|
19
|
+
from fusion_bench import LazyStateDict
|
|
17
20
|
from fusion_bench.compat.modelpool import to_modelpool
|
|
18
21
|
from fusion_bench.method import BaseAlgorithm
|
|
19
22
|
from fusion_bench.mixins import SimpleProfilerMixin, auto_register_config
|
|
@@ -98,12 +101,25 @@ class TiesMergingAlgorithm(
|
|
|
98
101
|
merge_func=merge_func,
|
|
99
102
|
)
|
|
100
103
|
merged_check = flat_ptm + scaling_factor * merged_tv
|
|
101
|
-
|
|
104
|
+
state_dict = vector_to_state_dict(
|
|
102
105
|
merged_check, ptm_check, remove_keys=remove_keys
|
|
103
106
|
)
|
|
104
|
-
|
|
105
|
-
# Load the merged state dict into the pretrained model
|
|
106
|
-
pretrained_model.load_state_dict(merged_state_dict)
|
|
107
|
-
|
|
108
107
|
self.print_profile_summary()
|
|
109
|
-
|
|
108
|
+
|
|
109
|
+
# apply state dict to model
|
|
110
|
+
if isinstance(pretrained_model, nn.Module):
|
|
111
|
+
model = pretrained_model
|
|
112
|
+
model.load_state_dict(state_dict)
|
|
113
|
+
elif isinstance(pretrained_model, LazyStateDict):
|
|
114
|
+
model = deepcopy(pretrained_model.meta_module)
|
|
115
|
+
model = model.to_empty(device=pretrained_model._device)
|
|
116
|
+
result = model.load_state_dict(state_dict, strict=False)
|
|
117
|
+
if result.unexpected_keys:
|
|
118
|
+
raise ValueError(
|
|
119
|
+
f"Unexpected keys in state dict: {result.unexpected_keys}"
|
|
120
|
+
)
|
|
121
|
+
if result.missing_keys:
|
|
122
|
+
log.warning(f"Missing keys in state dict: {result.missing_keys}")
|
|
123
|
+
else:
|
|
124
|
+
raise TypeError(f"Unsupported model type: {type(pretrained_model)}")
|
|
125
|
+
return model
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .wudi import WUDIMerging, wudi_merging
|
|
@@ -0,0 +1,105 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Whoever Started the Interference Should End It: Guiding Data-Free Model Merging via Task Vectors
|
|
3
|
+
Arxiv: http://arxiv.org/abs/2503.08099
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from typing import List
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
from tqdm import tqdm
|
|
10
|
+
|
|
11
|
+
from fusion_bench import BaseAlgorithm, BaseModelPool, auto_register_config
|
|
12
|
+
from fusion_bench.mixins import LightningFabricMixin
|
|
13
|
+
from fusion_bench.utils import timeit_context
|
|
14
|
+
from fusion_bench.utils.state_dict_arithmetic import state_dict_add, state_dict_sub
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def wudi_merging(
    task_vectors: List[dict],
    accelerator="cuda",
    iter_num: int = 300,
    exclude_keys: List[str] = None,
    lr: float = 1e-5,
):
    """
    Merge several task vectors into one with the WUDI procedure.

    For every key of the first task vector, the matching tensors of all task
    vectors are stacked on ``accelerator``. Then:

    - floating-point (or complex) 2-D tensors whose key is not excluded are
      initialised to the sum of the task vectors and refined with Adam for
      ``iter_num`` steps, minimising
      ``sum_i ||(merged - tv_i) @ tv_i^T||^2 / ||tv_i||^2``;
    - every other tensor is the plain average of the task vectors.

    Each merged tensor is moved back to the device of the first task vector's
    tensor and returned detached from the autograd graph.

    Args:
        task_vectors: List of task vectors, each a mapping from parameter
            name to tensor (state-dict like). All are indexed with the keys
            of the first one.
        accelerator: Device on which stacking and optimisation are done.
        iter_num: Number of Adam steps per optimised tensor.
        exclude_keys: Keys that are averaged instead of optimised.
            ``None`` means no exclusions.
        lr: Adam learning rate (defaults to the previously hard-coded 1e-5).

    Returns:
        dict mapping each key to its merged tensor.

    Raises:
        ValueError: If ``task_vectors`` is empty.
    """
    if not task_vectors:
        raise ValueError("wudi_merging requires at least one task vector.")

    # Build the lookup once instead of scanning a list for every key.
    excluded = frozenset() if exclude_keys is None else frozenset(exclude_keys)
    reference = task_vectors[0]
    num_tvs = len(task_vectors)

    with timeit_context("WUDI Merging"):
        new_vector = {}
        for key in tqdm(reference, desc="WUDI Merging", leave=False):
            tqdm.write(f"key: {key}")
            original_device = reference[key].device
            tvs = torch.stack(
                [
                    task_vector[key].to(device=accelerator, non_blocking=True)
                    for task_vector in task_vectors
                ]
            )
            summed = torch.sum(tvs, dim=0)

            # Only tensors that can carry gradients are optimised; integer
            # buffers cannot be wrapped in nn.Parameter and are averaged.
            optimizable = (
                reference[key].dim() == 2
                and key not in excluded
                and (summed.is_floating_point() or summed.is_complex())
            )

            if optimizable:
                merged = torch.nn.Parameter(summed)
                optimizer = torch.optim.Adam([merged], lr=lr, weight_decay=0)
                l2_norms = torch.square(
                    torch.norm(tvs.reshape(tvs.shape[0], -1), p=2, dim=-1)
                )
                # An all-zero task vector would give 0/0 = NaN and poison the
                # update; clamping makes its loss term and gradient exactly 0
                # while leaving non-zero norms untouched.
                l2_norms = l2_norms.clamp_min(torch.finfo(l2_norms.dtype).tiny)
                scale = l2_norms.unsqueeze(-1).unsqueeze(-1)

                # Re-enable autograd so this works under torch.no_grad().
                with torch.enable_grad():
                    for _ in tqdm(
                        range(iter_num),
                    ):
                        disturbing_vectors = merged.unsqueeze(0) - tvs
                        product = torch.matmul(
                            disturbing_vectors, tvs.transpose(1, 2)
                        )
                        loss = torch.sum(torch.square(product) / scale)
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()
                result = merged
            else:
                result = summed / num_tvs

            # Detach so the result does not keep the accelerator-side
            # parameter alive through the autograd graph. The copy back is
            # blocking: a non-blocking device-to-host copy may be read before
            # the data has arrived.
            new_vector[key] = result.detach().to(device=original_device)
    return new_vector
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
@auto_register_config
class WUDIMerging(
    LightningFabricMixin,
    BaseAlgorithm,
):
    """
    Whoever Started the Interference Should End It: Guiding Data-Free Model Merging via Task Vectors

    Computes one task vector per fine-tuned model (fine-tuned minus
    pretrained state dict), merges them with ``wudi_merging`` on
    ``self.fabric.device`` and adds the merged task vector back onto the
    pretrained model.

    NOTE(review): ``run`` reads ``self.iter_num`` and ``self.exclude_keys``,
    which ``__init__`` does not assign; they are presumably registered by
    ``@auto_register_config`` -- confirm.
    """

    def __init__(
        self,
        iter_num: int,
        exclude_keys: List[str] = None,
        **kwargs,
    ):
        """
        Args:
            iter_num: Number of optimisation steps passed to ``wudi_merging``.
            exclude_keys: Keys passed to ``wudi_merging`` as ``exclude_keys``.
            **kwargs: Forwarded to the parent constructors.
        """
        super().__init__(**kwargs)

    def run(self, modelpool: BaseModelPool):
        """
        Merge all models of ``modelpool`` into its pretrained model.

        Args:
            modelpool: Pool providing the pretrained model and the
                fine-tuned models listed in ``model_names``.

        Returns:
            The pretrained model with the merged weights loaded into it.
        """
        # load the pretrained model and the task vectors of all the finetuned models
        with torch.no_grad():
            pretrained_model = modelpool.load_pretrained_model()
            # Build the pretrained state dict once instead of per model.
            pretrained_state_dict = pretrained_model.state_dict()
            task_vectors = []
            for model_name in modelpool.model_names:
                finetuned_model = modelpool.load_model(model_name)
                task_vectors.append(
                    state_dict_sub(
                        finetuned_model.state_dict(), pretrained_state_dict
                    )
                )
                del finetuned_model  # free memory

        # Must run outside no_grad: wudi_merging optimises with backward().
        merged_tv = wudi_merging(
            task_vectors,
            accelerator=self.fabric.device,
            iter_num=self.iter_num,
            exclude_keys=self.exclude_keys,
        )

        # The merged tensors may still require grad; do the final sum
        # without recording an autograd graph.
        with torch.no_grad():
            pretrained_model.load_state_dict(
                state_dict_add(pretrained_model.state_dict(), merged_tv)
            )

        return pretrained_model
|
|
@@ -100,6 +100,10 @@ class LightningFabricMixin:
|
|
|
100
100
|
self.setup_lightning_fabric(getattr(self, "config", DictConfig({})))
|
|
101
101
|
return self._fabric_instance
|
|
102
102
|
|
|
103
|
+
@fabric.setter
def fabric(self, instance: L.Fabric):
    """
    Replace the stored Fabric instance.

    Writes to ``_fabric_instance``, the same attribute the ``fabric``
    getter returns.

    Args:
        instance: The Fabric object to store on this instance.
    """
    self._fabric_instance = instance
|
|
106
|
+
|
|
103
107
|
@property
|
|
104
108
|
def log_dir(self):
|
|
105
109
|
"""
|