fusion-bench 0.2.31__py3-none-any.whl → 0.2.32__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. fusion_bench/__init__.py +6 -0
  2. fusion_bench/__main__.py +2 -2
  3. fusion_bench/dataset/__init__.py +2 -0
  4. fusion_bench/dataset/clip_dataset.py +4 -72
  5. fusion_bench/dataset/image_dataset.py +44 -18
  6. fusion_bench/method/base_algorithm.py +4 -0
  7. fusion_bench/method/dop/dop.py +0 -22
  8. fusion_bench/method/dop/dop_general.py +489 -0
  9. fusion_bench/method/dop/utils.py +24 -4
  10. fusion_bench/method/emr_merging/__init__.py +1 -0
  11. fusion_bench/method/emr_merging/emr_merging.py +53 -0
  12. fusion_bench/method/emr_merging/utils.py +162 -0
  13. fusion_bench/method/opcm/opcm.py +6 -2
  14. fusion_bench/method/opcm/opcm_general.py +356 -0
  15. fusion_bench/method/opcm/utils.py +1 -4
  16. fusion_bench/method/simple_average.py +52 -18
  17. fusion_bench/method/task_arithmetic/task_arithmetic.py +1 -1
  18. fusion_bench/mixins/lightning_fabric.py +108 -3
  19. fusion_bench/mixins/serialization.py +1 -1
  20. fusion_bench/modelpool/base_pool.py +37 -1
  21. fusion_bench/modelpool/convnext_for_image_classification.py +5 -2
  22. fusion_bench/models/hf_clip.py +20 -0
  23. fusion_bench/models/modulator/__init__.py +1 -0
  24. fusion_bench/models/modulator/base.py +123 -0
  25. fusion_bench/models/parameter_dict.py +119 -29
  26. fusion_bench/models/utils.py +190 -2
  27. fusion_bench/models/wrappers/switch.py +90 -0
  28. fusion_bench/programs/base_program.py +6 -0
  29. fusion_bench/programs/fabric_fusion_program.py +4 -0
  30. fusion_bench/scripts/cli.py +19 -8
  31. fusion_bench/taskpool/image_classification.py +270 -0
  32. fusion_bench/utils/__init__.py +18 -1
  33. fusion_bench/utils/data.py +1 -1
  34. fusion_bench/utils/dict.py +19 -0
  35. fusion_bench/utils/dtype.py +19 -0
  36. fusion_bench/utils/misc.py +1 -0
  37. fusion_bench/utils/packages.py +4 -0
  38. fusion_bench/utils/state_dict_arithmetic.py +183 -1
  39. fusion_bench/utils/tensorboard.py +21 -3
  40. {fusion_bench-0.2.31.dist-info → fusion_bench-0.2.32.dist-info}/METADATA +3 -1
  41. {fusion_bench-0.2.31.dist-info → fusion_bench-0.2.32.dist-info}/RECORD +51 -37
  42. {fusion_bench-0.2.31.dist-info → fusion_bench-0.2.32.dist-info}/WHEEL +1 -1
  43. {fusion_bench-0.2.31.dist-info → fusion_bench-0.2.32.dist-info}/entry_points.txt +1 -1
  44. fusion_bench_config/fabric/loggers/mlflow_logger.yaml +4 -0
  45. fusion_bench_config/method/dop/dop_general.yaml +33 -0
  46. fusion_bench_config/method/emr_merging/emr_merging.yaml +1 -0
  47. fusion_bench_config/method/opcm/opcm_general.yaml +18 -0
  48. fusion_bench_config/modelpool/ConvNextForImageClassification/convnext-base-224_8-tasks.yaml +15 -0
  49. fusion_bench_config/taskpool/ImageClassificationTaskPool/convnext-base-224_8-tasks.yaml +17 -0
  50. {fusion_bench-0.2.31.dist-info → fusion_bench-0.2.32.dist-info}/licenses/LICENSE +0 -0
  51. {fusion_bench-0.2.31.dist-info → fusion_bench-0.2.32.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,162 @@
1
+ from typing import Dict
2
+
3
+ import numpy as np
4
+ import torch
5
+ from torch import nn
6
+
7
+ from fusion_bench import StateDictType, TorchModelType
8
+ from fusion_bench.models.modulator import ModulatedModel, TaskModulator
9
+ from fusion_bench.models.modulator.base import ModulatedModel, TaskModulator
10
+ from fusion_bench.models.parameter_dict import ParameterDictModel
11
+ from fusion_bench.utils.state_dict_arithmetic import state_dict_sum
12
+
13
+
14
+ def _sign(x: torch.Tensor) -> torch.Tensor:
15
+ """
16
+ Return the sign of the tensor: 1 for positive, -1 for negative.
17
+ Zeros are treated as negative (i.e., sign -1).
18
+ """
19
+ return (x > 0) * 2 - 1
20
+
21
+
22
def emr_merge(task_vectors: list[StateDictType]):
    """
    EMR merging (Elect, Mask & Rescale), adapted from the original
    implementation to return the unified vector together with the
    per-task masks and rescalers.

    Args:
        task_vectors: List of task-specific vectors (state dicts).

    Returns:
        vector_unified: The unified task vector (state dict).
        masks: Dict mapping parameter names to a list of per-task boolean masks.
        rescalers: 1-D tensor with one rescaling factor per task.
    """
    n_tasks = len(task_vectors)

    # Elect the unified sign: gamma_uni = sign(sum_i tau_i).
    sign_flags = {
        name: _sign(total) for name, total in state_dict_sum(task_vectors).items()
    }

    # tau_uni: per parameter, the sign-aligned entry of largest magnitude.
    unified = {}
    raw_scales = torch.zeros(n_tasks)
    # Boolean masks, one list entry per task; True where the task vector
    # agrees in sign with the elected unified direction.
    masks: dict[str, list[torch.Tensor]] = {}
    for name, flag in sign_flags.items():
        per_task_masks = []
        elected = torch.zeros_like(task_vectors[0][name])
        for t in range(n_tasks):
            tau = task_vectors[t][name]
            agree = (tau * flag) > 0
            per_task_masks.append(agree)
            magnitude = torch.abs(agree * tau)
            elected = torch.where(magnitude > elected, magnitude, elected)
            # Accumulate each task's mean absolute magnitude for rescaling.
            raw_scales[t] += torch.mean(torch.abs(tau))
        masks[name] = per_task_masks
        unified[name] = elected * flag

    # Rescale so each task's masked unified vector matches the task's
    # original average magnitude.
    masked_scales = torch.zeros(n_tasks)
    for t in range(n_tasks):
        for name in unified:
            masked_scales[t] += torch.mean(torch.abs(unified[name] * masks[name][t]))
    rescalers = raw_scales / masked_scales

    return unified, masks, rescalers
66
+
67
+
68
class EMRModulatedModel(ModulatedModel[TorchModelType]):
    """
    Modulated Model for EMR Merging.
    """

    def __init__(
        self,
        backbone: TorchModelType,
        modulators: Dict[str, "EMRTaskModulator"],
        unified_task_vector: StateDictType,
    ):
        super().__init__(backbone, modulators)

        # Wrap raw tensors as frozen parameters so the unified task vector
        # moves with the module across devices/dtypes; already-registered
        # Parameters/Buffers are kept untouched.
        wrapped = {
            name: (
                tensor
                if isinstance(tensor, (nn.Parameter, nn.Buffer))
                else nn.Parameter(tensor, requires_grad=False)
            )
            for name, tensor in unified_task_vector.items()
        }
        self.unified_task_vector = ParameterDictModel(wrapped)
86
+
87
+
88
class EMRTaskModulator(TaskModulator[TorchModelType]):
    """
    Task Modulator for EMR (Elect, Mask & Rescale) Merging.

    This modulator applies task-specific adaptations to a unified model by:
    1. Masking: Aligning direction with task-specific model (mask sets inconsistent signs to zero)
    2. Rescaling: Aligning magnitude with task-specific model

    The application formula is:
        θ_new = θ_old + τ_unified ⊙ mask_i * rescaler_i

    where:
    - τ_unified is the unified task vector (elected from all task vectors)
    - mask_i is the task-specific binary mask
    - rescaler_i is the task-specific rescaling factor

    Args:
        vector: The unified task vector (τ_unified) as a state dict
        mask: Task-specific binary mask as a dict of tensors
        rescaler: Task-specific rescaling factor (scalar)
    """

    def __init__(
        self,
        mask: Dict[str, torch.Tensor],
        rescaler: float,
    ):
        super().__init__()

        # Copy the dict so the caller's mapping is not mutated, then wrap raw
        # tensors as frozen Parameters so they follow the module across
        # devices/dtypes (requires_grad=False keeps boolean masks valid).
        mask = mask.copy()
        for name, tensor in mask.items():
            if not isinstance(tensor, (nn.Parameter, nn.Buffer)):
                mask[name] = nn.Parameter(tensor, requires_grad=False)
        self.mask = ParameterDictModel(mask)

        # Register rescaler as a parameter for proper device handling
        self.rescaler = nn.Parameter(torch.tensor(rescaler), requires_grad=False)

    @torch.no_grad()
    def apply(self, modulated_model: "EMRModulatedModel[TorchModelType]"):
        """
        Apply the EMR task vector to the model.

        For each parameter in the state dict:
            θ_new = θ_old + τ_unified ⊙ mask_i * rescaler_i

        This applies the masked and rescaled unified task vector to align the backbone
        with the task-specific model.
        """
        unified_vector = modulated_model.unified_task_vector

        # In-place update of the backbone parameters; `remove` below is the
        # exact inverse (same delta, subtracted).
        for name in unified_vector:
            delta = unified_vector[name] * self.mask[name] * self.rescaler
            param = modulated_model.backbone.get_parameter(name)
            param.add_(delta)

    @torch.no_grad()
    def remove(self, modulated_model: "EMRModulatedModel[TorchModelType]"):
        """
        Remove the EMR task vector from the model.

        For each parameter in the state dict:
            θ_old = θ_new - τ_unified ⊙ mask_i * rescaler_i

        This reverses the task-specific adaptation to restore the original backbone.
        """
        unified_vector = modulated_model.unified_task_vector

        for name in unified_vector:
            delta = unified_vector[name] * self.mask[name] * self.rescaler
            param = modulated_model.backbone.get_parameter(name)
            param.sub_(delta)

        # NOTE(review): only `remove` clears `_current_task`; `apply` does not
        # set it here — presumably the ModulatedModel base class does that when
        # switching tasks. Confirm against the base-class contract.
        modulated_model._current_task = None
@@ -16,22 +16,23 @@ from transformers import CLIPVisionModel
16
16
 
17
17
  from fusion_bench import BaseAlgorithm, BaseModelPool
18
18
  from fusion_bench.mixins import LightningFabricMixin, SimpleProfilerMixin
19
+ from fusion_bench.models.utils import is_leaf_module
19
20
  from fusion_bench.taskpool import CLIPVisionModelTaskPool
20
21
  from fusion_bench.utils import instantiate
21
22
  from fusion_bench.utils.json import load_from_json, save_to_json
22
23
  from fusion_bench.utils.parameters import state_dict_to_vector
23
24
  from fusion_bench.utils.state_dict_arithmetic import state_dict_sub
24
25
 
25
- from .utils import frobenius_inner_product, get_task_vector_norm, is_leaf_module, svd
26
+ from .utils import frobenius_inner_product, get_task_vector_norm, svd
26
27
 
27
28
  if TYPE_CHECKING:
28
29
  from torch.utils.tensorboard import SummaryWriter
29
30
 
30
31
 
31
32
  class OPCMForCLIP(
32
- BaseAlgorithm,
33
33
  LightningFabricMixin,
34
34
  SimpleProfilerMixin,
35
+ BaseAlgorithm,
35
36
  ):
36
37
  def __init__(
37
38
  self,
@@ -219,6 +220,9 @@ class OPCMForCLIP(
219
220
  return merged_model
220
221
 
221
222
  def save_merged_model(self, merged_model: CLIPVisionModel, step: int):
223
+ if self.log_dir is None:
224
+ print("Log dir is None, skip saving merged model.")
225
+ return
222
226
  os.makedirs(Path(self.log_dir) / "checkpoints", exist_ok=True)
223
227
  merged_model.save_pretrained(
224
228
  Path(self.log_dir) / "checkpoints" / f"merged_model_{step}"
@@ -0,0 +1,356 @@
1
+ import os
2
+ import random
3
+ import time
4
+ from collections import defaultdict
5
+ from copy import deepcopy
6
+ from pathlib import Path
7
+ from typing import TYPE_CHECKING, List, Literal, Optional, Tuple, cast
8
+
9
+ import lightning as L
10
+ import numpy as np
11
+ import torch
12
+ from omegaconf import DictConfig
13
+ from torch import Tensor, nn
14
+ from tqdm.auto import tqdm
15
+
16
+ from fusion_bench import BaseAlgorithm, BaseModelPool, auto_register_config
17
+ from fusion_bench.mixins import LightningFabricMixin, SimpleProfilerMixin
18
+ from fusion_bench.models.utils import is_leaf_module, named_leaf_modules
19
+ from fusion_bench.utils import instantiate
20
+ from fusion_bench.utils.json import load_from_json, save_to_json
21
+ from fusion_bench.utils.packages import is_ray_available
22
+ from fusion_bench.utils.parameters import state_dict_to_vector
23
+ from fusion_bench.utils.state_dict_arithmetic import state_dict_sub
24
+
25
+ from .utils import frobenius_inner_product, get_task_vector_norm, svd
26
+
27
+ if TYPE_CHECKING:
28
+ from torch.utils.tensorboard import SummaryWriter
29
+
30
+
31
@auto_register_config
class OPCM(
    LightningFabricMixin,
    SimpleProfilerMixin,
    BaseAlgorithm,
):
    """
    Continual model merging via SVD projection (OPCM), model-agnostic variant.

    Models are merged one at a time. For each incoming task model, the linear
    layers' task vectors are projected away from the principal subspace of the
    previously merged task vector, and the merged task vector is then rescaled
    so its norm matches the running average of the individual task-vector
    norms. Per-layer linear merges can optionally be dispatched to a pool of
    Ray actors (``num_ray_actors > 0``).
    """

    def __init__(
        self,
        alpha: float,
        shuffle_order: bool = True,
        seed: Optional[int] = None,
        save_on_every_step: bool = True,
        evaluate_on_every_step: bool = False,
        num_ray_actors: int = 0,
        **kwargs,
    ):
        """
        Continual Model Merging via SVD Projection.

        Args:
            alpha (float): the scaling factor for the SVD projection.
            shuffle_order (bool): whether to shuffle the order of the models.
            seed (Optional[int]): the seed to use.
            save_on_every_step (bool): whether to save the merged model on every step.
            evaluate_on_every_step (bool): whether to evaluate the merged model on every step.
            num_ray_actors (int): number of Ray actors used to parallelize the
                per-layer linear merges; 0 disables Ray entirely.
        """
        self.alpha = alpha
        self.shuffle_order = shuffle_order
        self.seed = seed
        self.save_on_every_step = save_on_every_step
        self.evaluate_on_every_step = evaluate_on_every_step
        # Assigned explicitly like the other parameters above; previously this
        # was the only constructor argument never stored on the instance,
        # relying solely on `auto_register_config` to expose it.
        self.num_ray_actors = num_ray_actors
        super().__init__(**kwargs)

    @torch.no_grad()
    def run(self, modelpool: BaseModelPool):
        """Merge all models in `modelpool` sequentially; return the merged model."""
        if self.num_ray_actors > 0:
            if is_ray_available():
                import ray
                from ray.util.actor_pool import ActorPool

                if not ray.is_initialized():
                    ray.init()

                # Create the actors; reserve one GPU per actor when running
                # on CUDA so each remote SVD runs on its own device.
                if self.fabric.device.type == "cuda":
                    actor_options = {"num_gpus": 1}
                else:
                    actor_options = {}
                self.ray_actor_pool = ActorPool(
                    [
                        OPCMActor.options(**actor_options).remote(**self.config)
                        for _ in range(self.num_ray_actors)
                    ]
                )

        if self.seed is not None:
            L.seed_everything(self.seed)

        with self.profile("loading model"):
            pretrained_model = modelpool.load_pretrained_model()

        # Copy the name list so that shuffling does not mutate the model
        # pool's own ordering in place.
        model_names = list(modelpool.model_names)
        if self.shuffle_order:
            random.shuffle(model_names)

        # log the model names
        if self.log_dir is not None:
            save_to_json(model_names, Path(self.log_dir) / "model_names.json")
            tensorboard_summarywriter: "SummaryWriter" = self.tensorboard_summarywriter
            tensorboard_summarywriter.add_text(
                "global/model_names", str(model_names), global_step=0
            )

        # The first model in the (possibly shuffled) order seeds the merge.
        with self.profile("loading model"):
            print("Using the first model as the initial merged model.")
            merged_model = modelpool.load_model(model_names[0])
            assert merged_model is not None, "Failed to load the first model"

        self.avg_task_vector_norm = get_task_vector_norm(merged_model, pretrained_model)
        self.all_task_vector_norm = [self.avg_task_vector_norm]
        self.fabric.log("model/task_vector_norm", self.avg_task_vector_norm, step=0)
        self.fabric.log("model/avg_task_vector_norm", self.avg_task_vector_norm, step=0)
        self.fabric.log(
            "model/merged_task_vector_norm", self.avg_task_vector_norm, step=0
        )

        self.previous_lambda_t = 1
        self.lambda_t = None
        self.fabric.log("model/lambda_t", self.previous_lambda_t, step=0)
        self.fabric.log("empirical/lambda_t", 1, step=0)

        if self.save_on_every_step:
            self.save_merged_model(merged_model, 0)

        for model_idx, model_name in tqdm(
            enumerate(model_names[1:], start=1), desc="Processing models"
        ):
            with self.profile("loading model"):
                task_model = modelpool.load_model(model_name)

            with self.profile("merging model"):
                self.all_task_vector_norm.append(
                    get_task_vector_norm(task_model, pretrained_model)
                )
                self.avg_task_vector_norm = np.mean(self.all_task_vector_norm)
                self.fabric.log(
                    "model/task_vector_norm",
                    self.all_task_vector_norm[-1],
                    step=model_idx,
                )
                self.fabric.log(
                    "model/avg_task_vector_norm",
                    self.avg_task_vector_norm,
                    step=model_idx,
                )

                # Temporary value; the true lambda_t is fixed below after the
                # merged task-vector norm is known.
                self.lambda_t = 1

                self._layer_wise_merge(
                    merged_model=merged_model,
                    pretrained_model=pretrained_model,
                    task_model=task_model,
                    model_name=model_name,
                )

                # Rescale the merged task vector so that its norm matches the
                # running average of the individual task-vector norms.
                task_vector_norm = get_task_vector_norm(merged_model, pretrained_model)
                self.lambda_t *= task_vector_norm / self.avg_task_vector_norm
                for param_name, param in merged_model.named_parameters():
                    param.data = pretrained_model.get_parameter(param_name) + (
                        param - pretrained_model.get_parameter(param_name)
                    ) * (self.avg_task_vector_norm / task_vector_norm)
                self.fabric.log("model/lambda_t", self.lambda_t, step=model_idx)
                self.fabric.log(
                    "empirical/lambda_t", np.sqrt(model_idx + 1), step=model_idx
                )
                self.previous_lambda_t = self.lambda_t
                self.lambda_t = None

            self.fabric.log(
                "model/merged_task_vector_norm",
                get_task_vector_norm(merged_model, pretrained_model),
                step=model_idx,
            )

            if self.save_on_every_step:
                with self.profile("saving model"):
                    self.save_merged_model(merged_model, model_idx)

        self.print_profile_summary()
        return merged_model

    def _layer_wise_merge(self, merged_model, pretrained_model, task_model, model_name):
        """Merge `task_model` into `merged_model` one leaf module at a time."""
        if self.num_ray_actors > 0:
            # Keep lambda state on the actors in sync with the master process.
            self._update_attributes_across_ray()

        for module_name, module in tqdm(
            list(named_leaf_modules(merged_model, ignore_empty=True)),
            desc=f"Processing {model_name}",
            leave=False,
            disable=self.num_ray_actors > 0,
        ):
            if isinstance(module, nn.Linear):
                # processing linear layers
                merge_kwargs = {
                    "merged_W": module.weight,
                    "pretrained_W": pretrained_model.get_submodule(module_name).weight,
                    "task_W": task_model.get_submodule(module_name).weight,
                    "param_name": ".".join([module_name, "weight"]),
                    "alpha": self.alpha,
                }
                if not self.num_ray_actors > 0:
                    _, merged_weight = self.merge_linear_weights(**merge_kwargs)
                    module.weight.data = merged_weight
                else:
                    # Backpressure: when no actor is free, harvest one finished
                    # result before submitting the next job.
                    if not self.ray_actor_pool.has_free():
                        returned_module_name, merged_weight = (
                            self.ray_actor_pool.get_next_unordered()
                        )
                        print(f"merged weight {returned_module_name} from ray actors.")
                        # BUGFIX: write the result back into the *merged* model
                        # (previously this wrote into `pretrained_model`,
                        # corrupting the pretrained reference that all
                        # subsequent task-vector computations depend on).
                        merged_model.get_submodule(
                            returned_module_name
                        ).weight.data = merged_weight
                    self.ray_actor_pool.submit(
                        lambda actor, kwargs: actor.merge_linear_weights.remote(
                            **kwargs
                        ),
                        merge_kwargs,
                    )
                # processing bias if exists
                if module.bias is not None:
                    module.bias.data = self.merge_other_parameters(
                        module.bias,
                        pretrained_model.get_submodule(module_name).bias,
                        task_model.get_submodule(module_name).bias,
                        param_name=".".join([module_name, "bias"]),
                    )
            else:
                # Non-linear leaves (norms, embeddings, ...) are merged
                # parameter-wise without any SVD projection.
                for param_name, param in module.named_parameters():
                    param.data = self.merge_other_parameters(
                        merged_W=param,
                        pretrained_W=pretrained_model.get_submodule(
                            module_name
                        ).get_parameter(param_name),
                        task_W=task_model.get_submodule(module_name).get_parameter(
                            param_name
                        ),
                        param_name=".".join([module_name, param_name]),
                    )

        if self.num_ray_actors > 0:
            # Drain the pool so every in-flight merge lands in the merged
            # model before this step's norm rescaling runs.
            while self.ray_actor_pool.has_next():
                returned_module_name, merged_weight = (
                    self.ray_actor_pool.get_next_unordered()
                )
                print(f"merged weight {returned_module_name} from ray actors.")
                merged_model.get_submodule(returned_module_name).weight.data = (
                    merged_weight
                )

    def save_merged_model(self, merged_model, step: int):
        """Save `merged_model` under `<log_dir>/checkpoints/merged_model_<step>`."""
        if self.log_dir is None:
            print("Log dir is None, skip saving merged model.")
            return
        os.makedirs(Path(self.log_dir) / "checkpoints", exist_ok=True)
        merged_model.save_pretrained(
            Path(self.log_dir) / "checkpoints" / f"merged_model_{step}"
        )

    def _update_attributes_across_ray(self, attr_dict=None):
        """Sync lambda state to all Ray actors (master) or apply it (actor)."""
        if attr_dict is None:
            # called on master
            attrs_to_sync = ["previous_lambda_t", "lambda_t"]
            assert (
                not self.ray_actor_pool.has_next()
            ), "All previous tasks must be merged before syncing attributes."

            for actor in self.ray_actor_pool._idle_actors:
                actor._update_attributes_across_ray.remote(
                    {attr: getattr(self, attr) for attr in attrs_to_sync}
                )
        else:
            # called on ray actors
            for attr, value in attr_dict.items():
                setattr(self, attr, value)

    def merge_linear_weights(
        self,
        merged_W: Tensor,
        pretrained_W: Tensor,
        task_W: Tensor,
        param_name: str,
        alpha: float,
    ):
        """
        Merge one linear weight via SVD projection.

        The task vector is projected onto the singular basis of the previous
        merged task vector; its diagonal and its top-`alpha`-energy block are
        zeroed out to remove interference with already-merged directions.

        Returns:
            Tuple of (module name derived from `param_name`, merged weight on
            the caller's original device).
        """
        accelerator = self.fabric.device

        original_device = merged_W.device
        merged_W = merged_W.to(accelerator)
        pretrained_W = pretrained_W.to(accelerator)
        task_W = task_W.to(accelerator)

        previous_merged_tv = merged_W - pretrained_W
        task_tv = task_W - pretrained_W

        u, s, v = svd(previous_merged_tv)
        # Smallest rank whose cumulative singular energy exceeds `alpha`.
        split_rank = (s.cumsum(dim=0) / s.sum() > alpha).float().argmax().item()

        projected_task_tv = u.T @ task_tv @ v
        projected_task_tv.diagonal().fill_(0)

        # Zero out components overlapping the dominant subspace of the
        # previously merged task vector.
        projected_task_tv[:split_rank, :split_rank] = 0

        cleaned_task_tv = u @ projected_task_tv @ v.T

        previous_lambda_t = self.previous_lambda_t
        lambda_t = self.lambda_t
        new_merged_W = (
            pretrained_W
            + (previous_lambda_t * previous_merged_tv + cleaned_task_tv) / lambda_t
        )
        module_name = param_name[: param_name.rfind(".")]
        return module_name, new_merged_W.to(original_device)

    def merge_other_parameters(
        self,
        merged_W: Tensor,
        pretrained_W: Tensor,
        task_W: Tensor,
        param_name: str,
    ):
        """Merge a non-linear parameter: plain lambda-weighted task-vector sum."""
        accelerator = self.fabric.device

        original_device = merged_W.device
        merged_W = merged_W.to(accelerator)
        pretrained_W = pretrained_W.to(accelerator)
        task_W = task_W.to(accelerator)

        previous_merged_tv = merged_W - pretrained_W
        task_tv = task_W - pretrained_W

        previous_lambda_t = self.previous_lambda_t
        lambda_t = self.lambda_t

        new_merged_W = (
            pretrained_W + (previous_lambda_t * previous_merged_tv + task_tv) / lambda_t
        )
        return new_merged_W.to(original_device)

    def compute_lambda_t(
        self, previous_merged_tv: Tensor, task_tv: Tensor, previous_lambda_t: float
    ):
        """Norm-preserving lambda: ||prev_lambda * prev_tv + task_tv|| / ||prev_tv||."""
        previous_merged_tv = torch.flatten(previous_merged_tv)
        task_tv = torch.flatten(task_tv)

        lambda_t = torch.linalg.vector_norm(
            previous_lambda_t * previous_merged_tv + task_tv
        ) / torch.linalg.vector_norm(previous_merged_tv)
        return lambda_t.item()
351
+
352
+
353
# Register a Ray remote wrapper for OPCM so `run` can spawn parallel actors
# when `num_ray_actors > 0`. Guarded so the module still imports when ray is
# not installed; `OPCMActor` is only referenced at runtime after checking
# `is_ray_available()` again inside `run`.
if is_ray_available():
    import ray

    OPCMActor = ray.remote(OPCM)
@@ -3,6 +3,7 @@ from typing import Tuple
3
3
  import torch
4
4
  from torch import Tensor, nn
5
5
 
6
+ from fusion_bench.models.utils import is_leaf_module
6
7
  from fusion_bench.utils.parameters import state_dict_to_vector
7
8
  from fusion_bench.utils.state_dict_arithmetic import state_dict_sub
8
9
 
@@ -51,10 +52,6 @@ def frobenius_inner_product(w1: Tensor, w2: Tensor) -> Tensor:
51
52
  return torch.trace(w1.T @ w2)
52
53
 
53
54
 
54
- def is_leaf_module(module: nn.Module) -> bool:
55
- return len(list(module.children())) == 0
56
-
57
-
58
55
  def get_task_vector_norm(model: nn.Module, pretrained_model: nn.Module) -> Tensor:
59
56
  """
60
57
  Get the vector norm of the task model.