fusion-bench 0.2.23__py3-none-any.whl → 0.2.24__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. fusion_bench/method/__init__.py +8 -0
  2. fusion_bench/method/ensemble.py +17 -2
  3. fusion_bench/method/linear/__init__.py +6 -2
  4. fusion_bench/method/linear/{simple_average_for_llama.py → simple_average_for_causallm.py} +8 -4
  5. fusion_bench/method/linear/{task_arithmetic_for_llama.py → task_arithmetic_for_causallm.py} +22 -12
  6. fusion_bench/method/linear/ties_merging_for_causallm.py +70 -0
  7. fusion_bench/method/simple_average.py +2 -2
  8. fusion_bench/method/task_arithmetic/task_arithmetic.py +35 -10
  9. fusion_bench/method/ties_merging/ties_merging.py +22 -6
  10. fusion_bench/method/wudi/__init__.py +1 -0
  11. fusion_bench/method/wudi/wudi.py +105 -0
  12. fusion_bench/mixins/lightning_fabric.py +4 -0
  13. fusion_bench/mixins/serialization.py +25 -78
  14. fusion_bench/modelpool/causal_lm/causal_lm.py +32 -10
  15. fusion_bench/models/hf_clip.py +4 -0
  16. fusion_bench/models/hf_utils.py +2 -1
  17. fusion_bench/models/model_card_templates/default.md +8 -1
  18. fusion_bench/models/wrappers/ensemble.py +136 -7
  19. fusion_bench/scripts/cli.py +2 -2
  20. fusion_bench/taskpool/clip_vision/taskpool.py +11 -4
  21. fusion_bench/utils/devices.py +30 -8
  22. fusion_bench/utils/lazy_state_dict.py +3 -0
  23. fusion_bench/utils/rich_utils.py +7 -3
  24. {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.24.dist-info}/METADATA +10 -3
  25. {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.24.dist-info}/RECORD +37 -30
  26. fusion_bench_config/method/ensemble/simple_ensemble.yaml +1 -0
  27. fusion_bench_config/method/linear/{simple_average_for_llama.yaml → simple_average_for_causallm.yaml} +1 -1
  28. fusion_bench_config/method/linear/task_arithmetic_for_causallm.yaml +4 -0
  29. fusion_bench_config/method/linear/ties_merging_for_causallm.yaml +13 -0
  30. fusion_bench_config/method/wudi/wudi.yaml +4 -0
  31. fusion_bench_config/modelpool/CausalLMPool/{Qwen2.5-1.5B_math_and_coder.yaml → Qwen2.5-1.5B_math_and_code.yaml} +1 -2
  32. fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_three_models.yaml +11 -0
  33. fusion_bench_config/modelpool/CausalLMPool/llama-7b_3-models_v1.yaml +11 -0
  34. fusion_bench_config/method/linear/task_arithmetic_for_llama.yaml +0 -4
  35. {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.24.dist-info}/WHEEL +0 -0
  36. {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.24.dist-info}/entry_points.txt +0 -0
  37. {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.24.dist-info}/licenses/LICENSE +0 -0
  38. {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.24.dist-info}/top_level.txt +0 -0
