fusion-bench 0.2.9__py3-none-any.whl → 0.2.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/compat/method/__init__.py +5 -0
- fusion_bench/dataset/fer2013.py +1 -0
- fusion_bench/method/DOGE_TA/DOGE_TA.py +364 -0
- fusion_bench/method/DOGE_TA/__init__.py +2 -0
- fusion_bench/method/DOGE_TA/clip_layer_wise_adamerging.py +46 -0
- fusion_bench/method/DOGE_TA/layer_wise_adamerging.py +250 -0
- fusion_bench/method/__init__.py +22 -0
- fusion_bench/method/classification/continual_clip_finetune.py +1 -1
- fusion_bench/method/concrete_subspace/__init__.py +8 -0
- fusion_bench/method/concrete_subspace/clip_post_defense.py +744 -0
- fusion_bench/method/concrete_subspace/clip_safe_concrete_adamerging.py +832 -0
- fusion_bench/method/isotropic_merging/__init__.py +15 -0
- fusion_bench/method/isotropic_merging/iso.py +114 -0
- fusion_bench/method/isotropic_merging/iso_utils.py +176 -0
- fusion_bench/method/task_singular_vector/TSVM.py +22 -2
- fusion_bench/models/wrappers/layer_wise_fusion_doge_ta.py +531 -0
- {fusion_bench-0.2.9.dist-info → fusion_bench-0.2.11.dist-info}/METADATA +1 -1
- {fusion_bench-0.2.9.dist-info → fusion_bench-0.2.11.dist-info}/RECORD +30 -13
- {fusion_bench-0.2.9.dist-info → fusion_bench-0.2.11.dist-info}/WHEEL +1 -1
- fusion_bench_config/method/DOGE_TA/DOGE_TA.yaml +4 -0
- fusion_bench_config/method/concrete_subspace/clip_post_defense_AWM.yaml +38 -0
- fusion_bench_config/method/concrete_subspace/clip_post_defense_SAU.yaml +41 -0
- fusion_bench_config/method/concrete_subspace/clip_safe_concrete_layer_wise_adamerging.yaml +39 -0
- fusion_bench_config/method/concrete_subspace/clip_safe_concrete_task_arithmetic.yaml +40 -0
- fusion_bench_config/method/isotropic_merging/iso_c.yaml +4 -0
- fusion_bench_config/method/isotropic_merging/iso_cts.yaml +5 -0
- fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +6 -0
- {fusion_bench-0.2.9.dist-info → fusion_bench-0.2.11.dist-info}/LICENSE +0 -0
- {fusion_bench-0.2.9.dist-info → fusion_bench-0.2.11.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.9.dist-info → fusion_bench-0.2.11.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This module implements the Isotropic Merging in Common Subspace (Iso-C) and Isotropic Merging in Common and Task-Specific Subspaces (Iso-CTS) algorithms.
|
|
3
|
+
Modified from the original implementation: https://github.com/danielm1405/iso-merging
|
|
4
|
+
|
|
5
|
+
Reference:
|
|
6
|
+
- Daniel Marczak, et al. No Task Left Behind: Isotropic Model Merging with Common and Task-Specific Subspaces. 2025.
|
|
7
|
+
https://arxiv.org/abs/2502.04959
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from .iso import (
|
|
11
|
+
ISO_C_Merge,
|
|
12
|
+
ISO_CTS_Merge,
|
|
13
|
+
IsotropicMergingInCommonAndTaskSubspace,
|
|
14
|
+
IsotropicMergingInCommonSubspace,
|
|
15
|
+
)
|
|
@@ -0,0 +1,114 @@
|
|
|
1
|
+
from typing import List
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
from fusion_bench import BaseAlgorithm, BaseModelPool
|
|
6
|
+
from fusion_bench.mixins import LightningFabricMixin
|
|
7
|
+
from fusion_bench.utils.state_dict_arithmetic import (
|
|
8
|
+
state_dict_add,
|
|
9
|
+
state_dict_mul,
|
|
10
|
+
state_dict_sub,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
from .iso_utils import check_parameterNamesMatch, iso_c, iso_cts
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class IsotropicMergingInCommonSubspace(BaseAlgorithm, LightningFabricMixin):
    """
    Isotropic Merging in Common Subspace (Iso-C).

    Builds one task vector per fine-tuned model (fine-tuned state dict minus
    pre-trained state dict), merges them with ``iso_c``, and adds the merged
    vector, multiplied by ``scaling_factor``, onto the pre-trained model.

    Args:
        scaling_factor: weight applied to the merged task vector.
        exclude_keys: state-dict keys forwarded to ``iso_c`` as
            ``exclude_keys`` (those tensors are only averaged there).
    """

    def __init__(
        self,
        scaling_factor: float,
        exclude_keys: List[str] = None,
    ):
        self.scaling_factor = scaling_factor
        self.exclude_keys = exclude_keys
        super().__init__()

    @staticmethod
    def _collect_task_vectors(modelpool: BaseModelPool, pretrained_model):
        """Return ``finetuned - pretrained`` state dicts, one per pool model."""
        pretrained_state = pretrained_model.state_dict()
        deltas = []
        for name in modelpool.model_names:
            finetuned = modelpool.load_model(name)
            deltas.append(state_dict_sub(finetuned.state_dict(), pretrained_state))
            del finetuned  # drop fine-tuned weights before loading the next model
        return deltas

    @torch.no_grad()
    def run(self, modelpool: BaseModelPool):
        """
        Merge every fine-tuned model of ``modelpool`` into its pre-trained model.

        Returns:
            The pre-trained model, whose weights are overwritten in place with
            ``pretrained + scaling_factor * iso_c(task_vectors)``.
        """
        merged_model = modelpool.load_pretrained_model()
        deltas = self._collect_task_vectors(modelpool, merged_model)
        check_parameterNamesMatch(deltas)

        merged_delta = iso_c(
            deltas,
            accelerator=self.fabric.device,
            exclude_keys=self.exclude_keys,
        )

        # theta_merged = theta_pretrained + scaling_factor * merged_delta
        scaled_delta = state_dict_mul(merged_delta, self.scaling_factor)
        merged_model.load_state_dict(
            state_dict_add(merged_model.state_dict(), scaled_delta)
        )
        return merged_model
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
class IsotropicMergingInCommonAndTaskSubspace(BaseAlgorithm, LightningFabricMixin):
    """
    Isotropic Merging in Common and Task-Specific Subspaces (Iso-CTS).

    Builds one task vector per fine-tuned model (fine-tuned state dict minus
    pre-trained state dict), merges them with ``iso_cts`` (which reserves part
    of each weight matrix's spectrum for task-specific directions), and adds
    the merged vector, multiplied by ``scaling_factor``, onto the pre-trained
    model.

    Args:
        scaling_factor: weight applied to the merged task vector.
        common_space_fraction: forwarded to ``iso_cts``; fraction of each
            matrix's rank assigned to the common subspace.
        exclude_keys: state-dict keys forwarded to ``iso_cts`` as
            ``exclude_keys``.
    """

    def __init__(
        self,
        scaling_factor: float,
        common_space_fraction: float,
        exclude_keys: List[str] = None,
    ):
        self.common_space_fraction = common_space_fraction
        self.scaling_factor = scaling_factor
        self.exclude_keys = exclude_keys
        super().__init__()

    @staticmethod
    def _collect_task_vectors(modelpool: BaseModelPool, pretrained_model):
        """Return ``finetuned - pretrained`` state dicts, one per pool model."""
        pretrained_state = pretrained_model.state_dict()
        deltas = []
        for name in modelpool.model_names:
            finetuned = modelpool.load_model(name)
            deltas.append(state_dict_sub(finetuned.state_dict(), pretrained_state))
            del finetuned  # drop fine-tuned weights before loading the next model
        return deltas

    @torch.no_grad()
    def run(self, modelpool: BaseModelPool):
        """
        Merge every fine-tuned model of ``modelpool`` into its pre-trained model.

        Returns:
            The pre-trained model, whose weights are overwritten in place with
            ``pretrained + scaling_factor * iso_cts(task_vectors)``.
        """
        merged_model = modelpool.load_pretrained_model()
        deltas = self._collect_task_vectors(modelpool, merged_model)
        check_parameterNamesMatch(deltas)

        merged_delta = iso_cts(
            deltas,
            common_space_fraction=self.common_space_fraction,
            accelerator=self.fabric.device,
            exclude_keys=self.exclude_keys,
        )

        # theta_merged = theta_pretrained + scaling_factor * merged_delta
        scaled_delta = state_dict_mul(merged_delta, self.scaling_factor)
        merged_model.load_state_dict(
            state_dict_add(merged_model.state_dict(), scaled_delta)
        )
        return merged_model
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
# Short aliases for the two isotropic merging algorithms.
ISO_C_Merge = IsotropicMergingInCommonSubspace
ISO_CTS_Merge = IsotropicMergingInCommonAndTaskSubspace
|
|
@@ -0,0 +1,176 @@
|
|
|
1
|
+
import math
|
|
2
|
+
from typing import List
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
|
|
6
|
+
from fusion_bench.utils import timeit_context
|
|
7
|
+
from fusion_bench.utils.type import StateDictType
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def iso_c(
    task_vectors: List[StateDictType],
    accelerator="cuda",
    exclude_keys: List[str] = None,
) -> StateDictType:
    """
    Merge task vectors with Iso-C (Isotropic Merging in Common Subspace).

    For every key the task tensors are first averaged on ``accelerator``.
    Keys whose tensor is 2-D and not listed in ``exclude_keys`` are then
    re-merged spectrally: the *summed* task matrices are decomposed with SVD
    and rebuilt with every singular value replaced by the mean singular value
    (an isotropic spectrum). All other keys keep the plain average.

    Args:
        task_vectors: one state dict of parameter differences per task; all
            are expected to share the same keys and shapes.
        accelerator: device on which the per-key computation runs.
        exclude_keys: keys that are only averaged, even if they are 2-D.

    Returns:
        The merged task vector. Each tensor is moved back to the device on
        which the first task vector stored it.

    Raises:
        ValueError: if ``task_vectors`` is empty.
    """
    if len(task_vectors) == 0:
        raise ValueError("iso_c requires at least one task vector.")
    exclude_keys = [] if exclude_keys is None else exclude_keys

    with torch.no_grad(), timeit_context("ISO-C Merging"):
        new_vector = {}
        for key in task_vectors[0]:
            print(f"Merging {key}...")
            original_device = task_vectors[0][key].device
            tvs = [
                task_vector[key].to(device=accelerator, non_blocking=True)
                for task_vector in task_vectors
            ]
            num_tvs = len(tvs)
            new_vector[key] = sum(tvs) / num_tvs
            del tvs  # free memory

            if len(task_vectors[0][key].shape) == 2 and key not in exclude_keys:
                # Merge in the common space: SVD of the summed task matrices,
                # then flatten the spectrum to its mean.
                new_vector[key] *= num_tvs
                U, S, V = torch.linalg.svd(new_vector[key], full_matrices=False)
                S_mean = torch.ones_like(S) * S.mean()
                new_vector[key] = torch.linalg.multi_dot(
                    (
                        U,
                        torch.diag(S_mean),
                        V,
                    )
                )

            # Blocking copy: an asynchronous (non_blocking) device->host copy
            # may still be in flight when the caller reads the returned
            # tensors, which yields stale/garbage values.
            new_vector[key] = new_vector[key].to(device=original_device)
    return new_vector
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def _iso_cts_subspace_sizes(
    min_dim: int, num_tasks: int, common_space_fraction: float
):
    """
    Split ``min_dim`` singular directions into one common block plus
    ``num_tasks`` equally sized task-specific blocks.

    Returns:
        ``(common_dims, n_dims_per_task)``, both non-negative, with
        ``common_dims + num_tasks * n_dims_per_task == min_dim``.
    """
    common_dims = int(min_dim * common_space_fraction)
    # Round the task-specific share so that it divides evenly among tasks ...
    n_dims_per_task = max(0, round((min_dim - common_dims) / num_tasks))
    # ... but never hand out more task-specific directions than exist. Rounding
    # up (e.g. when min_dim < num_tasks) would otherwise make the common block
    # negative-sized and break the slicing in iso_cts.
    n_dims_per_task = min(n_dims_per_task, min_dim // num_tasks)
    return min_dim - num_tasks * n_dims_per_task, n_dims_per_task


@torch.no_grad()
def iso_cts(
    task_vectors: List[StateDictType],
    common_space_fraction: float,
    accelerator: str = "cuda",
    exclude_keys: List[str] = None,
):
    """
    Merge task vectors with Iso-CTS (Isotropic Merging in Common and
    Task-Specific Subspaces).

    For each floating-point 2-D tensor whose key is not in ``exclude_keys``:

    1. sum the task matrices and keep their top singular directions as the
       common subspace (roughly ``common_space_fraction`` of the rank);
    2. for every task, remove the common-subspace component of its matrix and
       keep the top ``n_dims_per_task`` singular directions of the remainder;
    3. stack the task-specific and common directions, replace each stacked
       basis by its nearest orthonormal matrix (``U @ Vh`` of its SVD), set
       every singular value to their mean, and rebuild the matrix.

    Every other tensor is replaced by its mean across tasks.

    Args:
        task_vectors: one state dict of parameter differences per task; all
            are expected to share keys and shapes.
        common_space_fraction: fraction of ``min(rows, cols)`` reserved for the
            common subspace; adjusted so the remainder splits evenly among
            tasks.
        accelerator: device on which the per-key computation runs.
        exclude_keys: keys that are only averaged, even if they are 2-D.

    Returns:
        The merged task vector. Each tensor is moved back to the device on
        which the first task vector stored it.

    Raises:
        ValueError: if ``task_vectors`` is empty.
    """
    if len(task_vectors) == 0:
        raise ValueError("iso_cts requires at least one task vector.")
    exclude_keys = [] if exclude_keys is None else exclude_keys
    num_tasks = len(task_vectors)
    new_vector = {}

    print("ISO-CTS Merging")
    for key in task_vectors[0]:
        reference = task_vectors[0][key]
        shape_ = reference.shape
        original_device = reference.device
        # SVD is only defined for floating-point matrices; integer buffers fall
        # back to plain averaging instead of crashing.
        is_2d_matrix = (
            len(shape_) == 2
            and key not in exclude_keys
            and reference.is_floating_point()
        )
        if not is_2d_matrix:
            print(f"Combining by avg {key}...")
            for i, task_vector in enumerate(task_vectors):
                vec = task_vector[key].to(device=accelerator, non_blocking=True)
                if i == 0:
                    new_vector[key] = vec.clone()
                else:
                    # Running mean, computed out-of-place so integer tensors are
                    # promoted instead of failing an in-place float->int cast.
                    new_vector[key] = new_vector[key] + (vec - new_vector[key]) / (
                        i + 1
                    )

            # Blocking copy: an asynchronous device->host copy may still be in
            # flight when the caller reads the returned tensors.
            new_vector[key] = new_vector[key].to(device=original_device)
            continue

        print(f"Computing common space using sum for {key}...")
        combined_w = sum(
            task_vector[key].to(device=accelerator, non_blocking=True)
            for task_vector in task_vectors
        )

        ### Common subspace (sized so the task-specific space divides evenly) ###
        common_space_index_s, n_dims_per_task = _iso_cts_subspace_sizes(
            min(shape_), num_tasks, common_space_fraction
        )
        u, s, v = torch.linalg.svd(combined_w, full_matrices=False)
        common_space_u = u[:, :common_space_index_s]
        common_space_s = s[:common_space_index_s]
        common_space_v = v[:common_space_index_s, :]

        ### Task-specific subspaces ###
        # Every reduced SVD of a (rows x cols) matrix has the shapes of u, s, v.
        combined_space_u = torch.zeros_like(u)
        combined_space_s = torch.zeros_like(s)
        combined_space_v = torch.zeros_like(v)
        for i, task_vector in enumerate(task_vectors):
            w = task_vector[key].to(device=accelerator)
            # project onto the complement of the common subspace
            w_ts = w - common_space_u @ common_space_u.T @ w
            u_ts, s_ts, v_ts = torch.linalg.svd(w_ts, full_matrices=False)

            start, stop = i * n_dims_per_task, (i + 1) * n_dims_per_task
            combined_space_u[:, start:stop] = u_ts[:, :n_dims_per_task]
            combined_space_s[start:stop] = s_ts[:n_dims_per_task]
            combined_space_v[start:stop, :] = v_ts[:n_dims_per_task, :]

        # The common directions fill the slots after all task-specific ones.
        offset = num_tasks * n_dims_per_task
        combined_space_u[:, offset : offset + common_space_index_s] = common_space_u
        combined_space_s[offset : offset + common_space_index_s] = common_space_s
        combined_space_v[offset : offset + common_space_index_s, :] = common_space_v

        ### Orthogonalize the stacked bases ###
        u_cu, _, vh_cu = torch.linalg.svd(combined_space_u, full_matrices=False)
        u_cv, _, vh_cv = torch.linalg.svd(combined_space_v, full_matrices=False)
        combined_space_u = u_cu @ vh_cu
        combined_space_v = u_cv @ vh_cv

        # isotropic spectrum: every singular value becomes the mean
        combined_space_s = torch.ones_like(combined_space_s) * combined_space_s.mean()

        merged = torch.linalg.multi_dot(
            (
                combined_space_u,
                torch.diag(combined_space_s),
                combined_space_v,
            )
        )
        # Blocking copy back to the original device (see the averaging branch).
        new_vector[key] = merged.to(device=original_device)

    return new_vector
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
def check_parameterNamesMatch(checkpoints):
    """
    Verify that every checkpoint (state dict / task vector) has the same keys.

    Args:
        checkpoints: sequence of mappings; the key set of each one is compared
            against the key set of the first.

    Raises:
        ValueError: if ``checkpoints`` is empty, or if some checkpoint's key
            set differs from the first one's (the message lists the keys that
            are not shared).
    """
    if len(checkpoints) == 0:
        raise ValueError("Expected at least one checkpoint, got none.")

    parameter_names = set(checkpoints[0].keys())
    # A single checkpoint trivially matches itself; the slice is then empty.
    for checkpoint in checkpoints[1:]:
        current_parameterNames = set(checkpoint.keys())
        if current_parameterNames != parameter_names:
            raise ValueError(
                "Differing parameter names in models. "
                f"The different parameters are {parameter_names.symmetric_difference(current_parameterNames)}"
            )
|
|
@@ -9,15 +9,20 @@ fusion_bench \
|
|
|
9
9
|
```
|
|
10
10
|
"""
|
|
11
11
|
|
|
12
|
-
from typing import List, Optional
|
|
12
|
+
from typing import Iterable, List, Optional, Union
|
|
13
13
|
|
|
14
14
|
import torch
|
|
15
|
+
from omegaconf import ListConfig
|
|
15
16
|
from torch import Tensor, nn
|
|
16
17
|
|
|
17
18
|
from fusion_bench import BaseAlgorithm
|
|
18
19
|
from fusion_bench.mixins import LightningFabricMixin
|
|
19
20
|
from fusion_bench.utils import timeit_context
|
|
20
|
-
from fusion_bench.utils.state_dict_arithmetic import
|
|
21
|
+
from fusion_bench.utils.state_dict_arithmetic import (
|
|
22
|
+
state_dict_add,
|
|
23
|
+
state_dict_mul,
|
|
24
|
+
state_dict_sub,
|
|
25
|
+
)
|
|
21
26
|
from fusion_bench.utils.type import StateDictType
|
|
22
27
|
|
|
23
28
|
from .utils import (
|
|
@@ -33,9 +38,11 @@ class TaskSingularVectorMerging(BaseAlgorithm, LightningFabricMixin):
|
|
|
33
38
|
|
|
34
39
|
def __init__(
    self,
    alpha: Optional[Union[float, Iterable[float]]] = None,
    remove_keys: Optional[List[str]] = None,
    **kwargs,
):
    """
    Args:
        alpha: optional scaling. An iterable gives one factor per task vector,
            applied before merging; a scalar scales the merged task vector.
            Real, non-bool scalars (e.g. ``alpha: 2`` in a YAML config) are
            converted to ``float`` here, because ``run`` only rescales when
            ``isinstance(alpha, float)``.
        remove_keys: stored on ``self.remove_keys``; ``None`` becomes an empty
            list.
        **kwargs: forwarded to the base-class initializer.
    """
    import numbers

    # Without this, an integer (or numpy integer) alpha would be silently
    # ignored by `run`. bool is excluded on purpose: it was ignored before, and
    # False -> 0.0 would zero out the merged task vector.
    if isinstance(alpha, numbers.Real) and not isinstance(alpha, bool):
        alpha = float(alpha)
    self.alpha = alpha
    self.remove_keys = remove_keys if remove_keys is not None else []
    super().__init__(**kwargs)
|
|
41
48
|
|
|
@@ -50,6 +57,14 @@ class TaskSingularVectorMerging(BaseAlgorithm, LightningFabricMixin):
|
|
|
50
57
|
|
|
51
58
|
with timeit_context("Flattening out Checkpoints"):
|
|
52
59
|
task_vectors = [state_dict_sub(check, ptm_check) for check in ft_checks]
|
|
60
|
+
if isinstance(self.alpha, Iterable):
|
|
61
|
+
assert len(self.alpha) == len(
|
|
62
|
+
task_vectors
|
|
63
|
+
), "Alpha and task vectors must have the same length"
|
|
64
|
+
task_vectors = [
|
|
65
|
+
state_dict_mul(state_dict=tv, scalar=alpha)
|
|
66
|
+
for alpha, tv in zip(self.alpha, task_vectors)
|
|
67
|
+
]
|
|
53
68
|
|
|
54
69
|
new_merged_tv = TSVM_utils.compute_and_sum_svd_mem_reduction(
|
|
55
70
|
task_vectors,
|
|
@@ -57,6 +72,11 @@ class TaskSingularVectorMerging(BaseAlgorithm, LightningFabricMixin):
|
|
|
57
72
|
accelerator=self.fabric.device,
|
|
58
73
|
)
|
|
59
74
|
|
|
75
|
+
# If alpha is a float, we need to scale the new merged task vector by alpha
|
|
76
|
+
if self.alpha is not None and isinstance(self.alpha, float):
|
|
77
|
+
print(f"Scaling new merged task vector by alpha: {self.alpha}")
|
|
78
|
+
new_merged_tv = state_dict_mul(state_dict=new_merged_tv, scalar=self.alpha)
|
|
79
|
+
|
|
60
80
|
pretrained_model.load_state_dict(
|
|
61
81
|
state_dict_add(new_merged_tv, pretrained_model.state_dict())
|
|
62
82
|
)
|