fusion-bench 0.2.14__py3-none-any.whl → 0.2.16__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (86)
  1. fusion_bench/method/__init__.py +4 -0
  2. fusion_bench/method/fw_merging/__init__.py +2 -0
  3. fusion_bench/method/fw_merging/fw_hard.py +448 -0
  4. fusion_bench/method/fw_merging/fw_soft.py +519 -0
  5. fusion_bench/method/fw_merging/utils.py +331 -0
  6. fusion_bench/method/moe_pruner/__init__.py +7 -0
  7. fusion_bench/method/moe_pruner/hooks/__init__.py +6 -0
  8. fusion_bench/method/moe_pruner/hooks/deepseek_v2.py +85 -0
  9. fusion_bench/method/moe_pruner/hooks/hook.py +23 -0
  10. fusion_bench/method/moe_pruner/hooks/mixtral.py +93 -0
  11. fusion_bench/method/moe_pruner/moe_pruner.py +304 -0
  12. fusion_bench/method/moe_pruner/utils/__init__.py +1 -0
  13. fusion_bench/method/moe_pruner/utils/data.py +154 -0
  14. fusion_bench/method/moe_pruner/utils/layerwrapper.py +61 -0
  15. fusion_bench/method/moe_pruner/utils/prune.py +313 -0
  16. fusion_bench/method/moe_pruner/utils/score.py +41 -0
  17. fusion_bench/method/pruning/__init__.py +1 -0
  18. fusion_bench/method/pruning/llama_sparsegpt_prune.py +223 -0
  19. fusion_bench/method/pruning/sparsegpt_utils/__init__.py +1 -0
  20. fusion_bench/method/pruning/sparsegpt_utils/sparsegpt.py +128 -0
  21. fusion_bench/method/pruning/wanda_utils/data.py +33 -14
  22. fusion_bench/method/randes/__init__.py +15 -0
  23. fusion_bench/method/randes/base_algorithm.py +1013 -0
  24. fusion_bench/method/randes/modelsoup.py +126 -0
  25. fusion_bench/method/randes/task_arithmetic.py +318 -0
  26. fusion_bench/method/sparselo/sparselo.py +20 -2
  27. fusion_bench/method/tall_mask/__init__.py +1 -0
  28. fusion_bench/method/tall_mask/task_arithmetic.py +133 -0
  29. fusion_bench/modelpool/causal_lm/causal_lm.py +73 -10
  30. fusion_bench/modelpool/lazy_state_dict_pool.py +15 -0
  31. fusion_bench/models/modeling_deepseek_v2/__init__.py +15 -0
  32. fusion_bench/models/modeling_deepseek_v2/configuration_deepseek.py +208 -0
  33. fusion_bench/models/modeling_deepseek_v2/modeling_deepseek.py +1922 -0
  34. fusion_bench/models/modeling_deepseek_v2/tokenization_deepseek_fast.py +38 -0
  35. fusion_bench/programs/fabric_fusion_program.py +5 -0
  36. fusion_bench/taskpool/clip_vision/taskpool.py +8 -1
  37. fusion_bench/utils/__init__.py +1 -0
  38. fusion_bench/utils/data.py +1 -1
  39. fusion_bench/utils/lazy_state_dict.py +268 -0
  40. fusion_bench/utils/parameters.py +33 -0
  41. fusion_bench/utils/state_dict_arithmetic.py +74 -2
  42. fusion_bench/utils/type.py +1 -0
  43. {fusion_bench-0.2.14.dist-info → fusion_bench-0.2.16.dist-info}/METADATA +10 -3
  44. {fusion_bench-0.2.14.dist-info → fusion_bench-0.2.16.dist-info}/RECORD +86 -22
  45. {fusion_bench-0.2.14.dist-info → fusion_bench-0.2.16.dist-info}/WHEEL +1 -1
  46. fusion_bench_config/dataset/image_classification/test/TALL10.yaml +28 -0
  47. fusion_bench_config/dataset/image_classification/test/TALL12.yaml +28 -0
  48. fusion_bench_config/dataset/image_classification/test/TALL16.yaml +28 -0
  49. fusion_bench_config/dataset/image_classification/test/TALL18.yaml +28 -0
  50. fusion_bench_config/dataset/image_classification/train/TALL10.yaml +28 -0
  51. fusion_bench_config/dataset/image_classification/train/TALL12.yaml +28 -0
  52. fusion_bench_config/dataset/image_classification/train/TALL16.yaml +28 -0
  53. fusion_bench_config/dataset/image_classification/train/TALL18.yaml +28 -0
  54. fusion_bench_config/method/fw_merging/fw_hard.yaml +11 -0
  55. fusion_bench_config/method/fw_merging/fw_soft.yaml +12 -0
  56. fusion_bench_config/method/moe_pruner/moe_pruner.yaml +15 -0
  57. fusion_bench_config/method/pruning/llama_sparsegpt_pruning.yaml +16 -0
  58. fusion_bench_config/method/randes/superposed_model_soup.yaml +18 -0
  59. fusion_bench_config/method/randes/superposed_task_arithmetic.yaml +20 -0
  60. fusion_bench_config/method/randes/superposed_task_arithmetic_lora.yaml +20 -0
  61. fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml +2 -1
  62. fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml +1 -1
  63. fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml +1 -1
  64. fusion_bench_config/method/tall_mask/task_arithmetic.yaml +4 -0
  65. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL10.yaml +29 -0
  66. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL12.yaml +29 -0
  67. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL16.yaml +29 -0
  68. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL18.yaml +29 -0
  69. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL10.yaml +8 -0
  70. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL12.yaml +8 -0
  71. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL16.yaml +8 -0
  72. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL18.yaml +8 -0
  73. fusion_bench_config/modelpool/CausalLMPool/deepseek-v2-lite.yaml +15 -0
  74. fusion_bench_config/modelpool/CausalLMPool/mergebench/Llama-3.1-8B-Instruct.yaml +11 -0
  75. fusion_bench_config/modelpool/CausalLMPool/mergebench/Llama-3.1-8B.yaml +11 -0
  76. fusion_bench_config/modelpool/CausalLMPool/mergebench/Llama-3.2-3B-Instruct.yaml +11 -0
  77. fusion_bench_config/modelpool/CausalLMPool/mergebench/Llama-3.2-3B.yaml +11 -0
  78. fusion_bench_config/modelpool/CausalLMPool/mergebench/gemma-2-2b-it.yaml +11 -0
  79. fusion_bench_config/modelpool/CausalLMPool/mergebench/gemma-2-2b.yaml +11 -0
  80. fusion_bench_config/modelpool/CausalLMPool/mergebench/gemma-2-9b-it.yaml +11 -0
  81. fusion_bench_config/modelpool/CausalLMPool/mergebench/gemma-2-9b.yaml +11 -0
  82. fusion_bench_config/modelpool/CausalLMPool/mixtral-8x7b.yaml +14 -0
  83. fusion_bench_config/modelpool/SeqenceClassificationModelPool/roberta-base_glue.yaml +69 -0
  84. {fusion_bench-0.2.14.dist-info → fusion_bench-0.2.16.dist-info}/entry_points.txt +0 -0
  85. {fusion_bench-0.2.14.dist-info → fusion_bench-0.2.16.dist-info}/licenses/LICENSE +0 -0
  86. {fusion_bench-0.2.14.dist-info → fusion_bench-0.2.16.dist-info}/top_level.txt +0 -0
fusion_bench/method/fw_merging/fw_soft.py (new file)
@@ -0,0 +1,519 @@
+ """
+ This script implements the soft variant of Frank-Wolfe merging. It builds on the
+ Task Arithmetic method:
+
+ http://arxiv.org/abs/2212.04089
+ """
+
+ import functools
+ import logging
+ import os
+ from abc import abstractmethod
+ from collections import defaultdict
+ from copy import deepcopy
+ from functools import partial
+ from typing import TYPE_CHECKING, Any, Dict, List, Mapping, TypeVar, Union
+
+ import torch
+ from lightning.fabric.utilities.rank_zero import rank_zero_only
+ from omegaconf import DictConfig
+ from torch import Tensor, nn
+ from torch.utils.data import DataLoader
+ from tqdm.autonotebook import tqdm
+
+ from fusion_bench.compat.method import ModelFusionAlgorithm
+ from fusion_bench.compat.modelpool import HuggingFaceClipVisionPool, ModelPool
+ from fusion_bench.dataset.clip_dataset import CLIPDataset
+ from fusion_bench.mixins import CLIPClassificationMixin
+ from fusion_bench.mixins.lightning_fabric import LightningFabricMixin
+ from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
+ from fusion_bench.models.wrappers.layer_wise_fusion import (
+     LayerWiseMergedModel,
+     get_layer_wise_weights,
+ )
+ from fusion_bench.utils.data import load_tensor_from_file
+ from fusion_bench.utils.type import TorchModelType
+
+ from .utils import *
+
+ if TYPE_CHECKING:
+     from fusion_bench.programs.fabric_fusion_program import FabricModelFusionProgram
+
+ from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
+ from fusion_bench.modelpool import BaseModelPool
+ from fusion_bench.utils import instantiate
+ from fusion_bench.utils.data import InfiniteDataLoader
+ from fusion_bench.utils.state_dict_arithmetic import (
+     state_dict_add,
+     state_dict_mul,
+     state_dict_sub,
+ )
+ from fusion_bench.utils.type import StateDictType
+
+ log = logging.getLogger(__name__)
+
+
+ def projection_simplex_sort(v, z=1):
+     # print(v.shape)
+     n_features = v.shape[0]  # Get the number of elements in v
+     u, _ = torch.sort(v, descending=True)  # Sort v in descending order
+     cssv = torch.cumsum(u, dim=0) - z  # Compute cumulative sum and subtract z
+     ind = torch.arange(
+         1, n_features + 1, dtype=torch.long, device=v.device
+     )  # Create index tensor (1 to n_features)
+     cond = u - cssv / ind > 0  # Condition to find rho
+     if cond.any():  # Ensure there is at least one valid rho
+         rho = ind[cond][-1]  # Find the largest index satisfying the condition
+         theta = cssv[rho - 1] / rho  # Compute the correct threshold theta
+     else:
+         theta = 0  # Default case when all values are zero or negative
+     w = torch.clamp(
+         v - theta, min=0
+     )  # Compute the projected vector, ensuring non-negativity
+     return w
+
+
+ @torch.no_grad()
+ def task_arithmetic_merge(
+     pretrained_model: nn.Module,
+     finetuned_models: List[Dict[str, Tensor]],
+     scaling_factor: float,
+     inplace: bool = True,
+ ) -> nn.Module:
+     """
+     Merges the task vectors from multiple fine-tuned models into a single pre-trained model.
+
+     Args:
+         pretrained_model (nn.Module): The pre-trained model to which the task vectors will be added.
+         finetuned_models (List[nn.Module] or List[Dict[str, Tensor]]): A list of fine-tuned models
+             (or their state dicts) from which task vectors will be calculated.
+         scaling_factor (float): A factor by which the task vectors will be scaled before merging.
+         inplace (bool, optional): If True, the pre-trained model will be modified in place.
+             If False, a copy of the pre-trained model will be modified. Defaults to True.
+
+     Returns:
+         nn.Module: The pre-trained model with the merged task vectors.
+     """
+     if not inplace:
+         pretrained_model = deepcopy(pretrained_model)
+     if isinstance(finetuned_models[0], nn.Module):
+         finetuned_models = [
+             deepcopy(model.state_dict(keep_vars=True)) for model in finetuned_models
+         ]
+     task_vector: StateDictType = None
+     # Calculate the total task vector
+     for model in finetuned_models:
+         if task_vector is None:
+             task_vector = state_dict_sub(
+                 model,
+                 pretrained_model.state_dict(keep_vars=True),
+             )
+         else:
+             task_vector = state_dict_add(
+                 task_vector,
+                 state_dict_sub(
+                     model,
+                     pretrained_model.state_dict(keep_vars=True),
+                 ),
+             )
+     # scale the task vector
+     task_vector = state_dict_mul(task_vector, scaling_factor)
+     # add the task vector to the pretrained model
+     state_dict = state_dict_add(
+         pretrained_model.state_dict(keep_vars=True), task_vector
+     )
+     pretrained_model.load_state_dict(state_dict)
+     return pretrained_model
+
+
+ def entropy_loss(logits: Tensor, pred=None, eps: float = 1e-8) -> Tensor:
+     """
+     Compute the entropy loss of a set of logits.
+
+     Args:
+         logits (Tensor): The logits to compute the entropy loss of.
+         eps (float): A small value to avoid log(0). Default is 1e-8.
+
+     Returns:
+         Tensor: The entropy loss of the logits.
+     """
+     # Ensure the logits tensor has 2 dimensions
+     assert (
+         logits.dim() == 2
+     ), f"Expected logits to have 2 dimensions, found {logits.dim()}, {logits.size()=}"
+
+     # Compute the softmax probabilities
+     probs = torch.softmax(logits, dim=-1)
+
+     # Compute the entropy loss
+     return -torch.sum(probs * torch.log(probs + eps), dim=-1).mean()
+
+
+ class FrankWolfeSoftAlgorithm(
+     CLIPClassificationMixin,
+     ModelFusionAlgorithm,
+     SimpleProfilerMixin,
+ ):
+     def __init__(
+         self,
+         max_iters: int,
+         dataset_size: int,
+         ada_iters: int,
+         ada_coeff: float,
+         merge_fn: str,
+         granularity: str = "task",
+         max_num_models: int = 100,
+         step_size: float = 0.3,
+         tasks: List[str] = [],
+         init_weight: str = "",
+         ada_loss="entropy_loss",
+         **kwargs,
+     ):
+         """
+         Initializes the FrankWolfeSoftAlgorithm with the given hyperparameters.
+
+         Args:
+             step_size (float): The factor by which the task vectors will be scaled before merging.
+         """
+         self.merge_fn = merge_fn
+
+         self.init_weight = init_weight
+         self.max_iters = max_iters
+         self.ada_iters = ada_iters
+         self.ada_coeff = ada_coeff
+         self.granularity = granularity
+         self.tasks = tasks
+         self.step_size = step_size
+         self.dataset_size = dataset_size
+         self.max_num_models = max_num_models
+         self.ada_loss = ada_loss
+         super().__init__(**kwargs)
+
+     def on_frank_wolfe_iteration_start(self):
+         self.setup_zero_shot_classification_head()
+
+     @functools.cache
+     def get_shuffled_train_loader_iter(self, task: str, batch_size: int = 1):
+         # get dataloader kwargs
+         dataloader_kwargs = self._dataloader_kwargs.copy()
+         dataloader_kwargs["shuffle"] = True
+         dataloader_kwargs["batch_size"] = batch_size
+
+         # get the train dataset
+         clip_dataset = CLIPDataset(
+             self.modelpool.load_train_dataset(task), self.clip_processor
+         )
+         # create the dataloader
+         loader = DataLoader(clip_dataset, **dataloader_kwargs)
+         loader = self.fabric.setup_dataloaders(loader)
+         return iter(InfiniteDataLoader(loader))
+
+     @functools.cache
+     def get_shuffled_test_loader_iter(self, task: str, batch_size: int = 1):
+         return super().get_shuffled_test_loader_iter(task, batch_size=batch_size)
+
+     def run_adamerging(self, module):
+         use_entropy_loss = self.ada_loss == "entropy_loss"
+
+         optimizer = torch.optim.Adam([module.merge_weight], lr=1e-3)
+         module, optimizer = self.fabric.setup(module, optimizer)
+         module.train()
+         for step_idx in (
+             pbar := tqdm(
+                 range(self.ada_iters),
+                 "AdaMerging (2/2)",
+                 dynamic_ncols=True,
+                 disable=not self.fabric.is_global_zero,
+             )
+         ):
+             with self.profile("merge weights"):
+                 module.merge_weights()
+
+             metrics = {}
+             total_loss = None
+             tasks = self.modelpool.model_names if self.tasks == [] else self.tasks
+             if not use_entropy_loss:
+                 loss_fn = nn.CrossEntropyLoss()
+             for task in tasks:
+                 with self.profile("data loading"):
+                     if use_entropy_loss:
+                         batch = next(
+                             self.get_shuffled_test_loader_iter(task, batch_size=16)
+                         )
+                     else:
+                         batch = next(
+                             self.get_shuffled_train_loader_iter(task, batch_size=16)
+                         )
+                     # NOTE: The labels are not allowed to be used during test-time adaptation
+                     images = batch[0]
+                 with self.profile("forward pass"):
+                     logits = self.compute_logits(module, images, task)
+                     if use_entropy_loss:
+                         loss = entropy_loss(logits)
+                     else:
+                         loss = loss_fn(logits, batch[1])
+                     total_loss = loss if total_loss is None else total_loss + loss
+
+             optimizer.zero_grad()
+             with self.profile("compute grad"):
+                 self.fabric.backward(total_loss)
+
+             with self.profile("base optimizer step"):
+                 optimizer.step()
+
+             metrics.update({"train/loss": loss.item()})
+             self.fabric.log_dict(metrics, step=step_idx)
+             pbar.set_postfix(metrics)
+         return module
+
+     def frank_wolfe_iteration(self, merged_model, task):
+
+         merged_model.train()
+         # zero the gradients
+         requires_grad_dict = {}
+         for name, param in merged_model.named_parameters():
+             requires_grad_dict[name] = param.requires_grad
+             param.requires_grad = True
+             param.grad = None
+
+         loss_fn = nn.CrossEntropyLoss()
+         avg_loss = defaultdict(list)
+         log.info(f"Processing task {task}")
+         for i in range(self.dataset_size):
+             with self.profile("data loading"):
+                 batch = next(self.get_shuffled_train_loader_iter(task))
+             with self.profile("forward pass"):
+                 logits = self.compute_logits(merged_model, batch[0], task)
+                 loss = loss_fn(logits, batch[1]) / (
+                     self.dataset_size * len(self.modelpool.model_names)
+                 )
+             with self.profile("backward pass"):
+                 loss.backward()
+             avg_loss[task].append(loss.item())
+
+         # calculate the loss
+         avg_loss = {
+             task: sum(losses) / len(losses) for task, losses in avg_loss.items()
+         }
+         log.info(
+             f"Average Loss: {avg_loss}, Total Loss: {sum(avg_loss.values()) / len(avg_loss)}"
+         )
+
+         gradients = {
+             name: param.grad.clone().to("cpu")
+             for name, param in merged_model.named_parameters()
+             if param.requires_grad
+         }
+         for name, param in merged_model.named_parameters():
+             param.requires_grad = requires_grad_dict[name]
+             param.grad = None
+         merged_model.eval()
+
+         return gradients
+
+     def frank_wolfe_selection(
+         self, gradients, checkpoints, model_to_merge_names=[], type="task"
+     ):
+         assert type in [
+             "task",
+             "layer",
+         ], f"Unsupported FW selection type: {type}, supported types are ['task', 'layer']"
+         min_inner_product = float("inf")
+         min_model = None
+         min_model_name = None
+         log_dict = {}
+         if type == "task":
+             for model_name, model_to_merge in checkpoints.items():
+                 model_to_merge = model_to_merge.to("cpu").state_dict()
+                 inner_product_sum = 0
+                 for param_name, param_value in model_to_merge.items():
+                     # calculate cosine similarity
+                     grad = gradients[param_name]
+                     ckpt = model_to_merge[param_name]
+                     param_alignment = torch.dot(grad.flatten(), ckpt.flatten()) / (
+                         torch.norm(grad) * torch.norm(ckpt)
+                     )
+                     inner_product_sum += param_alignment
+                 log_dict[model_name] = inner_product_sum.item()
+                 if (
+                     inner_product_sum < min_inner_product
+                     and model_name not in model_to_merge_names
+                 ):
+                     min_inner_product = inner_product_sum
+                     min_model = deepcopy(model_to_merge)
+                     min_model_name = model_name
+         else:
+             min_model = {}
+             min_inner_product = {}
+             min_idx = {}
+             min_model_name = {}
+             for model_name, model_to_merge in checkpoints.items():
+                 model_to_merge = model_to_merge.to("cpu").state_dict()
+                 for param_name, param_value in model_to_merge.items():
+                     # calculate cosine similarity
+                     grad = gradients[param_name]
+                     ckpt = model_to_merge[param_name]
+                     param_alignment = torch.dot(grad.flatten(), ckpt.flatten()) / (
+                         torch.norm(grad) * torch.norm(ckpt)
+                     )
+                     if (
+                         param_name not in min_inner_product
+                         or param_alignment < min_inner_product[param_name]
+                     ) and model_name not in model_to_merge_names[param_name]:
+                         min_inner_product[param_name] = param_alignment
+                         min_model[param_name] = param_value
+                         min_idx[param_name] = model_name
+                         min_model_name[param_name] = model_name
+             min_inner_product = sum(min_inner_product.values())
+             log_dict = {model_name: 0 for model_name in checkpoints.keys()}
+             for k in min_idx.values():
+                 log_dict[k] += 1
+
+         return min_model, min_model_name, min_inner_product, log_dict
+
+     def run(self, modelpool: HuggingFaceClipVisionPool):
+         log.info("Fusing models using FW merging.")
+         self.modelpool = modelpool
+         tasks = self.tasks if self.tasks else self.modelpool.model_names
+         self.log_hyperparams(self.config)
+         self.on_frank_wolfe_iteration_start()
+
+         assert modelpool.has_pretrained, "Pretrained model is required."
+         finetuned_models = {
+             name: modelpool.load_model(name)
+             for name in modelpool.model_names[: self.max_num_models]
+         }
+
+         if self.init_weight == "base" or self.init_weight == "":
+             merged_model = modelpool.load_model("_pretrained_")
+         else:
+             log.info("Initializing the merged model with the initial weight")
+             if isinstance(self.init_weight, str):
+                 # self.init_weight is a path to a saved tensor of layer-wise weights
+                 layer_wise_weight = load_tensor_from_file(self.init_weight)
+             else:
+                 raise ValueError(f"Unsupported weights format: {self.init_weight}")
+
+             pretrained_model = modelpool.load_model("_pretrained_")
+             layerwise_merged_model = LayerWiseMergedModel(
+                 layer_wise_weight=layer_wise_weight,
+                 pretrained_model=pretrained_model,
+                 finetuned_models=list(finetuned_models.values())[: self.max_num_models],
+                 clamp_weights=False,
+                 tie_weights=True,
+                 strict=False,
+             ).cuda()
+             merged_model = layerwise_merged_model.merge_and_unload()
+
+         initial_model = modelpool.load_model("_pretrained_")
+         self.set_requires_grad(merged_model, initial_model)
+         # initial_model.load_state_dict(deepcopy(merged_model.state_dict()))
+         # finetuned_models['initial'] = initial_model
+         for step_idx in (
+             pbar := tqdm(
+                 range(self.max_iters if not self.is_debug_mode else 1),
+                 ("[DEBUG MODE] " if self.is_debug_mode else "") + "Frank-Wolfe Merging",
+                 dynamic_ncols=True,
+             )
+         ):
+             # Find the checkpoint most aligned with the negative gradient (smallest inner product)
+             models_dict_to_merge = []
+             model_to_merge_names = (
+                 []
+                 if self.granularity == "task"
+                 else {name: [] for name in merged_model.state_dict().keys()}
+             )
+             inner_products = []
+             for task in tasks:
+                 torch.set_grad_enabled(True)
+                 torch.cuda.empty_cache()
+                 gradients = self.frank_wolfe_iteration(merged_model.cuda(), task)
+                 torch.set_grad_enabled(False)
+                 grad_norm = torch.norm(
+                     torch.stack([torch.norm(g) for g in gradients.values()])
+                 )
+
+                 min_model, min_model_name, min_inner_product, log_dict = (
+                     self.frank_wolfe_selection(
+                         gradients,
+                         finetuned_models,
+                         model_to_merge_names,
+                         type=self.granularity,
+                     )
+                 )
+                 if self.granularity == "task":
+                     model_to_merge_names.append(min_model_name)
+                 else:
+                     for k, v in min_model_name.items():
+                         model_to_merge_names[k].append(v)
+                 models_dict_to_merge.append(min_model)
+                 inner_products.append(min_inner_product)
+
+                 log.info(f"Task: {task}, Inner Products: {log_dict}")
+                 if (
+                     len(models_dict_to_merge) >= len(self.modelpool.model_names)
+                     or len(models_dict_to_merge) >= self.max_num_models
+                 ):
+                     log.info(f"Breaking at {len(models_dict_to_merge)}")
+                     break
+
+             # print iteration information
+             log.info(
+                 f"Iteration {step_idx+1}, Task Vector: {model_to_merge_names}, Gradient Norm: {grad_norm:.6f}, Inner Products: {inner_products}"
+             )
+
+             if self.merge_fn == "adamerging":
+                 models_to_merge = [
+                     modelpool.load_model("_pretrained_")
+                     for _ in range(len(models_dict_to_merge))
+                 ]
+                 layer_wise_weight = get_layer_wise_weights(
+                     num_models=len(models_to_merge),
+                     num_layers=len(
+                         tuple(
+                             filter(
+                                 lambda p: p.requires_grad,
+                                 models_to_merge[0].parameters(),
+                             )
+                         )
+                     ),
+                     init_values=self.ada_coeff if step_idx > 0 else 0.3,
+                 )
+                 for model_to_merge, model_to_merge_dict in zip(
+                     models_to_merge, models_dict_to_merge
+                 ):
+                     model_to_merge.load_state_dict(model_to_merge_dict)
+                 layerwise_merged_model = LayerWiseMergedModel(
+                     layer_wise_weight=layer_wise_weight,
+                     pretrained_model=merged_model.to("cpu"),
+                     finetuned_models=models_to_merge,
+                     clamp_weights=False,
+                     tie_weights=True,
+                     strict=False,
+                 ).cuda()
+                 torch.set_grad_enabled(True)
+                 layerwise_merged_model = self.run_adamerging(layerwise_merged_model)
+                 torch.set_grad_enabled(False)
+                 with torch.no_grad():
+                     merged_model = layerwise_merged_model.merge_and_unload()
+                     self.set_requires_grad(merged_model, initial_model)
+                 del (
+                     models_to_merge,
+                     layerwise_merged_model,
+                     layer_wise_weight,
+                     models_dict_to_merge,
+                 )
+             else:
+                 step = 2 / (step_idx + 2) * self.step_size if step_idx > 0 else 1
+                 merged_model = task_arithmetic_merge(
+                     merged_model.to("cpu"), models_dict_to_merge, 0.3 * step
+                 )
+                 del models_dict_to_merge
+
+         torch.set_grad_enabled(False)
+         merged_model = merged_model.cuda().eval()
+         return merged_model
+
+     def set_requires_grad(self, merged_model, initial_model):
+         for name, param in initial_model.named_parameters():
+             for n, p in merged_model.named_parameters():
+                 if name == n:
+                     p.requires_grad = param.requires_grad
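
For readers inspecting the new fw_soft.py above, the following is a minimal sketch of how its two standalone helpers, task_arithmetic_merge and projection_simplex_sort, could be exercised. The toy nn.Linear models are assumptions made purely for illustration, and the module import path is inferred from the file path in the list above; none of this is part of the packaged diff.

# Sketch only: toy models and the import path below are illustrative assumptions.
import torch
from torch import nn

# Module path inferred from fusion_bench/method/fw_merging/fw_soft.py (new in 0.2.16).
from fusion_bench.method.fw_merging.fw_soft import (
    projection_simplex_sort,
    task_arithmetic_merge,
)

# Toy "pretrained" and "fine-tuned" models with identical architectures.
pretrained = nn.Linear(4, 2)
finetuned_a = nn.Linear(4, 2)
finetuned_b = nn.Linear(4, 2)

# Task arithmetic: theta_merged = theta_pre + lambda * sum_i (theta_i - theta_pre).
merged = task_arithmetic_merge(
    pretrained_model=pretrained,
    finetuned_models=[finetuned_a, finetuned_b],
    scaling_factor=0.3,
    inplace=False,  # keep the original pretrained weights untouched
)
print(type(merged).__name__)  # Linear, now carrying the merged weights

# Project an arbitrary vector onto the probability simplex (non-negative, sums to z).
v = torch.randn(5)
w = projection_simplex_sort(v, z=1)
print(w.sum().item(), bool((w >= 0).all()))  # ~1.0, True

The sketch only calls functions whose full definitions appear in the hunk above; whether the released wheel re-exports them elsewhere is not shown in this diff.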