fusion-bench 0.2.23__py3-none-any.whl → 0.2.25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/__init__.py +152 -42
- fusion_bench/dataset/__init__.py +27 -4
- fusion_bench/dataset/clip_dataset.py +2 -2
- fusion_bench/method/__init__.py +18 -1
- fusion_bench/method/classification/__init__.py +27 -2
- fusion_bench/method/classification/image_classification_finetune.py +214 -0
- fusion_bench/method/ensemble.py +17 -2
- fusion_bench/method/linear/__init__.py +6 -2
- fusion_bench/method/linear/{simple_average_for_llama.py → simple_average_for_causallm.py} +8 -4
- fusion_bench/method/linear/{task_arithmetic_for_llama.py → task_arithmetic_for_causallm.py} +22 -12
- fusion_bench/method/linear/ties_merging_for_causallm.py +70 -0
- fusion_bench/method/opcm/opcm.py +1 -0
- fusion_bench/method/pwe_moe/module.py +0 -2
- fusion_bench/method/simple_average.py +2 -2
- fusion_bench/method/tall_mask/task_arithmetic.py +2 -2
- fusion_bench/method/task_arithmetic/task_arithmetic.py +35 -10
- fusion_bench/method/ties_merging/ties_merging.py +22 -6
- fusion_bench/method/wudi/__init__.py +1 -0
- fusion_bench/method/wudi/wudi.py +105 -0
- fusion_bench/mixins/__init__.py +2 -0
- fusion_bench/mixins/lightning_fabric.py +4 -0
- fusion_bench/mixins/pyinstrument.py +174 -0
- fusion_bench/mixins/serialization.py +25 -78
- fusion_bench/mixins/simple_profiler.py +106 -23
- fusion_bench/modelpool/__init__.py +2 -0
- fusion_bench/modelpool/base_pool.py +77 -14
- fusion_bench/modelpool/causal_lm/causal_lm.py +32 -10
- fusion_bench/modelpool/clip_vision/modelpool.py +56 -19
- fusion_bench/modelpool/resnet_for_image_classification.py +208 -0
- fusion_bench/models/__init__.py +35 -9
- fusion_bench/models/hf_clip.py +4 -0
- fusion_bench/models/hf_utils.py +2 -1
- fusion_bench/models/model_card_templates/default.md +8 -1
- fusion_bench/models/wrappers/ensemble.py +136 -7
- fusion_bench/optim/__init__.py +40 -2
- fusion_bench/optim/lr_scheduler/__init__.py +27 -1
- fusion_bench/optim/muon.py +339 -0
- fusion_bench/programs/__init__.py +2 -0
- fusion_bench/programs/fabric_fusion_program.py +2 -2
- fusion_bench/programs/fusion_program.py +271 -0
- fusion_bench/scripts/cli.py +2 -2
- fusion_bench/taskpool/clip_vision/taskpool.py +11 -4
- fusion_bench/tasks/clip_classification/__init__.py +15 -0
- fusion_bench/utils/__init__.py +167 -21
- fusion_bench/utils/devices.py +30 -8
- fusion_bench/utils/lazy_imports.py +91 -12
- fusion_bench/utils/lazy_state_dict.py +58 -5
- fusion_bench/utils/misc.py +104 -13
- fusion_bench/utils/packages.py +4 -0
- fusion_bench/utils/path.py +7 -0
- fusion_bench/utils/pylogger.py +6 -0
- fusion_bench/utils/rich_utils.py +8 -3
- fusion_bench/utils/state_dict_arithmetic.py +935 -162
- {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/METADATA +10 -3
- {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/RECORD +76 -55
- fusion_bench_config/method/classification/image_classification_finetune.yaml +16 -0
- fusion_bench_config/method/classification/image_classification_finetune_test.yaml +6 -0
- fusion_bench_config/method/ensemble/simple_ensemble.yaml +1 -0
- fusion_bench_config/method/linear/{simple_average_for_llama.yaml → simple_average_for_causallm.yaml} +1 -1
- fusion_bench_config/method/linear/task_arithmetic_for_causallm.yaml +4 -0
- fusion_bench_config/method/linear/ties_merging_for_causallm.yaml +13 -0
- fusion_bench_config/method/wudi/wudi.yaml +4 -0
- fusion_bench_config/model_fusion.yaml +45 -0
- fusion_bench_config/modelpool/CausalLMPool/{Qwen2.5-1.5B_math_and_coder.yaml → Qwen2.5-1.5B_math_and_code.yaml} +1 -2
- fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_three_models.yaml +11 -0
- fusion_bench_config/modelpool/CausalLMPool/llama-7b_3-models_v1.yaml +11 -0
- fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet152_cifar10.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet152_cifar100.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet18_cifar10.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet18_cifar100.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet50_cifar10.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet50_cifar100.yaml +14 -0
- fusion_bench_config/method/linear/task_arithmetic_for_llama.yaml +0 -4
- {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/top_level.txt +0 -0
|
@@ -1,22 +1,27 @@
|
|
|
1
1
|
import logging
|
|
2
|
+
import os
|
|
2
3
|
from typing import Dict, List, Mapping, Optional, TypeVar, Union # noqa: F401
|
|
3
4
|
|
|
4
5
|
from typing_extensions import override
|
|
5
6
|
|
|
6
|
-
from fusion_bench import timeit_context
|
|
7
|
+
from fusion_bench import auto_register_config, timeit_context
|
|
7
8
|
from fusion_bench.method import TaskArithmeticAlgorithm
|
|
8
9
|
from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
|
|
9
10
|
from fusion_bench.modelpool import CausalLMBackbonePool, CausalLMPool
|
|
11
|
+
from fusion_bench.models.hf_utils import create_default_model_card
|
|
10
12
|
|
|
11
13
|
log = logging.getLogger(__name__)
|
|
12
14
|
|
|
13
15
|
|
|
14
|
-
|
|
16
|
+
@auto_register_config
|
|
17
|
+
class TaskArithmeticForCausalLM(
|
|
18
|
+
TaskArithmeticAlgorithm,
|
|
19
|
+
):
|
|
15
20
|
R"""
|
|
16
21
|
Examples:
|
|
17
22
|
|
|
18
23
|
fusion_bench \
|
|
19
|
-
method=linear/
|
|
24
|
+
method=linear/task_arithmetic_for_causallm \
|
|
20
25
|
method.scaling_factor=0.3 \
|
|
21
26
|
method.model_save_path=outputs/simle_mixtral_exp_v4/task_arithmetic_0.3 \
|
|
22
27
|
modelpool=CausalLMPool/simle_mixtral_exp_v4.yaml
|
|
@@ -29,18 +34,14 @@ class TaskArithmeticForLlama(TaskArithmeticAlgorithm, SimpleProfilerMixin):
|
|
|
29
34
|
def __init__(
|
|
30
35
|
self,
|
|
31
36
|
scaling_factor: float,
|
|
32
|
-
merge_backbone: bool,
|
|
37
|
+
merge_backbone: bool = False,
|
|
33
38
|
model_save_path: Optional[str] = None,
|
|
39
|
+
**kwargs,
|
|
34
40
|
):
|
|
35
|
-
|
|
36
|
-
self.model_save_path = model_save_path
|
|
37
|
-
super().__init__(scaling_factor=scaling_factor)
|
|
41
|
+
super().__init__(scaling_factor=scaling_factor, **kwargs)
|
|
38
42
|
|
|
39
43
|
@override
|
|
40
44
|
def run(self, modelpool: CausalLMPool):
|
|
41
|
-
if self.model_save_path:
|
|
42
|
-
tokenizer = modelpool.load_tokenizer()
|
|
43
|
-
|
|
44
45
|
if self.merge_backbone:
|
|
45
46
|
assert modelpool.has_pretrained
|
|
46
47
|
backbone_modelpool = CausalLMBackbonePool(**modelpool.config)
|
|
@@ -52,6 +53,15 @@ class TaskArithmeticForLlama(TaskArithmeticAlgorithm, SimpleProfilerMixin):
|
|
|
52
53
|
|
|
53
54
|
if self.model_save_path is not None:
|
|
54
55
|
with timeit_context(f"Saving the model to {self.model_save_path}"):
|
|
55
|
-
|
|
56
|
-
|
|
56
|
+
description = f"Merged model using task arithmetic with scaling factor {self.scaling_factor}."
|
|
57
|
+
modelpool.save_model(
|
|
58
|
+
model=model,
|
|
59
|
+
path=self.model_save_path,
|
|
60
|
+
save_tokenizer=True,
|
|
61
|
+
algorithm_config=self.config,
|
|
62
|
+
description=description,
|
|
63
|
+
)
|
|
57
64
|
return model
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
TaskArithmeticForLlama = TaskArithmeticForCausalLM
|
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import os
|
|
3
|
+
from typing import Dict, List, Mapping, Optional, TypeVar, Union # noqa: F401
|
|
4
|
+
|
|
5
|
+
from typing_extensions import override
|
|
6
|
+
|
|
7
|
+
from fusion_bench import auto_register_config, timeit_context
|
|
8
|
+
from fusion_bench.method import TiesMergingAlgorithm
|
|
9
|
+
from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
|
|
10
|
+
from fusion_bench.modelpool import CausalLMBackbonePool, CausalLMPool
|
|
11
|
+
from fusion_bench.models.hf_utils import create_default_model_card
|
|
12
|
+
|
|
13
|
+
log = logging.getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@auto_register_config
class TiesMergingForCausalLM(
    TiesMergingAlgorithm,
):
    R"""
    TIES merging specialised for causal language models.

    Wraps :class:`TiesMergingAlgorithm` with two CausalLM-specific features:

    * ``merge_backbone``: instead of merging whole models, build a
      :class:`CausalLMBackbonePool` from the pool's config, merge that, and
      install the result as ``model.model.layers`` of the pretrained model.
    * ``model_save_path``: when set, the merged model (plus tokenizer) is saved
      through ``modelpool.save_model`` together with this algorithm's config.

    Args:
        scaling_factor: Scale applied to the merged task vector.
        threshold: TIES trimming threshold.
        remove_keys: Keys to exclude from merging (``None`` for none).
        merge_func: TIES disjoint-merge function name.
        merge_backbone: Merge only the backbone layers (requires a pretrained model).
        model_save_path: Directory to save the merged model to, or ``None`` to skip saving.
        **kwargs: Forwarded to :class:`TiesMergingAlgorithm`.
    """

    _config_mapping = TiesMergingAlgorithm._config_mapping | {
        "merge_backbone": "merge_backbone",
    }

    def __init__(
        self,
        scaling_factor: float,
        threshold: float,
        remove_keys: List[str] = None,
        merge_func: str = "sum",
        merge_backbone: bool = False,
        model_save_path: Optional[str] = None,
        **kwargs,
    ):
        # ``merge_backbone`` and ``model_save_path`` are not assigned here; they
        # are read later as attributes, presumably set by ``@auto_register_config``
        # (TODO confirm).
        super().__init__(
            scaling_factor=scaling_factor,
            threshold=threshold,
            remove_keys=remove_keys,
            merge_func=merge_func,
            **kwargs,
        )

    @override
    def run(self, modelpool: CausalLMPool):
        """
        Merge the models in ``modelpool`` with TIES and optionally save the result.

        Args:
            modelpool: Pool of causal LM models to merge.

        Returns:
            The merged model.
        """
        if not self.merge_backbone:
            merged_model = super().run(modelpool)
        else:
            assert modelpool.has_pretrained
            backbone_pool = CausalLMBackbonePool(**modelpool.config)
            merged_model = modelpool.load_model("_pretrained_")
            # Swap in the merged backbone; every other module keeps the
            # pretrained weights loaded just above.
            merged_model.model.layers = super().run(backbone_pool)

        save_path = self.model_save_path
        if save_path is None:
            return merged_model

        with timeit_context(f"Saving the model to {save_path}"):
            modelpool.save_model(
                model=merged_model,
                path=save_path,
                save_tokenizer=True,
                algorithm_config=self.config,
                description=(
                    f"Merged model using TIES merging with scaling factor "
                    f"{self.scaling_factor} and threshold {self.threshold}."
                ),
            )
        return merged_model
|
fusion_bench/method/opcm/opcm.py
CHANGED
|
@@ -87,6 +87,7 @@ class OPCMForCLIP(
|
|
|
87
87
|
# get the average model
|
|
88
88
|
with self.profile("loading model"):
|
|
89
89
|
merged_model = modelpool.load_model(model_names[0])
|
|
90
|
+
assert merged_model is not None, "Failed to load the first model"
|
|
90
91
|
|
|
91
92
|
if self.evaluate_on_every_step:
|
|
92
93
|
with self.profile("evaluating model"):
|
|
@@ -89,7 +89,7 @@ class SimpleAverageAlgorithm(
|
|
|
89
89
|
modelpool = BaseModelPool(modelpool)
|
|
90
90
|
|
|
91
91
|
log.info(
|
|
92
|
-
f"Fusing models using simple average on {len(modelpool.model_names)} models."
|
|
92
|
+
f"Fusing models using simple average on {len(modelpool.model_names)} models. "
|
|
93
93
|
f"models: {modelpool.model_names}"
|
|
94
94
|
)
|
|
95
95
|
sd: Optional[StateDictType] = None
|
|
@@ -119,7 +119,7 @@ class SimpleAverageAlgorithm(
|
|
|
119
119
|
|
|
120
120
|
if isinstance(forward_model, LazyStateDict):
|
|
121
121
|
# if the model is a LazyStateDict, convert it to an empty module
|
|
122
|
-
forward_model = forward_model.meta_module.to_empty(
|
|
122
|
+
forward_model = deepcopy(forward_model.meta_module).to_empty(
|
|
123
123
|
device=forward_model._device
|
|
124
124
|
)
|
|
125
125
|
result = forward_model.load_state_dict(sd, strict=False)
|
|
@@ -15,7 +15,7 @@ from fusion_bench.utils.state_dict_arithmetic import (
|
|
|
15
15
|
state_dict_add,
|
|
16
16
|
state_dict_binary_mask,
|
|
17
17
|
state_dict_diff_abs,
|
|
18
|
-
|
|
18
|
+
state_dict_hadamard_product,
|
|
19
19
|
state_dict_mul,
|
|
20
20
|
state_dict_sub,
|
|
21
21
|
state_dict_sum,
|
|
@@ -111,7 +111,7 @@ class TallMaskTaskArithmeticAlgorithm(
|
|
|
111
111
|
|
|
112
112
|
with self.profile("compress and retrieve"):
|
|
113
113
|
for model_name in modelpool.model_names:
|
|
114
|
-
retrieved_task_vector =
|
|
114
|
+
retrieved_task_vector = state_dict_hadamard_product(
|
|
115
115
|
tall_masks[model_name], multi_task_vector
|
|
116
116
|
)
|
|
117
117
|
retrieved_state_dict = state_dict_add(
|
|
@@ -6,11 +6,20 @@ http://arxiv.org/abs/2212.04089
|
|
|
6
6
|
|
|
7
7
|
import logging
|
|
8
8
|
from copy import deepcopy
|
|
9
|
-
from typing import
|
|
9
|
+
from typing import ( # noqa: F401
|
|
10
|
+
TYPE_CHECKING,
|
|
11
|
+
Dict,
|
|
12
|
+
List,
|
|
13
|
+
Mapping,
|
|
14
|
+
Optional,
|
|
15
|
+
TypeVar,
|
|
16
|
+
Union,
|
|
17
|
+
)
|
|
10
18
|
|
|
11
19
|
import torch
|
|
12
20
|
from torch import nn
|
|
13
21
|
|
|
22
|
+
from fusion_bench import LazyStateDict
|
|
14
23
|
from fusion_bench.method.base_algorithm import BaseAlgorithm
|
|
15
24
|
from fusion_bench.mixins import SimpleProfilerMixin, auto_register_config
|
|
16
25
|
from fusion_bench.modelpool import BaseModelPool
|
|
@@ -21,6 +30,8 @@ from fusion_bench.utils.state_dict_arithmetic import (
|
|
|
21
30
|
)
|
|
22
31
|
from fusion_bench.utils.type import StateDictType, TorchModelType
|
|
23
32
|
|
|
33
|
+
if TYPE_CHECKING:
|
|
34
|
+
from transformers import PreTrainedModel
|
|
24
35
|
log = logging.getLogger(__name__)
|
|
25
36
|
|
|
26
37
|
|
|
@@ -125,25 +136,39 @@ class TaskArithmeticAlgorithm(
|
|
|
125
136
|
with self.profile("merge weights"):
|
|
126
137
|
if task_vector is None:
|
|
127
138
|
task_vector = state_dict_sub(
|
|
128
|
-
model.state_dict(
|
|
129
|
-
pretrained_model.state_dict(
|
|
139
|
+
model.state_dict(),
|
|
140
|
+
pretrained_model.state_dict(),
|
|
130
141
|
)
|
|
131
142
|
else:
|
|
132
143
|
task_vector = state_dict_add(
|
|
133
144
|
task_vector,
|
|
134
145
|
state_dict_sub(
|
|
135
|
-
model.state_dict(
|
|
136
|
-
pretrained_model.state_dict(
|
|
146
|
+
model.state_dict(),
|
|
147
|
+
pretrained_model.state_dict(),
|
|
137
148
|
),
|
|
138
149
|
)
|
|
139
150
|
with self.profile("merge weights"):
|
|
140
151
|
# scale the task vector
|
|
141
152
|
task_vector = state_dict_mul(task_vector, self.config.scaling_factor)
|
|
142
153
|
# add the task vector to the pretrained model
|
|
143
|
-
state_dict = state_dict_add(
|
|
144
|
-
pretrained_model.state_dict(keep_vars=True), task_vector
|
|
145
|
-
)
|
|
154
|
+
state_dict = state_dict_add(pretrained_model.state_dict(), task_vector)
|
|
146
155
|
|
|
147
156
|
self.print_profile_summary()
|
|
148
|
-
|
|
149
|
-
|
|
157
|
+
|
|
158
|
+
# apply state dict to model
|
|
159
|
+
if isinstance(pretrained_model, nn.Module):
|
|
160
|
+
model = pretrained_model
|
|
161
|
+
model.load_state_dict(state_dict)
|
|
162
|
+
elif isinstance(pretrained_model, LazyStateDict):
|
|
163
|
+
model = deepcopy(pretrained_model.meta_module)
|
|
164
|
+
model = model.to_empty(device=pretrained_model._device)
|
|
165
|
+
result = model.load_state_dict(state_dict, strict=False)
|
|
166
|
+
if result.unexpected_keys:
|
|
167
|
+
raise ValueError(
|
|
168
|
+
f"Unexpected keys in state dict: {result.unexpected_keys}"
|
|
169
|
+
)
|
|
170
|
+
if result.missing_keys:
|
|
171
|
+
log.warning(f"Missing keys in state dict: {result.missing_keys}")
|
|
172
|
+
else:
|
|
173
|
+
raise TypeError(f"Unsupported model type: {type(pretrained_model)}")
|
|
174
|
+
return model
|
|
@@ -9,11 +9,14 @@ Overview of Ties-Merging:
|
|
|
9
9
|
"""
|
|
10
10
|
|
|
11
11
|
import logging
|
|
12
|
+
from copy import deepcopy
|
|
12
13
|
from typing import Any, Dict, List, Literal, Mapping, Union # noqa: F401
|
|
13
14
|
|
|
14
15
|
import torch
|
|
15
16
|
from torch import Tensor, nn
|
|
17
|
+
from transformers import PreTrainedModel
|
|
16
18
|
|
|
19
|
+
from fusion_bench import LazyStateDict
|
|
17
20
|
from fusion_bench.compat.modelpool import to_modelpool
|
|
18
21
|
from fusion_bench.method import BaseAlgorithm
|
|
19
22
|
from fusion_bench.mixins import SimpleProfilerMixin, auto_register_config
|
|
@@ -98,12 +101,25 @@ class TiesMergingAlgorithm(
|
|
|
98
101
|
merge_func=merge_func,
|
|
99
102
|
)
|
|
100
103
|
merged_check = flat_ptm + scaling_factor * merged_tv
|
|
101
|
-
|
|
104
|
+
state_dict = vector_to_state_dict(
|
|
102
105
|
merged_check, ptm_check, remove_keys=remove_keys
|
|
103
106
|
)
|
|
104
|
-
|
|
105
|
-
# Load the merged state dict into the pretrained model
|
|
106
|
-
pretrained_model.load_state_dict(merged_state_dict)
|
|
107
|
-
|
|
108
107
|
self.print_profile_summary()
|
|
109
|
-
|
|
108
|
+
|
|
109
|
+
# apply state dict to model
|
|
110
|
+
if isinstance(pretrained_model, nn.Module):
|
|
111
|
+
model = pretrained_model
|
|
112
|
+
model.load_state_dict(state_dict)
|
|
113
|
+
elif isinstance(pretrained_model, LazyStateDict):
|
|
114
|
+
model = deepcopy(pretrained_model.meta_module)
|
|
115
|
+
model = model.to_empty(device=pretrained_model._device)
|
|
116
|
+
result = model.load_state_dict(state_dict, strict=False)
|
|
117
|
+
if result.unexpected_keys:
|
|
118
|
+
raise ValueError(
|
|
119
|
+
f"Unexpected keys in state dict: {result.unexpected_keys}"
|
|
120
|
+
)
|
|
121
|
+
if result.missing_keys:
|
|
122
|
+
log.warning(f"Missing keys in state dict: {result.missing_keys}")
|
|
123
|
+
else:
|
|
124
|
+
raise TypeError(f"Unsupported model type: {type(pretrained_model)}")
|
|
125
|
+
return model
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .wudi import WUDIMerging, wudi_merging
|
|
@@ -0,0 +1,105 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Whoever Started the Interference Should End It: Guiding Data-Free Model Merging via Task Vectors
|
|
3
|
+
Arxiv: http://arxiv.org/abs/2503.08099
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from typing import List
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
from tqdm import tqdm
|
|
10
|
+
|
|
11
|
+
from fusion_bench import BaseAlgorithm, BaseModelPool, auto_register_config
|
|
12
|
+
from fusion_bench.mixins import LightningFabricMixin
|
|
13
|
+
from fusion_bench.utils import timeit_context
|
|
14
|
+
from fusion_bench.utils.state_dict_arithmetic import state_dict_add, state_dict_sub
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def wudi_merging(
    task_vectors: List[dict[str, torch.Tensor]],
    accelerator="cuda",
    iter_num: int = 300,
    exclude_keys: List[str] = None,
    lr: float = 1e-5,
):
    """
    Merge task vectors with WUDI (data-free, interference-guided merging).

    For every key of the first task vector, the per-task tensors are stacked on
    ``accelerator`` and summed. Then:

    * 2-D floating-point/complex tensors not listed in ``exclude_keys`` are
      refined with Adam for ``iter_num`` steps. The loss is
      ``sum_i ||(tau - tau_i) @ tau_i^T||_F^2 / ||tau_i||_F^2``, where ``tau`` is
      the merged matrix (initialised to the sum) and ``tau_i`` is task i's matrix.
    * all other floating-point/complex tensors are averaged;
    * integer tensors (which cannot require gradients) are averaged with floor
      division and cast back to their original dtype.

    Args:
        task_vectors: One mapping per task from parameter name to task-vector
            tensor. Keys are taken from the first mapping; every mapping must
            contain them.
        accelerator: Device used for the computation (e.g. ``"cuda"``, ``"cpu"``
            or a ``torch.device``).
        iter_num: Number of Adam steps per optimised matrix.
        exclude_keys: Keys of 2-D tensors to average instead of optimise.
        lr: Adam learning rate (default matches the previous hard-coded value).

    Returns:
        Dict mapping each key to the merged tensor. The tensor is detached from
        autograd and placed on the device of the first task vector's tensor.
    """
    exclude_keys = set() if exclude_keys is None else set(exclude_keys)

    with timeit_context("WUDI Merging"):
        new_vector = {}
        for key in tqdm(task_vectors[0], desc="WUDI Merging", leave=False):
            reference = task_vectors[0][key]
            original_device = reference.device
            # Host-to-device copies with non_blocking are stream-ordered, so
            # they are safe here.
            tvs = torch.stack(
                [
                    task_vector[key].to(device=accelerator, non_blocking=True)
                    for task_vector in task_vectors
                ]
            )
            num_tvs = len(tvs)
            summed = torch.sum(tvs, dim=0)
            # Only floating-point / complex tensors may become Parameters.
            differentiable = summed.is_floating_point() or summed.is_complex()

            if differentiable and reference.dim() == 2 and key not in exclude_keys:
                merged = torch.nn.Parameter(summed)
                optimizer = torch.optim.Adam([merged], lr=lr, weight_decay=0)
                # Squared Frobenius norm of each task's matrix, one per task.
                l2_norms = torch.square(
                    torch.norm(tvs.reshape(num_tvs, -1), p=2, dim=-1)
                )
                # An all-zero task matrix yields a zero projection; divide it by
                # 1 instead of 0 so the loss (and the result) does not become NaN.
                l2_norms = torch.where(
                    l2_norms > 0, l2_norms, torch.ones_like(l2_norms)
                )
                for _ in tqdm(range(iter_num), desc=f"Optimizing {key}", leave=False):
                    disturbing_vectors = merged.unsqueeze(0) - tvs
                    product = torch.matmul(disturbing_vectors, tvs.transpose(1, 2))
                    loss = torch.sum(
                        torch.square(product) / l2_norms.unsqueeze(-1).unsqueeze(-1)
                    )
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                merged_value = merged.detach()
            elif differentiable:
                merged_value = summed / num_tvs
            else:
                merged_value = torch.div(summed, num_tvs, rounding_mode="floor").to(
                    reference.dtype
                )

            # Blocking copy: a non_blocking device-to-host transfer may be read
            # before it has completed.
            new_vector[key] = merged_value.to(device=original_device)
    return new_vector
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
@auto_register_config
class WUDIMerging(
    LightningFabricMixin,
    BaseAlgorithm,
):
    """
    Whoever Started the Interference Should End It: Guiding Data-Free Model Merging via Task Vectors

    Data-free merging algorithm (WUDI). Task vectors ``finetuned - pretrained``
    are computed for every model in the pool. They are merged with
    :func:`wudi_merging` on the Lightning Fabric device, and the merged task
    vector is added back onto the pretrained model.

    Args:
        iter_num: Number of optimisation steps per 2-D weight matrix.
        exclude_keys: Parameter names of 2-D tensors to average instead of
            optimise; ``None`` means no exclusions.
        **kwargs: Forwarded to the base classes.
    """

    def __init__(
        self,
        iter_num: int,
        exclude_keys: List[str] = None,
        **kwargs,
    ):
        # ``iter_num`` / ``exclude_keys`` are read as attributes in ``run``;
        # presumably ``@auto_register_config`` stores them (TODO confirm).
        super().__init__(**kwargs)

    def run(self, modelpool: BaseModelPool):
        """
        Merge all models in ``modelpool`` into its pretrained model.

        Args:
            modelpool: Pool providing the pretrained model and fine-tuned models.

        Returns:
            The pretrained model, updated in place with the merged task vector.
        """
        # load the pretrained model and the task vectors of all the finetuned models
        with torch.no_grad():
            pretrained_model = modelpool.load_pretrained_model()
            # The pretrained weights are not modified while task vectors are
            # computed, so build their state dict once instead of per model.
            pretrained_state_dict = pretrained_model.state_dict()
            task_vectors = []
            for model_name in modelpool.model_names:
                finetuned_model = modelpool.load_model(model_name)
                task_vectors.append(
                    state_dict_sub(finetuned_model.state_dict(), pretrained_state_dict)
                )
                del finetuned_model  # free memory

        # Must run outside ``torch.no_grad()``: wudi_merging calls backward().
        merged_tv = wudi_merging(
            task_vectors,
            accelerator=self.fabric.device,
            iter_num=self.iter_num,
            exclude_keys=self.exclude_keys,
        )

        pretrained_model.load_state_dict(
            state_dict_add(pretrained_model.state_dict(), merged_tv)
        )

        return pretrained_model
|
fusion_bench/mixins/__init__.py
CHANGED
|
@@ -11,6 +11,7 @@ _import_structure = {
|
|
|
11
11
|
"hydra_config": ["HydraConfigMixin"],
|
|
12
12
|
"lightning_fabric": ["LightningFabricMixin"],
|
|
13
13
|
"openclip_classification": ["OpenCLIPClassificationMixin"],
|
|
14
|
+
"pyinstrument": ["PyinstrumentProfilerMixin"],
|
|
14
15
|
"serialization": [
|
|
15
16
|
"BaseYAMLSerializable",
|
|
16
17
|
"YAMLSerializationMixin",
|
|
@@ -25,6 +26,7 @@ if TYPE_CHECKING:
|
|
|
25
26
|
from .hydra_config import HydraConfigMixin
|
|
26
27
|
from .lightning_fabric import LightningFabricMixin
|
|
27
28
|
from .openclip_classification import OpenCLIPClassificationMixin
|
|
29
|
+
from .pyinstrument import PyinstrumentProfilerMixin
|
|
28
30
|
from .serialization import (
|
|
29
31
|
BaseYAMLSerializable,
|
|
30
32
|
YAMLSerializationMixin,
|
|
@@ -100,6 +100,10 @@ class LightningFabricMixin:
|
|
|
100
100
|
self.setup_lightning_fabric(getattr(self, "config", DictConfig({})))
|
|
101
101
|
return self._fabric_instance
|
|
102
102
|
|
|
103
|
+
@fabric.setter
|
|
104
|
+
def fabric(self, instance: L.Fabric):
|
|
105
|
+
self._fabric_instance = instance
|
|
106
|
+
|
|
103
107
|
@property
|
|
104
108
|
def log_dir(self):
|
|
105
109
|
"""
|
|
@@ -0,0 +1,174 @@
|
|
|
1
|
+
from contextlib import contextmanager
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from typing import Generator, Optional, Union
|
|
4
|
+
|
|
5
|
+
from lightning.fabric.utilities.rank_zero import rank_zero_only
|
|
6
|
+
from pyinstrument import Profiler
|
|
7
|
+
|
|
8
|
+
__all__ = ["PyinstrumentProfilerMixin"]
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class PyinstrumentProfilerMixin:
    """
    A mixin class that provides statistical profiling capabilities using pyinstrument.

    This mixin allows for easy profiling of code blocks using a context manager.
    It provides methods to start and stop profiling actions, save profiling results
    to files, and print profiling summaries. ``profile`` blocks may be nested:
    only the outermost block starts and stops the profiler.

    Note:
        This mixin requires the `pyinstrument` package to be installed.
        If not available, an ImportError will be raised when importing this module.

    Examples:

    ```python
    class MyClass(PyinstrumentProfilerMixin):
        def do_something(self):
            with self.profile("work"):
                # do some work here
                ...

            # save the profiling results
            self.save_profile_report("profile_report.html")

            # or print the summary
            self.print_profile_summary()
    ```

    Attributes:
        _profiler (Profiler): An instance of the pyinstrument Profiler class.
    """

    _profiler: Optional[Profiler] = None
    _is_profiling: bool = False

    @property
    def profiler(self) -> Optional[Profiler]:
        """Get the profiler instance, creating it if necessary."""
        if self._profiler is None:
            self._profiler = Profiler()
        return self._profiler

    def _has_profile_data(self) -> bool:
        """Return True if the profiler has at least one completed session."""
        # The ``profiler`` property always creates a Profiler, so it can never be
        # None. Look at the private attribute and the last recorded session
        # instead. Otherwise pyinstrument raises when asked to render without a
        # session.
        return self._profiler is not None and self._profiler.last_session is not None

    @contextmanager
    def profile(self, action_name: Optional[str] = None) -> Generator:
        """
        Context manager for profiling a code block.

        When called while profiling is already active (a nested block), the
        profiler is left running on exit so the enclosing block keeps profiling.

        Args:
            action_name: Optional name for the profiling action (for logging purposes).

        Example:

        ```python
        with self.profile("expensive_operation"):
            # do some expensive work here
            expensive_function()
        ```
        """
        # Only the block that actually starts the profiler may stop it.
        owns_profiler = not self._is_profiling
        try:
            self.start_profile(action_name)
            yield action_name
        finally:
            if owns_profiler:
                self.stop_profile(action_name)

    def start_profile(self, action_name: Optional[str] = None):
        """
        Start profiling (no-op if profiling is already active).

        Args:
            action_name: Optional name for the profiling action.
        """
        if self._is_profiling:
            return

        self.profiler.start()
        self._is_profiling = True
        if action_name:
            print(f"Started profiling: {action_name}")

    def stop_profile(self, action_name: Optional[str] = None):
        """
        Stop profiling (no-op if profiling is not active).

        Args:
            action_name: Optional name for the profiling action.
        """
        if not self._is_profiling:
            return

        self.profiler.stop()
        self._is_profiling = False
        if action_name:
            print(f"Stopped profiling: {action_name}")

    @rank_zero_only
    def print_profile_summary(
        self, title: Optional[str] = None, unicode: bool = True, color: bool = True
    ):
        """
        Print a summary of the profiling results.

        Prints "No profiling data available." if no session has been recorded.

        Args:
            title: Optional title to print before the summary.
            unicode: Whether to use unicode characters in the output.
            color: Whether to use color in the output.
        """
        if not self._has_profile_data():
            print("No profiling data available.")
            return

        if title is not None:
            print(title)

        print(self.profiler.output_text(unicode=unicode, color=color))

    @rank_zero_only
    def save_profile_report(
        self,
        output_path: Union[str, Path] = "profile_report.html",
        format: str = "html",
        title: Optional[str] = None,
    ):
        """
        Save the profiling results to a file.

        Prints "No profiling data available." and writes nothing if no session
        has been recorded.

        Args:
            output_path: Path where to save the profiling report. Parent
                directories are created if needed.
            format: Output format ('html', or 'text').
            title: Optional title for the report (currently unused).

        Raises:
            ValueError: If ``format`` is not 'html' or 'text'.
        """
        if not self._has_profile_data():
            print("No profiling data available.")
            return

        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        if format.lower() == "html":
            content = self.profiler.output_html()
        elif format.lower() == "text":
            content = self.profiler.output_text(unicode=True, color=False)
        else:
            raise ValueError(f"Unsupported format: {format}. Use 'html', or 'text'.")

        with open(output_path, "w", encoding="utf-8") as f:
            f.write(content)

        print(f"Profile report saved to: {output_path}")

    def reset_profile(self):
        """Reset the profiler to start fresh, stopping any active session first."""
        if self._is_profiling:
            self.stop_profile()

        self._profiler = None

    def __del__(self):
        """Stop any active session and drop the profiler when the object is destroyed."""
        if self._is_profiling:
            self.stop_profile()

        if self._profiler is not None:
            del self._profiler
            self._profiler = None
|