fusion-bench 0.2.31__py3-none-any.whl → 0.2.32__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. fusion_bench/__init__.py +6 -0
  2. fusion_bench/__main__.py +2 -2
  3. fusion_bench/dataset/__init__.py +2 -0
  4. fusion_bench/dataset/clip_dataset.py +4 -72
  5. fusion_bench/dataset/image_dataset.py +44 -18
  6. fusion_bench/method/base_algorithm.py +4 -0
  7. fusion_bench/method/dop/dop.py +0 -22
  8. fusion_bench/method/dop/dop_general.py +489 -0
  9. fusion_bench/method/dop/utils.py +24 -4
  10. fusion_bench/method/emr_merging/__init__.py +1 -0
  11. fusion_bench/method/emr_merging/emr_merging.py +53 -0
  12. fusion_bench/method/emr_merging/utils.py +162 -0
  13. fusion_bench/method/opcm/opcm.py +6 -2
  14. fusion_bench/method/opcm/opcm_general.py +356 -0
  15. fusion_bench/method/opcm/utils.py +1 -4
  16. fusion_bench/method/simple_average.py +52 -18
  17. fusion_bench/method/task_arithmetic/task_arithmetic.py +1 -1
  18. fusion_bench/mixins/lightning_fabric.py +108 -3
  19. fusion_bench/mixins/serialization.py +1 -1
  20. fusion_bench/modelpool/base_pool.py +37 -1
  21. fusion_bench/modelpool/convnext_for_image_classification.py +5 -2
  22. fusion_bench/models/hf_clip.py +20 -0
  23. fusion_bench/models/modulator/__init__.py +1 -0
  24. fusion_bench/models/modulator/base.py +123 -0
  25. fusion_bench/models/parameter_dict.py +119 -29
  26. fusion_bench/models/utils.py +190 -2
  27. fusion_bench/models/wrappers/switch.py +90 -0
  28. fusion_bench/programs/base_program.py +6 -0
  29. fusion_bench/programs/fabric_fusion_program.py +4 -0
  30. fusion_bench/scripts/cli.py +19 -8
  31. fusion_bench/taskpool/image_classification.py +270 -0
  32. fusion_bench/utils/__init__.py +18 -1
  33. fusion_bench/utils/data.py +1 -1
  34. fusion_bench/utils/dict.py +19 -0
  35. fusion_bench/utils/dtype.py +19 -0
  36. fusion_bench/utils/misc.py +1 -0
  37. fusion_bench/utils/packages.py +4 -0
  38. fusion_bench/utils/state_dict_arithmetic.py +183 -1
  39. fusion_bench/utils/tensorboard.py +21 -3
  40. {fusion_bench-0.2.31.dist-info → fusion_bench-0.2.32.dist-info}/METADATA +3 -1
  41. {fusion_bench-0.2.31.dist-info → fusion_bench-0.2.32.dist-info}/RECORD +51 -37
  42. {fusion_bench-0.2.31.dist-info → fusion_bench-0.2.32.dist-info}/WHEEL +1 -1
  43. {fusion_bench-0.2.31.dist-info → fusion_bench-0.2.32.dist-info}/entry_points.txt +1 -1
  44. fusion_bench_config/fabric/loggers/mlflow_logger.yaml +4 -0
  45. fusion_bench_config/method/dop/dop_general.yaml +33 -0
  46. fusion_bench_config/method/emr_merging/emr_merging.yaml +1 -0
  47. fusion_bench_config/method/opcm/opcm_general.yaml +18 -0
  48. fusion_bench_config/modelpool/ConvNextForImageClassification/convnext-base-224_8-tasks.yaml +15 -0
  49. fusion_bench_config/taskpool/ImageClassificationTaskPool/convnext-base-224_8-tasks.yaml +17 -0
  50. {fusion_bench-0.2.31.dist-info → fusion_bench-0.2.32.dist-info}/licenses/LICENSE +0 -0
  51. {fusion_bench-0.2.31.dist-info → fusion_bench-0.2.32.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,489 @@
1
+ """
2
+ Continual Model Merging without Data: Dual Projections for Balancing Stability and Plasticity. NeurIPS, 2025.
3
+ (Architecture agnostic implementation)
4
+ """
5
+
6
+ import logging
7
+ import os
8
+ import random
9
+ import time
10
+ from copy import deepcopy
11
+ from pathlib import Path
12
+ from typing import Dict, List, Literal, Optional, Tuple, cast
13
+
14
+ import lightning as L
15
+ import numpy as np
16
+ import torch
17
+ from omegaconf import DictConfig
18
+ from torch import Tensor, nn
19
+ from torch.autograd import Variable
20
+ from tqdm.auto import tqdm
21
+
22
+ from fusion_bench import BaseAlgorithm, BaseModelPool, auto_register_config
23
+ from fusion_bench.method.simple_average import simple_average
24
+ from fusion_bench.mixins import LightningFabricMixin, SimpleProfilerMixin
25
+ from fusion_bench.models.utils import named_leaf_modules
26
+ from fusion_bench.utils import seed_everything_by_time
27
+ from fusion_bench.utils.dtype import dtype_support_svd
28
+ from fusion_bench.utils.json import save_to_json
29
+ from fusion_bench.utils.packages import is_ray_available
30
+
31
+ from .min_norm_solvers import MinNormSolver, gradient_normalizers
32
+ from .utils import is_leaf_module, print_params, svd
33
+
34
+ log = logging.getLogger(__name__)
35
+
36
+
37
@auto_register_config
class DOPMerging(LightningFabricMixin, SimpleProfilerMixin, BaseAlgorithm):
    """
    Dual Projections for Balancing Stability and Plasticity (DOP) merging algorithm.

    This method implements continual model merging without data by using dual
    projections in the SVD space to balance stability (preserving the previously
    merged model's knowledge) and plasticity (incorporating the new model's
    knowledge).

    The algorithm merges models sequentially, optimizing each merge using
    gradient descent with an optional multi-gradient descent algorithm (MGDA)
    for better trade-offs.

    Reference:
        Continual Model Merging without Data: Dual Projections for Balancing
        Stability and Plasticity. NeurIPS, 2025.
    """

    def __init__(
        self,
        seed: Optional[int] = None,
        shuffle_order: bool = False,
        save_on_every_step: bool = True,
        evaluate_on_every_step: bool = False,
        lr: float = 1e-4,
        num_steps: int = 200,
        mgda: bool = True,
        ema: bool = True,
        ema_beta: float = 0.99,
        alpha: Optional[float] = None,
        svd_epsilon: float = 1.0,
        svd_proj_space: str = "uv",
        exclude_keys: Optional[List[str]] = None,
        num_ray_actors: int = 0,
        **kwargs,
    ):
        """
        Initialize the DOP merging algorithm.

        Args:
            seed: Random seed for reproducibility. If None, uses time-based seeding.
            shuffle_order: Whether to shuffle the order of models before merging.
            save_on_every_step: Whether to save the model after each merge step.
            evaluate_on_every_step: Whether to evaluate the model after each merge step.
            lr: Learning rate for the optimization process.
            num_steps: Number of optimization steps per layer.
            mgda: Whether to use Multi-Gradient Descent Algorithm for balancing losses.
            ema: Whether to use exponential moving average for MGDA weights.
            ema_beta: EMA decay rate for MGDA weights (only used if ema=True).
            alpha: Weight for balancing between stability and plasticity (0-1).
                When mgda=False, used as a fixed weight (required). When mgda=True
                with ema=True, used as the initial weight (defaults to 0.5 when None).
            svd_epsilon: Threshold for SVD rank selection (0-1). Determines how much
                variance to preserve in the projection space.
            svd_proj_space: SVD projection space to use: 'u', 'v', or 'uv' (both).
            exclude_keys: List of module names to exclude from optimization.
            num_ray_actors: Number of Ray actors to use for parallel processing.
                If 0, ray is not used.
            **kwargs: Additional arguments passed to BaseAlgorithm.
        """
        self.lr = lr
        self.num_steps = num_steps
        self.mgda = mgda
        self.ema = ema
        self.ema_beta = ema_beta
        self.alpha = alpha
        self.svd_epsilon = svd_epsilon
        self.svd_proj_space = svd_proj_space
        self.seed = seed
        self.shuffle_order = shuffle_order
        self.save_on_every_step = save_on_every_step
        self.evaluate_on_every_step = evaluate_on_every_step
        # Fix: `num_ray_actors` is read in `run()` but, unlike every other
        # constructor argument, was never assigned explicitly.
        self.num_ray_actors = num_ray_actors

        self.exclude_keys = [] if exclude_keys is None else exclude_keys

        assert (
            0 <= self.svd_epsilon <= 1
        ), "The svd_epsilon should be in the range of [0, 1]"
        # Fix: `alpha` defaults to None, but the original code asserted
        # `self.alpha >= 0` unconditionally, which raises a TypeError with the
        # default arguments. Validate the range only when a value is given;
        # code paths that actually require `alpha` check for None themselves.
        if self.alpha is not None:
            assert (
                0 <= self.alpha <= 1
            ), "The alpha should be in the range of [0, 1]"
        super().__init__(**kwargs)

    def run(self, modelpool: BaseModelPool):
        """
        Execute the DOP merging algorithm on a pool of models.

        Merges models sequentially, where each new model is merged with the
        previously merged result. The first model is used as-is, and subsequent
        models are merged using layer-wise optimization.

        Args:
            modelpool: The model pool containing models to merge and the pretrained model.

        Returns:
            The final merged model after sequentially merging all models in the pool.
        """
        if self.num_ray_actors > 0:
            if is_ray_available():
                import ray
                from ray.util.actor_pool import ActorPool

                if not ray.is_initialized():
                    ray.init()

                # Create one actor per requested worker; reserve a GPU for each
                # actor when running on CUDA.
                if self.fabric.device.type == "cuda":
                    actor_options = {"num_gpus": 1}
                else:
                    actor_options = {}
                self.ray_actor_pool = ActorPool(
                    [
                        DOPMergingActor.options(**actor_options).remote(**self.config)
                        for _ in range(self.num_ray_actors)
                    ]
                )
            else:
                raise ImportError(
                    "Ray is not installed. Please install ray to use this feature. Install with `pip install 'ray[default]'`."
                )

        # Fix: copy before shuffling so we never mutate the name list owned by
        # the model pool in place.
        model_names = list(modelpool.model_names)
        if self.shuffle_order:
            random.shuffle(model_names)

        pretrained_model = modelpool.load_pretrained_model()

        # NOTE(review): `seed`, `save_on_every_step` and `evaluate_on_every_step`
        # are stored in __init__ but not used anywhere in this implementation —
        # confirm whether per-step saving/evaluation was intended here.
        merged_model = None
        for model_idx, model_name in enumerate(model_names):
            print(
                f"--------- Optimizing {model_idx + 1}/{len(model_names)}-th with {model_name} ---------"
            )
            if model_idx == 0:
                print("Using the first model as the initial merged model.")
                with self.profile("loading models"):
                    merged_model = modelpool.load_model(model_names[0])
            else:
                with self.profile("loading models"):
                    finetuned_model = modelpool.load_model(model_name)
                with self.profile("DOP merging"):
                    merged_model = self._layer_wise_optimize(
                        model_names=["merged", model_name],
                        pretrained_model=deepcopy(pretrained_model),
                        finetuned_models={
                            "merged": merged_model,
                            model_name: finetuned_model,
                        },
                        model_idx=model_idx,
                    )
                del finetuned_model

        self.print_profile_summary()
        return merged_model

    def _optimize_linear_layer(
        self,
        module_name: str,
        module: nn.Linear,
        finetuned_weights: Dict[str, Tensor],
        model_idx: int,
    ) -> Tuple[str, Tensor]:
        """
        Merge the weight matrix of a single linear layer.

        Trainable weights that are not excluded are optimized via
        `_optimize_weight`; frozen or excluded weights are simply averaged.

        Returns:
            A `(module_name, merged_weight)` tuple so results returned
            asynchronously from Ray actors can be matched to their module.
        """
        if module.weight.requires_grad and module_name not in self.exclude_keys:
            original_dtype = module.weight.dtype
            merged_weight = self._optimize_weight(
                module.weight,
                finetuned_weights,
                module_name,
                model_idx,
            )
            merged_weight = merged_weight.to(dtype=original_dtype)
        else:
            merged_weight = simple_average(list(finetuned_weights.values()))
        return module_name, merged_weight

    def _layer_wise_optimize(
        self,
        model_names: List[str],
        pretrained_model: nn.Module,
        finetuned_models: Dict[str, nn.Module],
        model_idx: int,
    ):
        """
        Optimize model parameters layer by layer.

        Iterates through all leaf modules in the pretrained model and merges their
        weights with the corresponding modules in the finetuned models. Linear
        layers with trainable weights (not in exclude_keys) are optimized using
        gradient descent, while other parameters are simply averaged.

        Args:
            model_names: List of model names to merge (e.g., ['merged', 'new_model']).
            pretrained_model: The base pretrained model (modified in-place).
            finetuned_models: Dictionary mapping model names to their finetuned versions.
            model_idx: Index of the current model being merged (for tracking/logging).

        Returns:
            The pretrained_model with optimized/merged weights from finetuned models.
        """
        for module_name, module in named_leaf_modules(pretrained_model):
            finetuned_modules = {
                model_name: finetuned_models[model_name].get_submodule(module_name)
                for model_name in model_names
            }
            if isinstance(module, nn.Linear):
                # process weight
                finetuned_weights = {
                    model_name: finetuned_modules[model_name].weight
                    for model_name in model_names
                }
                if self.num_ray_actors == 0:
                    _, merged_weight = self._optimize_linear_layer(
                        module_name,
                        module=module,
                        finetuned_weights=finetuned_weights,
                        model_idx=model_idx,
                    )
                    module.weight.data = merged_weight.data
                else:
                    # Drain one finished result first if all actors are busy,
                    # then submit the current layer to the pool.
                    if not self.ray_actor_pool.has_free():
                        returned_module_name, merged_weight = (
                            self.ray_actor_pool.get_next_unordered()
                        )
                        print(f"merged weight {returned_module_name} from ray actors.")
                        pretrained_model.get_submodule(
                            returned_module_name
                        ).weight.data = merged_weight
                    self.ray_actor_pool.submit(
                        lambda actor, kwargs: actor._optimize_linear_layer.remote(
                            **kwargs
                        ),
                        {
                            "module_name": module_name,
                            "module": module,
                            "finetuned_weights": finetuned_weights,
                            "model_idx": model_idx,
                        },
                    )

                # process bias if exists
                if module.bias is not None:
                    module.bias.data = simple_average(
                        [m.bias for m in finetuned_modules.values()]
                    )
            else:
                simple_average(list(finetuned_modules.values()), base_module=module)

        # Collect any results still pending in the actor pool.
        if self.num_ray_actors > 0:
            while self.ray_actor_pool.has_next():
                module_name, merged_weight = self.ray_actor_pool.get_next_unordered()
                print(f"merged weight {module_name} from ray actors.")
                pretrained_model.get_submodule(module_name).weight.data = merged_weight

        return pretrained_model

    def _optimize_weight(
        self,
        pretrained_weight: Tensor,
        finetuned_weights: Dict[str, Tensor],
        module_name: str,
        model_idx: int,
    ):
        """
        Optimize a single weight matrix by balancing projections in SVD space.

        Performs gradient-based optimization to find merged weights that minimize
        the projection loss in the SVD space of task vectors. Uses either MGDA
        for automatic weight balancing or fixed alpha weighting.

        The algorithm:
        1. Computes SVD of each task vector (finetuned - pretrained)
        2. Projects the difference between merged and finetuned weights onto SVD subspaces
        3. Optimizes merged weights to minimize projection losses

        Args:
            pretrained_weight: The original pretrained weight matrix.
            finetuned_weights: Dictionary mapping model names to their finetuned weight matrices.
            module_name: Name of the module being optimized (for logging).
            model_idx: Index of the current model being merged (for tracking).

        Returns:
            Optimized merged weight matrix on CPU.
        """
        assert (
            self.fabric.world_size == 1
        ), "This algorithm is not currently supported in distributed training"

        with torch.no_grad():
            # Convert weights to float if original dtype does not support SVD
            original_dtype = pretrained_weight.dtype
            if not dtype_support_svd(original_dtype):
                pretrained_weight = pretrained_weight.float()
                finetuned_weights = {
                    model_name: finetuned_weight.float()
                    for model_name, finetuned_weight in finetuned_weights.items()
                }

            # Move weights to the appropriate device
            pretrained_weight = self.fabric.to_device(pretrained_weight.detach())
            finetuned_weights = {
                model_name: self.fabric.to_device(finetuned_weight.detach())
                for model_name, finetuned_weight in finetuned_weights.items()
            }

            # Initialize merged weight as simple average of finetuned weights
            merged_weight = self.fabric.to_device(
                nn.Parameter(
                    simple_average(list(finetuned_weights.values())), requires_grad=True
                )
            )

            # Compute SVD of the difference between the finetuned and pretrained weights
            proj_u_dict = {}
            proj_v_dict = {}
            proj_s_dict = {}
            for i, finetuned_weight in enumerate(finetuned_weights.values()):
                finetuned_tv = finetuned_weight - pretrained_weight
                u, s, v = svd(finetuned_tv, full_matrices=True)
                epsilon = 1.0 if self.svd_epsilon > 1.0 else self.svd_epsilon
                cumsum_ratio = s.cumsum(dim=0) / s.sum()
                # NOTE(review): `searchsorted` returns the index of the first
                # element >= epsilon, and the slices below are exclusive, so
                # that component itself is dropped (with epsilon=1.0 the last
                # singular direction is excluded) — confirm against the paper.
                split_rank = torch.searchsorted(cumsum_ratio, epsilon).item()
                u_main = u[:, :split_rank]
                v_main = v[:, :split_rank]
                s_main = s[:split_rank]
                proj_u_dict[i] = u_main
                proj_v_dict[i] = v_main
                proj_s_dict[i] = s_main

        if self.mgda:
            if self.ema:
                # Initial mixing weight; fall back to an even split when alpha
                # was not provided (the original crashed in __init__ for None).
                init_alpha = 0.5 if self.alpha is None else self.alpha
                ema_sol = [init_alpha, 1 - init_alpha]
            # This is multiple-gradient descent algorithm (MGDA) optimization
            optimizer = torch.optim.Adam([merged_weight], lr=self.lr)
            all_losses = [[], []]  # per-task loss history (exactly two tasks)
            all_alphas = [[], []]  # per-task EMA weight history
            for step_idx in tqdm(
                range(self.num_steps),
                desc=f"Optimizing {module_name} weight",
                disable=self.num_ray_actors > 0,
            ):
                # Scaling the loss functions based on the algorithm choice
                loss_data = {}
                grads = {}
                for i, finetuned_weight in enumerate(finetuned_weights.values()):
                    proj_u = proj_u_dict[i]
                    proj_v = proj_v_dict[i]
                    proj_s = proj_s_dict[i]
                    delta_tv = merged_weight - finetuned_weight
                    loss_i = self.cal_loss_i(delta_tv, proj_s, proj_u, proj_v)
                    loss_data[i] = float(loss_i.data)

                    all_losses[i].append(float(loss_i.data))

                    optimizer.zero_grad()
                    loss_i.backward()
                    # Fix: avoid the long-deprecated torch.autograd.Variable
                    # wrapper; a detached clone is equivalent here.
                    grads[i] = merged_weight.grad.detach().clone()

                # Normalize all gradients
                gn = gradient_normalizers(
                    grads=grads, losses=loss_data, normalization_type="loss"
                )
                for i in range(len(finetuned_weights)):
                    grads[i] = grads[i] / float(gn[i])

                # Frank-Wolfe iteration to compute scales.
                sol, min_norm = MinNormSolver.find_min_norm_element(
                    [[grads[i]] for i in range(len(finetuned_weights))]
                )

                if self.ema:
                    ema_sol = [
                        self.ema_beta * ema_sol[i] + (1 - self.ema_beta) * float(sol[i])
                        for i in range(len(sol))
                    ]
                    sol = ema_sol
                    all_alphas[0].append(ema_sol[0])
                    all_alphas[1].append(ema_sol[1])

                # Scaled back-propagation
                loss = 0
                for i, finetuned_weight in enumerate(finetuned_weights.values()):
                    # Compute gradients of each loss function wrt parameters
                    proj_u = proj_u_dict[i]
                    proj_v = proj_v_dict[i]
                    proj_s = proj_s_dict[i]
                    delta_tv = merged_weight - finetuned_weight
                    loss_i = self.cal_loss_i(delta_tv, proj_s, proj_u, proj_v)
                    loss += float(sol[i]) * loss_i

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        else:
            # This is a naive weighted optimization; a fixed alpha is required.
            assert (
                self.alpha is not None
            ), "alpha must be provided when mgda=False"
            optimizer = torch.optim.Adam([merged_weight], lr=self.lr)
            for step_idx in tqdm(
                range(self.num_steps),
                desc=f"Optimizing {module_name} weight",
                disable=self.num_ray_actors > 0,
            ):
                loss = 0
                for i, finetuned_weight in enumerate(finetuned_weights.values()):
                    proj_u = proj_u_dict[i]
                    proj_v = proj_v_dict[i]
                    proj_s = proj_s_dict[i]
                    delta_tv = merged_weight - finetuned_weight
                    loss_i = self.cal_loss_i(delta_tv, proj_s, proj_u, proj_v)
                    loss += self.alpha * loss_i if i == 0 else (1 - self.alpha) * loss_i

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        return merged_weight.detach().to(dtype=original_dtype, device="cpu")

    def cal_loss_i(self, delta_tv, proj_s, proj_u, proj_v):
        """
        Calculate the projection loss for a single task.

        Computes the Frobenius norm of the projection of the weight difference
        onto the SVD subspace(s) defined by U and/or V matrices.

        Args:
            delta_tv: Difference between merged weight and finetuned weight.
            proj_s: Singular values from SVD of the task vector.
            proj_u: Left singular vectors (U) from SVD.
            proj_v: Right singular vectors (V) from SVD.

        Returns:
            Scalar loss value representing the projection distance.

        Raises:
            ValueError: If `svd_proj_space` is not one of 'u', 'v', or 'uv'.
        """
        proj_delta_1 = torch.diag(proj_s) @ proj_u.T @ delta_tv
        proj_delta_2 = delta_tv @ proj_v @ torch.diag(proj_s)
        loss_i_u = torch.linalg.matrix_norm(proj_delta_1, ord="fro") ** 2
        loss_i_v = torch.linalg.matrix_norm(proj_delta_2, ord="fro") ** 2
        if self.svd_proj_space == "uv":
            loss_i = loss_i_u + loss_i_v
        elif self.svd_proj_space == "u":
            loss_i = loss_i_u
        elif self.svd_proj_space == "v":
            loss_i = loss_i_v
        else:
            raise ValueError("Invalid svd_proj_space")

        return loss_i
484
+
485
+
486
# Ray can only wrap a concrete class into a remote actor, so expose a remote
# version of DOPMerging only when ray is installed; importing this module
# therefore does not require ray unless `num_ray_actors > 0` is used.
if is_ray_available():
    import ray

    DOPMergingActor = ray.remote(DOPMerging)
@@ -3,6 +3,7 @@ from typing import Tuple
3
3
  import torch
4
4
  from torch import Tensor, nn
5
5
 
6
+ from fusion_bench.models.utils import is_leaf_module
6
7
  from fusion_bench.utils.parameters import state_dict_to_vector
7
8
  from fusion_bench.utils.state_dict_arithmetic import state_dict_sub
8
9
 
@@ -51,10 +52,6 @@ def frobenius_inner_product(w1: Tensor, w2: Tensor) -> Tensor:
51
52
  return torch.trace(w1.T @ w2)
52
53
 
53
54
 
54
- def is_leaf_module(module: nn.Module) -> bool:
55
- return len(list(module.children())) == 0
56
-
57
-
58
55
  def get_task_vector_norm(model: nn.Module, pretrained_model: nn.Module) -> Tensor:
59
56
  """
60
57
  Get the vector norm of the task model.
@@ -71,3 +68,26 @@ def get_task_vector_norm(model: nn.Module, pretrained_model: nn.Module) -> Tenso
71
68
  state_dict_sub(model.state_dict(), pretrained_model.state_dict())
72
69
  )
73
70
  )
71
+
72
+
73
def print_params(model):
    """
    Print a parameter-count summary for *model*.

    Reports the total parameter count over all leaf modules, the count of
    parameters in ``nn.Linear`` layers, the count of linear *weight*
    parameters (parameters whose name contains ``"weight"``), and the ratio
    of each linear count to the total.

    Args:
        model: Any ``nn.Module`` whose parameters should be summarized.
    """
    total_params = 0
    linear_params = 0
    linear_weight_params = 0
    for module in model.modules():
        # Count only leaf modules so parameters are not double-counted
        # through their parent containers.
        if not is_leaf_module(module):
            continue
        if isinstance(module, nn.Linear):
            linear_params += sum(p.numel() for p in module.parameters())
            linear_weight_params += sum(
                p.numel() for n, p in module.named_parameters() if "weight" in n
            )
        total_params += sum(p.numel() for p in module.parameters())

    # Fix: guard against ZeroDivisionError for models with no parameters.
    if total_params > 0:
        linear_ratio = linear_params / total_params * 100
        linear_weight_ratio = linear_weight_params / total_params * 100
    else:
        linear_ratio = 0.0
        linear_weight_ratio = 0.0
    print(f"Total Parameters: {total_params}")
    print(f"Linear Parameters: {linear_params}")
    print(f"Linear Weight Parameters: {linear_weight_params}")
    print(f"Linear Ratio: {linear_ratio:.2f}%")
    print(f"Linear Weight Ratio: {linear_weight_ratio:.2f}%")
@@ -0,0 +1 @@
1
+ from .emr_merging import EMRMerging
@@ -0,0 +1,53 @@
1
+ from fusion_bench import BaseAlgorithm, BaseModelPool, auto_register_config
2
+ from fusion_bench.utils.state_dict_arithmetic import state_dict_sub
3
+
4
+ from .utils import EMRModulatedModel, EMRTaskModulator, emr_merge
5
+
6
+
7
@auto_register_config
class EMRMerging(BaseAlgorithm):
    """
    EMR Merging Algorithm.

    Merges multiple task-specific models into a unified model with the
    Elect, Mask & Rescale (EMR) strategy. The result is a modulated model
    that can adapt to different tasks via task-specific modulators.
    """

    def load_pretrained_model_and_task_vectors(self, modelpool: BaseModelPool):
        """Load the pretrained model plus one task vector per pool model.

        Each task vector is the state-dict difference
        ``finetuned.state_dict() - pretrained.state_dict()``.

        Returns:
            A ``(pretrained_model, task_vectors)`` tuple.
        """
        pretrained_model = modelpool.load_pretrained_model()
        pretrained_sd = pretrained_model.state_dict()
        task_vectors = [
            state_dict_sub(modelpool.load_model(name).state_dict(), pretrained_sd)
            for name in modelpool.model_names
        ]
        return pretrained_model, task_vectors

    def run(self, modelpool: BaseModelPool) -> EMRModulatedModel:
        """Run EMR merging over the pool and return the modulated model."""
        if not isinstance(modelpool, BaseModelPool):
            modelpool = BaseModelPool(modelpool)

        # NOTE(review): this calls the *model pool's*
        # `load_pretrained_model_and_task_vectors`; the identically named
        # helper defined on this class is unused here — confirm which one
        # is intended.
        pretrained_model, task_vectors = (
            modelpool.load_pretrained_model_and_task_vectors()
        )

        # Elect, Mask & Rescale over all task vectors.
        unified_vector, masks, rescalers = emr_merge(task_vectors)

        emr_model = EMRModulatedModel(
            backbone=pretrained_model,
            modulators={},
            unified_task_vector=unified_vector,
        )

        # Attach one modulator per task, selecting that task's slice of each
        # stacked mask and its rescaling factor.
        for task_idx, task_name in enumerate(modelpool.model_names):
            task_mask = {key: stacked[task_idx] for key, stacked in masks.items()}
            modulator = EMRTaskModulator(
                mask=task_mask,
                rescaler=rescalers[task_idx],
            )
            emr_model.add_modulator(task_name=task_name, modulator=modulator)

        return emr_model