fusion-bench 0.2.15__py3-none-any.whl → 0.2.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (77)
  1. fusion_bench/method/__init__.py +4 -0
  2. fusion_bench/method/fw_merging/__init__.py +2 -0
  3. fusion_bench/method/fw_merging/fw_hard.py +448 -0
  4. fusion_bench/method/fw_merging/fw_soft.py +519 -0
  5. fusion_bench/method/fw_merging/utils.py +331 -0
  6. fusion_bench/method/moe_pruner/__init__.py +7 -0
  7. fusion_bench/method/moe_pruner/hooks/__init__.py +6 -0
  8. fusion_bench/method/moe_pruner/hooks/deepseek_v2.py +85 -0
  9. fusion_bench/method/moe_pruner/hooks/hook.py +23 -0
  10. fusion_bench/method/moe_pruner/hooks/mixtral.py +93 -0
  11. fusion_bench/method/moe_pruner/moe_pruner.py +304 -0
  12. fusion_bench/method/moe_pruner/utils/__init__.py +1 -0
  13. fusion_bench/method/moe_pruner/utils/data.py +154 -0
  14. fusion_bench/method/moe_pruner/utils/layerwrapper.py +61 -0
  15. fusion_bench/method/moe_pruner/utils/prune.py +313 -0
  16. fusion_bench/method/moe_pruner/utils/score.py +41 -0
  17. fusion_bench/method/pruning/__init__.py +1 -0
  18. fusion_bench/method/pruning/llama_sparsegpt_prune.py +223 -0
  19. fusion_bench/method/pruning/sparsegpt_utils/__init__.py +1 -0
  20. fusion_bench/method/pruning/sparsegpt_utils/sparsegpt.py +128 -0
  21. fusion_bench/method/pruning/wanda_utils/data.py +33 -14
  22. fusion_bench/method/randes/__init__.py +15 -0
  23. fusion_bench/method/randes/base_algorithm.py +1013 -0
  24. fusion_bench/method/randes/modelsoup.py +126 -0
  25. fusion_bench/method/randes/task_arithmetic.py +318 -0
  26. fusion_bench/method/sparselo/sparselo.py +20 -2
  27. fusion_bench/method/tall_mask/__init__.py +1 -0
  28. fusion_bench/method/tall_mask/task_arithmetic.py +133 -0
  29. fusion_bench/modelpool/lazy_state_dict_pool.py +15 -0
  30. fusion_bench/models/modeling_deepseek_v2/__init__.py +15 -0
  31. fusion_bench/models/modeling_deepseek_v2/configuration_deepseek.py +208 -0
  32. fusion_bench/models/modeling_deepseek_v2/modeling_deepseek.py +1922 -0
  33. fusion_bench/models/modeling_deepseek_v2/tokenization_deepseek_fast.py +38 -0
  34. fusion_bench/programs/fabric_fusion_program.py +5 -0
  35. fusion_bench/taskpool/clip_vision/taskpool.py +8 -1
  36. fusion_bench/utils/__init__.py +1 -0
  37. fusion_bench/utils/data.py +1 -1
  38. fusion_bench/utils/lazy_state_dict.py +268 -0
  39. fusion_bench/utils/parameters.py +33 -0
  40. fusion_bench/utils/state_dict_arithmetic.py +74 -2
  41. fusion_bench/utils/type.py +1 -0
  42. {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.16.dist-info}/METADATA +6 -2
  43. {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.16.dist-info}/RECORD +77 -21
  44. {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.16.dist-info}/WHEEL +1 -1
  45. fusion_bench_config/dataset/image_classification/test/TALL10.yaml +28 -0
  46. fusion_bench_config/dataset/image_classification/test/TALL12.yaml +28 -0
  47. fusion_bench_config/dataset/image_classification/test/TALL16.yaml +28 -0
  48. fusion_bench_config/dataset/image_classification/test/TALL18.yaml +28 -0
  49. fusion_bench_config/dataset/image_classification/train/TALL10.yaml +28 -0
  50. fusion_bench_config/dataset/image_classification/train/TALL12.yaml +28 -0
  51. fusion_bench_config/dataset/image_classification/train/TALL16.yaml +28 -0
  52. fusion_bench_config/dataset/image_classification/train/TALL18.yaml +28 -0
  53. fusion_bench_config/method/fw_merging/fw_hard.yaml +11 -0
  54. fusion_bench_config/method/fw_merging/fw_soft.yaml +12 -0
  55. fusion_bench_config/method/moe_pruner/moe_pruner.yaml +15 -0
  56. fusion_bench_config/method/pruning/llama_sparsegpt_pruning.yaml +16 -0
  57. fusion_bench_config/method/randes/superposed_model_soup.yaml +18 -0
  58. fusion_bench_config/method/randes/superposed_task_arithmetic.yaml +20 -0
  59. fusion_bench_config/method/randes/superposed_task_arithmetic_lora.yaml +20 -0
  60. fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml +2 -1
  61. fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml +1 -1
  62. fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml +1 -1
  63. fusion_bench_config/method/tall_mask/task_arithmetic.yaml +4 -0
  64. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL10.yaml +29 -0
  65. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL12.yaml +29 -0
  66. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL16.yaml +29 -0
  67. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL18.yaml +29 -0
  68. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL10.yaml +8 -0
  69. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL12.yaml +8 -0
  70. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL16.yaml +8 -0
  71. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL18.yaml +8 -0
  72. fusion_bench_config/modelpool/CausalLMPool/deepseek-v2-lite.yaml +15 -0
  73. fusion_bench_config/modelpool/CausalLMPool/mixtral-8x7b.yaml +14 -0
  74. fusion_bench_config/modelpool/SeqenceClassificationModelPool/roberta-base_glue.yaml +69 -0
  75. {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.16.dist-info}/entry_points.txt +0 -0
  76. {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.16.dist-info}/licenses/LICENSE +0 -0
  77. {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.16.dist-info}/top_level.txt +0 -0
@@ -67,6 +67,7 @@ _import_structure = {
         "CLIPTaskWiseGossipAlgorithm",
         "FlanT5LayerWiseGossipAlgorithm",
     ],
+    "fw_merging": ["FrankWolfeHardAlgorithm", "FrankWolfeSoftAlgorithm"],
     # plug-and-play model merging methods
     "concrete_subspace": [
         "ConcreteTaskArithmeticAlgorithmForCLIP",
@@ -103,6 +104,7 @@ _import_structure = {
         "RandomPruningForLlama",
         "MagnitudePruningForLlama",
         "WandaPruningForLlama",
+        "SparseGPTPruningForLlama",
     ],
     "sparselo": [
         "IterativeSparseLoForLlama",
@@ -141,6 +143,7 @@ if TYPE_CHECKING:
         WeightedEnsembleAlgorithm,
     )
     from .fisher_merging import FisherMergingForCLIPVisionModel
+    from .fw_merging import FrankWolfeHardAlgorithm, FrankWolfeSoftAlgorithm
     from .gossip import (
         CLIPLayerWiseGossipAlgorithm,
         CLIPTaskWiseGossipAlgorithm,
@@ -172,6 +175,7 @@ if TYPE_CHECKING:
         MagnitudeDiffPruningAlgorithm,
         MagnitudePruningForLlama,
         RandomPruningForLlama,
+        SparseGPTPruningForLlama,
         WandaPruningForLlama,
     )
     from .pwe_moe import (
@@ -0,0 +1,2 @@
+from .fw_hard import FrankWolfeHardAlgorithm
+from .fw_soft import FrankWolfeSoftAlgorithm
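
For orientation, the _import_structure and TYPE_CHECKING changes above register the new Frank-Wolfe merging classes for lazy import from the package root. A minimal sketch of the resulting import path (illustrative only, not part of this diff):

    # Illustrative sketch: the new exports become importable from fusion_bench.method
    from fusion_bench.method import FrankWolfeHardAlgorithm, FrankWolfeSoftAlgorithm
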
@@ -0,0 +1,448 @@
+"""
+Implementation of Frank-Wolfe "hard" merging (FrankWolfeHardAlgorithm), which uses
+Task Arithmetic or TIES merging as its inner merge step.
+
+Task Arithmetic: http://arxiv.org/abs/2212.04089
+"""
+
+import functools
+import logging
+import os
+from abc import abstractmethod
+from collections import defaultdict
+from copy import deepcopy
+from functools import partial
+from typing import TYPE_CHECKING, Any, Dict, List, Mapping, TypeVar, Union
+
+import torch
+from lightning.fabric.utilities.rank_zero import rank_zero_only
+from omegaconf import DictConfig
+from torch import Tensor, nn
+from torch.utils.data import DataLoader
+from tqdm.autonotebook import tqdm
+
+from fusion_bench.compat.method import ModelFusionAlgorithm
+from fusion_bench.compat.modelpool import HuggingFaceClipVisionPool, ModelPool
+from fusion_bench.dataset.clip_dataset import CLIPDataset
+from fusion_bench.mixins import CLIPClassificationMixin
+from fusion_bench.mixins.lightning_fabric import LightningFabricMixin
+from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
+from fusion_bench.models.wrappers.layer_wise_fusion import (
+    LayerWiseMergedModel,
+    get_layer_wise_weights,
+)
+from fusion_bench.utils.data import load_tensor_from_file
+from fusion_bench.utils.type import TorchModelType
+
+from .utils import *
+
+if TYPE_CHECKING:
+    from fusion_bench.programs.fabric_fusion_program import FabricModelFusionProgram
+
+from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
+from fusion_bench.modelpool import BaseModelPool
+from fusion_bench.utils import instantiate
+from fusion_bench.utils.data import InfiniteDataLoader
+from fusion_bench.utils.state_dict_arithmetic import (
+    state_dict_add,
+    state_dict_mul,
+    state_dict_sub,
+)
+from fusion_bench.utils.type import StateDictType
+
+log = logging.getLogger(__name__)
+
+
+@torch.no_grad()
+def task_arithmetic_merge(
+    pretrained_model: nn.Module,
+    finetuned_models: List[Dict[str, Tensor]],
+    scaling_factor: float,
+    inplace: bool = True,
+) -> nn.Module:
+    """
+    Merges the task vectors from multiple fine-tuned models into a single pre-trained model.
+
+    Args:
+        pretrained_model (nn.Module): The pre-trained model to which the task vectors will be added.
+        finetuned_models (List[nn.Module]): A list of fine-tuned models from which task vectors will be calculated.
+        scaling_factor (float): A factor by which the task vectors will be scaled before merging.
+        inplace (bool, optional): If True, the pre-trained model will be modified in place.
+            If False, a copy of the pre-trained model will be modified. Defaults to True.
+
+    Returns:
+        nn.Module: The pre-trained model with the merged task vectors.
+    """
+    if not inplace:
+        pretrained_model = deepcopy(pretrained_model)
+    if isinstance(finetuned_models[0], nn.Module):
+        finetuned_models = [
+            deepcopy(model.state_dict(keep_vars=True)) for model in finetuned_models
+        ]
+    task_vector: StateDictType = None
+    # Calculate the total task vector
+    for model in finetuned_models:
+        if task_vector is None:
+            task_vector = state_dict_sub(
+                model,
+                pretrained_model.state_dict(keep_vars=True),
+            )
+        else:
+            task_vector = state_dict_add(
+                task_vector,
+                state_dict_sub(
+                    model,
+                    pretrained_model.state_dict(keep_vars=True),
+                ),
+            )
+    # scale the task vector
+    task_vector = state_dict_mul(task_vector, scaling_factor)
+    # add the task vector to the pretrained model
+    state_dict = state_dict_add(
+        pretrained_model.state_dict(keep_vars=True), task_vector
+    )
+    pretrained_model.load_state_dict(state_dict)
+    return pretrained_model
+
+
+@torch.no_grad()
+def ties_merge(
+    pretrained_model: nn.Module,
+    finetuned_models: List[Dict[str, Tensor]],
+    scaling_factor: float,
+    threshold: float,
+) -> nn.Module:
+    remove_keys = []
+    merge_func = "sum"
+    if isinstance(finetuned_models[0], nn.Module):
+        finetuned_models = [
+            deepcopy(model.state_dict(keep_vars=True)) for model in finetuned_models
+        ]
+
+    ptm_check = pretrained_model.state_dict(keep_vars=True)
+
+    # Compute the task vectors
+    flat_ft = torch.vstack(
+        [state_dict_to_vector(check, remove_keys) for check in finetuned_models]
+    )
+    flat_ptm = state_dict_to_vector(ptm_check, remove_keys)
+    tv_flat_checks = flat_ft - flat_ptm
+
+    # Perform TIES Merging
+    merged_tv = ties_merging(
+        tv_flat_checks,
+        reset_thresh=threshold,
+        merge_func=merge_func,
+    )
+    merged_check = flat_ptm + scaling_factor * merged_tv
+    merged_state_dict = vector_to_state_dict(
+        merged_check, ptm_check, remove_keys=remove_keys
+    )
+
+    # Load the merged state dict into the pretrained model
+    pretrained_model.load_state_dict(merged_state_dict)
+    return pretrained_model
+
+
+def entropy_loss(logits: Tensor, pred=None, eps: float = 1e-8) -> Tensor:
+    """
+    Compute the entropy loss of a set of logits.
+
+    Args:
+        logits (Tensor): The logits to compute the entropy loss of.
+        eps (float): A small value to avoid log(0). Default is 1e-8.
+
+    Returns:
+        Tensor: The entropy loss of the logits.
+    """
+    # Ensure the logits tensor has 2 dimensions
+    assert (
+        logits.dim() == 2
+    ), f"Expected logits to have 2 dimensions, found {logits.dim()}, {logits.size()=}"
+
+    # Compute the softmax probabilities
+    probs = torch.softmax(logits, dim=-1)
+
+    # Compute the entropy loss
+    return -torch.sum(probs * torch.log(probs + eps), dim=-1).mean()
+
+
+class FrankWolfeHardAlgorithm(
+    CLIPClassificationMixin,
+    ModelFusionAlgorithm,
+    SimpleProfilerMixin,
+):
+
+    def __init__(
+        self,
+        merge_fn: str,
+        step_size: float,
+        max_iters: int,
+        dataset_size: int,
+        tasks: List[str] = [],
+        granularity: str = "task",
+        max_num_models: int = 100,
+        loss_fn: str = "cross_entropy",
+        init_weight: str = "",
+        scaling_factor: float = 1.0,
+        threshold: int = 20,
+        **kwargs,
+    ):
+        """
+        Initializes the FrankWolfeHardAlgorithm with the given merge function and optimization settings.
+
+        Args:
+            merge_fn (str): The inner merge function, either "task_arithmetic" or "ties".
+            step_size (float): The Frank-Wolfe step size multiplier.
+            max_iters (int): The maximum number of Frank-Wolfe iterations.
+            scaling_factor (float): The factor by which the task vectors will be scaled before merging.
+        """
+        self.merger = merge_fn
+        if merge_fn == "task_arithmetic":
+            self.merge_fn = task_arithmetic_merge
+        elif merge_fn == "ties":
+            self.merge_fn = partial(ties_merge, threshold=threshold)
+        # elif merge_fn == "concrete_ta":
+        #     self.merge_fn = ConcreteTaskArithmeticAlgorithmForCLIP(
+        #         instantiate(OmegaConf.load("config/method/concrete_subspace/clip_concrete_task_arithmetic.yaml"))
+        #     )
+        else:
+            raise ValueError(f"Unsupported merge_fn: {merge_fn}")
+        self.scaling_factor = scaling_factor
+
+        self.init_weight = init_weight
+        self.step_size = step_size
+        self.max_iters = max_iters
+        self.granularity = granularity
+        self.loss_fn = loss_fn
+        self.tasks = tasks
+        self.dataset_size = dataset_size
+        self.max_num_models = max_num_models
+        super().__init__(**kwargs)
+
+    def on_frank_wolfe_iteration_start(self):
+        self.setup_zero_shot_classification_head()
+
+    @functools.cache
+    def get_shuffled_loader_iter(self, task: str):
+        if self.loss_fn == "cross_entropy":
+            # get dataloader kwargs
+            dataloader_kwargs = self._dataloader_kwargs.copy()
+            dataloader_kwargs["shuffle"] = True
+            dataloader_kwargs["batch_size"] = 1
+
+            # get the test dataset
+            clip_dataset = CLIPDataset(
+                self.modelpool.load_train_dataset(task), self.clip_processor
+            )
+            # create the dataloader
+            loader = DataLoader(clip_dataset, **dataloader_kwargs)
+            loader = self.fabric.setup_dataloaders(loader)
+            return iter(InfiniteDataLoader(loader))
+        elif self.loss_fn == "entropy":
+            return super().get_shuffled_test_loader_iter(
+                task,
+                batch_size=1,
+            )
+        else:
+            raise ValueError(f"Unsupported loss function: {self.loss_fn}")
+
+    def frank_wolfe_iteration(self, merged_model):
+
+        merged_model.train()
+        # zero the gradients
+        for name, param in merged_model.named_parameters():
+            param.requires_grad = True
+            param.grad = None
+
+        if self.loss_fn == "cross_entropy":
+            loss_fn = nn.CrossEntropyLoss()
+        elif self.loss_fn == "entropy":
+            loss_fn = entropy_loss
+        avg_loss = defaultdict(list)
+        tasks = self.tasks if self.tasks else self.modelpool.model_names
+        for task in tasks:
+            log.info(f"Processing task {task}")
+            for _ in range(self.dataset_size):
+                with self.profile("data loading"):
+                    batch = next(self.get_shuffled_loader_iter(task))
+                with self.profile("forward pass"):
+                    logits = self.compute_logits(merged_model, batch[0], task)
+                    loss = loss_fn(logits, batch[1]) / (
+                        self.dataset_size * len(self.modelpool.model_names)
+                    )
+                with self.profile("backward pass"):
+                    # self.fabric.backward(loss, retain_graph=True)
+                    loss.backward()
+                avg_loss[task].append(loss.item())
+
+        # calculate the loss
+        avg_loss = {
+            task: sum(losses) / len(losses) for task, losses in avg_loss.items()
+        }
+        log.info(
+            f"Average Loss: {avg_loss}, Total Loss: {sum(avg_loss.values()) / len(avg_loss)}"
+        )
+
+        gradients = {
+            name: param.grad.clone().to("cpu")
+            for name, param in merged_model.named_parameters()
+            if param.requires_grad
+        }
+        for name, param in merged_model.named_parameters():
+            param.grad = None
+        merged_model.eval()
+
+        return gradients
+
+    def frank_wolfe_selection(
+        self, gradients, checkpoints, model_to_merge_names={}, type="task"
+    ):
+        assert type in [
+            "task",
+            "layer",
+        ], f"Unsupported FW selection type: {type}, supported types are ['task', 'layer']"
+        min_inner_product = float("inf")
+        min_model = None
+        min_model_name = None
+        log_dict = {}
+        if type == "task":
+            for model_name, model_to_merge in checkpoints.items():
+                model_to_merge = model_to_merge.to("cpu").state_dict()
+                inner_product_sum = 0
+                for param_name, param_value in model_to_merge.items():
+                    # calculate cosine similarity
+                    grad = gradients[param_name]
+                    ckpt = model_to_merge[param_name]
+                    param_alignment = torch.dot(grad.flatten(), ckpt.flatten()) / (
+                        torch.norm(grad) * torch.norm(ckpt)
+                    )
+                    inner_product_sum += param_alignment
+                log_dict[model_name] = inner_product_sum.item()
+                if (
+                    inner_product_sum < min_inner_product
+                    and model_name not in model_to_merge_names
+                ):
+                    min_inner_product = inner_product_sum
+                    min_model = deepcopy(model_to_merge)
+                    min_model_name = model_name
+        else:
+            min_model = {}
+            min_inner_product = {}
+            min_idx = {}
+            min_model_name = {}
+            for model_name, model_to_merge in checkpoints.items():
+                model_to_merge = model_to_merge.to("cpu").state_dict()
+                for param_name, param_value in model_to_merge.items():
+                    # calculate cosine similarity
+                    grad = gradients[param_name]
+                    ckpt = model_to_merge[param_name]
+                    param_alignment = torch.dot(grad.flatten(), ckpt.flatten()) / (
+                        torch.norm(grad) * torch.norm(ckpt)
+                    )
+                    if (
+                        param_name not in min_inner_product
+                        or param_alignment < min_inner_product[param_name]
+                    ) and model_name not in model_to_merge_names[param_name]:
+                        min_inner_product[param_name] = param_alignment
+                        # if min_inner_product[param_name] < 0:
+                        min_model[param_name] = param_value
+                        min_idx[param_name] = model_name
+                        min_model_name[param_name] = model_name
+                    # else:
+                    #     min_model[param_name] = torch.zeros_like(param_value)
+            min_inner_product = sum(min_inner_product.values())
+            log_dict = {model_name: 0 for model_name in checkpoints.keys()}
+            for k in min_idx.values():
+                log_dict[k] += 1
+
+        return min_model, min_model_name, min_inner_product, log_dict
+
+    def run(self, modelpool: HuggingFaceClipVisionPool):
+        log.info("Fusing models using FW merging.")
+        self.modelpool = modelpool
+        self.log_hyperparams(self.config)
+        self.on_frank_wolfe_iteration_start()
+
+        assert modelpool.has_pretrained, "Pretrained model is required."
+        finetuned_models = {
+            name: modelpool.load_model(name)
+            for name in modelpool.model_names[: self.max_num_models]
+        }
+        pretrained_model = modelpool.load_model("_pretrained_")
+
+        if self.init_weight:
+            if self.init_weight == "base":
+                log.info("Initializing the merged model with the base model")
+                merged_model = pretrained_model
+            else:
+                log.info("Initializing the merged model with the initial weight")
+                if isinstance(self.init_weight, str):
+                    # self.config.weights is a path to a saved tensor
+                    layer_wise_weight = load_tensor_from_file(self.init_weight)
+                else:
+                    raise ValueError(f"Unsupported weights format: {self.init_weight}")
+
+                merged_model = LayerWiseMergedModel(
+                    layer_wise_weight=layer_wise_weight,
+                    pretrained_model=modelpool.load_model("_pretrained_"),
+                    finetuned_models=list(finetuned_models.values()),
+                    clamp_weights=False,
+                    tie_weights=True,
+                    strict=False,
+                ).cuda()
+                merged_model = merged_model.merge_and_unload()
+        else:
+            log.info("Initializing the merged model with merge function")
+            merged_model = self.merge_fn(
+                pretrained_model=modelpool.load_model("_pretrained_"),
+                finetuned_models=list(finetuned_models.values()),
+                scaling_factor=self.scaling_factor,
+            ).cuda()
+            # merged_model = self.fabric.setup(merged_model)
+
+        initial_model = modelpool.load_model("_pretrained_")
+        initial_model.load_state_dict(deepcopy(merged_model.state_dict()))
+        finetuned_models["initial"] = initial_model
+        for step_idx in (
+            pbar := tqdm(
+                range(self.max_iters if not self.is_debug_mode else 1),
+                ("[DEBUG MODE] " if self.is_debug_mode else "") + "Frank-Wolfe Merging",
+                dynamic_ncols=True,
+            )
+        ):
+            torch.cuda.empty_cache()
+            torch.set_grad_enabled(True)
+            gradients = self.frank_wolfe_iteration(merged_model.cuda())
+            torch.set_grad_enabled(False)
+            grad_norm = torch.norm(
+                torch.stack([torch.norm(g) for g in gradients.values()])
+            )
+
+            model_to_merge_names = (
+                []
+                if self.granularity == "task"
+                else {name: [] for name in merged_model.state_dict().keys()}
+            )
+            min_model, min_model_name, min_alignment, chosen_model = (
+                self.frank_wolfe_selection(
+                    gradients,
+                    finetuned_models,
+                    model_to_merge_names=model_to_merge_names,
+                    type=self.granularity,
+                )
+            )
+
+            # Determine step size
+            step = 2 / (step_idx + 2) * self.step_size
+
+            # print iteration information
+            log.info(
+                f"Iteration {step_idx+1}, Task Vector: {min_model_name}, Gradient Norm: {grad_norm:.6f}, Inner Products: {min_alignment:.6f}, Chosen Model: {chosen_model}"
+            )
+
+            merged_model = self.merge_fn(
+                pretrained_model=merged_model.to("cpu"),
+                finetuned_models=[min_model],
+                scaling_factor=step * self.scaling_factor,
+            )
+
+        torch.set_grad_enabled(False)
+        merged_model = merged_model.cuda().eval()
+        return merged_model
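
Taken together, fw_hard.py implements the following loop: at iteration t, FrankWolfeHardAlgorithm computes gradients of the chosen loss on the current merged model, selects the checkpoint (or, with granularity="layer", the per-parameter source) whose parameters have the lowest cosine alignment with that gradient, and merges it in with the decaying Frank-Wolfe step 2 / (t + 2) * step_size. A hedged usage sketch, assuming a model pool with a "_pretrained_" entry is already configured (the pool setup and all parameter values below are illustrative, not taken from this diff):

    from fusion_bench.method import FrankWolfeHardAlgorithm

    # Hypothetical, pre-configured pool containing "_pretrained_" plus fine-tuned models.
    modelpool = ...

    algorithm = FrankWolfeHardAlgorithm(
        merge_fn="task_arithmetic",  # inner merge: "task_arithmetic" or "ties"
        step_size=0.3,               # illustrative Frank-Wolfe step-size multiplier
        max_iters=10,                # illustrative number of iterations
        dataset_size=128,            # illustrative samples per task per iteration
        granularity="task",          # or "layer" for per-parameter selection
        loss_fn="cross_entropy",     # or "entropy" to use unlabeled test batches
    )
    merged_model = algorithm.run(modelpool)

The new config file fusion_bench_config/method/fw_merging/fw_hard.yaml listed above is the corresponding method configuration shipped with this release.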