fusion-bench 0.2.11__py3-none-any.whl → 0.2.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (23)
  1. fusion_bench/compat/method/__init__.py +1 -1
  2. fusion_bench/dataset/fer2013.py +0 -1
  3. fusion_bench/method/__init__.py +2 -2
  4. fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +1 -1
  5. fusion_bench/method/doge_ta/__init__.py +2 -0
  6. fusion_bench/method/{DOGE_TA → doge_ta}/clip_layer_wise_adamerging.py +1 -1
  7. fusion_bench/method/{DOGE_TA/DOGE_TA.py → doge_ta/doge_ta.py} +1 -1
  8. fusion_bench/method/opcm/opcm.py +93 -84
  9. fusion_bench/method/opcm/task_arithmetic.py +35 -21
  10. fusion_bench/method/opcm/ties_merging.py +71 -52
  11. fusion_bench/models/wrappers/layer_wise_fusion.py +1 -46
  12. fusion_bench/models/wrappers/layer_wise_fusion_doge_ta.py +4 -119
  13. {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.12.dist-info}/METADATA +15 -2
  14. {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.12.dist-info}/RECORD +22 -21
  15. {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.12.dist-info}/WHEEL +1 -1
  16. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16_tta.yaml +8 -8
  17. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_tta.yaml +68 -0
  18. fusion_bench/method/DOGE_TA/__init__.py +0 -2
  19. /fusion_bench/method/{DOGE_TA → doge_ta}/layer_wise_adamerging.py +0 -0
  20. {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.12.dist-info}/entry_points.txt +0 -0
  21. {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.12.dist-info/licenses}/LICENSE +0 -0
  22. {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.12.dist-info}/top_level.txt +0 -0
  23. /fusion_bench_config/method/{DOGE_TA/DOGE_TA.yaml → doge_ta/doge_ta.yaml} +0 -0
@@ -20,7 +20,7 @@ class AlgorithmFactory:
     # model merging methods
     "clip_task_wise_adamerging": ".adamerging.clip_task_wise_adamerging.CLIPTaskWiseAdaMergingAlgorithm",
     "clip_layer_wise_adamerging": ".adamerging.clip_layer_wise_adamerging.CLIPLayerWiseAdaMergingAlgorithm",
-    "clip_layer_wise_adamerging_doge_ta": ".DOGE_TA.clip_layer_wise_adamerging.CLIPLayerWiseAdaMergingAlgorithm",
+    "clip_layer_wise_adamerging_doge_ta": ".doge_ta.clip_layer_wise_adamerging.CLIPLayerWiseAdaMergingAlgorithm",
     "singular_projection_merging": "fusion_bench.method.smile_upscaling.singular_projection_merging.SingularProjectionMergingAlgorithm",
     "clip_layer_wise_adamerging_surgery": ".surgery.clip_layer_wise_adamerging_surgery.CLIPLayerWiseAdaMergingSurgeryAlgorithm",
     # plug-and-play model merging methods
@@ -7,7 +7,6 @@ def load_fer2013(path: str = "clip-benchmark/wds_fer2013", split: str = "train")
     dataset = dataset.rename_columns({"jpg": "image", "cls": "label"})
     return dataset
 
-
 if __name__ == "__main__":
     dataset = load_fer2013(split="test")
     print(dataset)
@@ -53,7 +53,7 @@ _import_structure = {
         "PWEMoExactParetoOptimalForCLIP",
     ],
     "ada_svd": ["AdaSVDMergingForCLIPVisionModel"],
-    "DOGE_TA": ["DOGE_TA_Algorithm"],
+    "doge_ta": ["DOGE_TA_Algorithm"],
     "task_singular_vector": ["TaskSingularVectorMerging"],
     "isotropic_merging": [
         "ISO_C_Merge",  # alias
@@ -128,7 +128,7 @@ if TYPE_CHECKING:
     from .dare import DareSimpleAverage, DareTaskArithmetic, DareTiesMerging
     from .dawe import DataAdaptiveWeightEnsemblingForCLIP
     from .depth_upscaling import DepthUpscalingAlgorithm, DepthUpscalingForLlama
-    from .DOGE_TA import DOGE_TA_Algorithm
+    from .doge_ta import DOGE_TA_Algorithm
     from .dummy import DummyAlgorithm
     from .ensemble import (
         MaxModelPredictorAlgorithm,
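The two hunks above have to stay in sync: `_import_structure` drives lazy loading at runtime, while the `TYPE_CHECKING` branch gives static type checkers the same names eagerly. A minimal sketch of this pattern (illustrative only; fusion_bench's actual loader may differ in its details):

```python
# Sketch of a lazy-import __init__.py, assuming a PEP 562 __getattr__-based loader.
import importlib
from typing import TYPE_CHECKING

_import_structure = {
    "doge_ta": ["DOGE_TA_Algorithm"],  # submodule name -> exported symbols
}

if TYPE_CHECKING:
    # Type checkers resolve the import eagerly and statically.
    from .doge_ta import DOGE_TA_Algorithm
else:
    def __getattr__(name):
        # Import the owning submodule only when the symbol is first accessed.
        for module_name, exports in _import_structure.items():
            if name in exports:
                module = importlib.import_module(f".{module_name}", __name__)
                return getattr(module, name)
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```

Because the submodule name appears both as a dict key and as an import path, the `DOGE_TA` → `doge_ta` rename must be applied in both places, which is exactly what this pair of hunks does.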
@@ -9,7 +9,7 @@ fusion_bench \
     modelpool=clip-vit-base-patch32_TA8 \
     taskpool=clip-vit-classification_TA8 \
     fabric.loggers.root_dir=outputs/logs/ViT-B-32 \
-    fabric.loggers.name=clip_layer_wise_adamerging_adam
+    fabric.loggers.name=clip_layer_wise_adamerging_adamerging
 ```
 """
 
@@ -0,0 +1,2 @@
+# flake8: noqa F401
+from .doge_ta import DOGE_TA_Algorithm
@@ -9,7 +9,7 @@ fusion_bench \
     modelpool=clip-vit-base-patch32_TA8 \
     taskpool=clip-vit-classification_TA8 \
     fabric.loggers.root_dir=outputs/logs/ViT-B-32 \
-    fabric.loggers.name=clip_layer_wise_adamerging_adam
+    fabric.loggers.name=clip_layer_wise_adamerging_adamerging
 ```
 """
 
@@ -7,7 +7,7 @@ Example Usage:
 
 ```bash
 fusion_bench \
-    method=DOGE_TA/DOGE_TA \
+    method=doge_ta/doge_ta \
     modelpool=CLIPVisionModelPool/clip-vit-base-patch32_TA8_model_only \
     taskpool=CLIPVisionModelTaskPool/clip-vit-classification_TA8
 
@@ -15,7 +15,7 @@ from tqdm.auto import tqdm
 from transformers import CLIPVisionModel
 
 from fusion_bench import BaseAlgorithm, BaseModelPool
-from fusion_bench.mixins import LightningFabricMixin
+from fusion_bench.mixins import LightningFabricMixin, SimpleProfilerMixin
 from fusion_bench.taskpool import CLIPVisionModelTaskPool
 from fusion_bench.utils import instantiate
 from fusion_bench.utils.json import load_from_json, save_to_json
@@ -31,6 +31,7 @@ if TYPE_CHECKING:
 class OPCMForCLIP(
     BaseAlgorithm,
     LightningFabricMixin,
+    SimpleProfilerMixin,
 ):
     def __init__(
         self,
@@ -64,7 +65,8 @@ class OPCMForCLIP(
         L.seed_everything(self.seed)
         accelerator = self.fabric.device
 
-        pretrained_model = modelpool.load_pretrained_model()
+        with self.profile("loading model"):
+            pretrained_model = modelpool.load_pretrained_model()
 
         model_names = modelpool.model_names
         if self.shuffle_order:
@@ -83,15 +85,17 @@ class OPCMForCLIP(
         )
 
         # get the average model
-        merged_model = modelpool.load_model(model_names[0])
+        with self.profile("loading model"):
+            merged_model = modelpool.load_model(model_names[0])
 
         if self.evaluate_on_every_step:
-            self.taskpool._is_setup = False
-            self.taskpool._test_datasets = DictConfig(
-                {model_names[0]: self._test_datasets[model_names[0]]}
-            )
-            report = self.taskpool.evaluate(deepcopy(merged_model))
-            save_to_json(report, Path(self.log_dir) / "report_0.json")
+            with self.profile("evaluating model"):
+                self.taskpool._is_setup = False
+                self.taskpool._test_datasets = DictConfig(
+                    {model_names[0]: self._test_datasets[model_names[0]]}
+                )
+                report = self.taskpool.evaluate(deepcopy(merged_model))
+                save_to_json(report, Path(self.log_dir) / "report_0.json")
 
         self.avg_task_vector_norm = get_task_vector_norm(merged_model, pretrained_model)
         self.all_task_vector_norm = [self.avg_task_vector_norm]
@@ -113,90 +117,95 @@ class OPCMForCLIP(
         for model_idx, model_name in tqdm(
             enumerate(model_names[1:]), desc="Processing models"
         ):
             model_idx += 1
-            task_model = modelpool.load_model(model_name)
+            with self.profile("loading model"):
+                task_model = modelpool.load_model(model_name)
 
-            self.all_task_vector_norm.append(
-                get_task_vector_norm(task_model, pretrained_model)
-            )
-            self.avg_task_vector_norm = np.mean(self.all_task_vector_norm)
-            self.fabric.log(
-                "model/task_vector_norm", self.all_task_vector_norm[-1], step=model_idx
-            )
-            self.fabric.log(
-                "model/avg_task_vector_norm", self.avg_task_vector_norm, step=model_idx
-            )
+            with self.profile("merging model"):
+                self.all_task_vector_norm.append(
+                    get_task_vector_norm(task_model, pretrained_model)
+                )
+                self.avg_task_vector_norm = np.mean(self.all_task_vector_norm)
+                self.fabric.log(
+                    "model/task_vector_norm", self.all_task_vector_norm[-1], step=model_idx
+                )
+                self.fabric.log(
+                    "model/avg_task_vector_norm", self.avg_task_vector_norm, step=model_idx
+                )
 
-            self.lambda_t = 1  # temporary value
-
-            for module_name, module in tqdm(
-                list(merged_model.named_modules()),
-                desc=f"Processing {model_name}",
-                leave=False,
-            ):
-                if not is_leaf_module(module):
-                    continue
-
-                if isinstance(module, nn.Linear):
-                    module.weight.data = self.merge_linear_weights(
-                        module.weight,
-                        pretrained_model.get_submodule(module_name).weight,
-                        task_model.get_submodule(module_name).weight,
-                        param_name=".".join([module_name, "weight"]),
-                        alpha=self.alpha,
-                        accelerator=accelerator,
-                    )
-                    if module.bias is not None:
-                        module.bias.data = self.merge_other_parameters(
-                            module.bias,
-                            pretrained_model.get_submodule(module_name).bias,
-                            task_model.get_submodule(module_name).bias,
-                            param_name=".".join([module_name, "bias"]),
+                self.lambda_t = 1  # temporary value
+
+                for module_name, module in tqdm(
+                    list(merged_model.named_modules()),
+                    desc=f"Processing {model_name}",
+                    leave=False,
+                ):
+                    if not is_leaf_module(module):
+                        continue
+
+                    if isinstance(module, nn.Linear):
+                        module.weight.data = self.merge_linear_weights(
+                            module.weight,
+                            pretrained_model.get_submodule(module_name).weight,
+                            task_model.get_submodule(module_name).weight,
+                            param_name=".".join([module_name, "weight"]),
+                            alpha=self.alpha,
                             accelerator=accelerator,
                         )
-                else:
-                    for param_name, param in module.named_parameters():
-                        param.data = self.merge_other_parameters(
-                            merged_W=param,
-                            pretrained_W=pretrained_model.get_submodule(
-                                module_name
-                            ).get_parameter(param_name),
-                            task_W=task_model.get_submodule(module_name).get_parameter(
-                                param_name
-                            ),
-                            param_name=".".join([module_name, param_name]),
-                            accelerator=accelerator,
-                        )
-
-            task_vector_norm = get_task_vector_norm(merged_model, pretrained_model)
-            self.lambda_t *= task_vector_norm / self.avg_task_vector_norm
-            for param_name, param in merged_model.named_parameters():
-                param.data = pretrained_model.get_parameter(param_name) + (
-                    param - pretrained_model.get_parameter(param_name)
-                ) * (self.avg_task_vector_norm / task_vector_norm)
-            self.fabric.log("model/lambda_t", self.lambda_t, step=model_idx)
-            self.fabric.log(
-                "empirical/lambda_t", np.sqrt(model_idx + 1), step=model_idx
-            )
-            self.previous_lambda_t = self.lambda_t
-            self.lambda_t = None
+                        if module.bias is not None:
+                            module.bias.data = self.merge_other_parameters(
+                                module.bias,
+                                pretrained_model.get_submodule(module_name).bias,
+                                task_model.get_submodule(module_name).bias,
+                                param_name=".".join([module_name, "bias"]),
+                                accelerator=accelerator,
+                            )
+                    else:
+                        for param_name, param in module.named_parameters():
+                            param.data = self.merge_other_parameters(
+                                merged_W=param,
+                                pretrained_W=pretrained_model.get_submodule(
+                                    module_name
+                                ).get_parameter(param_name),
+                                task_W=task_model.get_submodule(module_name).get_parameter(
+                                    param_name
+                                ),
+                                param_name=".".join([module_name, param_name]),
+                                accelerator=accelerator,
+                            )
+
+                task_vector_norm = get_task_vector_norm(merged_model, pretrained_model)
+                self.lambda_t *= task_vector_norm / self.avg_task_vector_norm
+                for param_name, param in merged_model.named_parameters():
+                    param.data = pretrained_model.get_parameter(param_name) + (
+                        param - pretrained_model.get_parameter(param_name)
+                    ) * (self.avg_task_vector_norm / task_vector_norm)
+                self.fabric.log("model/lambda_t", self.lambda_t, step=model_idx)
+                self.fabric.log(
+                    "empirical/lambda_t", np.sqrt(model_idx + 1), step=model_idx
+                )
+                self.previous_lambda_t = self.lambda_t
+                self.lambda_t = None
 
-            self.fabric.log(
-                "model/merged_task_vector_norm",
-                get_task_vector_norm(merged_model, pretrained_model),
-                step=model_idx,
-            )
+                self.fabric.log(
+                    "model/merged_task_vector_norm",
+                    get_task_vector_norm(merged_model, pretrained_model),
+                    step=model_idx,
+                )
 
             if self.save_on_every_step:
-                self.save_merged_model(merged_model, model_idx)
+                with self.profile("saving model"):
+                    self.save_merged_model(merged_model, model_idx)
 
             if self.evaluate_on_every_step:
-                self.taskpool._is_setup = False
-                self.taskpool._test_datasets = DictConfig(
-                    {n: self._test_datasets[n] for n in model_names[: model_idx + 1]}
-                )
-                report = self.taskpool.evaluate(deepcopy(merged_model))
-                save_to_json(report, Path(self.log_dir) / f"report_{model_idx}.json")
+                with self.profile("evaluating model"):
+                    self.taskpool._is_setup = False
+                    self.taskpool._test_datasets = DictConfig(
+                        {n: self._test_datasets[n] for n in model_names[: model_idx + 1]}
+                    )
+                    report = self.taskpool.evaluate(deepcopy(merged_model))
+                    save_to_json(report, Path(self.log_dir) / f"report_{model_idx}.json")
 
+        self.print_profile_summary()
         return merged_model
 
     def save_merged_model(self, merged_model: CLIPVisionModel, step: int):
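Most of this hunk is the existing merging loop re-indented under `with self.profile(...)` blocks from the newly mixed-in `SimpleProfilerMixin`, plus a final `self.print_profile_summary()` call. A hypothetical sketch of a mixin with that interface (the real `fusion_bench.mixins.SimpleProfilerMixin` may well be implemented differently, e.g. on top of Lightning's profilers):

```python
# Hypothetical minimal profiler mixin; only the interface used in the diff
# (profile() as a context manager and print_profile_summary()) is assumed.
import time
from collections import defaultdict
from contextlib import contextmanager


class SimpleProfilerMixin:
    @contextmanager
    def profile(self, action: str):
        # Accumulate wall-clock time spent inside each named action.
        if not hasattr(self, "_profile_totals"):
            self._profile_totals = defaultdict(float)
        start = time.perf_counter()
        try:
            yield
        finally:
            self._profile_totals[action] += time.perf_counter() - start

    def print_profile_summary(self):
        for action, seconds in sorted(getattr(self, "_profile_totals", {}).items()):
            print(f"{action}: {seconds:.2f}s total")
```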
@@ -227,7 +236,7 @@ class OPCMForCLIP(
         split_rank = (s.cumsum(dim=0) / s.sum() > alpha).float().argmax().item()
 
         projected_task_tv = u.T @ task_tv @ v
-        projected_task_tv.diag().fill_(0)
+        projected_task_tv.diagonal().fill_(0)
 
         projected_task_tv[:split_rank, :split_rank] = 0
 
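The `.diag()` → `.diagonal()` change above is a genuine bug fix, not a rename: for a 2-D tensor, `Tensor.diag()` copies the diagonal into a new 1-D tensor, so the in-place `fill_(0)` mutated a temporary and `projected_task_tv` kept its diagonal, whereas `Tensor.diagonal()` returns a view into the same storage, so the fill actually takes effect:

```python
import torch

m = torch.ones(3, 3)

m.diag().fill_(0)      # .diag() returns a copy of the diagonal,
print(m.sum())         # so m is unchanged: tensor(9.)

m.diagonal().fill_(0)  # .diagonal() returns a view into m,
print(m.sum())         # so the diagonal is zeroed: tensor(6.)
```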
@@ -15,7 +15,7 @@ from tqdm.auto import tqdm
 from transformers import CLIPVisionModel
 
 from fusion_bench import BaseAlgorithm, BaseModelPool
-from fusion_bench.mixins import LightningFabricMixin
+from fusion_bench.mixins import LightningFabricMixin, SimpleProfilerMixin
 from fusion_bench.taskpool import CLIPVisionModelTaskPool
 from fusion_bench.utils.json import load_from_json, save_to_json
 from fusion_bench.utils.state_dict_arithmetic import state_dict_add, state_dict_sub
@@ -24,7 +24,11 @@ if TYPE_CHECKING:
     from torch.utils.tensorboard import SummaryWriter
 
 
-class ContinualTaskArithmeticForCLIP(BaseAlgorithm, LightningFabricMixin):
+class ContinualTaskArithmeticForCLIP(
+    BaseAlgorithm,
+    LightningFabricMixin,
+    SimpleProfilerMixin,
+):
     def __init__(
         self,
         scaling_factor: float,
@@ -79,32 +83,42 @@ class ContinualTaskArithmeticForCLIP(BaseAlgorithm, LightningFabricMixin):
         for model_idx, model_name in tqdm(
             enumerate(model_names), desc="Processing models"
         ):
-            task_model = modelpool.load_model(model_name)
+            with self.profile("loading model"):
+                task_model = modelpool.load_model(model_name)
 
-            for param_name, param in task_model.named_parameters():
-                if not param.requires_grad:
-                    continue
+            with self.profile("merging model"):
+                for param_name, param in task_model.named_parameters():
+                    if not param.requires_grad:
+                        continue
 
-                task_param = param
-                merged_param = merged_model.get_parameter(param_name)
-                pretrained_param = pretrained_model.get_parameter(param_name)
+                    task_param = param
+                    merged_param = merged_model.get_parameter(param_name)
+                    pretrained_param = pretrained_model.get_parameter(param_name)
 
-                new_param = merged_param + self.scaling_factor * (
-                    task_param - pretrained_param
-                )
-                merged_model.get_parameter(param_name).data = new_param
+                    new_param = merged_param + self.scaling_factor * (
+                        task_param - pretrained_param
+                    )
+                    merged_model.get_parameter(param_name).data = new_param
 
             if self.save_on_every_step:
-                self.save_merged_model(merged_model, model_idx)
+                with self.profile("saving model"):
+                    self.save_merged_model(merged_model, model_idx)
 
             if self.evaluate_on_every_step:
-                self.taskpool._is_setup = False
-                self.taskpool._test_datasets = DictConfig(
-                    {n: self._test_datasets[n] for n in model_names[: model_idx + 1]}
-                )
-                report = self.taskpool.evaluate(deepcopy(merged_model))
-                save_to_json(report, Path(self.log_dir) / f"report_{model_idx}.json")
-
+                with self.profile("evaluating model"):
+                    self.taskpool._is_setup = False
+                    self.taskpool._test_datasets = DictConfig(
+                        {
+                            n: self._test_datasets[n]
+                            for n in model_names[: model_idx + 1]
+                        }
+                    )
+                    report = self.taskpool.evaluate(deepcopy(merged_model))
+                    save_to_json(
+                        report, Path(self.log_dir) / f"report_{model_idx}.json"
+                    )
+
+        self.print_profile_summary()
         return merged_model
 
     def save_merged_model(self, merged_model: CLIPVisionModel, step: int):
@@ -20,7 +20,7 @@ from fusion_bench.method.ties_merging.ties_merging_utils import (
     ties_merging,
     vector_to_state_dict,
 )
-from fusion_bench.mixins import LightningFabricMixin
+from fusion_bench.mixins import LightningFabricMixin, SimpleProfilerMixin
 from fusion_bench.taskpool import CLIPVisionModelTaskPool
 from fusion_bench.utils.json import load_from_json, save_to_json
 from fusion_bench.utils.state_dict_arithmetic import state_dict_add, state_dict_sub
@@ -29,7 +29,11 @@ if TYPE_CHECKING:
     from torch.utils.tensorboard import SummaryWriter
 
 
-class ContinualTiesMergingForCLIP(BaseAlgorithm, LightningFabricMixin):
+class ContinualTiesMergingForCLIP(
+    BaseAlgorithm,
+    LightningFabricMixin,
+    SimpleProfilerMixin,
+):
     def __init__(
         self,
         scaling_factor: float,
@@ -84,68 +88,83 @@ class ContinualTiesMergingForCLIP(BaseAlgorithm, LightningFabricMixin):
         )
 
         # get the average model
-        pretrained_model = modelpool.load_pretrained_model()
+        with self.profile("loading model"):
+            pretrained_model = modelpool.load_pretrained_model()
         merged_model = deepcopy(pretrained_model)
 
         for model_idx, model_name in tqdm(
             enumerate(model_names), desc="Processing models"
         ):
-            task_model = modelpool.load_model(model_name)
+            with self.profile("loading model"):
+                task_model = modelpool.load_model(model_name)
 
-            task_vector = state_dict_sub(
-                task_model.state_dict(),
-                pretrained_model.state_dict(),
-            )
-            if model_idx == 0:
-                # if is the first model, the merged task vector is equal to the task vector
-                ties_merging_state_dict = task_vector
-            else:
-                # if is not the first model, we need to merge the task vector with the previous merged task vector
-                merged_tv = state_dict_sub(
-                    merged_model.state_dict(),
+            with self.profile("merging model"):
+                task_vector = state_dict_sub(
+                    task_model.state_dict(),
                     pretrained_model.state_dict(),
                 )
-                tv_flat_checks = torch.vstack(
-                    [
-                        state_dict_to_vector(merged_tv, remove_keys=self.remove_keys),
-                        state_dict_to_vector(task_vector, remove_keys=self.remove_keys),
-                    ]
-                )
-                # perform the TIES merging
-                ties_merging_tv = ties_merging(
-                    tv_flat_checks,
-                    reset_thresh=self.threshold,
-                    merge_func=self.merge_func,
-                )
-                # convert the merged task vector back to a state dict
-                ties_merging_state_dict = vector_to_state_dict(
-                    ties_merging_tv,
-                    merged_model.state_dict(),
-                    remove_keys=self.remove_keys,
-                )
-
-            for param_name, param in task_model.named_parameters():
-                if not param.requires_grad:
-                    continue
-
-                merged_param = merged_model.get_parameter(param_name)
-                new_param = (
-                    merged_param
-                    + self.scaling_factor * ties_merging_state_dict[param_name]
-                )
-                merged_model.get_parameter(param_name).data = new_param
+                if model_idx == 0:
+                    # if is the first model, the merged task vector is equal to the task vector
+                    ties_merging_state_dict = task_vector
+                else:
+                    # if is not the first model, we need to merge the task vector with the previous merged task vector
+                    merged_tv = state_dict_sub(
+                        merged_model.state_dict(),
+                        pretrained_model.state_dict(),
+                    )
+                    tv_flat_checks = torch.vstack(
+                        [
+                            state_dict_to_vector(
+                                merged_tv, remove_keys=self.remove_keys
+                            ),
+                            state_dict_to_vector(
+                                task_vector, remove_keys=self.remove_keys
+                            ),
+                        ]
+                    )
+                    # perform the TIES merging
+                    ties_merging_tv = ties_merging(
+                        tv_flat_checks,
+                        reset_thresh=self.threshold,
+                        merge_func=self.merge_func,
+                    )
+                    # convert the merged task vector back to a state dict
+                    ties_merging_state_dict = vector_to_state_dict(
+                        ties_merging_tv,
+                        merged_model.state_dict(),
+                        remove_keys=self.remove_keys,
+                    )
+
+                for param_name, param in task_model.named_parameters():
+                    if not param.requires_grad:
+                        continue
+
+                    merged_param = merged_model.get_parameter(param_name)
+                    new_param = (
+                        merged_param
+                        + self.scaling_factor * ties_merging_state_dict[param_name]
+                    )
+                    merged_model.get_parameter(param_name).data = new_param
 
             if self.save_on_every_step:
-                self.save_merged_model(merged_model, model_idx)
+                with self.profile("saving model"):
+                    self.save_merged_model(merged_model, model_idx)
 
             if self.evaluate_on_every_step:
-                self.taskpool._is_setup = False
-                self.taskpool._test_datasets = DictConfig(
-                    {n: self._test_datasets[n] for n in model_names[: model_idx + 1]}
-                )
-                report = self.taskpool.evaluate(deepcopy(merged_model))
-                save_to_json(report, Path(self.log_dir) / f"report_{model_idx}.json")
-
+                with self.profile("evaluating model"):
+                    self.taskpool._is_setup = False
+                    self.taskpool._test_datasets = DictConfig(
+                        {
+                            n: self._test_datasets[n]
+                            for n in model_names[: model_idx + 1]
+                        }
+                    )
+                    report = self.taskpool.evaluate(deepcopy(merged_model))
+                    save_to_json(
+                        report, Path(self.log_dir) / f"report_{model_idx}.json"
+                    )
+
+        self.print_profile_summary()
         return merged_model
 
     def save_merged_model(self, merged_model: CLIPVisionModel, step: int):
@@ -16,6 +16,7 @@ import torch
 from torch import Tensor, nn
 from torch.func import functional_call
 
+from fusion_bench.models.utils import del_attr, get_attr, set_attr
 from fusion_bench.utils.type import StateDictType, TorchModelType
 
 __all__ = ["get_layer_wise_weights", "fuse_weights", "LayerWiseMergedModel"]
@@ -23,52 +24,6 @@ __all__ = ["get_layer_wise_weights", "fuse_weights", "LayerWiseMergedModel"]
 log = logging.getLogger(__name__)
 
 
-def del_attr(obj, names: List[str]):
-    """
-    Deletes an attribute from an object recursively.
-
-    Args:
-        obj (object): Object to delete attribute from.
-        names (list): List of attribute names to delete recursively.
-    """
-    if len(names) == 1:
-        delattr(obj, names[0])
-    else:
-        del_attr(getattr(obj, names[0]), names[1:])
-
-
-def set_attr(obj, names: List[str], val):
-    """
-    Sets an attribute of an object recursively.
-
-    Args:
-        obj (object): Object to set attribute of.
-        names (list): List of attribute names to set recursively.
-        val (object): Value to set the attribute to.
-    """
-    if len(names) == 1:
-        setattr(obj, names[0], val)
-    else:
-        set_attr(getattr(obj, names[0]), names[1:], val)
-
-
-def get_attr(obj, names: List[str]):
-    """
-    Gets an attribute of an object recursively.
-
-    Args:
-        obj (object): Object to get attribute of.
-        names (list): List of attribute names to get recursively.
-
-    Returns:
-        object: The attribute of the object.
-    """
-    if len(names) == 1:
-        return getattr(obj, names[0])
-    else:
-        return get_attr(getattr(obj, names[0]), names[1:])
-
-
 def get_layer_wise_weights(
     num_models: int,
     num_layers: int,
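These three recursive helpers are deleted here because they now live in `fusion_bench.models.utils` (see the import added at the top of this file and the next). They walk an attribute path that has been pre-split into a list of names; a short usage sketch based on the signatures shown above:

```python
import torch.nn as nn

from fusion_bench.models.utils import del_attr, get_attr, set_attr

model = nn.Sequential(nn.Linear(4, 4))
names = "0.weight".split(".")  # path to model[0].weight as a list of names

w = get_attr(model, names)                       # equivalent to model[0].weight
del_attr(model, names)                           # removes the parameter
set_attr(model, names, nn.Parameter(w.clone()))  # registers a fresh copy
```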
@@ -10,132 +10,17 @@ import torch
 from torch import Tensor, nn
 from torch.func import functional_call
 
+from fusion_bench.models.utils import del_attr, get_attr, set_attr
 from fusion_bench.utils.state_dict_arithmetic import state_dict_add
 from fusion_bench.utils.type import StateDictType
 
+from .layer_wise_fusion import fuse_weights, get_layer_wise_weights
+
 __all__ = ["get_layer_wise_weights", "fuse_weights", "LayerWiseMergedModel"]
 
 log = logging.getLogger(__name__)
 
 
-def del_attr(obj, names: List[str]):
-    """
-    Deletes an attribute from an object recursively.
-
-    Args:
-        obj (object): Object to delete attribute from.
-        names (list): List of attribute names to delete recursively.
-    """
-    if len(names) == 1:
-        delattr(obj, names[0])
-    else:
-        del_attr(getattr(obj, names[0]), names[1:])
-
-
-def set_attr(obj, names: List[str], val):
-    """
-    Sets an attribute of an object recursively.
-
-    Args:
-        obj (object): Object to set attribute of.
-        names (list): List of attribute names to set recursively.
-        val (object): Value to set the attribute to.
-    """
-    if len(names) == 1:
-        setattr(obj, names[0], val)
-    else:
-        set_attr(getattr(obj, names[0]), names[1:], val)
-
-
-def get_attr(obj, names: List[str]):
-    """
-    Gets an attribute of an object recursively.
-
-    Args:
-        obj (object): Object to get attribute of.
-        names (list): List of attribute names to get recursively.
-
-    Returns:
-        object: The attribute of the object.
-    """
-    if len(names) == 1:
-        return getattr(obj, names[0])
-    else:
-        return get_attr(getattr(obj, names[0]), names[1:])
-
-
-def get_layer_wise_weights(
-    num_models: int,
-    num_layers: int,
-    init_values: float = None,
-    dtype: torch.dtype = torch.float32,
-):
-    """
-    Return a tensor of layer-wise weights for the given number of models and layers.
-
-    Args:
-        num_models (int): The number of models to fuse.
-        num_layers (int): The number of layers in each model.
-        init_values (float, optional): The initial value for each weight. Defaults to 1.0 / num_models.
-        dtype (torch.dtype): dtype of weights. This should be the same with model dtype.
-
-    Returns:
-        Tensor: A tensor of shape (num_models, num_layers) containing the layer-wise weights.
-    """
-    assert num_models >= 1, f"num_models must be >= 1, got {num_models}"
-    assert num_layers >= 1, f"num_layers must be >= 1, got {num_layers}"
-    if init_values is None:
-        init_values = 1.0 / num_models
-    return torch.full((num_models, num_layers), init_values, dtype=dtype)
-
-
-def _fuse_weights(layer_wise_weight: Tensor, tensors: List[Tensor]):
-    """
-    Fuse the layer-wise weights with the given state dictionaries.
-
-    Args:
-        layer_wise_weight (Tensor): A tensor of shape (num_models,) containing the layer-wise weights.
-        state_dicts (List[Tensor]): A list of state dictionaries, each containing the weights for a single layer.
-
-    Returns:
-        Tensor: A tensor of shape (num_params,) containing the fused weights.
-    """
-    assert len(layer_wise_weight) == len(
-        tensors
-    ), f"layer_wise_weight.shape={layer_wise_weight.shape}, len(tensors)={len(tensors)}"
-    return sum(
-        layer_wise_weight[i] * w.to(layer_wise_weight.device)
-        for i, w in enumerate(tensors)
-    )
-
-
-def fuse_weights(
-    layer_wise_weight: Tensor, state_dicts: List[StateDictType]
-) -> StateDictType:
-    """
-    Fuse the weights of multiple models using layer-wise fusion.
-
-    Args:
-        layer_wise_weight (Tensor): A tensor of shape (num_models, num_layers) representing the weight of each layer for each model.
-        state_dicts (List[StateDict]): A list of state dictionaries, one for each model.
-
-    Returns:
-        A dictionary mapping each weight tensor key to the fused weight tensor.
-    """
-    num_models = len(state_dicts)
-    num_layers = len(state_dicts[0])
-    assert layer_wise_weight.shape == (
-        num_models,
-        num_layers,
-    ), f"layer_wise_weight.shape={layer_wise_weight.shape}, expected (num_models, num_layers): ({num_models}, {num_layers})"
-    return {
-        k: _fuse_weights(
-            layer_wise_weight[:, i], [state_dict[k] for state_dict in state_dicts]
-        )
-        for i, k in enumerate(state_dicts[0].keys())
-    }
-
-
 class LayerWiseMergedModel(nn.Module):
     _merged_state_dict: StateDictType = None
 
@@ -390,7 +275,7 @@ class LayerWiseMergedModel(nn.Module):
         layer_vectors_scale = layer_vectors * layer_lamdas.view(-1, 1, 1)
         sum_over_num_vectors = layer_vectors_scale.sum(dim=0)
 
-        layer_delta_scale = layer_delta.unsqueeze(0) * layer_lamdas.view(-1, 1, 1)
+        layer_delta_scale = layer_delta * layer_lamdas.view(-1, 1, 1)
         sum_over_delta = layer_delta_scale.sum(dim=0)
 
         # Iterate through each vector and calculate the loss one by one
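Assuming `layer_delta` already carries a leading per-model dimension, like `layer_vectors` two lines up, the removed `unsqueeze(0)` added a spurious axis: the product broadcast to four dimensions and `sum(dim=0)` collapsed the singleton axis instead of the model axis. A shape check under that assumption:

```python
# Shape sketch; num_models/d_out/d_in are illustrative values.
import torch

num_models, d_out, d_in = 4, 3, 2
layer_delta = torch.randn(num_models, d_out, d_in)
layer_lamdas = torch.randn(num_models)

old = (layer_delta.unsqueeze(0) * layer_lamdas.view(-1, 1, 1)).sum(dim=0)
new = (layer_delta * layer_lamdas.view(-1, 1, 1)).sum(dim=0)

print(old.shape)  # torch.Size([4, 3, 2]) -- no reduction over models happened
print(new.shape)  # torch.Size([3, 2])    -- scaled deltas summed over models
```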
@@ -1,6 +1,6 @@
-Metadata-Version: 2.2
+Metadata-Version: 2.4
 Name: fusion_bench
-Version: 0.2.11
+Version: 0.2.12
 Summary: A Comprehensive Benchmark of Deep Model Fusion
 Author-email: Anke Tang <tang.anke@foxmail.com>
 License: MIT License
@@ -45,6 +45,7 @@ Requires-Dist: rich
 Requires-Dist: scipy
 Requires-Dist: h5py
 Requires-Dist: pytest
+Dynamic: license-file
 
 <div align='center'>
 
@@ -69,6 +70,18 @@ FusionBench is a benchmark suite designed to evaluate the performance of various
 
 Projects based on FusionBench and news from the community (descending order of date):
 
+<details>
+<summary>Hao Mark Chen, et al. FW-Merging: Scaling Model Merging with Frank-Wolfe Optimization. Mar 2025. https://arxiv.org/abs/2503.12649</summary>
+
+Model merging has emerged as a promising approach for multi-task learning (MTL), offering a data-efficient alternative to conventional fine-tuning. However, with the rapid development of the open-source AI ecosystem and the increasing availability of fine-tuned foundation models, existing model merging methods face two key limitations: (i) They are primarily designed for in-house fine-tuned models, making them less adaptable to diverse model sources with partially unknown model and task information, (ii) They struggle to scale effectively when merging numerous model checkpoints. To address these challenges, we formulate model merging as a constrained optimization problem and introduce a novel approach: Frank-Wolfe Merging (FW-Merging). Inspired by Frank-Wolfe optimization, our approach iteratively selects the most relevant model in the pool to minimize a linear approximation of the objective function and then executes a local merging similar to the Frank-Wolfe update. The objective function is designed to capture the desired behavior of the target-merged model, while the fine-tuned candidate models define the constraint set. More importantly, FW-Merging serves as an orthogonal technique for existing merging methods, seamlessly integrating with them to further enhance accuracy performance. Our experiments show that FW-Merging scales across diverse model sources, remaining stable with 16 irrelevant models and improving by 15.3% with 16 relevant models on 20 CV tasks, while maintaining constant memory overhead, unlike the linear overhead of data-informed merging methods. Compared with the state-of-the-art approaches, FW-Merging surpasses the data-free merging method by 32.8% and outperforms the data-informed Adamerging by 8.39% when merging 20 ViT models.
+</details>
+
+<details>
+<summary>Daniel Marczak, et al. No Task Left Behind: Isotropic Model Merging with Common and Task-Specific Subspaces. Feb 2025. https://arxiv.org/abs/2502.04959</summary>
+
+Model merging integrates the weights of multiple task-specific models into a single multi-task model. Despite recent interest in the problem, a significant performance gap between the combined and single-task models remains. In this paper, we investigate the key characteristics of task matrices -- weight update matrices applied to a pre-trained model -- that enable effective merging. We show that alignment between singular components of task-specific and merged matrices strongly correlates with performance improvement over the pre-trained model. Based on this, we propose an isotropic merging framework that flattens the singular value spectrum of task matrices, enhances alignment, and reduces the performance gap. Additionally, we incorporate both common and task-specific subspaces to further improve alignment and performance. Our proposed approach achieves state-of-the-art performance across multiple scenarios, including various sets of tasks and model scales. This work advances the understanding of model merging dynamics, offering an effective methodology to merge models without requiring additional training.
+</details>
+
 <details>
 <summary>Anke Tang, et al. Merging Models on the Fly Without Retraining: A Sequential Approach to Scalable Continual Model Merging. Jan 2025. https://arxiv.org/pdf/2501.09522</summary>
 
@@ -1,7 +1,7 @@
 fusion_bench/__init__.py,sha256=68dF-zPvb8E2MgYnmgIJsxIHJBy1MApKeOrRZvQEVlg,421
 fusion_bench/__main__.py,sha256=weUjxpP3ULnDgUxCehdbmoCM9cqfkhDhGB85tAF5qoE,81
 fusion_bench/compat/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-fusion_bench/compat/method/__init__.py,sha256=97izLAf4JssNAoOXR4MYffFxb3OEwpHeQeSlL_ihMKI,5566
+fusion_bench/compat/method/__init__.py,sha256=qbm_0o4Y-X2FY3skmsQpYnKQ3qnR24Z0-uLOEnzO59M,5566
 fusion_bench/compat/method/base_algorithm.py,sha256=63_AQDj1eJOO6RyTSGXVC6G2DsG8yg9E4pT3RJXgP3A,1952
 fusion_bench/compat/modelpool/AutoModelForSeq2SeqLM.py,sha256=m68BRGy4P-P9lLB10oXOBI-p58a-0FOPcrJ4r4MU32k,1100
 fusion_bench/compat/modelpool/__init__.py,sha256=KD8Ddr9D7rJ5YdHEQsTuNmQ0bgQfqF4l3WNMtHmRHD8,4687
@@ -15,7 +15,7 @@ fusion_bench/constants/__init__.py,sha256=Pyc4dLbl6oNduOCdnpeXQ9LDyVoIrkdl9eZ_l2
 fusion_bench/constants/paths.py,sha256=DVZyQ9FLhkyUdw6ARpXUCAMf_B8hFyJ6UNI-oYly3pE,591
 fusion_bench/dataset/__init__.py,sha256=OJiYmcqz0Vm5O7mE4PB5QFJeL_KjrsseQTRsQATGTm4,1050
 fusion_bench/dataset/clip_dataset.py,sha256=XLpCOiXlLEP3DffAlBn4P2PpUenbEFl-Yk9MNy6nbbI,2790
-fusion_bench/dataset/fer2013.py,sha256=Lub_xVhHfqaiPprvOsDVspJNioh1FjSrkhn3gL_UXDA,404
+fusion_bench/dataset/fer2013.py,sha256=bAdujQSj1PcUVFlKJgqcHAuE9AWz7JE1fzZ6scFVvmc,403
 fusion_bench/dataset/gpt2_glue.py,sha256=Qq1ZkEIQsTjj8tImvkZDNlduocSYwlEfVrDReZqDWdw,8761
 fusion_bench/dataset/gsm8k.py,sha256=CmANZ0A89PfPwVu_myKhXk1D9IwypOpjH3iqDo1KxcQ,2233
 fusion_bench/dataset/image_dataset.py,sha256=MSZE_UESyRRQDwnkm2KpyIARUg9SWcwqnH4fDNstzS4,1870
@@ -41,20 +41,16 @@ fusion_bench/dataset/llama/stanford_shp.py,sha256=6ueXKnFXIBBobacU1h5WxGLZrSOtBk
 fusion_bench/dataset/llama/ultrachat.py,sha256=Go7WvrDAYnm184fdazHGRYLbSY6Xd7jrESyQeUJtOww,1736
 fusion_bench/dataset/llama/wikitext.py,sha256=9ZHR-nMfXRumd3o-PIj3n7B83YlVeqpGkZ2zJs2B-9Y,2883
 fusion_bench/dataset/llama/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-fusion_bench/method/__init__.py,sha256=QGJzdOpZxonu_WUNXSFQIiMy4OHsgqmcU5Bs6OB_RT0,7040
+fusion_bench/method/__init__.py,sha256=7S1ODkq2Zppx59o80qcIwDlRtfOC2EU58ooGFlDdJIU,7040
 fusion_bench/method/base_algorithm.py,sha256=5dutGZfPqNhO8F8FOlo3UFR91TZu2Xj7O0pTB40JvWo,1135
 fusion_bench/method/dummy.py,sha256=hb1y6LR_geRZ5eRgGwt5zJUcHYorCeIbs5i76CvurUc,1031
 fusion_bench/method/ensemble.py,sha256=rGxvJTeorfcBuE_e0XO-0-MAc9un7ZCC46ikKGuAcN4,3077
 fusion_bench/method/model_recombination.py,sha256=2tviqmYSPOL0_Ktv8_gt_YzQ4tyCANHxXquUot_3Cgo,5360
 fusion_bench/method/simple_average.py,sha256=2ghcL1E-eLbIYDCHYCoR9WtiYSb1GvFAH163OTTTEEI,4481
-fusion_bench/method/DOGE_TA/DOGE_TA.py,sha256=veNjBfq65fB7oqQL66zAuA339WCY5mG-mefkVteg2-k,13785
-fusion_bench/method/DOGE_TA/__init__.py,sha256=OTukCLUlbCUTDqGBtgBZop7eYFDfU2wjG4PkP4fXN4Q,59
-fusion_bench/method/DOGE_TA/clip_layer_wise_adamerging.py,sha256=YdQ4trHohW6QzWC2enYvXA44WHxvzmoH_6sMrPn6z60,1305
-fusion_bench/method/DOGE_TA/layer_wise_adamerging.py,sha256=rLk3Nep5d6wMUNCp6q7pC7L0pfBvUwGBIuiGM7CQOf4,9780
 fusion_bench/method/ada_svd/__init__.py,sha256=4XzQbbvE9HI3NtEmEFvo8iC3ds_85vJXe7P7qJfL7kk,77
 fusion_bench/method/ada_svd/clip_vision.py,sha256=QrT6cSwgVEGxXEpVhkvKQVQaoRW5P9V52Y3_8NX0f-o,12556
 fusion_bench/method/adamerging/__init__.py,sha256=nt0saBT_3bqghk-pINQ-XCWm9UWwSZllu4R1sDuAJAA,376
-fusion_bench/method/adamerging/clip_layer_wise_adamerging.py,sha256=YdQ4trHohW6QzWC2enYvXA44WHxvzmoH_6sMrPn6z60,1305
+fusion_bench/method/adamerging/clip_layer_wise_adamerging.py,sha256=UUSldRPBxHVOfkMM7ZwqZay5Wjc6XQ3Vy9PgyqV_TZo,1311
 fusion_bench/method/adamerging/clip_task_wise_adamerging.py,sha256=Tys9pDJzz5YNUCO43pO44fGAnizfSaeAwgH4-vVxRN4,6948
 fusion_bench/method/adamerging/entropy_loss.py,sha256=ZeVe0Hq1PaMfppLqDbB0MOscZUZRNh4CALrvt8pmQC0,736
 fusion_bench/method/adamerging/flan_t5_layer_wise_adamerging.py,sha256=osc6ueCgiS4u8KUV_sZkHGFBYC8dThnTSp4NB0wkQIg,12915
@@ -87,6 +83,10 @@ fusion_bench/method/dawe/warppers/dawe_model.py,sha256=Z1L91vu3UzEHWrHs9i9UbwZpn
 fusion_bench/method/depth_upscaling/__init__.py,sha256=heVUh4tTzK427A10RFknf9eHwoZ1cpn1_0xyNXRU7YM,135
 fusion_bench/method/depth_upscaling/depth_upscaling.py,sha256=pf08zEae-WaWM4oUwn6_Dm65K59wf9AbTQ5iZU0ydsc,3256
 fusion_bench/method/depth_upscaling/depth_upscaling_for_llama.py,sha256=bSMhnrG-JtR0JBnOFy7aWAhD6A-YBB84qm_YnWjc7pA,2180
+fusion_bench/method/doge_ta/__init__.py,sha256=dixO0i5fmhgC_W2_DAQ4PzYnkMCZX5D8tDz84soqQ-Q,59
+fusion_bench/method/doge_ta/clip_layer_wise_adamerging.py,sha256=UUSldRPBxHVOfkMM7ZwqZay5Wjc6XQ3Vy9PgyqV_TZo,1311
+fusion_bench/method/doge_ta/doge_ta.py,sha256=ec0qIq3F72nhbCVlfqdk1PYFM7QIlfMofeVFVvmDKiE,13785
+fusion_bench/method/doge_ta/layer_wise_adamerging.py,sha256=rLk3Nep5d6wMUNCp6q7pC7L0pfBvUwGBIuiGM7CQOf4,9780
 fusion_bench/method/fisher_merging/__init__.py,sha256=KWsjrtxKkPYwcUA5rB_6UNIqvesqk2NJw5AY_1ztLVE,225
 fusion_bench/method/fisher_merging/clip_fisher_merging.py,sha256=QCutGqjkfW3OWETPZsCChqLRAhvfJp4QKD9TGSpTyV0,7635
 fusion_bench/method/fisher_merging/fisher_merging.py,sha256=CPU-tJiDv9FCIBYl7Pn0zA5cdRB1Md5kWchRDlJgly0,20456
@@ -109,9 +109,9 @@ fusion_bench/method/mixture_of_experts/__init__.py,sha256=r95iu1-3tgIUP7sWuAbLuq
 fusion_bench/method/mixture_of_experts/mixtral_merging.py,sha256=-n1CLP1o08VyMSfaTq42kRutbw-cFDSCWHTu0iNh6ok,4237
 fusion_bench/method/mixture_of_experts/mixtral_upcycling.py,sha256=tQYAeS8MLFEfH3zDFfNZrML7lRnpGLN-HquQvjPtHNw,11208
 fusion_bench/method/opcm/__init__.py,sha256=0QcltOnjIYV1XEPDEagChLixLAhjiBnYwfWK00am29k,202
-fusion_bench/method/opcm/opcm.py,sha256=USPPMFFVQ9UbcGvvK1573tgkO1kgcrhA5jzKdbNTy9g,10693
-fusion_bench/method/opcm/task_arithmetic.py,sha256=SNuuSyzHqvOT_e3i0z0MHNWaMP6xnDdkI9c2t1OcxO4,4328
-fusion_bench/method/opcm/ties_merging.py,sha256=38ogIysnRfePhB9SAfr1BPwtHyM8gEdhU2td_yTiB2g,6080
+fusion_bench/method/opcm/opcm.py,sha256=-sqfK5q_-yr_3YWigmXKVYRP1J7swHOR9eGMMzu1Dgw,11445
+fusion_bench/method/opcm/task_arithmetic.py,sha256=YvtsWkjtnk7E3C4_xNr--uQWjQhoDZZB-klSx81_tGw,4824
+fusion_bench/method/opcm/ties_merging.py,sha256=-N3i7eMbhK95qyJsmmNMKNmPCkgGHGFa423a52cgi6g,6868
 fusion_bench/method/opcm/utils.py,sha256=_q7yy3ENNFUh1qUd5J5DThRL4J1tIxEcknCO2AKmeYM,2102
 fusion_bench/method/opcm/weight_average.py,sha256=JfQoIU5J1jvrNKpO9k_t4Zj0y8PtteIfyoSQWx1yg2k,4379
 fusion_bench/method/pruning/__init__.py,sha256=3gtmay2bkdIAEGjpAhbY2ztMZOZLKhiJcKV3mCe2H5w,252
@@ -259,8 +259,8 @@ fusion_bench/models/surgery/__init__.py,sha256=tcUSi2m9GzGWfvRDQScIbdEbFBS_35gm9
 fusion_bench/models/surgery/surgerymodelwrapper.py,sha256=F8jX88K5zVWC6HsfN-nGNkEiPwNrN11ydyQQ1EZHehM,5133
 fusion_bench/models/wrappers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 fusion_bench/models/wrappers/ensemble.py,sha256=wIMZMRyXw5boWAm96c4Tiyebs_HDQovKxpGQ8rLnHUQ,6308
-fusion_bench/models/wrappers/layer_wise_fusion.py,sha256=ZizBGQtSLKOzMLFAhrMNMcv6ZNdvABTyO7M1-DGHh3c,12316
-fusion_bench/models/wrappers/layer_wise_fusion_doge_ta.py,sha256=k335dxzq3ezuYkDVOv4ePi128NVyiHVCW6zyuDRTg30,20689
+fusion_bench/models/wrappers/layer_wise_fusion.py,sha256=KamNaq4DlyxQrOp1i9aQLgA2WX81YD5NhzAQ5GF6rg0,11188
+fusion_bench/models/wrappers/layer_wise_fusion_doge_ta.py,sha256=q5Hc4BtLpAawMbxsWJRL-8OR-x7994Jhr9IyN7vKZ9o,16930
 fusion_bench/models/wrappers/task_wise_fusion.py,sha256=Wn3buQvWw_lihWaKB03_iz34cBPzwBD94kBT6uafWVQ,8404
 fusion_bench/optim/__init__.py,sha256=lemrcuiA6OLjQkpYm-RP-Ox2MgjngN1ywvCo0NgShlM,61
 fusion_bench/optim/exception.py,sha256=fMgo1heiqfGhuI5RIbf30BwWSShn5RQiyeb30QtfTI0,1607
@@ -359,6 +359,7 @@ fusion_bench/utils/plot/token_notebook.py,sha256=bsntXf46Zz_RavTxNiB9c3-KvHw7LFw
 fusion_bench/utils/strenum/__init__.py,sha256=id9ORi1uXrDxhbmVxitJ1KDwLS4H3AAwFpaK5h1cQzw,8531
 fusion_bench/utils/strenum/_name_mangler.py,sha256=o11M5-bURW2RBvRTYXFQIPNeqLzburdoWLIqk8X3ydw,3397
 fusion_bench/utils/strenum/_version.py,sha256=6JQRo9LcvODbCOeVFYQb9HNJ_J9XiG_Zbn8ws2A3BV8,18466
+fusion_bench-0.2.12.dist-info/licenses/LICENSE,sha256=nhnOJlw4CPuPVE0qvkGmxfFgHmKi-6nzXvTu8t0NUdg,1066
 fusion_bench_config/README.md,sha256=Lc8YSBJ5oxf9KV5kKDivJ9LRyGuraGQPmBbgbdVA-j4,703
 fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml,sha256=GtK3VuD2FOpFHH_1Hi6tlaYpdLE5Cz0nYKP92Ss9G2Y,1164
 fusion_bench_config/fabric_model_fusion.yaml,sha256=1shmbuC0B9snkFkLErBCiroF-z7UnEHscyEmKBne7Oo,949
@@ -472,7 +473,6 @@ fusion_bench_config/method/pwe_moe_ls_for_clip.yaml,sha256=brs9zYeuXfFnnCoRrSaAY
 fusion_bench_config/method/simple_average.yaml,sha256=GtMNvt0-qWOevRX2V6fjiYUO2BwDvMw-EcxRMS_PhZQ,53
 fusion_bench_config/method/task_arithmetic.yaml,sha256=TbpAeTwIX48PFOkZU-Ihuu6U9Y5XHZJGDu7vHLt5FjU,74
 fusion_bench_config/method/ties_merging.yaml,sha256=N-XyOTEW0JRtyRJizpHqtb1GEIogUU22XSG76QvIvnw,292
-fusion_bench_config/method/DOGE_TA/DOGE_TA.yaml,sha256=6R9NRuWmj0oapJ_raMB6R6rZPMckt2JtMLrTQ6HhrFc,77
 fusion_bench_config/method/ada_svd/clip_vision.yaml,sha256=KDpDpzuNVqqyyqJcL0q-Ml2A7IUqn_-2dOZXs8zHKlU,184
 fusion_bench_config/method/adamerging/clip.yaml,sha256=fBG7jBBepygKpCbM3fmUeVAr2zzx0g8C21rGGfnEPkA,730
 fusion_bench_config/method/adamerging/layer_wise_flan_t5.yaml,sha256=7FPPMf6lcOD2dlNUbb5JyF3pqJ3D2jmvbWAbW9WGn0Y,546
@@ -493,6 +493,7 @@ fusion_bench_config/method/dare/simple_average.yaml,sha256=oTFSCHul86NTjTtJYK5pN
 fusion_bench_config/method/dare/task_arithmetic.yaml,sha256=Cvsam89yquamn_GkITT6q8qFKN_Yb5nv8p-XgvnVrgU,134
 fusion_bench_config/method/dare/ties_merging.yaml,sha256=50mPiRkzLN7gxaIs56sPWkAUSvqvdxjQJ8eVl1yUGOg,418
 fusion_bench_config/method/dawe/dawe_for_clip.yaml,sha256=8-Z_kwwGCy1AO4brW-R_pe8oJ0yqoD4WCLI9ZtJ4KOo,1026
+fusion_bench_config/method/doge_ta/doge_ta.yaml,sha256=6R9NRuWmj0oapJ_raMB6R6rZPMckt2JtMLrTQ6HhrFc,77
 fusion_bench_config/method/ensemble/max_model_predictor.yaml,sha256=fsWuNJwr1ohVB2aJ5L2fsiDLztm5GieE9JS99w--two,56
 fusion_bench_config/method/ensemble/simple_ensemble.yaml,sha256=bw9FabjhQYNbttsiMgTVd-Z4KIowf050Uy97vKtm2ys,55
 fusion_bench_config/method/ensemble/weighted_ensemble.yaml,sha256=U_wQXtogtgiqOTszHUgcGNfrKlXD6JrR_HjqNwAkkKo,262
@@ -691,7 +692,8 @@ fusion_bench_config/modelpool/CausalLMPool/single_llama_model.yaml,sha256=aX0rWw
 fusion_bench_config/modelpool/Seq2SeqLMPool/_template.yaml,sha256=mRx-Xx4s6_IBoJJRogIBW4egmqW0wi1kGVWp_YwYVvQ,233
 fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue.yaml,sha256=6Rgfq3cjCRWbAL8Bb-Dkvl9eJP4FKmqewBpokajwYWU,335
 fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16.yaml,sha256=1vaVb059Wh3XMD8MhXD9p5a0zx8mi9HovOcS0k51uK8,1699
-fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16_tta.yaml,sha256=GfTY343bt5YtxtUkQxSacrtQav9lT9Y-t1VIL1Chs4k,1726
+fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16_tta.yaml,sha256=dwBb3wPfyxH6cx6txBd31OOlrfCvPkM-nIN46FJer-I,1790
+fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_tta.yaml,sha256=2BBuK1uyKL_9uo3X3bScjZiK-PtIiE_7RHj4onK_3R0,1725
 fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_individual.yaml,sha256=2YBIzqYGluOT2r6dOFpUYE4Cbdd2XoHAUps-kCDxVPQ,185
 fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-large_glue_lora16.yaml,sha256=W1y3fKY9UTTRyv7nqbIO5DESlQVfNsWlhkHJMUYh7B4,1824
 fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml,sha256=JUzGOLANW92Y_rljOOZKmwBQvWrJsko_ziayurzHSTY,880
@@ -736,9 +738,8 @@ fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_sun397
 fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_svhn.yaml,sha256=2AqMiNCRRunLIrssHvFzu1lUzOaQn8uOHM9yjrQq-_A,109
 fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml,sha256=iQMj2VpDTe_D8OfCo94w5Ud2MON-EGa0DzVr6UmphrA,436
 fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_sparse_wemoe_clip-vit-classification_TA8.yaml,sha256=i5Bn8bLl2cgqvrgtIGmoovUfSMehk_m-6C2wwcx5JMU,435
-fusion_bench-0.2.11.dist-info/LICENSE,sha256=nhnOJlw4CPuPVE0qvkGmxfFgHmKi-6nzXvTu8t0NUdg,1066
-fusion_bench-0.2.11.dist-info/METADATA,sha256=AYdGcKXZ6BeHCv1piGgpK1yktQqVga-PjUDxS4RYwog,16780
-fusion_bench-0.2.11.dist-info/WHEEL,sha256=jB7zZ3N9hIM9adW7qlTAyycLYW9npaWKLRzaoVcLKcM,91
-fusion_bench-0.2.11.dist-info/entry_points.txt,sha256=iUQ8MCJvda7HP4vYh2n1Teoapb4G9PBVYZkAfcc5SHU,116
-fusion_bench-0.2.11.dist-info/top_level.txt,sha256=BuO4TL6iHL_2yPBUX9-LlIrHRczA_BNMIFwweK0PQEI,13
-fusion_bench-0.2.11.dist-info/RECORD,,
+fusion_bench-0.2.12.dist-info/METADATA,sha256=V0KZSil6pMjhZVA3x0wUrW-eskY5DsyclRkiuh8sfec,20085
+fusion_bench-0.2.12.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
+fusion_bench-0.2.12.dist-info/entry_points.txt,sha256=iUQ8MCJvda7HP4vYh2n1Teoapb4G9PBVYZkAfcc5SHU,116
+fusion_bench-0.2.12.dist-info/top_level.txt,sha256=BuO4TL6iHL_2yPBUX9-LlIrHRczA_BNMIFwweK0PQEI,13
+fusion_bench-0.2.12.dist-info/RECORD,,
@@ -1,5 +1,5 @@
 Wheel-Version: 1.0
-Generator: setuptools (75.8.2)
+Generator: setuptools (78.1.0)
 Root-Is-Purelib: true
 Tag: py3-none-any
 
@@ -1,14 +1,14 @@
 defaults:
   - /model/flan-t5@models:
       - flan-t5-base
-      - flan-t5-base_glue-cola
-      - flan-t5-base_glue-mnli
-      - flan-t5-base_glue-mrpc
-      - flan-t5-base_glue-qnli
-      - flan-t5-base_glue-qqp
-      - flan-t5-base_glue-rte
-      - flan-t5-base_glue-sst2
-      - flan-t5-base_glue-stsb
+      - flan-t5-base_glue-cola_lora-16
+      - flan-t5-base_glue-mnli_lora-16
+      - flan-t5-base_glue-mrpc_lora-16
+      - flan-t5-base_glue-qnli_lora-16
+      - flan-t5-base_glue-qqp_lora-16
+      - flan-t5-base_glue-rte_lora-16
+      - flan-t5-base_glue-sst2_lora-16
+      - flan-t5-base_glue-stsb_lora-16
 _target_: fusion_bench.modelpool.Seq2SeqLMPool
 _recursive_: false
 
@@ -0,0 +1,68 @@
+defaults:
+  - /model/flan-t5@models:
+      - flan-t5-base
+      - flan-t5-base_glue-cola
+      - flan-t5-base_glue-mnli
+      - flan-t5-base_glue-mrpc
+      - flan-t5-base_glue-qnli
+      - flan-t5-base_glue-qqp
+      - flan-t5-base_glue-rte
+      - flan-t5-base_glue-sst2
+      - flan-t5-base_glue-stsb
+_target_: fusion_bench.modelpool.Seq2SeqLMPool
+_recursive_: false
+
+_dataset_loader: fusion_bench.tasks.flan_t5_text_generation.glue_load_dataset.load_glue_dataset
+test_datasets:
+  glue-cola:
+    _target_: ${..._dataset_loader}
+    _recursive_: false
+    name: cola
+    tokenizer: ${...tokenizer}
+    split: validation
+  glue-mnli:
+    _target_: ${..._dataset_loader}
+    _recursive_: false
+    name: mnli
+    tokenizer: ${...tokenizer}
+    split: validation_matched
+  glue-mrpc:
+    _target_: ${..._dataset_loader}
+    _recursive_: false
+    name: mrpc
+    tokenizer: ${...tokenizer}
+    split: validation
+  glue-qnli:
+    _target_: ${..._dataset_loader}
+    _recursive_: false
+    name: qnli
+    tokenizer: ${...tokenizer}
+    split: validation
+  glue-qqp:
+    _target_: ${..._dataset_loader}
+    _recursive_: false
+    name: qqp
+    tokenizer: ${...tokenizer}
+    split: validation
+  glue-rte:
+    _target_: ${..._dataset_loader}
+    _recursive_: false
+    name: rte
+    tokenizer: ${...tokenizer}
+    split: validation
+  glue-sst2:
+    _target_: ${..._dataset_loader}
+    _recursive_: false
+    name: sst2
+    tokenizer: ${...tokenizer}
+    split: validation
+  glue-stsb:
+    _target_: ${..._dataset_loader}
+    _recursive_: false
+    name: stsb
+    tokenizer: ${...tokenizer}
+    split: validation
+
+tokenizer:
+  _target_: transformers.AutoTokenizer.from_pretrained
+  pretrained_model_name_or_path: google/flan-t5-base
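The new `flan-t5-base_glue_tta.yaml` mirrors the LoRA variant above but points at the fully fine-tuned Flan-T5 checkpoints, pairing each checkpoint with its GLUE validation split as a test dataset. Following the CLI pattern used in the docstrings earlier in this diff, it could be selected like this (the choice of method here is only an illustration):

```bash
fusion_bench \
    method=simple_average \
    modelpool=Seq2SeqLMPool/flan-t5-base_glue_tta
```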
@@ -1,2 +0,0 @@
-# flake8: noqa F401
-from .DOGE_TA import DOGE_TA_Algorithm