fusion-bench 0.2.9__py3-none-any.whl → 0.2.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. fusion_bench/compat/method/__init__.py +5 -0
  2. fusion_bench/dataset/fer2013.py +1 -0
  3. fusion_bench/method/DOGE_TA/DOGE_TA.py +364 -0
  4. fusion_bench/method/DOGE_TA/__init__.py +2 -0
  5. fusion_bench/method/DOGE_TA/clip_layer_wise_adamerging.py +46 -0
  6. fusion_bench/method/DOGE_TA/layer_wise_adamerging.py +250 -0
  7. fusion_bench/method/__init__.py +22 -0
  8. fusion_bench/method/classification/continual_clip_finetune.py +1 -1
  9. fusion_bench/method/concrete_subspace/__init__.py +8 -0
  10. fusion_bench/method/concrete_subspace/clip_post_defense.py +744 -0
  11. fusion_bench/method/concrete_subspace/clip_safe_concrete_adamerging.py +832 -0
  12. fusion_bench/method/isotropic_merging/__init__.py +15 -0
  13. fusion_bench/method/isotropic_merging/iso.py +114 -0
  14. fusion_bench/method/isotropic_merging/iso_utils.py +176 -0
  15. fusion_bench/method/task_singular_vector/TSVM.py +22 -2
  16. fusion_bench/models/wrappers/layer_wise_fusion_doge_ta.py +531 -0
  17. {fusion_bench-0.2.9.dist-info → fusion_bench-0.2.11.dist-info}/METADATA +1 -1
  18. {fusion_bench-0.2.9.dist-info → fusion_bench-0.2.11.dist-info}/RECORD +30 -13
  19. {fusion_bench-0.2.9.dist-info → fusion_bench-0.2.11.dist-info}/WHEEL +1 -1
  20. fusion_bench_config/method/DOGE_TA/DOGE_TA.yaml +4 -0
  21. fusion_bench_config/method/concrete_subspace/clip_post_defense_AWM.yaml +38 -0
  22. fusion_bench_config/method/concrete_subspace/clip_post_defense_SAU.yaml +41 -0
  23. fusion_bench_config/method/concrete_subspace/clip_safe_concrete_layer_wise_adamerging.yaml +39 -0
  24. fusion_bench_config/method/concrete_subspace/clip_safe_concrete_task_arithmetic.yaml +40 -0
  25. fusion_bench_config/method/isotropic_merging/iso_c.yaml +4 -0
  26. fusion_bench_config/method/isotropic_merging/iso_cts.yaml +5 -0
  27. fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +6 -0
  28. {fusion_bench-0.2.9.dist-info → fusion_bench-0.2.11.dist-info}/LICENSE +0 -0
  29. {fusion_bench-0.2.9.dist-info → fusion_bench-0.2.11.dist-info}/entry_points.txt +0 -0
  30. {fusion_bench-0.2.9.dist-info → fusion_bench-0.2.11.dist-info}/top_level.txt +0 -0
fusion_bench/method/isotropic_merging/__init__.py
@@ -0,0 +1,15 @@
+ """
+ This module contains the implementation of the Isotropic Merging in Common Subspace (Iso-C) algorithm and the Isotropic Merging in Common and Task-Specific Subspaces (Iso-CTS) algorithm.
+ Modified from the original implementation: https://github.com/danielm1405/iso-merging
+
+ Reference:
+ - Daniel Marczak, et al. No Task Left Behind: Isotropic Model Merging with Common and Task-Specific Subspaces. 2025.
+   https://arxiv.org/abs/2502.04959
+ """
+
+ from .iso import (
+     ISO_C_Merge,
+     ISO_CTS_Merge,
+     IsotropicMergingInCommonAndTaskSubspace,
+     IsotropicMergingInCommonSubspace,
+ )
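These exports make both the full class names and the short aliases importable from the subpackage. A minimal usage sketch, assuming a FusionBench model pool is already available (the `scaling_factor` value and the `modelpool` variable here are hypothetical, not part of the diff):

```python
from fusion_bench.method.isotropic_merging import ISO_C_Merge

# `modelpool` is assumed: a BaseModelPool holding a pretrained model plus
# several fine-tuned models, typically built from a FusionBench config.
algorithm = ISO_C_Merge(scaling_factor=0.3)  # hypothetical scaling factor
# merged_model = algorithm.run(modelpool)
```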
fusion_bench/method/isotropic_merging/iso.py
@@ -0,0 +1,114 @@
+ from typing import List
+
+ import torch
+
+ from fusion_bench import BaseAlgorithm, BaseModelPool
+ from fusion_bench.mixins import LightningFabricMixin
+ from fusion_bench.utils.state_dict_arithmetic import (
+     state_dict_add,
+     state_dict_mul,
+     state_dict_sub,
+ )
+
+ from .iso_utils import check_parameterNamesMatch, iso_c, iso_cts
+
+
+ class IsotropicMergingInCommonSubspace(BaseAlgorithm, LightningFabricMixin):
+     """
+     Isotropic Merging in Common Subspace (Iso-C)
+     """
+
+     def __init__(
+         self,
+         scaling_factor: float,
+         exclude_keys: List[str] = None,
+     ):
+         self.scaling_factor = scaling_factor
+         self.exclude_keys = exclude_keys
+         super().__init__()
+
+     def run(self, modelpool: BaseModelPool):
+         # load the pretrained model and the task vectors of all the finetuned models
+         with torch.no_grad():
+             pretrained_model = modelpool.load_pretrained_model()
+             task_vectors = []
+             for model_name in modelpool.model_names:
+                 finetuned_model = modelpool.load_model(model_name)
+                 task_vectors.append(
+                     state_dict_sub(
+                         finetuned_model.state_dict(), pretrained_model.state_dict()
+                     )
+                 )
+                 del finetuned_model  # free memory
+             check_parameterNamesMatch(task_vectors)
+
+             # compute the merged task vector
+             merged_tv = iso_c(
+                 task_vectors,
+                 accelerator=self.fabric.device,
+                 exclude_keys=self.exclude_keys,
+             )
+
+             # merged_parameters = pretrained_parameters + scaling_factor * merged_task_vector
+             pretrained_model.load_state_dict(
+                 state_dict_add(
+                     pretrained_model.state_dict(),
+                     state_dict_mul(merged_tv, self.scaling_factor),
+                 )
+             )
+
+         return pretrained_model
+
+
+ class IsotropicMergingInCommonAndTaskSubspace(BaseAlgorithm, LightningFabricMixin):
+     """
+     Isotropic Merging in Common and Task-Specific Subspaces (Iso-CTS)
+     """
+
+     def __init__(
+         self,
+         scaling_factor: float,
+         common_space_fraction: float,
+         exclude_keys: List[str] = None,
+     ):
+         self.common_space_fraction = common_space_fraction
+         self.scaling_factor = scaling_factor
+         self.exclude_keys = exclude_keys
+         super().__init__()
+
+     def run(self, modelpool: BaseModelPool):
+         # load the pretrained model and the task vectors of all the finetuned models
+         with torch.no_grad():
+             pretrained_model = modelpool.load_pretrained_model()
+             task_vectors = []
+             for model_name in modelpool.model_names:
+                 finetuned_model = modelpool.load_model(model_name)
+                 task_vectors.append(
+                     state_dict_sub(
+                         finetuned_model.state_dict(), pretrained_model.state_dict()
+                     )
+                 )
+                 del finetuned_model  # free memory
+             check_parameterNamesMatch(task_vectors)
+
+             # compute the merged task vector
+             merged_tv = iso_cts(
+                 task_vectors,
+                 common_space_fraction=self.common_space_fraction,
+                 accelerator=self.fabric.device,
+                 exclude_keys=self.exclude_keys,
+             )
+
+             # merged_parameters = pretrained_parameters + scaling_factor * merged_task_vector
+             pretrained_model.load_state_dict(
+                 state_dict_add(
+                     pretrained_model.state_dict(),
+                     state_dict_mul(merged_tv, self.scaling_factor),
+                 )
+             )
+
+         return pretrained_model
+
+
+ ISO_C_Merge = IsotropicMergingInCommonSubspace  # alias
+ ISO_CTS_Merge = IsotropicMergingInCommonAndTaskSubspace  # alias
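Both `run` methods end with the same update, `merged = pretrained + scaling_factor * merged_task_vector`. A toy sketch of that arithmetic, with plain dicts standing in for fusion_bench's `state_dict_add`/`state_dict_mul` helpers and a hypothetical scaling factor:

```python
import torch

pretrained = {"w": torch.ones(2, 2)}
merged_tv = {"w": torch.full((2, 2), 0.5)}  # merged task vector from iso_c/iso_cts
scaling_factor = 0.3  # hypothetical value

# equivalent of state_dict_add(pretrained, state_dict_mul(merged_tv, scaling_factor))
merged = {k: pretrained[k] + scaling_factor * merged_tv[k] for k in pretrained}
print(merged["w"])  # every entry is 1.15
```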
fusion_bench/method/isotropic_merging/iso_utils.py
@@ -0,0 +1,176 @@
+ import math
+ from typing import List
+
+ import torch
+
+ from fusion_bench.utils import timeit_context
+ from fusion_bench.utils.type import StateDictType
+
+
+ def iso_c(
+     task_vectors: List[StateDictType],
+     accelerator="cuda",
+     exclude_keys: List[str] = None,
+ ) -> StateDictType:
+     exclude_keys = [] if exclude_keys is None else exclude_keys
+
+     with torch.no_grad(), timeit_context("ISO-C Merging"):
+         new_vector = {}
+         for key in task_vectors[0]:
+             print(f"Merging {key}...")
+             original_device = task_vectors[0][key].device
+             tvs = [
+                 task_vector[key].to(device=accelerator, non_blocking=True)
+                 for task_vector in task_vectors
+             ]
+             num_tvs = len(tvs)
+             new_vector[key] = sum(tvs) / num_tvs
+             del tvs  # free memory
+
+             if len(task_vectors[0][key].shape) == 2 and key not in exclude_keys:
+                 # if the key is a 2D matrix, we need to merge the task vectors in the common space
+                 new_vector[key] *= num_tvs
+                 U, S, V = torch.linalg.svd(new_vector[key], full_matrices=False)
+                 S_mean = torch.ones_like(S) * S.mean()
+
+                 new_vector[key] = torch.linalg.multi_dot(
+                     (
+                         U,
+                         torch.diag(S_mean),
+                         V,
+                     )
+                 )
+             new_vector[key] = new_vector[key].to(
+                 device=original_device, non_blocking=True
+             )
+     return new_vector
+
+
+ @torch.no_grad()
+ def iso_cts(
+     task_vectors: List[StateDictType],
+     common_space_fraction: float,
+     accelerator: str = "cuda",
+     exclude_keys: List[str] = None,
+ ):
+     exclude_keys = [] if exclude_keys is None else exclude_keys
+     new_vector = {}
+
+     print("ISO-CTS Merging")
+     for key in task_vectors[0]:
+         shape_ = task_vectors[0][key].shape
+         original_device = task_vectors[0][key].device
+         is_2d_matrix = (len(shape_) == 2) and (key not in exclude_keys)
+         if not is_2d_matrix:
+             print(f"Combining by avg {key}...")
+             for i, task_vector in enumerate(task_vectors):
+                 vec = task_vector[key].to(device=accelerator, non_blocking=True)
+                 if i == 0:
+                     new_vector[key] = vec.clone()
+                 else:
+                     new_vector[key] += (vec - new_vector[key]) / (i + 1)
+
+             # move the new vector to the original device
+             new_vector[key] = new_vector[key].to(
+                 device=original_device, non_blocking=True
+             )
+             continue
+
+         print(f"Computing common space using sum for {key}...")
+         combined_w = sum(
+             [
+                 task_vector[key].to(device=accelerator, non_blocking=True)
+                 for task_vector in task_vectors
+             ]
+         )
+
+         ### Calculate the common space size (making sure that task specific space is equally divisible) ###
+         common_space_index_s = int(min(shape_) * common_space_fraction)
+         _task_specific_total_space_index_s = round(
+             (min(shape_) - common_space_index_s) / len(task_vectors)
+         ) * len(task_vectors)
+         common_space_index_s = min(shape_) - _task_specific_total_space_index_s
+
+         u, s, v = torch.linalg.svd(combined_w, full_matrices=False)
+         common_space_u = u[:, :common_space_index_s]
+         common_space_s = s[:common_space_index_s]
+         common_space_v = v[:common_space_index_s, :]
+         ###################################################################
+
+         ### Calculate task specific space ###
+         n_dims_per_task = int((min(shape_) - common_space_index_s) / len(task_vectors))
+         for i, task_vector in enumerate(task_vectors):
+             w = task_vector[key].to(device=accelerator)
+
+             # calculate the projection onto task specific space to remove the common space
+             w_ts = w - common_space_u @ common_space_u.T @ w
+             u_ts, s_ts, v_ts = torch.linalg.svd(w_ts, full_matrices=False)
+
+             if i == 0:
+                 combined_space_u = torch.zeros_like(u_ts, device=accelerator)
+                 combined_space_s = torch.zeros_like(s_ts, device=accelerator)
+                 combined_space_v = torch.zeros_like(v_ts, device=accelerator)
+
+             combined_space_u[:, i * n_dims_per_task : (i + 1) * n_dims_per_task] = u_ts[
+                 :, :n_dims_per_task
+             ]
+             combined_space_s[i * n_dims_per_task : (i + 1) * n_dims_per_task] = s_ts[
+                 :n_dims_per_task
+             ]
+             combined_space_v[i * n_dims_per_task : (i + 1) * n_dims_per_task, :] = v_ts[
+                 :n_dims_per_task, :
+             ]
+         ###################################################################
+
+         combined_space_u[
+             :,
+             len(task_vectors) * n_dims_per_task : len(task_vectors) * n_dims_per_task
+             + common_space_index_s,
+         ] = common_space_u
+         combined_space_s[
+             len(task_vectors) * n_dims_per_task : len(task_vectors) * n_dims_per_task
+             + common_space_index_s
+         ] = common_space_s
+         combined_space_v[
+             len(task_vectors) * n_dims_per_task : len(task_vectors) * n_dims_per_task
+             + common_space_index_s,
+             :,
+         ] = common_space_v
+
+         ### Orthogonalize combined_space_u and combined_space_v ###
+         u_combined_space_u, s_combined_space_u, v_combined_space_u = torch.linalg.svd(
+             combined_space_u, full_matrices=False
+         )
+         u_combined_space_v, s_combined_space_v, v_combined_space_v = torch.linalg.svd(
+             combined_space_v, full_matrices=False
+         )
+         combined_space_u = u_combined_space_u @ v_combined_space_u
+         combined_space_v = u_combined_space_v @ v_combined_space_v
+         ###################################################################
+
+         combined_space_s = torch.ones_like(combined_space_s) * combined_space_s.mean()
+
+         new_vector[key] = torch.linalg.multi_dot(
+             (
+                 combined_space_u,
+                 torch.diag(combined_space_s),
+                 combined_space_v,
+             )
+         )
+         new_vector[key] = new_vector[key].to(device=original_device, non_blocking=True)
+
+     return new_vector
+
+
+ def check_parameterNamesMatch(checkpoints):
+     parameter_names = set(checkpoints[0].keys())
+
+     if len(checkpoints) >= 2:
+         # raise ValueError("Number of models is less than 2.")
+         for checkpoint in checkpoints[1:]:
+             current_parameterNames = set(checkpoint.keys())
+             if current_parameterNames != parameter_names:
+                 raise ValueError(
+                     "Differing parameter names in models. "
+                     f"The different parameters are {parameter_names.symmetric_difference(current_parameterNames)}"
+                 )
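The core of `iso_c` is visible above: sum the per-task deltas of each 2D weight, then replace the singular values of the sum with their mean, so every retained direction contributes equally (`iso_cts` additionally reserves part of the spectrum for per-task subspaces). A self-contained sketch of that flattening step on a toy matrix, with hypothetical shapes and task count:

```python
import torch

torch.manual_seed(0)
task_vectors = [torch.randn(8, 4) for _ in range(3)]  # toy per-task weight deltas

summed = sum(task_vectors)  # iso_c averages, then scales back up by num_tvs
U, S, Vh = torch.linalg.svd(summed, full_matrices=False)
S_flat = torch.ones_like(S) * S.mean()  # isotropic (flattened) spectrum
merged = torch.linalg.multi_dot((U, torch.diag(S_flat), Vh))

print(S)       # decaying spectrum of the summed task vectors
print(S_flat)  # uniform spectrum actually used for the merge
```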
fusion_bench/method/task_singular_vector/TSVM.py
@@ -9,15 +9,20 @@ fusion_bench \
  ```
  """

- from typing import List, Optional
+ from typing import Iterable, List, Optional, Union

  import torch
+ from omegaconf import ListConfig
  from torch import Tensor, nn

  from fusion_bench import BaseAlgorithm
  from fusion_bench.mixins import LightningFabricMixin
  from fusion_bench.utils import timeit_context
- from fusion_bench.utils.state_dict_arithmetic import state_dict_sub, state_dict_add
+ from fusion_bench.utils.state_dict_arithmetic import (
+     state_dict_add,
+     state_dict_mul,
+     state_dict_sub,
+ )
  from fusion_bench.utils.type import StateDictType

  from .utils import (
@@ -33,9 +38,11 @@ class TaskSingularVectorMerging(BaseAlgorithm, LightningFabricMixin):

      def __init__(
          self,
+         alpha: Union[float, Iterable[float]] = None,
          remove_keys: Optional[List[str]] = None,
          **kwargs,
      ):
+         self.alpha = alpha
          self.remove_keys = remove_keys if remove_keys is not None else []
          super().__init__(**kwargs)

@@ -50,6 +57,14 @@

          with timeit_context("Flattening out Checkpoints"):
              task_vectors = [state_dict_sub(check, ptm_check) for check in ft_checks]
+         if isinstance(self.alpha, Iterable):
+             assert len(self.alpha) == len(
+                 task_vectors
+             ), "Alpha and task vectors must have the same length"
+             task_vectors = [
+                 state_dict_mul(state_dict=tv, scalar=alpha)
+                 for alpha, tv in zip(self.alpha, task_vectors)
+             ]

          new_merged_tv = TSVM_utils.compute_and_sum_svd_mem_reduction(
              task_vectors,
@@ -57,6 +72,11 @@
              accelerator=self.fabric.device,
          )

+         # If alpha is a float, we need to scale the new merged task vector by alpha
+         if self.alpha is not None and isinstance(self.alpha, float):
+             print(f"Scaling new merged task vector by alpha: {self.alpha}")
+             new_merged_tv = state_dict_mul(state_dict=new_merged_tv, scalar=self.alpha)
+
          pretrained_model.load_state_dict(
              state_dict_add(new_merged_tv, pretrained_model.state_dict())
          )
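The net effect of the new `alpha` parameter: an iterable `alpha` rescales each task vector before the SVD merge, while a scalar `alpha` rescales the merged task vector afterwards. A toy sketch of the two code paths (plain dicts stand in for state dicts; the alpha values are hypothetical):

```python
import torch

task_vectors = [{"w": torch.ones(2)}, {"w": 2 * torch.ones(2)}]

# Iterable alpha: per-task scaling before merging, as in the first hunk above.
alphas = [0.5, 1.5]  # hypothetical per-task weights
task_vectors = [
    {k: a * v for k, v in tv.items()} for a, tv in zip(alphas, task_vectors)
]

# Float alpha: scale the merged task vector after merging, as in the last hunk.
merged_tv = {"w": sum(tv["w"] for tv in task_vectors)}  # stand-in for the SVD merge
alpha = 0.4  # hypothetical value
merged_tv = {k: alpha * v for k, v in merged_tv.items()}
print(merged_tv["w"])  # tensor([1.4000, 1.4000])
```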