fusion-bench 0.2.15__py3-none-any.whl → 0.2.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/method/__init__.py +4 -0
- fusion_bench/method/fw_merging/__init__.py +2 -0
- fusion_bench/method/fw_merging/fw_hard.py +448 -0
- fusion_bench/method/fw_merging/fw_soft.py +519 -0
- fusion_bench/method/fw_merging/utils.py +331 -0
- fusion_bench/method/moe_pruner/__init__.py +7 -0
- fusion_bench/method/moe_pruner/hooks/__init__.py +6 -0
- fusion_bench/method/moe_pruner/hooks/deepseek_v2.py +85 -0
- fusion_bench/method/moe_pruner/hooks/hook.py +23 -0
- fusion_bench/method/moe_pruner/hooks/mixtral.py +93 -0
- fusion_bench/method/moe_pruner/moe_pruner.py +304 -0
- fusion_bench/method/moe_pruner/utils/__init__.py +1 -0
- fusion_bench/method/moe_pruner/utils/data.py +154 -0
- fusion_bench/method/moe_pruner/utils/layerwrapper.py +61 -0
- fusion_bench/method/moe_pruner/utils/prune.py +313 -0
- fusion_bench/method/moe_pruner/utils/score.py +41 -0
- fusion_bench/method/pruning/__init__.py +1 -0
- fusion_bench/method/pruning/llama_sparsegpt_prune.py +223 -0
- fusion_bench/method/pruning/sparsegpt_utils/__init__.py +1 -0
- fusion_bench/method/pruning/sparsegpt_utils/sparsegpt.py +128 -0
- fusion_bench/method/pruning/wanda_utils/data.py +33 -14
- fusion_bench/method/randes/__init__.py +15 -0
- fusion_bench/method/randes/base_algorithm.py +1013 -0
- fusion_bench/method/randes/modelsoup.py +126 -0
- fusion_bench/method/randes/task_arithmetic.py +318 -0
- fusion_bench/method/sparselo/sparselo.py +20 -2
- fusion_bench/method/tall_mask/__init__.py +1 -0
- fusion_bench/method/tall_mask/task_arithmetic.py +133 -0
- fusion_bench/modelpool/lazy_state_dict_pool.py +15 -0
- fusion_bench/models/modeling_deepseek_v2/__init__.py +15 -0
- fusion_bench/models/modeling_deepseek_v2/configuration_deepseek.py +208 -0
- fusion_bench/models/modeling_deepseek_v2/modeling_deepseek.py +1922 -0
- fusion_bench/models/modeling_deepseek_v2/tokenization_deepseek_fast.py +38 -0
- fusion_bench/programs/fabric_fusion_program.py +5 -0
- fusion_bench/taskpool/clip_vision/taskpool.py +8 -1
- fusion_bench/utils/__init__.py +1 -0
- fusion_bench/utils/data.py +1 -1
- fusion_bench/utils/lazy_state_dict.py +268 -0
- fusion_bench/utils/parameters.py +33 -0
- fusion_bench/utils/state_dict_arithmetic.py +74 -2
- fusion_bench/utils/type.py +1 -0
- {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.16.dist-info}/METADATA +6 -2
- {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.16.dist-info}/RECORD +77 -21
- {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.16.dist-info}/WHEEL +1 -1
- fusion_bench_config/dataset/image_classification/test/TALL10.yaml +28 -0
- fusion_bench_config/dataset/image_classification/test/TALL12.yaml +28 -0
- fusion_bench_config/dataset/image_classification/test/TALL16.yaml +28 -0
- fusion_bench_config/dataset/image_classification/test/TALL18.yaml +28 -0
- fusion_bench_config/dataset/image_classification/train/TALL10.yaml +28 -0
- fusion_bench_config/dataset/image_classification/train/TALL12.yaml +28 -0
- fusion_bench_config/dataset/image_classification/train/TALL16.yaml +28 -0
- fusion_bench_config/dataset/image_classification/train/TALL18.yaml +28 -0
- fusion_bench_config/method/fw_merging/fw_hard.yaml +11 -0
- fusion_bench_config/method/fw_merging/fw_soft.yaml +12 -0
- fusion_bench_config/method/moe_pruner/moe_pruner.yaml +15 -0
- fusion_bench_config/method/pruning/llama_sparsegpt_pruning.yaml +16 -0
- fusion_bench_config/method/randes/superposed_model_soup.yaml +18 -0
- fusion_bench_config/method/randes/superposed_task_arithmetic.yaml +20 -0
- fusion_bench_config/method/randes/superposed_task_arithmetic_lora.yaml +20 -0
- fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml +2 -1
- fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml +1 -1
- fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml +1 -1
- fusion_bench_config/method/tall_mask/task_arithmetic.yaml +4 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL10.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL12.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL16.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL18.yaml +29 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL10.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL12.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL16.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL18.yaml +8 -0
- fusion_bench_config/modelpool/CausalLMPool/deepseek-v2-lite.yaml +15 -0
- fusion_bench_config/modelpool/CausalLMPool/mixtral-8x7b.yaml +14 -0
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/roberta-base_glue.yaml +69 -0
- {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.16.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.16.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.16.dist-info}/top_level.txt +0 -0
fusion_bench/method/randes/base_algorithm.py (new file)

@@ -0,0 +1,1013 @@
import logging
import random
from collections import OrderedDict
from copy import deepcopy
from typing import Dict, List, Literal, Optional, Tuple

import numpy as np
import torch
from scipy.stats import ortho_group
from torch import Tensor, nn

from fusion_bench.method.base_algorithm import BaseAlgorithm
from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
from fusion_bench.modelpool import BaseModelPool
from fusion_bench.utils.parameters import get_parameter_summary, human_readable
from fusion_bench.utils.state_dict_arithmetic import state_dict_avg
from fusion_bench.utils.type import StateDictType

log = logging.getLogger(__name__)

def cosine_similarity(tensor1: Tensor, tensor2: Tensor) -> float:
    # NOTE: returns a 0-dim tensor in the general case and a plain float (0.0)
    # only when either input has zero norm; callers that rely on .item()
    # should account for both.
    if tensor1.shape != tensor2.shape:
        raise ValueError("Matrices must have the same shape")
    vec1 = tensor1.flatten()
    vec2 = tensor2.flatten()
    dot_product = torch.sum(vec1 * vec2)
    norm1 = torch.sqrt(torch.sum(vec1**2))
    norm2 = torch.sqrt(torch.sum(vec2**2))
    if norm1 == 0 or norm2 == 0:
        return 0.0
    return dot_product / (norm1 * norm2)

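A quick usage sketch of the helper above (illustrative only, not part of the diff):

import torch

a = torch.ones(2, 3)
b = 2.0 * torch.ones(2, 3)
print(cosine_similarity(a, b))                  # tensor(1.), the tensors are parallel
print(cosine_similarity(a, torch.zeros(2, 3)))  # 0.0, the zero-norm fallback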
def svd_and_partition(
    A: torch.Tensor, num_chunks: int = 3
) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
    U, S, V = torch.svd(A)
    singular_values = len(S)
    chunk_size = singular_values // num_chunks
    U_chunks, S_chunks, V_chunks = [], [], []

    for i in range(num_chunks):
        start_idx = i * chunk_size
        end_idx = singular_values if i == num_chunks - 1 else start_idx + chunk_size

        U_chunks.append(U[:, start_idx:end_idx])
        S_chunks.append(S[start_idx:end_idx])
        V_chunks.append(V[:, start_idx:end_idx])

    return U_chunks, S_chunks, V_chunks

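An illustrative call (not part of the diff); the last chunk absorbs the remainder when the number of singular values is not divisible by num_chunks:

A = torch.randn(6, 5)
U_chunks, S_chunks, V_chunks = svd_and_partition(A, num_chunks=3)
print([len(s) for s in S_chunks])  # [1, 1, 3]: chunk_size = 5 // 3 = 1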
def compute_svd_subspace_similarity(
    ref: torch.Tensor, retrieval: torch.Tensor, num_chunks: int = 3
) -> List[dict]:
    if torch.cuda.is_available():
        ref = ref.cuda()
        retrieval = retrieval.cuda()
    U_chunks, S_chunks, V_chunks = svd_and_partition(ref, num_chunks)
    similarities = []
    for i in range(num_chunks):
        # project `retrieval` onto the i-th left/right singular subspaces of `ref`
        retrieval_approx = (
            U_chunks[i] @ U_chunks[i].T @ retrieval @ V_chunks[i] @ V_chunks[i].T
        )
        # NOTE: despite the key name, this is a normalized Frobenius *distance*
        # (smaller means closer), not a similarity.
        frob_sim = torch.norm(ref - retrieval_approx, p="fro").item() / ref.numel()
        cos_sim = cosine_similarity(ref, retrieval_approx)
        if isinstance(cos_sim, torch.Tensor):
            cos_sim = cos_sim.item()
        similarities.append(
            {
                "subspace": i + 1,
                "frobenius_similarity": frob_sim,
                "cosine_similarity": cos_sim,
            }
        )
    return similarities

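An illustrative call (not part of the diff); each entry reports how closely the retrieved matrix matches the reference within one block of the reference's singular subspaces:

ref = torch.randn(8, 8)
retrieved = ref + 0.01 * torch.randn(8, 8)
for entry in compute_svd_subspace_similarity(ref, retrieved, num_chunks=2):
    print(entry["subspace"], entry["frobenius_similarity"], entry["cosine_similarity"])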
def pairwise_cosine_similarity_matrix(tensors: List[torch.Tensor]) -> torch.Tensor:
    if torch.cuda.is_available():
        tensors = [tensor.cuda() for tensor in tensors]
    n = len(tensors)
    similarity_matrix = torch.zeros((n, n))
    for i in range(n):
        for j in range(n):
            similarity = cosine_similarity(tensors[i], tensors[j])
            # NOTE: assumes a tensor result; cosine_similarity returns a plain
            # float when an input has zero norm, in which case .item() fails.
            similarity_matrix[i, j] = similarity.item()
    return similarity_matrix

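An illustrative call (not part of the diff); the diagonal is 1 because every tensor has cosine similarity 1 with itself:

vectors = [torch.randn(10) for _ in range(3)]
M = pairwise_cosine_similarity_matrix(vectors)
print(torch.allclose(torch.diag(M), torch.ones(3), atol=1e-6))  # True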
def compare_models(
    state_dict1: StateDictType, state_dict2: StateDictType, target_layers=None
):
    results = {
        "layerwise_l2": {},
        "layerwise_cosine_similarity": {},
        "total_l2": None,
        "average_l2": None,
        "total_cosine_similarity": None,
        "average_cosine_similarity": None,
    }
    # Initialize lists to store flattened parameters
    params1_list = []
    params2_list = []

    keys1 = set(state_dict1.keys())
    keys2 = set(state_dict2.keys())
    # filter out layers that are not in target_layers
    if target_layers is not None:
        keys1 = keys1.intersection(target_layers)
        keys2 = keys2.intersection(target_layers)

    common_keys = keys1 & keys2
    if keys1 != keys2:
        print(
            "Warning: State dicts have different keys. Comparison will be made on common keys only."
        )
    num_layers = len(common_keys)

    for key in common_keys:
        tensor1 = state_dict1[key].float()
        tensor2 = state_dict2[key].float()

        # Compute L2 norm difference
        l2_diff = torch.norm(tensor1 - tensor2, p=2) / tensor1.numel()
        results["layerwise_l2"][key] = l2_diff.item()

        # Compute cosine similarity
        tensor1_flat = tensor1.reshape(-1)
        tensor2_flat = tensor2.reshape(-1)
        cos_sim = cosine_similarity(tensor1_flat, tensor2_flat).item()
        results["layerwise_cosine_similarity"][key] = cos_sim

        # Collect parameters for total metrics
        params1_list.append(tensor1_flat)
        params2_list.append(tensor2_flat)

    # Compute total metrics over all parameters
    if params1_list and params2_list:
        params1 = torch.cat(params1_list)
        params2 = torch.cat(params2_list)
        # Compute total L2 norm difference
        total_l2_difference = (
            torch.norm(params1 - params2, p=2).item() / params1.numel()
        )
        results["total_l2"] = total_l2_difference
        # Compute total cosine similarity
        total_cosine_similarity = cosine_similarity(params1, params2).item()
        results["total_cosine_similarity"] = total_cosine_similarity
    else:
        results["total_l2"] = None
        results["total_cosine_similarity"] = None

    # Compute average metrics
    if num_layers > 0:
        average_l2 = sum(results["layerwise_l2"].values()) / num_layers
        average_cosine_similarity = (
            sum(results["layerwise_cosine_similarity"].values()) / num_layers
        )
        results["average_l2"] = average_l2
        results["average_cosine_similarity"] = average_cosine_similarity
    else:
        results["average_l2"] = None
        results["average_cosine_similarity"] = None

    return results

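An illustrative call on two tiny state dicts (not part of the diff). Restricting to `target_layers` avoids the zero-norm bias tensors, which would trip the `.item()` caveat noted in `cosine_similarity`:

sd1 = {"fc.weight": torch.ones(2, 2), "fc.bias": torch.zeros(2)}
sd2 = {"fc.weight": torch.ones(2, 2), "fc.bias": torch.zeros(2)}
out = compare_models(sd1, sd2, target_layers=["fc.weight"])
print(out["total_l2"], out["average_cosine_similarity"])  # 0.0 1.0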
class SuperposedAlgorithmBase(
    BaseAlgorithm,
    SimpleProfilerMixin,
):
    _config_mapping = BaseAlgorithm._config_mapping | {
        "mode": "mode",
        "target_layer": "target_layer",
        "random_seed": "random_seed",
        "different_across_layers": "different_across_layers",
        "joint_matrix_mode": "joint_matrix_mode",
        "rank": "rank",
        "random_components": "random_components",
        "shift_layers": "shift_layers",
        "absorber": "absorber",
        "debug": "debug",
        "ms_mode": "ms_mode",
        "verbose": "verbose",
        "dropout_rate": "dropout_rate",
    }

    def __init__(
        self,
        mode: str,
        target_layer: str,
        random_seed: int,
        different_across_layers: bool,
        joint_matrix_mode: str,
        rank: int,
        random_components: bool,
        shift_layers: int,
        absorber: Literal["average", "pretrained", "None"],
        debug: int,
        ms_mode: str,
        verbose: int,
        dropout_rate: int,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.mode = mode
        self.target_layer = target_layer
        self.random_seed = random_seed
        self.different_across_layers = different_across_layers
        self.joint_matrix_mode = joint_matrix_mode
        self.rank = rank
        self.random_components = random_components
        self.shift_layers = shift_layers
        self.absorber = absorber
        self.debug = debug
        self.ms_mode = ms_mode
        self.verbose = verbose
        self.dropout_rate = dropout_rate

|
+
def _compute_svd_subspace_similarities(
|
|
223
|
+
self,
|
|
224
|
+
original_state_dict: StateDictType,
|
|
225
|
+
retrieved_state_dict: StateDictType,
|
|
226
|
+
target_layers: Optional[List[str]] = None,
|
|
227
|
+
) -> dict:
|
|
228
|
+
svd_similarities = {}
|
|
229
|
+
for layer_name, original_param in original_state_dict.items():
|
|
230
|
+
if target_layers is not None and layer_name not in target_layers:
|
|
231
|
+
continue
|
|
232
|
+
if (
|
|
233
|
+
original_param.dim() == 2
|
|
234
|
+
): # Only compute for 2D tensors (weight matrices)
|
|
235
|
+
retrieved_param = retrieved_state_dict[layer_name]
|
|
236
|
+
svd_similarities[layer_name] = compute_svd_subspace_similarity(
|
|
237
|
+
original_param.float(), retrieved_param.float()
|
|
238
|
+
)
|
|
239
|
+
return svd_similarities
|
|
240
|
+
|
|
    def _load_state_dicts(self, modelpool: BaseModelPool) -> Dict[str, StateDictType]:
        """
        Load the state dicts of the models in the modelpool.

        Args:
            modelpool (BaseModelPool): The modelpool to load the state dicts from.

        Returns:
            Dict[str, StateDictType]: A dictionary of state dicts, keyed by model name.
        """
        state_dicts = {}
        for model_name in modelpool.model_names:
            with self.profile("load model"):
                model = modelpool.load_model(model_name)
                state_dicts[model_name] = model.state_dict(keep_vars=True)
        return state_dicts

    def _compute_absorber(
        self,
        state_dicts: Dict[str, StateDictType],
        pretrained_model: Optional[nn.Module] = None,
    ) -> Optional[StateDictType]:
        """
        Compute the absorber state dict.

        Args:
            state_dicts (Dict[str, StateDictType]): The state dicts of the models, keyed by model name, i.e. `{model_name: state_dict}`.
            pretrained_model (Optional[nn.Module]): The pretrained model.

        Returns:
            Optional[StateDictType]: The absorber state dict.
        """
        if self.absorber == "average":
            return state_dict_avg(list(state_dicts.values()))
        elif self.absorber == "pretrained":
            return pretrained_model.state_dict(keep_vars=True)
        elif self.absorber == "None":
            return None
        else:
            raise ValueError(
                f"Unsupported absorber type: {self.absorber}. Must be one of 'average', 'pretrained', or 'None'."
            )

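A minimal sketch of what the "average" absorber computes, assuming `state_dict_avg` takes the elementwise mean across state dicts (semantics inferred from the name, so treat this as an assumption):

sds = [{"w": torch.ones(2)}, {"w": 3.0 * torch.ones(2)}]
avg = {k: torch.stack([sd[k] for sd in sds]).mean(dim=0) for k in sds[0]}
print(avg["w"])  # tensor([2., 2.])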
    @staticmethod
    def svd_decomposition(A, r):
        # NOTE: this definition is dead code; it is shadowed by the
        # svd_decomposition defined with the same name further below.
        if torch.cuda.is_available():
            A = A.cuda()
        U, S, V = torch.svd(A)
        return (U[:, :r] @ torch.diag(S[:r])).cpu(), V.t()[:r, :].cpu()

    @staticmethod
    def svd_decomposition_bm(A, r_most, r_mid):
        if torch.cuda.is_available():
            A = A.cuda()

        # Perform SVD
        U, S, V = torch.svd(A)

        # Get the most significant 'r_most' dimensions
        U_most = U[:, :r_most]
        S_most = S[:r_most]
        V_most = V[:, :r_most]

        # Get the middle 'r_mid' dimensions
        start_mid = len(S) // 2 - r_mid // 2
        end_mid = start_mid + r_mid
        U_mid = U[:, start_mid:end_mid]
        S_mid = S[start_mid:end_mid]
        V_mid = V[:, start_mid:end_mid]

        # Combine the results into two sets
        U_combined = torch.cat([U_most, U_mid], dim=1)
        S_combined = torch.cat([S_most, S_mid])
        V_combined = torch.cat([V_most, V_mid], dim=1)

        return (U_combined @ torch.diag(S_combined)).cpu(), V_combined.t().cpu()

    @staticmethod
    def svd_decomposition(A, r=None, r_most=None, r_mid=None, random_components=False):
        """
        Perform SVD decomposition with options for:
        1. Truncated SVD with 'r' components (if r is provided and random_components=False).
        2. Most significant 'r_most' and middle 'r_mid' components (if r_most and r_mid are provided).
        3. Randomly selected 'r' components (if r is provided and random_components=True).

        Args:
            A (torch.Tensor): The input matrix to decompose.
            r (int, optional): Number of components for standard or random SVD.
            r_most (int, optional): Number of most significant components.
            r_mid (int, optional): Number of middle components.
            random_components (bool, optional): Whether to sample 'r' random components.

        Returns:
            (torch.Tensor, torch.Tensor): Two matrices resulting from the SVD decomposition.
        """
        if torch.cuda.is_available():
            A = A.cuda()

        # Perform SVD
        U, S, V = torch.svd(A)

        if r is not None and not random_components:
            # Standard SVD decomposition with 'r' components
            return (U[:, :r] @ torch.diag(S[:r])).cpu(), V.t()[:r, :].cpu()

        elif r_most is not None and r_mid is not None:
            # SVD decomposition with 'r_most' most significant and 'r_mid' middle components
            # Most significant components
            U_most = U[:, :r_most]
            S_most = S[:r_most]
            V_most = V[:, :r_most]

            # Middle components
            start_mid = len(S) // 2 - r_mid // 2
            end_mid = start_mid + r_mid
            U_mid = U[:, start_mid:end_mid]
            S_mid = S[start_mid:end_mid]
            V_mid = V[:, start_mid:end_mid]

            # Combine the most and middle components
            U_combined = torch.cat([U_most, U_mid], dim=1)
            S_combined = torch.cat([S_most, S_mid])
            V_combined = torch.cat([V_most, V_mid], dim=1)

            return (U_combined @ torch.diag(S_combined)).cpu(), V_combined.t().cpu()

        elif r is not None and random_components:
            # SVD decomposition with random 'r' components
            indices = torch.randperm(len(S))[:r]
            U_rand = U[:, indices]
            S_rand = S[indices]
            V_rand = V[:, indices]

            return (U_rand @ torch.diag(S_rand)).cpu(), V_rand.t().cpu()

        else:
            raise ValueError(
                "Invalid combination of arguments. Provide correct parameters."
            )

|
+
@staticmethod
|
|
382
|
+
def _get_rank(A, rank):
|
|
383
|
+
if isinstance(rank, str):
|
|
384
|
+
r1, r2 = rank.split("-")
|
|
385
|
+
r1 = int(float(r1) * min(A.shape)) if "." in r1 else int(r1)
|
|
386
|
+
r2 = int(float(r2) * min(A.shape)) if "." in r2 else int(r2)
|
|
387
|
+
return r1, r2
|
|
388
|
+
if isinstance(rank, int):
|
|
389
|
+
return rank
|
|
390
|
+
elif isinstance(rank, float):
|
|
391
|
+
return int(rank * min(A.shape))
|
|
392
|
+
|
|
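Illustrative calls (not part of the diff) showing the three accepted rank specifications:

A = torch.zeros(100, 50)
print(SuperposedAlgorithmBase._get_rank(A, 8))          # 8, ints pass through
print(SuperposedAlgorithmBase._get_rank(A, 0.5))        # 25, a fraction of min(A.shape)
print(SuperposedAlgorithmBase._get_rank(A, "0.1-0.2"))  # (5, 10), an "r1-r2" pair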
    def _target_layer_flag(self, layer: str):
        """
        Takes a layer name as input and returns a boolean indicating whether this layer should be targeted.

        The current implementation assumes a Transformer architecture and that the layer number is the first number in the layer name.

        Args:
            layer (str): The name of the layer.

        Returns:
            bool: True if the layer should be targeted, False otherwise.
        """
        target_layers = self.target_layer  # e.g. ["mlp_w", "attn_w"]
        # TODO: figure out what wo is in flan-t5
        mlp_flag = "mlp" in layer or "Dense" in layer
        attn_flag = "attn" in layer or "Attention" in layer
        weight_flag = "weight" in layer
        bias_flag = "bias" in layer
        target_flags = []
        for target_layer in target_layers:
            if target_layer == "mlp_w":
                target_flags.append(mlp_flag and not bias_flag)
            elif target_layer == "attn_w":
                target_flags.append(attn_flag and not bias_flag)
            elif target_layer == "all":
                target_flags.append(True)
            elif target_layer == "mlp":
                target_flags.append(mlp_flag)
            elif target_layer == "attn":
                target_flags.append(attn_flag)
            else:
                raise ValueError(f"Unsupported target layer: {target_layer}")
        target_flag = any(target_flags)
        return target_flag

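A standalone illustration of the matching rule (hypothetical layer name, mirroring the "mlp_w" branch above):

name = "model.layers.3.mlp.down_proj.weight"
is_mlp_w = ("mlp" in name or "Dense" in name) and "bias" not in name
print(is_mlp_w)  # True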
    def _compress_and_retrieve(self, state_dicts: Dict[str, StateDictType], mode: str):
        """
        Compress and retrieve the state dicts.

        Args:
            state_dicts (Dict[str, StateDictType]): The state dicts of the models, keyed by model name, i.e. `{model_name: state_dict}`.
            mode (str): The mode of the compression and retrieval.

        Returns:
            Dict[str, StateDictType]: The compressed and retrieved state dicts, keyed by model name, i.e. `{model_name: state_dict}`.
        """
        # Assume the state_dicts have the same layers.
        layers = state_dicts[list(state_dicts.keys())[0]].keys()
        models = list(state_dicts.keys())
        compressed_layers = {}
        compression_context = {model: {} for model in models}
        retrieval_context = {model: {} for model in models}
        retrieval_models = deepcopy(state_dicts)
        # target_layer_flags = [self._target_layer_flag(layer) for layer in layers]
        # implement target_layer_flags with dropout
        target_layer_flags: List[bool] = []
        count = 0
        for layer in layers:
            if self._target_layer_flag(layer):
                # keep one target layer per `self.dropout_rate` target layers,
                # e.g. if self.dropout_rate = 2, keep the 2nd and 4th target layers and skip the 1st and 3rd;
                # if self.dropout_rate = 1, keep all target layers.
                count += 1
                if count == self.dropout_rate:
                    target_layer_flags.append(True)
                    count = 0
                else:
                    target_layer_flags.append(False)
            else:
                target_layer_flags.append(False)

        target_layers = [
            layer for layer, flag in zip(layers, target_layer_flags) if flag
        ]
        log.info(
            f"filtered {len(target_layers)} target layers out of {len(layers)} layers"
        )

        metadata = {
            "nonzero_parameter_count": 0,
            "nonzero_param_count_context": 0,
            "task_vector_retrieval_similarity": {},
            "superposed_model_retrieval_similarity": {},
            "model_retrieval_similarity": {},
            "target_layers": target_layers,
            "task_vector_svd_subspace_similarities": {},
            "superposed_model_svd_subspace_similarities": {},
            "model_svd_subspace_similarities": {},
            "total_param_count_original": 0,
            "total_gb_original": 0,
            "total_gb_retrieved": 0,
        }

        if "absorber" in models:
            models.remove("absorber")
            absorber = state_dicts["absorber"]
        else:
            absorber = None

        # get the total number of parameters and bytes (in GB) of the original model
        original_param_summary = get_parameter_summary(state_dicts[models[0]])
        gbs = original_param_summary["bytes"] / 1e9
        log.info(
            f"Total parameters: {human_readable(original_param_summary['all_param'])}"
        )
        log.info(f"Total gigabytes: {gbs}")
        metadata["total_param_count_original"] = original_param_summary["all_param"]
        metadata["total_gb_original"] = gbs

        # for analysis purposes
        if self.debug >= 2:
            test_models = models
            # test_models = models[:2]
            # layers_old = {model: OrderedDict() for model in models}
            layers_old = {model: deepcopy(state_dicts[model]) for model in models}
            tv_new = {
                model: {model: OrderedDict() for model in models}
                for model in test_models
            }
            # layers_new = {model: {model: OrderedDict() for model in models} for model in test_models}

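The dropout selection above keeps every `dropout_rate`-th target layer. A standalone trace of the same counting logic (illustrative only, not part of the diff):

dropout_rate = 2
flags, count = [], 0
for is_target in [True, True, True, True, False]:
    if is_target:
        count += 1
        flags.append(count == dropout_rate)
        if count == dropout_rate:
            count = 0
    else:
        flags.append(False)
print(flags)  # [False, True, False, True, False]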
        # Shift the layers
        # TODO: make this more robust to other models.
        if self.shift_layers != 0:
            # random shuffling; do not shuffle layers with no number in their name,
            # because they are likely to be special layers like text embeddings.
            if self.shift_layers == -1:
                layer_mappings = {model: {} for model in models}
                temp_state_dicts = deepcopy(state_dicts)

                # get the layer number index; assume the first number in the layer name is the layer number,
                # that all numbered layers have their number at the same index,
                # and that the components of the layer name are separated by '.'
                found_digit = False
                for layer_idx, layer in enumerate(layers):
                    if target_layer_flags[layer_idx]:
                        layer_parts = layer.split(".")
                        for i, part in enumerate(layer_parts):
                            if part.isdigit():
                                layer_number_idx = i
                                found_digit = True  # remember the index so the outer scan can stop
                                break
                    if found_digit:
                        break

                # get groups of target layers with the same name except the layer number
                target_layer_groups = {}
                for layer_idx, layer in enumerate(layers):
                    if target_layer_flags[layer_idx]:
                        layer_parts = layer.split(".")
                        if (
                            layer_number_idx >= len(layer_parts)
                            or not layer_parts[layer_number_idx].isdigit()
                        ):
                            continue  # skip layers without number
                        base_name = ".".join(
                            layer_parts[:layer_number_idx]
                            + layer_parts[layer_number_idx + 1 :]
                        )
                        layer_number = int(layer_parts[layer_number_idx])
                        if base_name not in target_layer_groups:
                            target_layer_groups[base_name] = []
                        target_layer_groups[base_name].append(layer_number)

                # construct the randomly shuffled mapping
                random_shuffle_mapping = {model: {} for model in models}
                for model_idx, model in enumerate(models):
                    for layer_idx, layer in enumerate(layers):
                        if target_layer_flags[layer_idx]:
                            layer_parts = layer.split(".")
                            if (
                                layer_number_idx >= len(layer_parts)
                                or not layer_parts[layer_number_idx].isdigit()
                            ):
                                continue  # skip layers without number
                            base_name = ".".join(
                                layer_parts[:layer_number_idx]
                                + layer_parts[layer_number_idx + 1 :]
                            )

                            if base_name not in random_shuffle_mapping[model]:
                                rng_state = random.getstate()
                                # Shuffle the layer numbers differently for each model
                                random.seed(self.config.random_seed + model_idx)
                                shuffled_layer_numbers = target_layer_groups[
                                    base_name
                                ].copy()
                                random.shuffle(shuffled_layer_numbers)
                                random_shuffle_mapping[model][base_name] = {
                                    orig: str(shuffled)
                                    for orig, shuffled in zip(
                                        target_layer_groups[base_name],
                                        shuffled_layer_numbers,
                                    )
                                }
                                random.setstate(rng_state)

                    for layer_idx, layer in enumerate(layers):
                        if target_layer_flags[layer_idx]:
                            layer_parts = layer.split(".")
                            if (
                                layer_number_idx >= len(layer_parts)
                                or not layer_parts[layer_number_idx].isdigit()
                            ):
                                continue  # skip layers without number
                            base_name = ".".join(
                                layer_parts[:layer_number_idx]
                                + layer_parts[layer_number_idx + 1 :]
                            )
                            layer_number = int(layer_parts[layer_number_idx])
                            new_layer_number = random_shuffle_mapping[model][base_name][
                                layer_number
                            ]
                            new_layer_name = ".".join(
                                layer_parts[:layer_number_idx]
                                + [new_layer_number]
                                + layer_parts[layer_number_idx + 1 :]
                            )
                            temp_state_dicts[model][new_layer_name] = state_dicts[
                                model
                            ][layer]
                            layer_mappings[model][new_layer_name] = layer
                state_dicts = temp_state_dicts
            else:
                layer_numbers = {}
                for layer_idx, layer in enumerate(layers):
                    if target_layer_flags[layer_idx]:
                        layer_parts = layer.split(".")
                        for part in layer_parts:
                            if part.isdigit():
                                layer_numbers[layer] = int(part)
                                break  # Only consider the first number for each layer
                if layer_numbers:
                    max_layer_number = max(layer_numbers.values())
                else:
                    max_layer_number = 0
                temp_state_dicts = deepcopy(state_dicts)
                # Wrap around and shift each model by a different amount
                for model_idx, model in enumerate(models):
                    for layer_idx, layer in enumerate(layers):
                        target_flag = target_layer_flags[layer_idx]
                        if not target_flag:
                            continue
                        layer_number = layer_numbers.get(layer)
                        if layer_number is None:
                            continue
                        new_layer_number = (
                            layer_number + model_idx * self.config.shift_layers
                        ) % (max_layer_number + 1)
                        new_layer_parts = []
                        replaced = False  # Only replace the first numeric part FIXME: make it more robust
                        for part in layer.split("."):
                            if part.isdigit() and not replaced:
                                new_layer_parts.append(str(new_layer_number))
                                replaced = True
                            else:
                                new_layer_parts.append(part)
                        new_layer = ".".join(new_layer_parts)
                        temp_state_dicts[model][new_layer] = state_dicts[model][layer]
                state_dicts = temp_state_dicts

        if self.debug >= 2:
            # for evaluating pairwise cosine similarity
            unmerged_task_vectors = deepcopy(state_dicts)

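The fixed-shift branch renames layer i of model k to layer (i + k * shift_layers) mod (max_layer + 1). A standalone trace of that arithmetic (illustrative only, not part of the diff):

max_layer_number, shift_layers = 11, 2
for model_idx in range(3):
    print([(i + model_idx * shift_layers) % (max_layer_number + 1) for i in range(3)])
# model 0: [0, 1, 2]; model 1: [2, 3, 4]; model 2: [4, 5, 6]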
        # compress
        for layer_idx, layer in enumerate(layers):
            shape = state_dicts[models[0]][layer].shape
            compressed_layer = None
            target_flag = target_layer_flags[layer_idx]
            # self.verbose = 1
            if self.verbose >= 1:
                log.info(f"{layer} | {shape} | {target_flag}")
            if not target_flag:
                if absorber is not None:
                    compressed_layer = absorber[layer]
                else:
                    for model in models:
                        if compressed_layer is None:
                            compressed_layer = deepcopy(state_dicts[model][layer])
                        else:
                            compressed_layer += deepcopy(state_dicts[model][layer])
            else:
                if self.mode == "random_binary_diagonal_matrix":
                    for model_idx, model in enumerate(models):
                        if self.different_across_layers:
                            # NOTE: hash() is salted per process unless PYTHONHASHSEED is set,
                            # so per-layer seeds are not reproducible across runs by default.
                            seed = self.random_seed + model_idx + hash(layer) % 1e6
                        else:
                            seed = self.random_seed + model_idx
                        numpy_state = np.random.get_state()
                        np.random.seed(int(seed))
                        context = (
                            np.random.binomial(p=0.5, n=1, size=(1, shape[-1])).astype(
                                np.float32
                            )
                            * 2
                            - 1
                        )
                        context = torch.from_numpy(context)
                        np.random.set_state(numpy_state)
                        compression_context[model][layer] = context  # for analysis purposes
                        retrieval_context[model][layer] = context
                        if compressed_layer is None:
                            compressed_layer = state_dicts[model][layer] * context
                        else:
                            compressed_layer += state_dicts[model][layer] * context
                        if self.debug >= 2:
                            # the Hadamard product is not linear; convert the context back
                            # to a diagonal matrix and apply a matrix multiplication
                            context_diag = torch.diag(context.squeeze())
                            unmerged_task_vectors[model][layer] = (
                                unmerged_task_vectors[model][layer] @ context_diag
                            )
                elif self.mode == "random_rotation_matrix":
                    for model_idx, model in enumerate(models):
                        if self.different_across_layers:
                            # NOTE: `seed` is a float in this branch, which scipy's
                            # random_state argument does not accept.
                            seed = self.random_seed + model_idx + hash(layer) % 1e6
                        else:
                            seed = self.random_seed + model_idx
                        context = torch.from_numpy(
                            ortho_group.rvs(shape[-1], random_state=seed).astype(
                                "float32"
                            )
                        )
                        compression_context[model][layer] = context  # for analysis purposes
                        retrieval_context[model][layer] = context.t()
                        if compressed_layer is None:
                            compressed_layer = state_dicts[model][layer] @ context
                        else:
                            compressed_layer += state_dicts[model][layer] @ context
                        if self.debug >= 2:
                            unmerged_task_vectors[model][layer] = (
                                unmerged_task_vectors[model][layer] @ context
                            )
                elif self.mode == "random_dense_matrix":
                    for model_idx, model in enumerate(models):
                        if self.different_across_layers:
                            seed = self.random_seed + model_idx + hash(layer) % 1e6
                        else:
                            seed = self.random_seed + model_idx
                        numpy_state = np.random.get_state()
                        np.random.seed(int(seed))
                        context = torch.from_numpy(
                            np.random.randn(shape[-1], shape[-1]).astype(np.float32)
                        )
                        np.random.set_state(numpy_state)
                        compression_context[model][layer] = context  # for analysis purposes
                        retrieval_context[model][layer] = torch.linalg.pinv(
                            context.to("cuda")
                        ).to("cpu")
                        if compressed_layer is None:
                            compressed_layer = state_dicts[model][layer] @ context
                        else:
                            compressed_layer += state_dicts[model][layer] @ context
                        if self.debug >= 2:
                            unmerged_task_vectors[model][layer] = (
                                unmerged_task_vectors[model][layer] @ context
                            )
                elif self.mode == "random_diagonal_matrix":
                    for model_idx, model in enumerate(models):
                        if self.different_across_layers:
                            seed = self.random_seed + model_idx + hash(layer) % 1e6
                        else:
                            seed = self.random_seed + model_idx
                        numpy_state = np.random.get_state()
                        np.random.seed(int(seed))
                        context = torch.from_numpy(
                            np.random.randn(1, shape[-1]).astype(np.float32)
                        )
                        np.random.set_state(numpy_state)
                        compression_context[model][layer] = context  # for analysis purposes
                        retrieval_context[model][layer] = 1 / context
                        if compressed_layer is None:
                            compressed_layer = state_dicts[model][layer] * context
                        else:
                            compressed_layer += state_dicts[model][layer] * context
                        if self.debug >= 2:
                            unmerged_task_vectors[model][layer] = (
                                unmerged_task_vectors[model][layer] * context
                            )
                elif self.mode == "identity_matrix":
                    for model_idx, model in enumerate(models):
                        context = torch.eye(shape[-1])
                        compression_context[model][layer] = context  # for analysis purposes
                        retrieval_context[model][layer] = context
                        if compressed_layer is None:
                            compressed_layer = state_dicts[model][layer] @ context
                        else:
                            compressed_layer += state_dicts[model][layer] @ context
                        if self.debug >= 2:
                            unmerged_task_vectors[model][layer] = (
                                unmerged_task_vectors[model][layer] @ context
                            )
                else:
                    raise ValueError(f"Unsupported mode: {self.mode}")

            compressed_layers[layer] = compressed_layer

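Why the binary diagonal context is exactly invertible: its entries are plus or minus one, so retrieving with the same context undoes the compression for a single model. A minimal round trip (illustrative only, not part of the diff; reuses the module's numpy/torch imports):

W = torch.randn(4, 4)
c = torch.from_numpy(
    np.random.binomial(n=1, p=0.5, size=(1, 4)).astype(np.float32) * 2 - 1
)
print(torch.allclose((W * c) * c, W))  # True, since c * c == 1 elementwise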
        # retrieve: for benchmarking purposes, retrieve all models at once; in practice,
        # retrieval should be done per model request.
        nonzero_param_count = 0
        nonzero_param_count_context = 0
        total_bytes_retrieved = 0

        if self.debug >= 2:
            for model in test_models:
                tv_new[model] = deepcopy(unmerged_task_vectors)

        for layer_idx, layer in enumerate(layers):
            shape = state_dicts[models[0]][layer].shape
            target_flag = target_layer_flags[layer_idx]
            if not target_flag:
                if mode == "superposed_model_soup":
                    # we don't count non-target layers for superposed task arithmetic
                    # because they can be absorbed into the pretrained weights
                    param_count = torch.numel(compressed_layers[layer])
                    total_bytes_retrieved += (
                        param_count * compressed_layers[layer].element_size()
                    )
                    nonzero_param_count += param_count
                for model in models:
                    retrieval_models[model][layer] = compressed_layers[layer]
            else:
                if (
                    mode == "superposed_task_arithmetic"
                    and self.mode == "identity_matrix"
                    and self.shift_layers == 0
                ):
                    # we don't count target layers for task arithmetic
                    # because they can be absorbed into the pretrained weights
                    pass
                else:
                    param_count = torch.numel(compressed_layers[layer])
                    total_bytes_retrieved += (
                        param_count * compressed_layers[layer].element_size()
                    )
                    nonzero_param_count += torch.numel(compressed_layers[layer])

                if self.mode in [
                    "random_binary_diagonal_matrix",
                    "random_rotation_matrix",
                    "random_dense_matrix",
                    "random_diagonal_matrix",
                    "identity_matrix",
                ]:
                    for model in models:
                        if self.mode not in ["identity_matrix"]:
                            nonzero_count = torch.numel(retrieval_context[model][layer])
                            if self.mode == "random_binary_diagonal_matrix":
                                total_bytes_retrieved += (
                                    nonzero_count * 1
                                )  # 1 byte per element for binary
                            else:
                                total_bytes_retrieved += (
                                    nonzero_count
                                    * retrieval_context[model][layer].element_size()
                                )
                            nonzero_param_count += nonzero_count
                            nonzero_param_count_context += nonzero_count
                        if retrieval_context[model][layer].shape[0] == 1:
                            # diagonal contexts are stored as a row vector; retrieve elementwise
                            retrieval_models[model][layer] = (
                                compressed_layers[layer]
                                * retrieval_context[model][layer]
                            )
                        else:
                            retrieval_models[model][layer] = (
                                compressed_layers[layer]
                                @ retrieval_context[model][layer]
                            )
                        if self.debug >= 2 and model in test_models:
                            if retrieval_context[model][layer].shape[0] == 1:
                                retrieval_context_diag = torch.diag(
                                    retrieval_context[model][layer].squeeze()
                                )
                                for m in models:
                                    tv_new[model][m][layer] = (
                                        tv_new[model][m][layer] @ retrieval_context_diag
                                    )
                            else:
                                for m in models:
                                    tv_new[model][m][layer] = (
                                        tv_new[model][m][layer]
                                        @ retrieval_context[model][layer]
                                    )
                else:
                    raise ValueError(f"Unsupported mode: {self.mode}")
        # for model in test_models:
        #     # print(retrieval_context[model]['vision_model.encoder.layers.4.self_attn.q_proj.weight'])
        #     # print('a')
        #     print(tv_new[model][models[3]]['vision_model.encoder.layers.4.self_attn.q_proj.weight'])

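Similarly, the rotation context retrieves exactly because ortho_group samples orthogonal matrices, so R @ R.T is the identity. A minimal round trip (illustrative only, not part of the diff; reuses the module's imports):

R = torch.from_numpy(ortho_group.rvs(4, random_state=0).astype("float32"))
W = torch.randn(3, 4)
print(torch.allclose((W @ R) @ R.t(), W, atol=1e-5))  # True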
        # Shift the layers back
        if self.shift_layers != 0:
            if self.shift_layers == -1:  # random shuffling
                if self.debug >= 2:
                    temp_tv_new = deepcopy(tv_new)
                temp_retrieval_models = deepcopy(retrieval_models)
                for model_idx, model in enumerate(models):
                    # reverse_layer_mapping = {shuffled: original for original, shuffled in layer_mappings[model].items()}
                    for shuffled_layer, original_layer in layer_mappings[model].items():
                        temp_retrieval_models[model][original_layer] = retrieval_models[
                            model
                        ][shuffled_layer]
                        if self.debug >= 2 and model in test_models:
                            for m in models:
                                temp_tv_new[model][m][original_layer] = tv_new[model][
                                    m
                                ][shuffled_layer]
                retrieval_models = temp_retrieval_models
                if self.debug >= 2:
                    tv_new = temp_tv_new
            else:  # TODO: check the correctness of this mode
                # raise NotImplementedError("Shift back mode not implemented yet. No tv_new support yet.")
                if self.debug >= 2:
                    temp_tv_new = deepcopy(tv_new)
                temp_retrieval_models = deepcopy(retrieval_models)
                for model_idx, model in enumerate(models):
                    for layer_idx, layer in enumerate(layers):
                        target_flag = target_layer_flags[layer_idx]
                        if not target_flag:
                            continue
                        layer_parts = layer.split(".")
                        layer_number = None
                        for part in layer_parts:
                            if part.isdigit():
                                layer_number = int(part)
                                break  # Only consider the first number
                        if layer_number is None:
                            continue
                        new_layer_number = (
                            layer_number - model_idx * self.shift_layers
                        ) % (max_layer_number + 1)
                        new_layer_parts = []
                        replaced = False
                        for part in layer_parts:
                            if part.isdigit() and not replaced:
                                new_layer_parts.append(str(new_layer_number))
                                replaced = True  # Only replace the first numeric part
                            else:
                                new_layer_parts.append(part)
                        new_layer = ".".join(new_layer_parts)
                        temp_retrieval_models[model][new_layer] = retrieval_models[
                            model
                        ][layer]
                        if self.debug >= 2 and model in test_models:
                            for m in models:
                                temp_tv_new[model][m][new_layer] = tv_new[model][m][
                                    layer
                                ]
                retrieval_models = temp_retrieval_models
                if self.debug >= 2:
                    tv_new = temp_tv_new

        # for model in test_models:
        #     # print(retrieval_context[model]['vision_model.encoder.layers.4.self_attn.q_proj.weight'])
        #     # print('a')
        #     print(tv_new[model][models[3]]['vision_model.encoder.layers.4.self_attn.q_proj.weight'])

        # metadata
        if self.debug >= 2:
            if self.mode in [
                "random_binary_diagonal_matrix",
                "random_rotation_matrix",
                "random_dense_matrix",
                "random_diagonal_matrix",
                "identity_matrix",
            ]:
                layers = list(layers_old[models[0]].keys())
                layers_old_flattened = [
                    torch.cat([layers_old[model][layer].flatten() for layer in layers])
                    for model in models
                ]
                metadata["pairwise_cosine_similarity_matrix_before"] = (
                    pairwise_cosine_similarity_matrix(layers_old_flattened).tolist()
                )
                metadata["task_vector_dim"] = layers_old_flattened[0].shape[0]
                # layers_new = deepcopy(retrieval_models)
                rms = []
                for retrieval_model in test_models:
                    print(f"Retrieval model: {retrieval_model}")
                    layers_new_flattened = [
                        torch.cat(
                            [
                                tv_new[retrieval_model][model][layer].flatten()
                                for layer in layers
                            ]
                        )
                        for model in models
                    ]
                    rms.append(layers_new_flattened)
                    # print(layers_new_flattened[0][:50])
                    # layers_new_flattened = [torch.cat([layers_new[retrieval_model][layer].flatten() for layer in layers]) for model in models]
                    pcsm = pairwise_cosine_similarity_matrix(
                        layers_new_flattened
                    ).tolist()
                    print(pcsm)
                    metadata[
                        f"pairwise_cosine_similarity_matrix_after_{retrieval_model}"
                    ] = pcsm
        if self.debug >= 0:
            metadata["nonzero_parameter_count"] = (
                nonzero_param_count.item()
                if isinstance(nonzero_param_count, torch.Tensor)
                else nonzero_param_count
            )
            metadata["nonzero_param_count_context"] = (
                nonzero_param_count_context.item()
                if isinstance(nonzero_param_count_context, torch.Tensor)
                else nonzero_param_count_context
            )
            gbs = total_bytes_retrieved / 1e9
            metadata["total_gb_retrieved"] = gbs

        return retrieval_models, metadata