fusion-bench 0.2.30__py3-none-any.whl → 0.2.32__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/__init__.py +6 -0
- fusion_bench/__main__.py +2 -2
- fusion_bench/constants/runtime.py +4 -1
- fusion_bench/dataset/__init__.py +2 -0
- fusion_bench/dataset/clip_dataset.py +4 -72
- fusion_bench/dataset/image_dataset.py +44 -18
- fusion_bench/method/base_algorithm.py +4 -0
- fusion_bench/method/classification/image_classification_finetune.py +1 -0
- fusion_bench/method/concrete_subspace/clip_concrete_tsvm.py +285 -0
- fusion_bench/method/dop/dop.py +0 -22
- fusion_bench/method/dop/dop_general.py +489 -0
- fusion_bench/method/dop/utils.py +24 -4
- fusion_bench/method/emr_merging/__init__.py +1 -0
- fusion_bench/method/emr_merging/emr_merging.py +53 -0
- fusion_bench/method/emr_merging/utils.py +162 -0
- fusion_bench/method/opcm/opcm.py +6 -2
- fusion_bench/method/opcm/opcm_general.py +356 -0
- fusion_bench/method/opcm/utils.py +1 -4
- fusion_bench/method/simple_average.py +52 -18
- fusion_bench/method/task_arithmetic/task_arithmetic.py +1 -1
- fusion_bench/method/task_singular_vector/TSVM.py +7 -6
- fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +0 -1
- fusion_bench/mixins/lightning_fabric.py +110 -11
- fusion_bench/mixins/openclip_classification.py +155 -1
- fusion_bench/mixins/serialization.py +1 -1
- fusion_bench/modelpool/base_pool.py +37 -0
- fusion_bench/modelpool/convnext_for_image_classification.py +5 -2
- fusion_bench/modelpool/openclip_vision/modelpool.py +12 -3
- fusion_bench/models/hf_clip.py +20 -0
- fusion_bench/models/modulator/__init__.py +1 -0
- fusion_bench/models/modulator/base.py +123 -0
- fusion_bench/models/open_clip/modeling.py +61 -5
- fusion_bench/models/open_clip/utils.py +13 -2
- fusion_bench/models/parameter_dict.py +119 -29
- fusion_bench/models/utils.py +190 -2
- fusion_bench/models/wrappers/switch.py +90 -0
- fusion_bench/programs/base_program.py +6 -0
- fusion_bench/programs/fabric_fusion_program.py +4 -0
- fusion_bench/py.typed +1 -0
- fusion_bench/scripts/cli.py +25 -23
- fusion_bench/scripts/imgui.py +2 -2
- fusion_bench/scripts/webui.py +2 -2
- fusion_bench/taskpool/image_classification.py +270 -0
- fusion_bench/utils/__init__.py +20 -1
- fusion_bench/utils/data.py +1 -1
- fusion_bench/utils/dict.py +19 -0
- fusion_bench/utils/dtype.py +19 -0
- fusion_bench/utils/hydra_utils.py +75 -0
- fusion_bench/utils/misc.py +1 -0
- fusion_bench/utils/packages.py +4 -0
- fusion_bench/utils/parameters.py +33 -0
- fusion_bench/utils/rich_utils.py +42 -19
- fusion_bench/utils/state_dict_arithmetic.py +183 -1
- fusion_bench/utils/tensorboard.py +21 -3
- {fusion_bench-0.2.30.dist-info → fusion_bench-0.2.32.dist-info}/METADATA +3 -1
- {fusion_bench-0.2.30.dist-info → fusion_bench-0.2.32.dist-info}/RECORD +70 -53
- {fusion_bench-0.2.30.dist-info → fusion_bench-0.2.32.dist-info}/WHEEL +1 -1
- {fusion_bench-0.2.30.dist-info → fusion_bench-0.2.32.dist-info}/entry_points.txt +1 -1
- fusion_bench_config/README.md +9 -0
- fusion_bench_config/fabric/auto.yaml +1 -0
- fusion_bench_config/fabric/loggers/mlflow_logger.yaml +4 -0
- fusion_bench_config/hydra/default.yaml +3 -1
- fusion_bench_config/method/concrete_subspace/clip_concrete_tsvm.yaml +38 -0
- fusion_bench_config/method/dop/dop_general.yaml +33 -0
- fusion_bench_config/method/emr_merging/emr_merging.yaml +1 -0
- fusion_bench_config/method/opcm/opcm_general.yaml +18 -0
- fusion_bench_config/modelpool/ConvNextForImageClassification/convnext-base-224_8-tasks.yaml +15 -0
- fusion_bench_config/taskpool/ImageClassificationTaskPool/convnext-base-224_8-tasks.yaml +17 -0
- {fusion_bench-0.2.30.dist-info → fusion_bench-0.2.32.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.30.dist-info → fusion_bench-0.2.32.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,489 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Continual Model Merging without Data: Dual Projections for Balancing Stability and Plasticity. NeurIPS, 2025.
|
|
3
|
+
(Architecture agnostic implementation)
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import logging
|
|
7
|
+
import os
|
|
8
|
+
import random
|
|
9
|
+
import time
|
|
10
|
+
from copy import deepcopy
|
|
11
|
+
from pathlib import Path
|
|
12
|
+
from typing import Dict, List, Literal, Optional, Tuple, cast
|
|
13
|
+
|
|
14
|
+
import lightning as L
|
|
15
|
+
import numpy as np
|
|
16
|
+
import torch
|
|
17
|
+
from omegaconf import DictConfig
|
|
18
|
+
from torch import Tensor, nn
|
|
19
|
+
from torch.autograd import Variable
|
|
20
|
+
from tqdm.auto import tqdm
|
|
21
|
+
|
|
22
|
+
from fusion_bench import BaseAlgorithm, BaseModelPool, auto_register_config
|
|
23
|
+
from fusion_bench.method.simple_average import simple_average
|
|
24
|
+
from fusion_bench.mixins import LightningFabricMixin, SimpleProfilerMixin
|
|
25
|
+
from fusion_bench.models.utils import named_leaf_modules
|
|
26
|
+
from fusion_bench.utils import seed_everything_by_time
|
|
27
|
+
from fusion_bench.utils.dtype import dtype_support_svd
|
|
28
|
+
from fusion_bench.utils.json import save_to_json
|
|
29
|
+
from fusion_bench.utils.packages import is_ray_available
|
|
30
|
+
|
|
31
|
+
from .min_norm_solvers import MinNormSolver, gradient_normalizers
|
|
32
|
+
from .utils import is_leaf_module, print_params, svd
|
|
33
|
+
|
|
34
|
+
log = logging.getLogger(__name__)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
@auto_register_config
|
|
38
|
+
class DOPMerging(LightningFabricMixin, SimpleProfilerMixin, BaseAlgorithm):
|
|
39
|
+
"""
|
|
40
|
+
Dual Projections for Balancing Stability and Plasticity (DOP) merging algorithm.
|
|
41
|
+
|
|
42
|
+
This method implements continual model merging without data by using dual projections
|
|
43
|
+
in the SVD space to balance stability (preserving previously merged model's knowledge)
|
|
44
|
+
and plasticity (incorporating new model's knowledge).
|
|
45
|
+
|
|
46
|
+
The algorithm merges models sequentially, optimizing each merge using gradient descent
|
|
47
|
+
with optional multi-gradient descent algorithm (MGDA) for better trade-offs.
|
|
48
|
+
|
|
49
|
+
Reference:
|
|
50
|
+
Continual Model Merging without Data: Dual Projections for Balancing Stability and Plasticity.
|
|
51
|
+
NeurIPS, 2025.
|
|
52
|
+
"""
|
|
53
|
+
|
|
54
|
+
def __init__(
    self,
    seed: Optional[int] = None,
    shuffle_order: bool = False,
    save_on_every_step: bool = True,
    evaluate_on_every_step: bool = False,
    lr: float = 1e-4,
    num_steps: int = 200,
    mgda: bool = True,
    ema: bool = True,
    ema_beta: float = 0.99,
    alpha: float = None,
    svd_epsilon: float = 1.0,
    svd_proj_space: str = "uv",
    exclude_keys: List[str] | None = None,
    num_ray_actors: int = 0,
    **kwargs,
):
    """
    Initialize the DOP merging algorithm.

    Args:
        seed: Random seed for reproducibility. If None, uses time-based seeding.
        shuffle_order: Whether to shuffle the order of models before merging.
        save_on_every_step: Whether to save the model after each merge step.
        evaluate_on_every_step: Whether to evaluate the model after each merge step.
        lr: Learning rate for the optimization process.
        num_steps: Number of optimization steps per layer.
        mgda: Whether to use Multi-Gradient Descent Algorithm for balancing losses.
        ema: Whether to use exponential moving average for MGDA weights.
        ema_beta: EMA decay rate for MGDA weights (only used if ema=True).
        alpha: Weight for balancing between stability and plasticity (0-1).
            When mgda=False, used as a fixed weight. When mgda=True with ema=True,
            used as initial weight. Must be provided; the ``None`` default exists
            only to keep the keyword optional in the signature.
        svd_epsilon: Threshold for SVD rank selection (0-1). Determines how much
            variance to preserve in the projection space.
        svd_proj_space: SVD projection space to use: 'u', 'v', or 'uv' (both).
        exclude_keys: List of module names to exclude from optimization.
        num_ray_actors: Number of Ray actors to use for parallel processing. If 0, ray is not used.
        **kwargs: Additional arguments passed to BaseAlgorithm.
    """
    self.lr = lr
    self.num_steps = num_steps
    self.mgda = mgda
    self.ema = ema
    self.ema_beta = ema_beta
    self.alpha = alpha
    self.svd_epsilon = svd_epsilon
    self.svd_proj_space = svd_proj_space
    self.seed = seed
    self.shuffle_order = shuffle_order
    self.save_on_every_step = save_on_every_step
    self.evaluate_on_every_step = evaluate_on_every_step
    # Assign explicitly for consistency with the other parameters; `run`
    # reads `self.num_ray_actors` before dispatching to Ray.
    self.num_ray_actors = num_ray_actors

    if exclude_keys is None:
        exclude_keys = []
    self.exclude_keys = exclude_keys

    assert (
        self.svd_epsilon >= 0 and self.svd_epsilon <= 1
    ), "The svd_epsilon should be in the range of [0, 1]"
    # Check `is not None` first: comparing None with `>=` would otherwise
    # raise an opaque TypeError when the caller omits `alpha`.
    assert (
        self.alpha is not None and 0 <= self.alpha <= 1
    ), "The alpha should be provided and in the range of [0, 1]"
    super().__init__(**kwargs)
|
|
119
|
+
|
|
120
|
+
def run(self, modelpool: BaseModelPool):
    """
    Execute the DOP merging algorithm on a pool of models.

    Merges models sequentially, where each new model is merged with the
    previously merged result. The first model is used as-is, and subsequent
    models are merged using layer-wise optimization.

    Args:
        modelpool: The model pool containing models to merge and the pretrained model.

    Returns:
        The final merged model after sequentially merging all models in the pool.
    """
    # NOTE(review): `self.seed`, `self.save_on_every_step` and
    # `self.evaluate_on_every_step` are accepted by __init__ but not used in
    # this general implementation — confirm whether that is intentional.
    if self.num_ray_actors > 0:
        if is_ray_available():
            import ray
            from ray.util.actor_pool import ActorPool

            if not ray.is_initialized():
                ray.init()

            # create actors
            # One GPU per actor on CUDA hosts; CPU actors get default options.
            if self.fabric.device.type == "cuda":
                actor_options = {"num_gpus": 1}
            else:
                actor_options = {}
            # Each actor is a remote replica of this algorithm constructed
            # from the same config; layer optimizations are farmed out to it.
            self.ray_actor_pool = ActorPool(
                [
                    DOPMergingActor.options(**actor_options).remote(**self.config)
                    for _ in range(self.num_ray_actors)
                ]
            )
        else:
            raise ImportError(
                "Ray is not installed. Please install ray to use this feature. Install with `pip install 'ray[default]'`."
            )

    model_names = modelpool.model_names
    if self.shuffle_order:
        # NOTE(review): shuffles in place — if `modelpool.model_names`
        # returns the pool's internal list, its order is mutated; confirm.
        random.shuffle(model_names)

    pretrained_model = modelpool.load_pretrained_model()

    merged_model = None
    for model_idx, model_name in enumerate(model_names):
        print(
            f"--------- Optimizing {model_idx + 1}/{len(model_names)}-th with {model_name} ---------"
        )
        if model_idx == 0:
            # The first model seeds the running merge; no optimization yet.
            print("Using the first model as the initial merged model.")
            with self.profile("loading models"):
                merged_model = modelpool.load_model(model_names[0])
        else:
            with self.profile("loading models"):
                finetuned_model = modelpool.load_model(model_name)
            with self.profile("DOP merging"):
                # Merge exactly two models per step: the running "merged"
                # result and the newly loaded finetuned model. The
                # pretrained model is deep-copied because _layer_wise_optimize
                # writes merged weights into it in place.
                merged_model = self._layer_wise_optimize(
                    model_names=["merged", model_name],
                    pretrained_model=deepcopy(pretrained_model),
                    finetuned_models={
                        "merged": merged_model,
                        model_name: finetuned_model,
                    },
                    model_idx=model_idx,
                )
            # Free the finetuned model before loading the next one.
            del finetuned_model

    self.print_profile_summary()
    return merged_model
|
|
190
|
+
|
|
191
|
+
def _optimize_linear_layer(
    self,
    module_name: str,
    module: nn.Linear,
    finetuned_weights: Dict[str, nn.Linear],
    model_idx: int,
):
    """
    Merge the weight of a single linear layer.

    Frozen weights and modules listed in ``exclude_keys`` fall back to a
    plain average of the candidate weights; trainable weights go through
    the gradient-based DOP optimization.

    Returns:
        A ``(module_name, merged_weight)`` tuple so callers (including Ray
        actors returning results out of order) can route the weight back
        to the right module.
    """
    skip_optimization = (
        not module.weight.requires_grad or module_name in self.exclude_keys
    )
    if skip_optimization:
        # No gradient-based merging for this layer: average the candidates.
        return module_name, simple_average(list(finetuned_weights.values()))

    target_dtype = module.weight.dtype
    optimized = self._optimize_weight(
        module.weight,
        finetuned_weights,
        module_name,
        model_idx,
    )
    # _optimize_weight may upcast for SVD support; restore the layer dtype.
    return module_name, optimized.to(dtype=target_dtype)
|
|
210
|
+
|
|
211
|
+
def _layer_wise_optimize(
    self,
    model_names: List[str],
    pretrained_model: nn.Module,
    finetuned_models: Dict[str, nn.Module],
    model_idx: int,
):
    """
    Optimize model parameters layer by layer.

    Iterates through all leaf modules in the pretrained model and merges their weights
    with the corresponding modules in the finetuned models. Linear layers with trainable
    weights (not in exclude_keys) are optimized using gradient descent, while other
    parameters are simply averaged.

    Args:
        model_names: List of model names to merge (e.g., ['merged', 'new_model']).
        pretrained_model: The base pretrained model (structure modified in-place).
        finetuned_models: Dictionary mapping model names to their finetuned versions.
        model_idx: Index of the current model being merged (for tracking/logging).

    Returns:
        The pretrained_model with optimized/merged weights from finetuned models.
    """
    for module_name, module in named_leaf_modules(pretrained_model):
        # Look up the module at the same path in every candidate model.
        finetuned_modules = {
            model_name: finetuned_models[model_name].get_submodule(module_name)
            for model_name in model_names
        }
        if isinstance(module, nn.Linear):
            # process weight
            finetuned_weights = {
                model_name: finetuned_modules[model_name].weight
                for model_name in model_names
            }
            if self.num_ray_actors == 0:
                # Synchronous path: optimize locally and write back in place.
                _, merged_weight = self._optimize_linear_layer(
                    module_name,
                    module=module,
                    finetuned_weights=finetuned_weights,
                    model_idx=model_idx,
                )
                module.weight.data = merged_weight.data
            else:
                # Asynchronous Ray path: if every actor is busy, drain one
                # finished result first (results arrive out of order, so the
                # returned module name routes the weight to its module).
                if not self.ray_actor_pool.has_free():
                    returned_module_name, merged_weight = (
                        self.ray_actor_pool.get_next_unordered()
                    )
                    print(f"merged weight {returned_module_name} from ray actors.")
                    pretrained_model.get_submodule(
                        returned_module_name
                    ).weight.data = merged_weight
                # Submit this layer's optimization to the actor pool.
                self.ray_actor_pool.submit(
                    lambda actor, kwargs: actor._optimize_linear_layer.remote(
                        **kwargs
                    ),
                    {
                        "module_name": module_name,
                        "module": module,
                        "finetuned_weights": finetuned_weights,
                        "model_idx": model_idx,
                    },
                )

            # process bias if exists
            if module.bias is not None:
                module.bias.data = simple_average(
                    [m.bias for m in finetuned_modules.values()]
                )
        else:
            # Non-linear leaf modules are merged by plain averaging.
            # NOTE(review): presumably `base_module` makes simple_average
            # write the result into `module` in place — confirm against
            # fusion_bench.method.simple_average.
            simple_average(list(finetuned_modules.values()), base_module=module)

    if self.num_ray_actors > 0:
        # Drain all remaining in-flight results before returning.
        while self.ray_actor_pool.has_next():
            module_name, merged_weight = self.ray_actor_pool.get_next_unordered()
            print(f"merged weight {module_name} from ray actors.")
            pretrained_model.get_submodule(module_name).weight.data = merged_weight

    return pretrained_model
|
|
290
|
+
|
|
291
|
+
def _optimize_weight(
    self,
    pretrained_weight: Tensor,
    finetuned_weights: Dict[str, Tensor],
    module_name: str,
    model_idx: int,
):
    """
    Optimize a single weight matrix by balancing projections in SVD space.

    Performs gradient-based optimization to find merged weights that minimize
    the projection loss in the SVD space of task vectors. Uses either MGDA
    for automatic weight balancing or fixed alpha weighting.

    The algorithm:
    1. Computes SVD of each task vector (finetuned - pretrained)
    2. Projects the difference between merged and finetuned weights onto SVD subspaces
    3. Optimizes merged weights to minimize projection losses

    Args:
        pretrained_weight: The original pretrained weight matrix.
        finetuned_weights: Dictionary mapping model names to their finetuned weight matrices.
        module_name: Name of the module being optimized (for logging).
        model_idx: Index of the current model being merged (for tracking).

    Returns:
        Optimized merged weight matrix on CPU.
    """
    assert (
        self.fabric.world_size == 1
    ), "This algorithm is not currently supported in distributed training"

    with torch.no_grad():
        # Convert weights to float if original dtype does not support SVD
        original_dtype = pretrained_weight.dtype
        if not dtype_support_svd(original_dtype):
            pretrained_weight = pretrained_weight.float()
            finetuned_weights = {
                model_name: finetuned_weight.float()
                for model_name, finetuned_weight in finetuned_weights.items()
            }

        # Move weights to the appropriate device
        pretrained_weight = self.fabric.to_device(pretrained_weight.detach())
        finetuned_weights = {
            model_name: self.fabric.to_device(finetuned_weight.detach())
            for model_name, finetuned_weight in finetuned_weights.items()
        }

        # Initialize merged weight as simple average of finetuned weights
        merged_weight = self.fabric.to_device(
            nn.Parameter(
                simple_average(list(finetuned_weights.values())), requires_grad=True
            )
        )

        # Compute SVD of the difference between the finetuned and pretrained weights
        proj_u_dict = {}
        proj_v_dict = {}
        proj_s_dict = {}
        for i, finetuned_weight in enumerate(finetuned_weights.values()):
            finetuned_tv = finetuned_weight - pretrained_weight
            u, s, v = svd(finetuned_tv, full_matrices=True)
            epsilon = 1.0 if self.svd_epsilon > 1.0 else self.svd_epsilon
            # Fraction of total singular mass retained by the leading ranks.
            cumsum_ratio = s.cumsum(dim=0) / s.sum()
            # NOTE(review): searchsorted returns the left insertion point, so
            # when epsilon == 1.0 and the last cumulative ratio is exactly
            # 1.0 the final singular direction is dropped — confirm this
            # boundary behavior is intended.
            split_rank = torch.searchsorted(cumsum_ratio, epsilon).item()
            u_main = u[:, :split_rank]
            v_main = v[:, :split_rank]
            s_main = s[:split_rank]
            proj_u_dict[i] = u_main
            proj_v_dict[i] = v_main
            proj_s_dict[i] = s_main

    if self.mgda:
        if self.ema:
            # EMA of the MGDA solution, seeded from the user-provided alpha.
            # Assumes exactly two candidate models and alpha is not None
            # (enforced in __init__).
            ema_sol = [self.alpha, 1 - self.alpha]
        # This is multiple-gradient descent algorithm (MGDA) optimization
        optimizer = torch.optim.Adam([merged_weight], lr=self.lr)
        all_losses = [[], []]
        all_alphas = [[], []]
        for step_idx in tqdm(
            range(self.num_steps),
            desc=f"Optimizing {module_name} weight",
            disable=self.num_ray_actors > 0,
        ):
            # Scaling the loss functions based on the algorithm choice
            loss_data = {}
            grads = {}
            for i, finetuned_weight in enumerate(finetuned_weights.values()):
                proj_u = proj_u_dict[i]
                proj_v = proj_v_dict[i]
                proj_s = proj_s_dict[i]
                delta_tv = merged_weight - finetuned_weight
                loss_i = self.cal_loss_i(delta_tv, proj_s, proj_u, proj_v)
                loss_data[i] = float(loss_i.data)

                # NOTE(review): all_losses is hard-coded for two models
                # (indices 0 and 1), matching the two-model merge in `run`.
                all_losses[i].append(float(loss_i.data))

                # Backward per-task to capture each task's gradient separately.
                optimizer.zero_grad()
                loss_i.backward()
                grads[i] = Variable(
                    merged_weight.grad.data.clone(), requires_grad=False
                )

            # Normalize all gradients
            gn = gradient_normalizers(
                grads=grads, losses=loss_data, normalization_type="loss"
            )
            for i, _ in enumerate(finetuned_weights.values()):
                grads[i] = grads[i] / float(gn[i])

            # Frank-Wolfe iteration to compute scales.
            sol, min_norm = MinNormSolver.find_min_norm_element(
                [[grads[i]] for i in range(len(finetuned_weights.values()))]
            )

            if self.ema:
                # Smooth the MGDA weights across steps.
                ema_sol = [
                    self.ema_beta * ema_sol[i] + (1 - self.ema_beta) * float(sol[i])
                    for i in range(len(sol))
                ]
                sol = ema_sol
                all_alphas[0].append(ema_sol[0])
                all_alphas[1].append(ema_sol[1])

            # Scaled back-propagation
            loss = 0
            for i, finetuned_weight in enumerate(finetuned_weights.values()):
                # Comptue gradients of each loss function wrt parameters
                proj_u = proj_u_dict[i]
                proj_v = proj_v_dict[i]
                proj_s = proj_s_dict[i]
                delta_tv = merged_weight - finetuned_weight
                loss_i = self.cal_loss_i(delta_tv, proj_s, proj_u, proj_v)
                loss += float(sol[i]) * loss_i

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    else:
        # This is a naive weighted optimization
        optimizer = torch.optim.Adam([merged_weight], lr=self.lr)
        for step_idx in tqdm(
            range(self.num_steps),
            desc=f"Optimizing {module_name} weight",
            disable=self.num_ray_actors > 0,
        ):
            loss = 0
            for i, finetuned_weight in enumerate(finetuned_weights.values()):
                proj_u = proj_u_dict[i]
                proj_v = proj_v_dict[i]
                proj_s = proj_s_dict[i]
                delta_tv = merged_weight - finetuned_weight
                loss_i = self.cal_loss_i(delta_tv, proj_s, proj_u, proj_v)
                # Fixed trade-off: alpha weighs the first (stability) task,
                # 1 - alpha the second (plasticity) task.
                loss += self.alpha * loss_i if i == 0 else (1 - self.alpha) * loss_i

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    return merged_weight.detach().to(dtype=original_dtype, device="cpu")
|
|
453
|
+
|
|
454
|
+
def cal_loss_i(self, delta_tv, proj_s, proj_u, proj_v):
    """
    Calculate the projection loss for a single task.

    Computes the squared Frobenius norm of the singular-value-weighted
    projection of the weight difference onto the SVD subspace(s) selected
    by ``self.svd_proj_space`` ('u', 'v', or 'uv' for both).

    Args:
        delta_tv: Difference between merged weight and finetuned weight (task vector difference).
        proj_s: Singular values from SVD of the task vector.
        proj_u: Left singular vectors (U) from SVD.
        proj_v: Right singular vectors (V) from SVD.

    Returns:
        Scalar loss value representing the projection distance.
    """
    sigma = torch.diag(proj_s)
    # Row-space projection (via U) and column-space projection (via V).
    row_projection = sigma @ proj_u.T @ delta_tv
    col_projection = delta_tv @ proj_v @ sigma
    u_loss = torch.linalg.matrix_norm(row_projection, ord="fro") ** 2
    v_loss = torch.linalg.matrix_norm(col_projection, ord="fro") ** 2

    if self.svd_proj_space == "u":
        return u_loss
    if self.svd_proj_space == "v":
        return v_loss
    if self.svd_proj_space == "uv":
        return u_loss + v_loss
    raise ValueError("Invalid svd_proj_space")
|
|
484
|
+
|
|
485
|
+
|
|
486
|
+
# Expose DOPMerging as a Ray actor class so linear-layer optimizations can be
# dispatched to remote workers (used by DOPMerging.run when num_ray_actors > 0).
# Only defined when ray is importable; without ray, `run` raises ImportError
# before any actor would be constructed.
if is_ray_available():
    import ray

    DOPMergingActor = ray.remote(DOPMerging)
|
fusion_bench/method/dop/utils.py
CHANGED
|
@@ -3,6 +3,7 @@ from typing import Tuple
|
|
|
3
3
|
import torch
|
|
4
4
|
from torch import Tensor, nn
|
|
5
5
|
|
|
6
|
+
from fusion_bench.models.utils import is_leaf_module
|
|
6
7
|
from fusion_bench.utils.parameters import state_dict_to_vector
|
|
7
8
|
from fusion_bench.utils.state_dict_arithmetic import state_dict_sub
|
|
8
9
|
|
|
@@ -51,10 +52,6 @@ def frobenius_inner_product(w1: Tensor, w2: Tensor) -> Tensor:
|
|
|
51
52
|
return torch.trace(w1.T @ w2)
|
|
52
53
|
|
|
53
54
|
|
|
54
|
-
def is_leaf_module(module: nn.Module) -> bool:
|
|
55
|
-
return len(list(module.children())) == 0
|
|
56
|
-
|
|
57
|
-
|
|
58
55
|
def get_task_vector_norm(model: nn.Module, pretrained_model: nn.Module) -> Tensor:
|
|
59
56
|
"""
|
|
60
57
|
Get the vector norm of the task model.
|
|
@@ -71,3 +68,26 @@ def get_task_vector_norm(model: nn.Module, pretrained_model: nn.Module) -> Tenso
|
|
|
71
68
|
state_dict_sub(model.state_dict(), pretrained_model.state_dict())
|
|
72
69
|
)
|
|
73
70
|
)
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def print_params(model):
    """
    Print a breakdown of the model's parameter counts.

    Reports total parameters across leaf modules, the number of parameters
    belonging to ``nn.Linear`` modules, the number belonging specifically to
    linear ``weight`` tensors, and the corresponding percentages.

    Args:
        model: The module whose parameters are counted (iterated via
            ``named_modules``; only leaf modules contribute).
    """
    total_params = 0
    linear_params = 0
    linear_weight_params = 0
    # Count only leaf modules so parent containers don't double-count
    # the parameters of their children.
    for _, module in model.named_modules():
        if not is_leaf_module(module):
            continue
        if isinstance(module, nn.Linear):
            linear_params += sum(p.numel() for n, p in module.named_parameters())
            linear_weight_params += sum(
                p.numel() for n, p in module.named_parameters() if "weight" in n
            )
        total_params += sum(p.numel() for p in module.parameters())

    if total_params == 0:
        # Guard against ZeroDivisionError for parameter-free models.
        print("Total Parameters: 0")
        print("Linear Parameters: 0")
        print("Linear Weight Parameters: 0")
        return

    linear_ratio = linear_params / total_params * 100
    linear_weight_ratio = linear_weight_params / total_params * 100
    print(f"Total Parameters: {total_params}")
    print(f"Linear Parameters: {linear_params}")
    print(f"Linear Weight Parameters: {linear_weight_params}")
    print(f"Linear Ratio: {linear_ratio:.2f}%")
    print(f"Linear Weight Ratio: {linear_weight_ratio:.2f}%")
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .emr_merging import EMRMerging
|
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
from fusion_bench import BaseAlgorithm, BaseModelPool, auto_register_config
|
|
2
|
+
from fusion_bench.utils.state_dict_arithmetic import state_dict_sub
|
|
3
|
+
|
|
4
|
+
from .utils import EMRModulatedModel, EMRTaskModulator, emr_merge
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
@auto_register_config
class EMRMerging(BaseAlgorithm):
    """
    EMR Merging Algorithm.

    This algorithm merges multiple task-specific models into a unified model using
    the Elect, Mask & Rescale (EMR) strategy. It constructs a modulated model that
    can adapt to different tasks via task-specific modulators.
    """

    def load_pretrained_model_and_task_vectors(self, modelpool: BaseModelPool):
        """
        Load the pretrained model and compute one task vector per finetuned model.

        A task vector is the element-wise state-dict difference
        ``finetuned - pretrained``.

        Args:
            modelpool: Pool providing the pretrained model and finetuned models.

        Returns:
            A tuple ``(pretrained_model, task_vectors)`` where task_vectors is a
            list of state dicts in ``modelpool.model_names`` order.
        """
        # NOTE(review): `run` below calls the modelpool's method of the same
        # name rather than this helper — confirm whether this local method is
        # intentionally kept as an API mirror or is dead code.
        pretrained_model = modelpool.load_pretrained_model()

        task_vectors = []
        for model_name in modelpool.model_names:
            finetuned_model = modelpool.load_model(model_name)
            task_vector = state_dict_sub(
                finetuned_model.state_dict(), pretrained_model.state_dict()
            )
            task_vectors.append(task_vector)

        return pretrained_model, task_vectors

    def run(self, modelpool: BaseModelPool) -> EMRModulatedModel:
        """
        Merge all models in the pool into a single EMR-modulated model.

        Args:
            modelpool: The models to merge (wrapped in a BaseModelPool if needed).

        Returns:
            An ``EMRModulatedModel`` carrying the unified task vector plus one
            mask/rescaler modulator per task.
        """
        if not isinstance(modelpool, BaseModelPool):
            modelpool = BaseModelPool(modelpool)

        pretrained_model, task_vectors = (
            modelpool.load_pretrained_model_and_task_vectors()
        )

        # Elect, Mask & Rescale: one shared unified vector plus per-task
        # sign masks and scalar rescalers.
        unified_vector, masks, rescalers = emr_merge(task_vectors)

        emr_model = EMRModulatedModel(
            backbone=pretrained_model, modulators={}, unified_task_vector=unified_vector
        )

        for model_idx, model_name in enumerate(modelpool.model_names):
            # masks maps parameter name -> stacked per-task masks; slice out
            # this task's entry for its modulator.
            emr_model.add_modulator(
                task_name=model_name,
                modulator=EMRTaskModulator(
                    mask={n: m[model_idx] for n, m in masks.items()},
                    rescaler=rescalers[model_idx],
                ),
            )

        return emr_model
|