fusion-bench 0.2.31__py3-none-any.whl → 0.2.32__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/__init__.py +6 -0
- fusion_bench/__main__.py +2 -2
- fusion_bench/dataset/__init__.py +2 -0
- fusion_bench/dataset/clip_dataset.py +4 -72
- fusion_bench/dataset/image_dataset.py +44 -18
- fusion_bench/method/base_algorithm.py +4 -0
- fusion_bench/method/dop/dop.py +0 -22
- fusion_bench/method/dop/dop_general.py +489 -0
- fusion_bench/method/dop/utils.py +24 -4
- fusion_bench/method/emr_merging/__init__.py +1 -0
- fusion_bench/method/emr_merging/emr_merging.py +53 -0
- fusion_bench/method/emr_merging/utils.py +162 -0
- fusion_bench/method/opcm/opcm.py +6 -2
- fusion_bench/method/opcm/opcm_general.py +356 -0
- fusion_bench/method/opcm/utils.py +1 -4
- fusion_bench/method/simple_average.py +52 -18
- fusion_bench/method/task_arithmetic/task_arithmetic.py +1 -1
- fusion_bench/mixins/lightning_fabric.py +108 -3
- fusion_bench/mixins/serialization.py +1 -1
- fusion_bench/modelpool/base_pool.py +37 -1
- fusion_bench/modelpool/convnext_for_image_classification.py +5 -2
- fusion_bench/models/hf_clip.py +20 -0
- fusion_bench/models/modulator/__init__.py +1 -0
- fusion_bench/models/modulator/base.py +123 -0
- fusion_bench/models/parameter_dict.py +119 -29
- fusion_bench/models/utils.py +190 -2
- fusion_bench/models/wrappers/switch.py +90 -0
- fusion_bench/programs/base_program.py +6 -0
- fusion_bench/programs/fabric_fusion_program.py +4 -0
- fusion_bench/scripts/cli.py +19 -8
- fusion_bench/taskpool/image_classification.py +270 -0
- fusion_bench/utils/__init__.py +18 -1
- fusion_bench/utils/data.py +1 -1
- fusion_bench/utils/dict.py +19 -0
- fusion_bench/utils/dtype.py +19 -0
- fusion_bench/utils/misc.py +1 -0
- fusion_bench/utils/packages.py +4 -0
- fusion_bench/utils/state_dict_arithmetic.py +183 -1
- fusion_bench/utils/tensorboard.py +21 -3
- {fusion_bench-0.2.31.dist-info → fusion_bench-0.2.32.dist-info}/METADATA +3 -1
- {fusion_bench-0.2.31.dist-info → fusion_bench-0.2.32.dist-info}/RECORD +51 -37
- {fusion_bench-0.2.31.dist-info → fusion_bench-0.2.32.dist-info}/WHEEL +1 -1
- {fusion_bench-0.2.31.dist-info → fusion_bench-0.2.32.dist-info}/entry_points.txt +1 -1
- fusion_bench_config/fabric/loggers/mlflow_logger.yaml +4 -0
- fusion_bench_config/method/dop/dop_general.yaml +33 -0
- fusion_bench_config/method/emr_merging/emr_merging.yaml +1 -0
- fusion_bench_config/method/opcm/opcm_general.yaml +18 -0
- fusion_bench_config/modelpool/ConvNextForImageClassification/convnext-base-224_8-tasks.yaml +15 -0
- fusion_bench_config/taskpool/ImageClassificationTaskPool/convnext-base-224_8-tasks.yaml +17 -0
- {fusion_bench-0.2.31.dist-info → fusion_bench-0.2.32.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.31.dist-info → fusion_bench-0.2.32.dist-info}/top_level.txt +0 -0
fusion_bench/method/emr_merging/utils.py
ADDED
@@ -0,0 +1,162 @@
+from typing import Dict
+
+import numpy as np
+import torch
+from torch import nn
+
+from fusion_bench import StateDictType, TorchModelType
+from fusion_bench.models.modulator import ModulatedModel, TaskModulator
+from fusion_bench.models.modulator.base import ModulatedModel, TaskModulator
+from fusion_bench.models.parameter_dict import ParameterDictModel
+from fusion_bench.utils.state_dict_arithmetic import state_dict_sum
+
+
+def _sign(x: torch.Tensor) -> torch.Tensor:
+    """
+    Return the sign of the tensor: 1 for positive, -1 for negative.
+    Zeros are treated as negative (i.e., sign -1).
+    """
+    return (x > 0) * 2 - 1
+
+
+def emr_merge(task_vectors: list[StateDictType]):
+    """
+    Modified from the original EMR merging function to return the unified vector, masks, and rescalers.
+
+    Args:
+        task_vectors: List of task-specific vectors (state dicts).
+
+    Returns:
+        vector_unified: The unified task vector (state dict).
+        masks: Dict mapping parameter names to lists of task-specific masks (tensors).
+        rescalers: Tensor of rescaling factors, one per task.
+    """
+    num_tasks = len(task_vectors)
+
+    # compute the sign flag
+    # \gamma_uni = sign( sum_i tau_i )
+    flag_dict = {k: _sign(v) for k, v in state_dict_sum(task_vectors).items()}
+
+    # \tau_uni
+    vector_unified = {}
+    scales = torch.zeros(num_tasks)
+    # each mask indicates whether the direction of the task vector aligns with the unified vector
+    # {<param_name>: [mask_task1, mask_task2, ...]}
+    masks: dict[str, list[torch.Tensor]] = {}
+    for n, flag in flag_dict.items():
+        masks[n] = []
+        param_max = torch.zeros_like(task_vectors[0][n])
+        for m in range(num_tasks):
+            param = task_vectors[m][n]
+            mask = (param * flag) > 0
+            masks[n].append(mask)
+            param_abs = torch.abs(mask * param)
+            param_max = torch.where(param_abs > param_max, param_abs, param_max)
+            scales[m] += torch.mean(torch.abs(param))
+        vector_unified[n] = param_max * flag
+
+    new_scales = torch.zeros(num_tasks)
+    for m in range(num_tasks):
+        for n in vector_unified:
+            p = vector_unified[n] * masks[n][m]
+            new_scales[m] += torch.mean(torch.abs(p))
+    rescalers = scales / new_scales
+
+    return vector_unified, masks, rescalers
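Note: for intuition about the elect-mask-rescale steps in `emr_merge`, here is a minimal self-contained sketch on one parameter with two hypothetical task vectors (toy values, not from the package):

```python
import torch

tau_1 = torch.tensor([0.4, -0.2, 0.1])   # hypothetical task vector 1
tau_2 = torch.tensor([0.2, 0.3, -0.5])   # hypothetical task vector 2

# Elect: the unified sign is the sign of the element-wise sum (zeros count as negative)
gamma = ((tau_1 + tau_2) > 0) * 2 - 1     # tensor([ 1,  1, -1])

# Mask: keep only the entries whose sign agrees with the elected sign
mask_1 = (tau_1 * gamma) > 0              # tensor([ True, False, False])
mask_2 = (tau_2 * gamma) > 0              # tensor([ True,  True,  True])

# Unified magnitude: element-wise max of the masked absolute values
tau_uni = torch.maximum((mask_1 * tau_1).abs(), (mask_2 * tau_2).abs()) * gamma

# Rescale: restore each task's mean absolute magnitude
rescaler_1 = tau_1.abs().mean() / (tau_uni * mask_1).abs().mean()
print(tau_uni, rescaler_1)  # tensor([ 0.4000,  0.3000, -0.5000]) tensor(1.7500)
```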
+
+
+class EMRModulatedModel(ModulatedModel[TorchModelType]):
+    """
+    Modulated Model for EMR Merging.
+    """
+
+    def __init__(
+        self,
+        backbone: TorchModelType,
+        modulators: Dict[str, "EMRTaskModulator"],
+        unified_task_vector: StateDictType,
+    ):
+        super().__init__(backbone, modulators)
+
+        unified_task_vector = unified_task_vector.copy()
+        for name, tensor in unified_task_vector.items():
+            if not isinstance(tensor, (nn.Parameter, nn.Buffer)):
+                unified_task_vector[name] = nn.Parameter(tensor, requires_grad=False)
+        self.unified_task_vector = ParameterDictModel(unified_task_vector)
+
+
+class EMRTaskModulator(TaskModulator[TorchModelType]):
+    """
+    Task Modulator for EMR (Elect, Mask & Rescale) Merging.
+
+    This modulator applies task-specific adaptations to a unified model by:
+    1. Masking: aligning direction with the task-specific model (the mask zeroes out entries with inconsistent signs)
+    2. Rescaling: aligning magnitude with the task-specific model
+
+    The application formula is:
+        θ_new = θ_old + τ_unified ⊙ mask_i * rescaler_i
+
+    where:
+    - τ_unified is the unified task vector (elected from all task vectors)
+    - mask_i is the task-specific binary mask
+    - rescaler_i is the task-specific rescaling factor
+
+    Args:
+        mask: Task-specific binary mask as a dict of tensors
+        rescaler: Task-specific rescaling factor (scalar)
+    """
+
+    def __init__(
+        self,
+        mask: Dict[str, torch.Tensor],
+        rescaler: float,
+    ):
+        super().__init__()
+
+        # Store masks separately with a prefix to avoid conflicts
+        mask = mask.copy()
+        for name, tensor in mask.items():
+            if not isinstance(tensor, (nn.Parameter, nn.Buffer)):
+                mask[name] = nn.Parameter(tensor, requires_grad=False)
+        self.mask = ParameterDictModel(mask)
+
+        # Register rescaler as a parameter for proper device handling
+        self.rescaler = nn.Parameter(torch.tensor(rescaler), requires_grad=False)
+
+    @torch.no_grad()
+    def apply(self, modulated_model: "EMRModulatedModel[TorchModelType]"):
+        """
+        Apply the EMR task vector to the model.
+
+        For each parameter in the state dict:
+            θ_new = θ_old + τ_unified ⊙ mask_i * rescaler_i
+
+        This applies the masked and rescaled unified task vector to align the backbone
+        with the task-specific model.
+        """
+        unified_vector = modulated_model.unified_task_vector
+
+        for name in unified_vector:
+            delta = unified_vector[name] * self.mask[name] * self.rescaler
+            param = modulated_model.backbone.get_parameter(name)
+            param.add_(delta)
+
+    @torch.no_grad()
+    def remove(self, modulated_model: "EMRModulatedModel[TorchModelType]"):
+        """
+        Remove the EMR task vector from the model.
+
+        For each parameter in the state dict:
+            θ_old = θ_new - τ_unified ⊙ mask_i * rescaler_i
+
+        This reverses the task-specific adaptation to restore the original backbone.
+        """
+        unified_vector = modulated_model.unified_task_vector
+
+        for name in unified_vector:
+            delta = unified_vector[name] * self.mask[name] * self.rescaler
+            param = modulated_model.backbone.get_parameter(name)
+            param.sub_(delta)
+
+        modulated_model._current_task = None
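Note: a sketch of how these pieces are meant to fit together (the wiring below is hypothetical; `task_vectors`, `task_names`, and `backbone` are assumed to exist):

```python
vector_unified, masks, rescalers = emr_merge(task_vectors)

modulators = {
    name: EMRTaskModulator(
        mask={k: task_masks[i] for k, task_masks in masks.items()},
        rescaler=rescalers[i].item(),
    )
    for i, name in enumerate(task_names)
}
model = EMRModulatedModel(backbone, modulators, vector_unified)
# apply() adds tau_unified * mask_i * rescaler_i to the backbone in place;
# remove() subtracts the same delta, so the shared weights are restored exactly.
```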
fusion_bench/method/opcm/opcm.py
CHANGED
@@ -16,22 +16,23 @@ from transformers import CLIPVisionModel
 
 from fusion_bench import BaseAlgorithm, BaseModelPool
 from fusion_bench.mixins import LightningFabricMixin, SimpleProfilerMixin
+from fusion_bench.models.utils import is_leaf_module
 from fusion_bench.taskpool import CLIPVisionModelTaskPool
 from fusion_bench.utils import instantiate
 from fusion_bench.utils.json import load_from_json, save_to_json
 from fusion_bench.utils.parameters import state_dict_to_vector
 from fusion_bench.utils.state_dict_arithmetic import state_dict_sub
 
-from .utils import frobenius_inner_product, get_task_vector_norm,
+from .utils import frobenius_inner_product, get_task_vector_norm, svd
 
 if TYPE_CHECKING:
     from torch.utils.tensorboard import SummaryWriter
 
 
 class OPCMForCLIP(
-    BaseAlgorithm,
     LightningFabricMixin,
     SimpleProfilerMixin,
+    BaseAlgorithm,
 ):
     def __init__(
         self,
@@ -219,6 +220,9 @@ class OPCMForCLIP(
         return merged_model
 
     def save_merged_model(self, merged_model: CLIPVisionModel, step: int):
+        if self.log_dir is None:
+            print("Log dir is None, skip saving merged model.")
+            return
         os.makedirs(Path(self.log_dir) / "checkpoints", exist_ok=True)
         merged_model.save_pretrained(
             Path(self.log_dir) / "checkpoints" / f"merged_model_{step}"
fusion_bench/method/opcm/opcm_general.py
ADDED
@@ -0,0 +1,356 @@
+import os
+import random
+import time
+from collections import defaultdict
+from copy import deepcopy
+from pathlib import Path
+from typing import TYPE_CHECKING, List, Literal, Optional, Tuple, cast
+
+import lightning as L
+import numpy as np
+import torch
+from omegaconf import DictConfig
+from torch import Tensor, nn
+from tqdm.auto import tqdm
+
+from fusion_bench import BaseAlgorithm, BaseModelPool, auto_register_config
+from fusion_bench.mixins import LightningFabricMixin, SimpleProfilerMixin
+from fusion_bench.models.utils import is_leaf_module, named_leaf_modules
+from fusion_bench.utils import instantiate
+from fusion_bench.utils.json import load_from_json, save_to_json
+from fusion_bench.utils.packages import is_ray_available
+from fusion_bench.utils.parameters import state_dict_to_vector
+from fusion_bench.utils.state_dict_arithmetic import state_dict_sub
+
+from .utils import frobenius_inner_product, get_task_vector_norm, svd
+
+if TYPE_CHECKING:
+    from torch.utils.tensorboard import SummaryWriter
+
+
+@auto_register_config
+class OPCM(
+    LightningFabricMixin,
+    SimpleProfilerMixin,
+    BaseAlgorithm,
+):
+    def __init__(
+        self,
+        alpha: float,
+        shuffle_order: bool = True,
+        seed: Optional[int] = None,
+        save_on_every_step: bool = True,
+        evaluate_on_every_step: bool = False,
+        num_ray_actors: int = 0,
+        **kwargs,
+    ):
+        """
+        Continual Model Merging via SVD Projection.
+
+        Args:
+            alpha (float): the scaling factor for the SVD projection.
+            shuffle_order (bool): whether to shuffle the order of the models.
+            seed (Optional[int]): the seed to use.
+            save_on_every_step (bool): whether to save the merged model on every step.
+            evaluate_on_every_step (bool): whether to evaluate the merged model on every step.
+        """
+        self.alpha = alpha
+        self.shuffle_order = shuffle_order
+        self.seed = seed
+        self.save_on_every_step = save_on_every_step
+        self.evaluate_on_every_step = evaluate_on_every_step
+        super().__init__(**kwargs)
+
+    @torch.no_grad()
+    def run(self, modelpool: BaseModelPool):
+        if self.num_ray_actors > 0:
+            if is_ray_available():
+                import ray
+                from ray.util.actor_pool import ActorPool
+
+                if not ray.is_initialized():
+                    ray.init()
+
+                # create actors
+                if self.fabric.device.type == "cuda":
+                    actor_options = {"num_gpus": 1}
+                else:
+                    actor_options = {}
+                self.ray_actor_pool = ActorPool(
+                    [
+                        OPCMActor.options(**actor_options).remote(**self.config)
+                        for _ in range(self.num_ray_actors)
+                    ]
+                )
+
+        if self.seed is not None:
+            L.seed_everything(self.seed)
+
+        with self.profile("loading model"):
+            pretrained_model = modelpool.load_pretrained_model()
+
+        model_names = modelpool.model_names
+        if self.shuffle_order:
+            random.shuffle(model_names)
+
+        # log the model names
+        if self.log_dir is not None:
+            save_to_json(model_names, Path(self.log_dir) / "model_names.json")
+            tensorboard_summarywriter: "SummaryWriter" = self.tensorboard_summarywriter
+            tensorboard_summarywriter.add_text(
+                "global/model_names", str(model_names), global_step=0
+            )
+
+        # get the average model
+        with self.profile("loading model"):
+            print("Using the first model as the initial merged model.")
+            merged_model = modelpool.load_model(model_names[0])
+            assert merged_model is not None, "Failed to load the first model"
+
+        self.avg_task_vector_norm = get_task_vector_norm(merged_model, pretrained_model)
+        self.all_task_vector_norm = [self.avg_task_vector_norm]
+        self.fabric.log("model/task_vector_norm", self.avg_task_vector_norm, step=0)
+        self.fabric.log("model/avg_task_vector_norm", self.avg_task_vector_norm, step=0)
+        self.fabric.log(
+            "model/merged_task_vector_norm", self.avg_task_vector_norm, step=0
+        )
+
+        self.previous_lambda_t = 1
+        self.lambda_t = None
+        self.fabric.log("model/lambda_t", self.previous_lambda_t, step=0)
+        self.fabric.log("empirical/lambda_t", 1, step=0)
+
+        if self.save_on_every_step:
+            self.save_merged_model(merged_model, 0)
+
+        for model_idx, model_name in tqdm(
+            enumerate(model_names[1:]), desc="Processing models"
+        ):
+            model_idx += 1
+            with self.profile("loading model"):
+                task_model = modelpool.load_model(model_name)
+
+            with self.profile("merging model"):
+                self.all_task_vector_norm.append(
+                    get_task_vector_norm(task_model, pretrained_model)
+                )
+                self.avg_task_vector_norm = np.mean(self.all_task_vector_norm)
+                self.fabric.log(
+                    "model/task_vector_norm",
+                    self.all_task_vector_norm[-1],
+                    step=model_idx,
+                )
+                self.fabric.log(
+                    "model/avg_task_vector_norm",
+                    self.avg_task_vector_norm,
+                    step=model_idx,
+                )
+
+                self.lambda_t = 1  # temporary value
+
+                self._layer_wise_merge(
+                    merged_model=merged_model,
+                    pretrained_model=pretrained_model,
+                    task_model=task_model,
+                    model_name=model_name,
+                )
+
+                task_vector_norm = get_task_vector_norm(merged_model, pretrained_model)
+                self.lambda_t *= task_vector_norm / self.avg_task_vector_norm
+                for param_name, param in merged_model.named_parameters():
+                    param.data = pretrained_model.get_parameter(param_name) + (
+                        param - pretrained_model.get_parameter(param_name)
+                    ) * (self.avg_task_vector_norm / task_vector_norm)
+                self.fabric.log("model/lambda_t", self.lambda_t, step=model_idx)
+                self.fabric.log(
+                    "empirical/lambda_t", np.sqrt(model_idx + 1), step=model_idx
+                )
+                self.previous_lambda_t = self.lambda_t
+                self.lambda_t = None
+
+                self.fabric.log(
+                    "model/merged_task_vector_norm",
+                    get_task_vector_norm(merged_model, pretrained_model),
+                    step=model_idx,
+                )
+
+            if self.save_on_every_step:
+                with self.profile("saving model"):
+                    self.save_merged_model(merged_model, model_idx)
+
+        self.print_profile_summary()
+        return merged_model
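Note: the tail of the merging block above keeps the merged task vector's norm stable: after `_layer_wise_merge`, the merged task vector is rescaled to the running average of the individual task-vector norms, and the ratio is folded into `lambda_t`. A minimal numeric sketch (hypothetical values):

```python
import torch

avg_norm = torch.tensor(10.0)   # running mean of the task-vector norms
merged_tv = torch.randn(1000)   # stand-in for the merged task vector
merged_norm = merged_tv.norm()

lambda_t = 1.0 * (merged_norm / avg_norm)         # self.lambda_t *= norm ratio
merged_tv = merged_tv * (avg_norm / merged_norm)  # rescale toward the average norm

assert torch.isclose(merged_tv.norm(), avg_norm)
```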
+
+    def _layer_wise_merge(self, merged_model, pretrained_model, task_model, model_name):
+        if self.num_ray_actors > 0:
+            self._update_attributes_across_ray()
+
+        for module_name, module in tqdm(
+            list(named_leaf_modules(merged_model, ignore_empty=True)),
+            desc=f"Processing {model_name}",
+            leave=False,
+            disable=self.num_ray_actors > 0,
+        ):
+            if isinstance(module, nn.Linear):
+                # processing linear layers
+                merge_kwargs = {
+                    "merged_W": module.weight,
+                    "pretrained_W": pretrained_model.get_submodule(module_name).weight,
+                    "task_W": task_model.get_submodule(module_name).weight,
+                    "param_name": ".".join([module_name, "weight"]),
+                    "alpha": self.alpha,
+                }
+                if not self.num_ray_actors > 0:
+                    _, merged_weight = self.merge_linear_weights(**merge_kwargs)
+                    module.weight.data = merged_weight
+                else:
+                    if not self.ray_actor_pool.has_free():
+                        returned_module_name, merged_weight = (
+                            self.ray_actor_pool.get_next_unordered()
+                        )
+                        print(f"merged weight {returned_module_name} from ray actors.")
+                        pretrained_model.get_submodule(
+                            returned_module_name
+                        ).weight.data = merged_weight
+                    self.ray_actor_pool.submit(
+                        lambda actor, kwargs: actor.merge_linear_weights.remote(
+                            **kwargs
+                        ),
+                        merge_kwargs,
+                    )
+                # processing bias if exists
+                if module.bias is not None:
+                    module.bias.data = self.merge_other_parameters(
+                        module.bias,
+                        pretrained_model.get_submodule(module_name).bias,
+                        task_model.get_submodule(module_name).bias,
+                        param_name=".".join([module_name, "bias"]),
+                    )
+            else:
+                for param_name, param in module.named_parameters():
+                    param.data = self.merge_other_parameters(
+                        merged_W=param,
+                        pretrained_W=pretrained_model.get_submodule(
+                            module_name
+                        ).get_parameter(param_name),
+                        task_W=task_model.get_submodule(module_name).get_parameter(
+                            param_name
+                        ),
+                        param_name=".".join([module_name, param_name]),
+                    )
+
+        if self.num_ray_actors > 0:
+            while self.ray_actor_pool.has_next():
+                returned_module_name, merged_weight = (
+                    self.ray_actor_pool.get_next_unordered()
+                )
+                print(f"merged weight {returned_module_name} from ray actors.")
+                merged_model.get_submodule(returned_module_name).weight.data = (
+                    merged_weight
+                )
+
+    def save_merged_model(self, merged_model, step: int):
+        if self.log_dir is None:
+            print("Log dir is None, skip saving merged model.")
+            return
+        os.makedirs(Path(self.log_dir) / "checkpoints", exist_ok=True)
+        merged_model.save_pretrained(
+            Path(self.log_dir) / "checkpoints" / f"merged_model_{step}"
+        )
+
+    def _update_attributes_across_ray(self, attr_dict=None):
+        if attr_dict is None:
+            # called on master
+            attrs_to_sync = ["previous_lambda_t", "lambda_t"]
+            assert (
+                not self.ray_actor_pool.has_next()
+            ), "All previous tasks must be merged before syncing attributes."
+
+            for actor in self.ray_actor_pool._idle_actors:
+                actor._update_attributes_across_ray.remote(
+                    {attr: getattr(self, attr) for attr in attrs_to_sync}
+                )
+        else:
+            # called on ray actors
+            for attr, value in attr_dict.items():
+                setattr(self, attr, value)
+
+    def merge_linear_weights(
+        self,
+        merged_W: Tensor,
+        pretrained_W: Tensor,
+        task_W: Tensor,
+        param_name: str,
+        alpha: float,
+    ):
+        accelerator = self.fabric.device
+
+        original_device = merged_W.device
+        merged_W = merged_W.to(accelerator)
+        pretrained_W = pretrained_W.to(accelerator)
+        task_W = task_W.to(accelerator)
+
+        previous_merged_tv = merged_W - pretrained_W
+        task_tv = task_W - pretrained_W
+
+        u, s, v = svd(previous_merged_tv)
+        rank = s.size(0)
+        split_rank = (s.cumsum(dim=0) / s.sum() > alpha).float().argmax().item()
+
+        projected_task_tv = u.T @ task_tv @ v
+        projected_task_tv.diagonal().fill_(0)
+
+        projected_task_tv[:split_rank, :split_rank] = 0
+
+        cleaned_task_tv = u @ projected_task_tv @ v.T
+
+        previous_lambda_t = self.previous_lambda_t
+        lambda_t = self.lambda_t
+        new_merged_W = (
+            pretrained_W
+            + (previous_lambda_t * previous_merged_tv + cleaned_task_tv) / lambda_t
+        )
+        module_name = param_name[: param_name.rfind(".")]
+        return module_name, new_merged_W.to(original_device)
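Note: the heart of `merge_linear_weights` is the projection step: the new task vector is expressed in the singular basis of the previous merged task vector, and the components falling in the already-occupied leading directions (the block covering an `alpha` fraction of the cumulative singular-value mass, plus the diagonal) are zeroed before mapping back. A standalone sketch using `torch.linalg.svd` in place of the package's `svd` helper (an assumption; the helper's exact behavior is not shown in this diff):

```python
import torch

alpha = 0.5
prev_tv = torch.randn(8, 8)  # hypothetical previous merged task vector
task_tv = torch.randn(8, 8)  # hypothetical new task vector

u, s, vh = torch.linalg.svd(prev_tv)
v = vh.T
split_rank = (s.cumsum(0) / s.sum() > alpha).float().argmax().item()

proj = u.T @ task_tv @ v             # coordinates in the singular basis
proj.diagonal().fill_(0)             # drop components along each singular pair
proj[:split_rank, :split_rank] = 0   # drop the occupied leading subspace
cleaned_task_tv = u @ proj @ v.T     # back to parameter space
```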
+
+    def merge_other_parameters(
+        self,
+        merged_W: Tensor,
+        pretrained_W: Tensor,
+        task_W: Tensor,
+        param_name: str,
+    ):
+        accelerator = self.fabric.device
+
+        original_device = merged_W.device
+        merged_W = merged_W.to(accelerator)
+        pretrained_W = pretrained_W.to(accelerator)
+        task_W = task_W.to(accelerator)
+
+        previous_merged_tv = merged_W - pretrained_W
+        task_tv = task_W - pretrained_W
+
+        previous_lambda_t = self.previous_lambda_t
+        lambda_t = self.lambda_t
+
+        new_merged_W = (
+            pretrained_W + (previous_lambda_t * previous_merged_tv + task_tv) / lambda_t
+        )
+        return new_merged_W.to(original_device)
+
+    def compute_lambda_t(
+        self, previous_merged_tv: Tensor, task_tv: Tensor, previous_lambda_t: float
+    ):
+        previous_merged_tv = torch.flatten(previous_merged_tv)
+        task_tv = torch.flatten(task_tv)
+
+        lambda_t = torch.linalg.vector_norm(
+            previous_lambda_t * previous_merged_tv + task_tv
+        ) / torch.linalg.vector_norm(previous_merged_tv)
+        return lambda_t.item()
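Note: `compute_lambda_t` implements the recurrence

    λ_t = ‖λ_{t-1} · τ_merged + τ_task‖ / ‖τ_merged‖,

i.e. the factor by which the running (unnormalized) sum of task vectors outgrows the kept, norm-controlled merged vector. The `run` loop logs √(t + 1) next to it as the empirical reference (`empirical/lambda_t`), which is the value λ_t approaches when the task vectors are roughly orthogonal and of similar norm.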
+
+
+if is_ray_available():
+    import ray
+
+    OPCMActor = ray.remote(OPCM)
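Note: `ray.remote(OPCM)` turns the algorithm class itself into a Ray actor class, which is what `_layer_wise_merge` drives through the `ActorPool` submit/drain pattern above. A minimal standalone sketch of that pattern (the `Square` actor is hypothetical, not part of the package):

```python
import ray
from ray.util.actor_pool import ActorPool

@ray.remote
class Square:
    def compute(self, x):
        return x * x

ray.init()
pool = ActorPool([Square.remote() for _ in range(2)])

results = []
for x in range(8):
    if not pool.has_free():                 # all actors busy: drain one result first
        results.append(pool.get_next_unordered())
    pool.submit(lambda actor, v: actor.compute.remote(v), x)
while pool.has_next():                      # drain the remainder
    results.append(pool.get_next_unordered())
print(sorted(results))  # [0, 1, 4, 9, 16, 25, 36, 49]
```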
fusion_bench/method/opcm/utils.py
CHANGED
@@ -3,6 +3,7 @@ from typing import Tuple
 import torch
 from torch import Tensor, nn
 
+from fusion_bench.models.utils import is_leaf_module
 from fusion_bench.utils.parameters import state_dict_to_vector
 from fusion_bench.utils.state_dict_arithmetic import state_dict_sub
 
@@ -51,10 +52,6 @@ def frobenius_inner_product(w1: Tensor, w2: Tensor) -> Tensor:
     return torch.trace(w1.T @ w2)
 
 
-def is_leaf_module(module: nn.Module) -> bool:
-    return len(list(module.children())) == 0
-
-
 def get_task_vector_norm(model: nn.Module, pretrained_model: nn.Module) -> Tensor:
     """
     Get the vector norm of the task model.