@@ -26,9 +26,12 @@ _import_structure = {
26
26
  "linear": [
27
27
  "ExPOAlgorithm",
28
28
  "ExPOAlgorithmForLlama",
29
+ "SimpleAverageForCausalLM",
29
30
  "SimpleAverageForLlama",
31
+ "TaskArithmeticForCausalLM",
30
32
  "TaskArithmeticForLlama",
31
33
  "LinearInterpolationAlgorithm",
34
+ "TiesMergingForCausalLM",
32
35
  ],
33
36
  "slerp": ["SlerpMergeAlgorithm", "SlerpForCausalLM"],
34
37
  "simple_average": ["SimpleAverageAlgorithm"],
@@ -72,6 +75,7 @@ _import_structure = {
72
75
  "fw_merging": ["FrankWolfeHardAlgorithm", "FrankWolfeSoftAlgorithm"],
73
76
  "tall_mask": ["TallMaskTaskArithmeticAlgorithm"],
74
77
  "model_stock": ["ModelStock"],
78
+ "wudi": ["wudi_merging", "WUDIMerging"],
75
79
  # plug-and-play model merging methods
76
80
  "concrete_subspace": [
77
81
  "ConcreteTaskArithmeticAlgorithmForCLIP",
@@ -184,8 +188,11 @@ if TYPE_CHECKING:
184
188
  ExPOAlgorithm,
185
189
  ExPOAlgorithmForLlama,
186
190
  LinearInterpolationAlgorithm,
191
+ SimpleAverageForCausalLM,
187
192
  SimpleAverageForLlama,
193
+ TaskArithmeticForCausalLM,
188
194
  TaskArithmeticForLlama,
195
+ TiesMergingForCausalLM,
189
196
  )
190
197
  from .lm_finetune import *
191
198
  from .mixture_of_experts import (
@@ -238,6 +245,7 @@ if TYPE_CHECKING:
238
245
  FlanT5WeightEnsemblingMoEAlgorithm,
239
246
  )
240
247
  from .weighted_average import WeightedAverageAlgorithm, WeightedAverageForLLama
248
+ from .wudi import WUDIMerging, wudi_merging
241
249
 
242
250
  else:
243
251
  sys.modules[__name__] = LazyImporter(
@@ -17,7 +17,21 @@ from fusion_bench.models.wrappers.ensemble import (
17
17
  log = logging.getLogger(__name__)
18
18
 
19
19
 
20
+ @auto_register_config
20
21
  class SimpleEnsembleAlgorithm(BaseAlgorithm):
22
+ def __init__(
23
+ self,
24
+ device_map: Optional[Mapping[int, Union[str, torch.device]]] = None,
25
+ **kwargs,
26
+ ):
27
+ """
28
+ Initializes the SimpleEnsembleAlgorithm with an optional device map.
29
+
30
+ Args:
31
+ device_map (Optional[Mapping[int, Union[str, torch.device]]], optional): A mapping from model index to device. Defaults to None.
32
+ """
33
+ super().__init__(**kwargs)
34
+
21
35
  @torch.no_grad()
22
36
  def run(self, modelpool: BaseModelPool | List[nn.Module]) -> EnsembleModule:
23
37
  """
@@ -30,9 +44,10 @@ class SimpleEnsembleAlgorithm(BaseAlgorithm):
30
44
  EnsembleModule: The ensembled model.
31
45
  """
32
46
  log.info(f"Running ensemble algorithm with {len(modelpool)} models")
33
-
34
47
  models = [modelpool.load_model(m) for m in modelpool.model_names]
35
- ensemble = EnsembleModule(models=models)
48
+
49
+ log.info("creating ensemble module")
50
+ ensemble = EnsembleModule(models=models, device_map=self.device_map)
36
51
  return ensemble
37
52
 
38
53
 
@@ -2,5 +2,9 @@
2
2
  from .expo import ExPOAlgorithm
3
3
  from .linear_interpolation import LinearInterpolationAlgorithm
4
4
  from .llama_expo import ExPOAlgorithmForLlama
5
- from .simple_average_for_llama import SimpleAverageForLlama
6
- from .task_arithmetic_for_llama import TaskArithmeticForLlama
5
+ from .simple_average_for_causallm import SimpleAverageForCausalLM, SimpleAverageForLlama
6
+ from .task_arithmetic_for_causallm import (
7
+ TaskArithmeticForCausalLM,
8
+ TaskArithmeticForLlama,
9
+ )
10
+ from .ties_merging_for_causallm import TiesMergingForCausalLM
@@ -18,16 +18,16 @@ log = get_rankzero_logger(__name__)
18
18
 
19
19
 
20
20
  @auto_register_config
21
- class SimpleAverageForLlama(BaseAlgorithm):
21
+ class SimpleAverageForCausalLM(BaseAlgorithm):
22
22
  R"""
23
23
  A simple averaging algorithm for LLama models. If `merge_backbone` is set to `True`, the backbone of the model will be averaged and the rest of the model will be loaded from the pre-trained model.
24
24
 
25
25
  Examples:
26
- The following example demonstrates how to use the `SimpleAverageForLlama` algorithm to merge Mistral models.
26
+ The following example demonstrates how to use the `SimpleAverageForCausalLM` algorithm to merge Mistral models.
27
27
 
28
28
  ```bash
29
29
  fusion_bench \
30
- method=linear/simple_average_for_llama \
30
+ method=linear/simple_average_for_causallm \
31
31
  method.model_save_path=outputs/simle_mixtral_exp_v4/simple_average \
32
32
  modelpool=CausalLMPool/simle_mixtral_exp_v4.yaml
33
33
  ```
@@ -35,7 +35,7 @@ class SimpleAverageForLlama(BaseAlgorithm):
35
35
 
36
36
  def __init__(
37
37
  self,
38
- merge_backbone: bool,
38
+ merge_backbone: bool = False,
39
39
  model_save_path: Optional[str] = None,
40
40
  show_pbar: bool = False,
41
41
  **kwargs,
@@ -81,3 +81,7 @@ class SimpleAverageForLlama(BaseAlgorithm):
81
81
  with open(os.path.join(self.model_save_path, "README.md"), "w") as f:
82
82
  f.write(model_card_str)
83
83
  return model
84
+
85
+
86
# Backward-compatible alias: the class was called ``SimpleAverageForLlama``
# before it was renamed to ``SimpleAverageForCausalLM``, so existing imports
# of the old name keep working.
SimpleAverageForLlama = SimpleAverageForCausalLM
"""Alias for SimpleAverageForCausalLM"""
@@ -1,22 +1,27 @@
1
1
  import logging
2
+ import os
2
3
  from typing import Dict, List, Mapping, Optional, TypeVar, Union # noqa: F401
3
4
 
4
5
  from typing_extensions import override
5
6
 
6
- from fusion_bench import timeit_context
7
+ from fusion_bench import auto_register_config, timeit_context
7
8
  from fusion_bench.method import TaskArithmeticAlgorithm
8
9
  from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
9
10
  from fusion_bench.modelpool import CausalLMBackbonePool, CausalLMPool
11
+ from fusion_bench.models.hf_utils import create_default_model_card
10
12
 
11
13
  log = logging.getLogger(__name__)
12
14
 
13
15
 
14
- class TaskArithmeticForLlama(TaskArithmeticAlgorithm, SimpleProfilerMixin):
16
+ @auto_register_config
17
+ class TaskArithmeticForCausalLM(
18
+ TaskArithmeticAlgorithm,
19
+ ):
15
20
  R"""
16
21
  Examples:
17
22
 
18
23
  fusion_bench \
19
- method=linear/task_arithmetic_for_llama \
24
+ method=linear/task_arithmetic_for_causallm \
20
25
  method.scaling_factor=0.3 \
21
26
  method.model_save_path=outputs/simle_mixtral_exp_v4/task_arithmetic_0.3 \
22
27
  modelpool=CausalLMPool/simle_mixtral_exp_v4.yaml
@@ -29,18 +34,14 @@ class TaskArithmeticForLlama(TaskArithmeticAlgorithm, SimpleProfilerMixin):
29
34
  def __init__(
30
35
  self,
31
36
  scaling_factor: float,
32
- merge_backbone: bool,
37
+ merge_backbone: bool = False,
33
38
  model_save_path: Optional[str] = None,
39
+ **kwargs,
34
40
  ):
35
- self.merge_backbone = merge_backbone
36
- self.model_save_path = model_save_path
37
- super().__init__(scaling_factor=scaling_factor)
41
+ super().__init__(scaling_factor=scaling_factor, **kwargs)
38
42
 
39
43
  @override
40
44
  def run(self, modelpool: CausalLMPool):
41
- if self.model_save_path:
42
- tokenizer = modelpool.load_tokenizer()
43
-
44
45
  if self.merge_backbone:
45
46
  assert modelpool.has_pretrained
46
47
  backbone_modelpool = CausalLMBackbonePool(**modelpool.config)
@@ -52,6 +53,15 @@ class TaskArithmeticForLlama(TaskArithmeticAlgorithm, SimpleProfilerMixin):
52
53
 
53
54
  if self.model_save_path is not None:
54
55
  with timeit_context(f"Saving the model to {self.model_save_path}"):
55
- tokenizer.save_pretrained(self.model_save_path)
56
- model.save_pretrained(self.model_save_path)
56
+ description = f"Merged model using task arithmetic with scaling factor {self.scaling_factor}."
57
+ modelpool.save_model(
58
+ model=model,
59
+ path=self.model_save_path,
60
+ save_tokenizer=True,
61
+ algorithm_config=self.config,
62
+ description=description,
63
+ )
57
64
  return model
65
+
66
+
67
# Backward-compatible alias: the class was called ``TaskArithmeticForLlama``
# before it was renamed to ``TaskArithmeticForCausalLM``.
TaskArithmeticForLlama = TaskArithmeticForCausalLM
"""Alias for TaskArithmeticForCausalLM"""
@@ -0,0 +1,70 @@
1
+ import logging
2
+ import os
3
+ from typing import Dict, List, Mapping, Optional, TypeVar, Union # noqa: F401
4
+
5
+ from typing_extensions import override
6
+
7
+ from fusion_bench import auto_register_config, timeit_context
8
+ from fusion_bench.method import TiesMergingAlgorithm
9
+ from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
10
+ from fusion_bench.modelpool import CausalLMBackbonePool, CausalLMPool
11
+ from fusion_bench.models.hf_utils import create_default_model_card
12
+
13
+ log = logging.getLogger(__name__)
14
+
15
+
16
@auto_register_config
class TiesMergingForCausalLM(TiesMergingAlgorithm):
    R"""
    TIES merging for causal language models.

    Builds on :class:`TiesMergingAlgorithm` with two extra options:

    * ``merge_backbone`` -- when ``True`` the pool must contain a pretrained
      model. TIES is then run on a :class:`CausalLMBackbonePool` built from the
      same config, and the result replaces ``model.model.layers`` of a freshly
      loaded pretrained model. When ``False`` the full models are merged.
    * ``model_save_path`` -- when not ``None`` the merged model is written
      there via ``modelpool.save_model`` (tokenizer included), together with a
      description that records the scaling factor and threshold.
    """

    # Parent mapping extended with `merge_backbone` (presumably so it shows up
    # in the serialized algorithm config — confirm against the mixin).
    _config_mapping = TiesMergingAlgorithm._config_mapping | {
        "merge_backbone": "merge_backbone",
    }

    def __init__(
        self,
        scaling_factor: float,
        threshold: float,
        remove_keys: List[str] = None,
        merge_func: str = "sum",
        merge_backbone: bool = False,
        model_save_path: Optional[str] = None,
        **kwargs,
    ):
        """
        Args:
            scaling_factor: Forwarded to ``TiesMergingAlgorithm``.
            threshold: Forwarded to ``TiesMergingAlgorithm``.
            remove_keys: Forwarded to ``TiesMergingAlgorithm``.
            merge_func: Forwarded to ``TiesMergingAlgorithm``.
            merge_backbone: Merge only the backbone layers (see class docstring).
            model_save_path: Where to save the merged model, or ``None``.
            **kwargs: Forwarded to ``TiesMergingAlgorithm``.
        """
        # `merge_backbone` / `model_save_path` are not forwarded; `run` reads
        # them as attributes, presumably set by `@auto_register_config`.
        ties_kwargs = dict(
            scaling_factor=scaling_factor,
            threshold=threshold,
            remove_keys=remove_keys,
            merge_func=merge_func,
        )
        super().__init__(**ties_kwargs, **kwargs)

    @override
    def run(self, modelpool: CausalLMPool):
        """
        Merge the models in ``modelpool`` with TIES and optionally save them.

        Returns:
            The merged model.
        """
        if not self.merge_backbone:
            merged_model = super().run(modelpool)
        else:
            # Only the decoder layers are merged; every other module is taken
            # from the pretrained checkpoint.
            assert modelpool.has_pretrained
            backbone_pool = CausalLMBackbonePool(**modelpool.config)
            merged_model = modelpool.load_model("_pretrained_")
            merged_model.model.layers = super().run(backbone_pool)

        if self.model_save_path is None:
            return merged_model

        with timeit_context(f"Saving the model to {self.model_save_path}"):
            modelpool.save_model(
                model=merged_model,
                path=self.model_save_path,
                save_tokenizer=True,
                algorithm_config=self.config,
                description=(
                    f"Merged model using TIES merging with scaling factor "
                    f"{self.scaling_factor} and threshold {self.threshold}."
                ),
            )
        return merged_model
@@ -89,7 +89,7 @@ class SimpleAverageAlgorithm(
89
89
  modelpool = BaseModelPool(modelpool)
90
90
 
91
91
  log.info(
92
- f"Fusing models using simple average on {len(modelpool.model_names)} models."
92
+ f"Fusing models using simple average on {len(modelpool.model_names)} models. "
93
93
  f"models: {modelpool.model_names}"
94
94
  )
95
95
  sd: Optional[StateDictType] = None
@@ -119,7 +119,7 @@ class SimpleAverageAlgorithm(
119
119
 
120
120
  if isinstance(forward_model, LazyStateDict):
121
121
  # if the model is a LazyStateDict, convert it to an empty module
122
- forward_model = forward_model.meta_module.to_empty(
122
+ forward_model = deepcopy(forward_model.meta_module).to_empty(
123
123
  device=forward_model._device
124
124
  )
125
125
  result = forward_model.load_state_dict(sd, strict=False)
@@ -6,11 +6,20 @@ http://arxiv.org/abs/2212.04089
6
6
 
7
7
  import logging
8
8
  from copy import deepcopy
9
- from typing import Dict, List, Mapping, Optional, TypeVar, Union # noqa: F401
9
+ from typing import ( # noqa: F401
10
+ TYPE_CHECKING,
11
+ Dict,
12
+ List,
13
+ Mapping,
14
+ Optional,
15
+ TypeVar,
16
+ Union,
17
+ )
10
18
 
11
19
  import torch
12
20
  from torch import nn
13
21
 
22
+ from fusion_bench import LazyStateDict
14
23
  from fusion_bench.method.base_algorithm import BaseAlgorithm
15
24
  from fusion_bench.mixins import SimpleProfilerMixin, auto_register_config
16
25
  from fusion_bench.modelpool import BaseModelPool
@@ -21,6 +30,8 @@ from fusion_bench.utils.state_dict_arithmetic import (
21
30
  )
22
31
  from fusion_bench.utils.type import StateDictType, TorchModelType
23
32
 
33
+ if TYPE_CHECKING:
34
+ from transformers import PreTrainedModel
24
35
  log = logging.getLogger(__name__)
25
36
 
26
37
 
@@ -125,25 +136,39 @@ class TaskArithmeticAlgorithm(
125
136
  with self.profile("merge weights"):
126
137
  if task_vector is None:
127
138
  task_vector = state_dict_sub(
128
- model.state_dict(keep_vars=True),
129
- pretrained_model.state_dict(keep_vars=True),
139
+ model.state_dict(),
140
+ pretrained_model.state_dict(),
130
141
  )
131
142
  else:
132
143
  task_vector = state_dict_add(
133
144
  task_vector,
134
145
  state_dict_sub(
135
- model.state_dict(keep_vars=True),
136
- pretrained_model.state_dict(keep_vars=True),
146
+ model.state_dict(),
147
+ pretrained_model.state_dict(),
137
148
  ),
138
149
  )
139
150
  with self.profile("merge weights"):
140
151
  # scale the task vector
141
152
  task_vector = state_dict_mul(task_vector, self.config.scaling_factor)
142
153
  # add the task vector to the pretrained model
143
- state_dict = state_dict_add(
144
- pretrained_model.state_dict(keep_vars=True), task_vector
145
- )
154
+ state_dict = state_dict_add(pretrained_model.state_dict(), task_vector)
146
155
 
147
156
  self.print_profile_summary()
148
- pretrained_model.load_state_dict(state_dict)
149
- return pretrained_model
157
+
158
+ # apply state dict to model
159
+ if isinstance(pretrained_model, nn.Module):
160
+ model = pretrained_model
161
+ model.load_state_dict(state_dict)
162
+ elif isinstance(pretrained_model, LazyStateDict):
163
+ model = deepcopy(pretrained_model.meta_module)
164
+ model = model.to_empty(device=pretrained_model._device)
165
+ result = model.load_state_dict(state_dict, strict=False)
166
+ if result.unexpected_keys:
167
+ raise ValueError(
168
+ f"Unexpected keys in state dict: {result.unexpected_keys}"
169
+ )
170
+ if result.missing_keys:
171
+ log.warning(f"Missing keys in state dict: {result.missing_keys}")
172
+ else:
173
+ raise TypeError(f"Unsupported model type: {type(pretrained_model)}")
174
+ return model
@@ -9,11 +9,14 @@ Overview of Ties-Merging:
9
9
  """
10
10
 
11
11
  import logging
12
+ from copy import deepcopy
12
13
  from typing import Any, Dict, List, Literal, Mapping, Union # noqa: F401
13
14
 
14
15
  import torch
15
16
  from torch import Tensor, nn
17
+ from transformers import PreTrainedModel
16
18
 
19
+ from fusion_bench import LazyStateDict
17
20
  from fusion_bench.compat.modelpool import to_modelpool
18
21
  from fusion_bench.method import BaseAlgorithm
19
22
  from fusion_bench.mixins import SimpleProfilerMixin, auto_register_config
@@ -98,12 +101,25 @@ class TiesMergingAlgorithm(
98
101
  merge_func=merge_func,
99
102
  )
100
103
  merged_check = flat_ptm + scaling_factor * merged_tv
101
- merged_state_dict = vector_to_state_dict(
104
+ state_dict = vector_to_state_dict(
102
105
  merged_check, ptm_check, remove_keys=remove_keys
103
106
  )
104
-
105
- # Load the merged state dict into the pretrained model
106
- pretrained_model.load_state_dict(merged_state_dict)
107
-
108
107
  self.print_profile_summary()
109
- return pretrained_model
108
+
109
+ # apply state dict to model
110
+ if isinstance(pretrained_model, nn.Module):
111
+ model = pretrained_model
112
+ model.load_state_dict(state_dict)
113
+ elif isinstance(pretrained_model, LazyStateDict):
114
+ model = deepcopy(pretrained_model.meta_module)
115
+ model = model.to_empty(device=pretrained_model._device)
116
+ result = model.load_state_dict(state_dict, strict=False)
117
+ if result.unexpected_keys:
118
+ raise ValueError(
119
+ f"Unexpected keys in state dict: {result.unexpected_keys}"
120
+ )
121
+ if result.missing_keys:
122
+ log.warning(f"Missing keys in state dict: {result.missing_keys}")
123
+ else:
124
+ raise TypeError(f"Unsupported model type: {type(pretrained_model)}")
125
+ return model
@@ -0,0 +1 @@
1
+ from .wudi import WUDIMerging, wudi_merging
@@ -0,0 +1,105 @@
1
+ """
2
+ Whoever Started the Interference Should End It: Guiding Data-Free Model Merging via Task Vectors
3
+ Arxiv: http://arxiv.org/abs/2503.08099
4
+ """
5
+
6
+ from typing import List
7
+
8
+ import torch
9
+ from tqdm import tqdm
10
+
11
+ from fusion_bench import BaseAlgorithm, BaseModelPool, auto_register_config
12
+ from fusion_bench.mixins import LightningFabricMixin
13
+ from fusion_bench.utils import timeit_context
14
+ from fusion_bench.utils.state_dict_arithmetic import state_dict_add, state_dict_sub
15
+
16
+
17
def _wudi_optimize_matrix(tvs: torch.Tensor, iter_num: int, desc=None) -> torch.Tensor:
    """
    Optimise one merged matrix for a stack of task vectors with the WUDI loss.

    Args:
        tvs: Floating-point tensor of shape ``(num_tvs, rows, cols)``.
        iter_num: Number of Adam steps.
        desc: Optional progress-bar label.

    Returns:
        The optimised ``(rows, cols)`` tensor, detached from autograd.
    """
    # Squared Frobenius norm of every task vector, shaped (num_tvs, 1, 1) so it
    # broadcasts over the (num_tvs, rows, rows) products below.
    l2_norms = torch.square(torch.norm(tvs.reshape(tvs.shape[0], -1), p=2, dim=-1))
    # An all-zero task vector (e.g. a frozen weight) yields a product that is
    # exactly zero; dividing it by a zero norm would give 0/0 = NaN, which
    # would propagate through the gradients into the merged weights. Replacing
    # only exact zeros by 1 keeps that term at 0 and leaves all others as-is.
    l2_norms = torch.where(l2_norms > 0, l2_norms, torch.ones_like(l2_norms))
    l2_norms = l2_norms.view(-1, 1, 1)

    # The optimisation needs autograd even if the caller is inside `no_grad`.
    with torch.enable_grad():
        merged = torch.nn.Parameter(torch.sum(tvs, dim=0))
        optimizer = torch.optim.Adam([merged], lr=1e-5, weight_decay=0)
        for _ in tqdm(range(iter_num), desc=desc, leave=False):
            disturbing_vectors = merged.unsqueeze(0) - tvs
            product = torch.matmul(disturbing_vectors, tvs.transpose(1, 2))
            loss = torch.sum(torch.square(product) / l2_norms)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    return merged.detach()


def wudi_merging(
    task_vectors: List[dict],
    accelerator="cuda",
    iter_num: int = 300,
    exclude_keys: List[str] = None,
):
    """
    Merge task vectors with WUDI (data-free interference minimisation).

    For every key of ``task_vectors[0]`` the task vectors of all models are
    stacked on ``accelerator``:

    * 2-D floating-point entries whose key is not in ``exclude_keys`` are
      merged by optimising one matrix ``tau`` (initialised to the sum of the
      task vectors) with Adam (``lr=1e-5``) for ``iter_num`` steps, minimising
      ``sum_i ||(tau - tau_i) @ tau_i^T||_F^2 / ||tau_i||_F^2``.
    * All other entries (1-D, excluded, or non-floating-point) are averaged.

    Args:
        task_vectors: One mapping per fine-tuned model from parameter name to
            task-vector tensor. Every mapping must contain the keys of the
            first one, with tensors of matching shape.
        accelerator: Device on which the computation runs.
        iter_num: Number of Adam steps per optimised matrix.
        exclude_keys: Keys that are averaged instead of optimised.

    Returns:
        dict: Merged task vector per key, detached from autograd and moved back
        to the device of ``task_vectors[0][key]``.

    Raises:
        ValueError: If ``task_vectors`` is empty.
    """
    if not task_vectors:
        raise ValueError("wudi_merging requires at least one task vector.")
    excluded = set() if exclude_keys is None else set(exclude_keys)

    with timeit_context("WUDI Merging"):
        new_vector = {}
        for key in tqdm(task_vectors[0], desc="WUDI Merging", leave=False):
            original_device = task_vectors[0][key].device
            # Host-to-device copies are ordered on the device stream before the
            # kernels that consume them, so non_blocking is safe here.
            tvs = torch.stack(
                [
                    task_vector[key].to(device=accelerator, non_blocking=True)
                    for task_vector in task_vectors
                ]
            )
            num_tvs = len(tvs)

            if tvs.dim() == 3 and tvs.is_floating_point() and key not in excluded:
                merged = _wudi_optimize_matrix(tvs, iter_num, desc=key)
            else:
                with torch.no_grad():
                    merged = torch.sum(tvs, dim=0) / num_tvs

            # Blocking copy: a non-blocking device-to-host copy may still be in
            # flight when the caller reads the returned tensor.
            new_vector[key] = merged.to(device=original_device)
    return new_vector
61
+
62
+
63
@auto_register_config
class WUDIMerging(
    LightningFabricMixin,
    BaseAlgorithm,
):
    """
    Whoever Started the Interference Should End It: Guiding Data-Free Model Merging via Task Vectors

    Data-free merging: the task vectors (fine-tuned minus pretrained weights)
    of every model in the pool are merged with :func:`wudi_merging` on the
    Lightning Fabric device, and the merged task vector is added to the
    pretrained model.

    Args:
        iter_num: Number of optimisation steps per merged weight matrix.
        exclude_keys: Parameter names that are averaged instead of optimised.
    """

    def __init__(
        self,
        iter_num: int,
        exclude_keys: List[str] = None,
        **kwargs,
    ):
        # NOTE(review): `iter_num` / `exclude_keys` are not assigned here but
        # are read as `self.iter_num` / `self.exclude_keys` in `run`;
        # presumably `@auto_register_config` registers them — confirm.
        super().__init__(**kwargs)

    def run(self, modelpool: BaseModelPool):
        """
        Merge all models in ``modelpool`` into its pretrained model.

        Returns:
            The pretrained model, updated in place with the merged weights.
        """
        # Load the pretrained model and compute every fine-tuned model's task
        # vector; no gradients are needed for this part.
        with torch.no_grad():
            pretrained_model = modelpool.load_pretrained_model()
            pretrained_state_dict = pretrained_model.state_dict()
            task_vectors = []
            for model_name in modelpool.model_names:
                finetuned_model = modelpool.load_model(model_name)
                task_vectors.append(
                    state_dict_sub(finetuned_model.state_dict(), pretrained_state_dict)
                )
                del finetuned_model  # free memory before loading the next model

        # Runs outside `no_grad`: the merge optimises the weights with autograd.
        merged_tv = wudi_merging(
            task_vectors,
            accelerator=self.fabric.device,
            iter_num=self.iter_num,
            exclude_keys=self.exclude_keys,
        )
        del task_vectors  # no longer needed; release before the final update

        # The merged tensors may still require grad (e.g. `nn.Parameter`s);
        # without `no_grad` the addition would record an autograd graph over
        # every parameter of the model.
        with torch.no_grad():
            pretrained_model.load_state_dict(
                state_dict_add(pretrained_model.state_dict(), merged_tv)
            )

        return pretrained_model
@@ -100,6 +100,10 @@ class LightningFabricMixin:
100
100
  self.setup_lightning_fabric(getattr(self, "config", DictConfig({})))
101
101
  return self._fabric_instance
102
102
 
103
+ @fabric.setter
104
+ def fabric(self, instance: L.Fabric):
105
+ self._fabric_instance = instance
106
+
103
107
  @property
104
108
  def log_dir(self):
105
109
  """