fusion-bench 0.2.14__py3-none-any.whl → 0.2.16__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only.
Files changed (86)
  1. fusion_bench/method/__init__.py +4 -0
  2. fusion_bench/method/fw_merging/__init__.py +2 -0
  3. fusion_bench/method/fw_merging/fw_hard.py +448 -0
  4. fusion_bench/method/fw_merging/fw_soft.py +519 -0
  5. fusion_bench/method/fw_merging/utils.py +331 -0
  6. fusion_bench/method/moe_pruner/__init__.py +7 -0
  7. fusion_bench/method/moe_pruner/hooks/__init__.py +6 -0
  8. fusion_bench/method/moe_pruner/hooks/deepseek_v2.py +85 -0
  9. fusion_bench/method/moe_pruner/hooks/hook.py +23 -0
  10. fusion_bench/method/moe_pruner/hooks/mixtral.py +93 -0
  11. fusion_bench/method/moe_pruner/moe_pruner.py +304 -0
  12. fusion_bench/method/moe_pruner/utils/__init__.py +1 -0
  13. fusion_bench/method/moe_pruner/utils/data.py +154 -0
  14. fusion_bench/method/moe_pruner/utils/layerwrapper.py +61 -0
  15. fusion_bench/method/moe_pruner/utils/prune.py +313 -0
  16. fusion_bench/method/moe_pruner/utils/score.py +41 -0
  17. fusion_bench/method/pruning/__init__.py +1 -0
  18. fusion_bench/method/pruning/llama_sparsegpt_prune.py +223 -0
  19. fusion_bench/method/pruning/sparsegpt_utils/__init__.py +1 -0
  20. fusion_bench/method/pruning/sparsegpt_utils/sparsegpt.py +128 -0
  21. fusion_bench/method/pruning/wanda_utils/data.py +33 -14
  22. fusion_bench/method/randes/__init__.py +15 -0
  23. fusion_bench/method/randes/base_algorithm.py +1013 -0
  24. fusion_bench/method/randes/modelsoup.py +126 -0
  25. fusion_bench/method/randes/task_arithmetic.py +318 -0
  26. fusion_bench/method/sparselo/sparselo.py +20 -2
  27. fusion_bench/method/tall_mask/__init__.py +1 -0
  28. fusion_bench/method/tall_mask/task_arithmetic.py +133 -0
  29. fusion_bench/modelpool/causal_lm/causal_lm.py +73 -10
  30. fusion_bench/modelpool/lazy_state_dict_pool.py +15 -0
  31. fusion_bench/models/modeling_deepseek_v2/__init__.py +15 -0
  32. fusion_bench/models/modeling_deepseek_v2/configuration_deepseek.py +208 -0
  33. fusion_bench/models/modeling_deepseek_v2/modeling_deepseek.py +1922 -0
  34. fusion_bench/models/modeling_deepseek_v2/tokenization_deepseek_fast.py +38 -0
  35. fusion_bench/programs/fabric_fusion_program.py +5 -0
  36. fusion_bench/taskpool/clip_vision/taskpool.py +8 -1
  37. fusion_bench/utils/__init__.py +1 -0
  38. fusion_bench/utils/data.py +1 -1
  39. fusion_bench/utils/lazy_state_dict.py +268 -0
  40. fusion_bench/utils/parameters.py +33 -0
  41. fusion_bench/utils/state_dict_arithmetic.py +74 -2
  42. fusion_bench/utils/type.py +1 -0
  43. {fusion_bench-0.2.14.dist-info → fusion_bench-0.2.16.dist-info}/METADATA +10 -3
  44. {fusion_bench-0.2.14.dist-info → fusion_bench-0.2.16.dist-info}/RECORD +86 -22
  45. {fusion_bench-0.2.14.dist-info → fusion_bench-0.2.16.dist-info}/WHEEL +1 -1
  46. fusion_bench_config/dataset/image_classification/test/TALL10.yaml +28 -0
  47. fusion_bench_config/dataset/image_classification/test/TALL12.yaml +28 -0
  48. fusion_bench_config/dataset/image_classification/test/TALL16.yaml +28 -0
  49. fusion_bench_config/dataset/image_classification/test/TALL18.yaml +28 -0
  50. fusion_bench_config/dataset/image_classification/train/TALL10.yaml +28 -0
  51. fusion_bench_config/dataset/image_classification/train/TALL12.yaml +28 -0
  52. fusion_bench_config/dataset/image_classification/train/TALL16.yaml +28 -0
  53. fusion_bench_config/dataset/image_classification/train/TALL18.yaml +28 -0
  54. fusion_bench_config/method/fw_merging/fw_hard.yaml +11 -0
  55. fusion_bench_config/method/fw_merging/fw_soft.yaml +12 -0
  56. fusion_bench_config/method/moe_pruner/moe_pruner.yaml +15 -0
  57. fusion_bench_config/method/pruning/llama_sparsegpt_pruning.yaml +16 -0
  58. fusion_bench_config/method/randes/superposed_model_soup.yaml +18 -0
  59. fusion_bench_config/method/randes/superposed_task_arithmetic.yaml +20 -0
  60. fusion_bench_config/method/randes/superposed_task_arithmetic_lora.yaml +20 -0
  61. fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml +2 -1
  62. fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml +1 -1
  63. fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml +1 -1
  64. fusion_bench_config/method/tall_mask/task_arithmetic.yaml +4 -0
  65. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL10.yaml +29 -0
  66. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL12.yaml +29 -0
  67. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL16.yaml +29 -0
  68. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL18.yaml +29 -0
  69. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL10.yaml +8 -0
  70. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL12.yaml +8 -0
  71. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL16.yaml +8 -0
  72. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL18.yaml +8 -0
  73. fusion_bench_config/modelpool/CausalLMPool/deepseek-v2-lite.yaml +15 -0
  74. fusion_bench_config/modelpool/CausalLMPool/mergebench/Llama-3.1-8B-Instruct.yaml +11 -0
  75. fusion_bench_config/modelpool/CausalLMPool/mergebench/Llama-3.1-8B.yaml +11 -0
  76. fusion_bench_config/modelpool/CausalLMPool/mergebench/Llama-3.2-3B-Instruct.yaml +11 -0
  77. fusion_bench_config/modelpool/CausalLMPool/mergebench/Llama-3.2-3B.yaml +11 -0
  78. fusion_bench_config/modelpool/CausalLMPool/mergebench/gemma-2-2b-it.yaml +11 -0
  79. fusion_bench_config/modelpool/CausalLMPool/mergebench/gemma-2-2b.yaml +11 -0
  80. fusion_bench_config/modelpool/CausalLMPool/mergebench/gemma-2-9b-it.yaml +11 -0
  81. fusion_bench_config/modelpool/CausalLMPool/mergebench/gemma-2-9b.yaml +11 -0
  82. fusion_bench_config/modelpool/CausalLMPool/mixtral-8x7b.yaml +14 -0
  83. fusion_bench_config/modelpool/SeqenceClassificationModelPool/roberta-base_glue.yaml +69 -0
  84. {fusion_bench-0.2.14.dist-info → fusion_bench-0.2.16.dist-info}/entry_points.txt +0 -0
  85. {fusion_bench-0.2.14.dist-info → fusion_bench-0.2.16.dist-info}/licenses/LICENSE +0 -0
  86. {fusion_bench-0.2.14.dist-info → fusion_bench-0.2.16.dist-info}/top_level.txt +0 -0
fusion_bench/method/randes/base_algorithm.py (new file)
@@ -0,0 +1,1013 @@
+ import logging
+ import random
+ from collections import OrderedDict
+ from copy import deepcopy
+ from typing import Dict, List, Literal, Optional, Tuple
+
+ import numpy as np
+ import torch
+ from scipy.stats import ortho_group
+ from torch import Tensor, nn
+
+ from fusion_bench.method.base_algorithm import BaseAlgorithm
+ from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
+ from fusion_bench.modelpool import BaseModelPool
+ from fusion_bench.utils.parameters import get_parameter_summary, human_readable
+ from fusion_bench.utils.state_dict_arithmetic import state_dict_avg
+ from fusion_bench.utils.type import StateDictType
+
+ log = logging.getLogger(__name__)
+
+
+ def cosine_similarity(tensor1: Tensor, tensor2: Tensor) -> float:
+     if tensor1.shape != tensor2.shape:
+         raise ValueError("Matrices must have the same shape")
+     vec1 = tensor1.flatten()
+     vec2 = tensor2.flatten()
+     dot_product = torch.sum(vec1 * vec2)
+     norm1 = torch.sqrt(torch.sum(vec1**2))
+     norm2 = torch.sqrt(torch.sum(vec2**2))
+     if norm1 == 0 or norm2 == 0:
+         return 0.0
+     return dot_product / (norm1 * norm2)
+
+
+ def svd_and_partition(
+     A: torch.Tensor, num_chunks: int = 3
+ ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
+     U, S, V = torch.svd(A)
+     singular_values = len(S)
+     chunk_size = singular_values // num_chunks
+     U_chunks, S_chunks, V_chunks = [], [], []
+
+     for i in range(num_chunks):
+         start_idx = i * chunk_size
+         end_idx = singular_values if i == num_chunks - 1 else start_idx + chunk_size
+
+         U_chunks.append(U[:, start_idx:end_idx])
+         S_chunks.append(S[start_idx:end_idx])
+         V_chunks.append(V[:, start_idx:end_idx])
+
+     return U_chunks, S_chunks, V_chunks
+
+
+ def compute_svd_subspace_similarity(
+     ref: torch.Tensor, retrieval: torch.Tensor, num_chunks: int = 3
+ ) -> List[dict]:
+     if torch.cuda.is_available():
+         ref = ref.cuda()
+         retrieval = retrieval.cuda()
+     U_chunks, S_chunks, V_chunks = svd_and_partition(ref, num_chunks)
+     similarities = []
+     for i in range(num_chunks):
+         retrieval_approx = (
+             U_chunks[i] @ U_chunks[i].T @ retrieval @ V_chunks[i] @ V_chunks[i].T
+         )
+         frob_sim = torch.norm(ref - retrieval_approx, p="fro").item() / ref.numel()
+         cos_sim = cosine_similarity(ref, retrieval_approx)
+         if isinstance(cos_sim, torch.Tensor):
+             cos_sim = cos_sim.item()
+         similarities.append(
+             {
+                 "subspace": i + 1,
+                 "frobenius_similarity": frob_sim,
+                 "cosine_similarity": cos_sim,
+             }
+         )
+     return similarities
+
+
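compute_svd_subspace_similarity projects the retrieved matrix onto successive chunks of the reference matrix's singular subspaces (retrieval_approx = U_k U_k^T @ retrieval @ V_k V_k^T) and reports a Frobenius and a cosine score per chunk. A minimal usage sketch, assuming the helpers above are in scope; the shapes and noise level are illustrative:

    import torch

    # A reference matrix and a noisy "retrieved" copy of it.
    ref = torch.randn(64, 64)
    retrieval = ref + 0.1 * torch.randn(64, 64)

    # Score each third of ref's singular spectrum; early chunks carry most energy.
    for entry in compute_svd_subspace_similarity(ref, retrieval, num_chunks=3):
        print(entry["subspace"], entry["frobenius_similarity"], entry["cosine_similarity"])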
+ def pairwise_cosine_similarity_matrix(tensors: List[torch.Tensor]) -> torch.Tensor:
+     if torch.cuda.is_available():
+         tensors = [tensor.cuda() for tensor in tensors]
+     n = len(tensors)
+     similarity_matrix = torch.zeros((n, n))
+     for i in range(n):
+         for j in range(n):
+             similarity = cosine_similarity(tensors[i], tensors[j])
+             similarity_matrix[i, j] = similarity.item()
+     return similarity_matrix
+
+
+ def compare_models(
+     state_dict1: StateDictType, state_dict2: StateDictType, target_layers=None
+ ):
+     results = {
+         "layerwise_l2": {},
+         "layerwise_cosine_similarity": {},
+         "total_l2": None,
+         "average_l2": None,
+         "total_cosine_similarity": None,
+         "average_cosine_similarity": None,
+     }
+     # Initialize lists to store flattened parameters
+     params1_list = []
+     params2_list = []
+
+     keys1 = set(state_dict1.keys())
+     keys2 = set(state_dict2.keys())
+     # filter out layers that are not in target_layers
+     if target_layers is not None:
+         keys1 = keys1.intersection(target_layers)
+         keys2 = keys2.intersection(target_layers)
+
+     common_keys = keys1 & keys2
+     if keys1 != keys2:
+         print(
+             "Warning: State dicts have different keys. Comparison will be made on common keys only."
+         )
+     num_layers = len(common_keys)
+
+     for key in common_keys:
+         tensor1 = state_dict1[key].float()
+         tensor2 = state_dict2[key].float()
+
+         # Compute L2 norm difference
+         l2_diff = torch.norm(tensor1 - tensor2, p=2) / tensor1.numel()
+         results["layerwise_l2"][key] = l2_diff.item()
+
+         # Compute cosine similarity
+         tensor1_flat = tensor1.reshape(-1)
+         tensor2_flat = tensor2.reshape(-1)
+         cos_sim = cosine_similarity(tensor1_flat, tensor2_flat).item()
+         results["layerwise_cosine_similarity"][key] = cos_sim
+
+         # Collect parameters for total metrics
+         params1_list.append(tensor1_flat)
+         params2_list.append(tensor2_flat)
+
+     # Compute total metrics over all parameters
+     if params1_list and params2_list:
+         params1 = torch.cat(params1_list)
+         params2 = torch.cat(params2_list)
+         # Compute total L2 norm difference
+         total_l2_difference = (
+             torch.norm(params1 - params2, p=2).item() / params1.numel()
+         )
+         results["total_l2"] = total_l2_difference
+         # Compute total cosine similarity
+         total_cosine_similarity = cosine_similarity(params1, params2).item()
+         results["total_cosine_similarity"] = total_cosine_similarity
+     else:
+         results["total_l2"] = None
+         results["total_cosine_similarity"] = None
+
+     # Compute average metrics
+     if num_layers > 0:
+         average_l2 = sum(results["layerwise_l2"].values()) / num_layers
+         average_cosine_similarity = (
+             sum(results["layerwise_cosine_similarity"].values()) / num_layers
+         )
+         results["average_l2"] = average_l2
+         results["average_cosine_similarity"] = average_cosine_similarity
+     else:
+         results["average_l2"] = None
+         results["average_cosine_similarity"] = None
+
+     return results
+
+
+ class SuperposedAlgorithmBase(
+     BaseAlgorithm,
+     SimpleProfilerMixin,
+ ):
+     _config_mapping = BaseAlgorithm._config_mapping | {
+         "mode": "mode",
+         "target_layer": "target_layer",
+         "random_seed": "random_seed",
+         "different_across_layers": "different_across_layers",
+         "joint_matrix_mode": "joint_matrix_mode",
+         "rank": "rank",
+         "random_components": "random_components",
+         "shift_layers": "shift_layers",
+         "absorber": "absorber",
+         "debug": "debug",
+         "ms_mode": "ms_mode",
+         "verbose": "verbose",
+         "dropout_rate": "dropout_rate",
+     }
+
+     def __init__(
+         self,
+         mode: str,
+         target_layer: str,
+         random_seed: int,
+         different_across_layers: bool,
+         joint_matrix_mode: str,
+         rank: int,
+         random_components: bool,
+         shift_layers: int,
+         absorber: Literal["average", "pretrained", "None"],
+         debug: int,
+         ms_mode: str,
+         verbose: int,
+         dropout_rate: int,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.mode = mode
+         self.target_layer = target_layer
+         self.random_seed = random_seed
+         self.different_across_layers = different_across_layers
+         self.joint_matrix_mode = joint_matrix_mode
+         self.rank = rank
+         self.random_components = random_components
+         self.shift_layers = shift_layers
+         self.absorber = absorber
+         self.debug = debug
+         self.ms_mode = ms_mode
+         self.verbose = verbose
+         self.dropout_rate = dropout_rate
+
+     def _compute_svd_subspace_similarities(
+         self,
+         original_state_dict: StateDictType,
+         retrieved_state_dict: StateDictType,
+         target_layers: Optional[List[str]] = None,
+     ) -> dict:
+         svd_similarities = {}
+         for layer_name, original_param in original_state_dict.items():
+             if target_layers is not None and layer_name not in target_layers:
+                 continue
+             if (
+                 original_param.dim() == 2
+             ):  # Only compute for 2D tensors (weight matrices)
+                 retrieved_param = retrieved_state_dict[layer_name]
+                 svd_similarities[layer_name] = compute_svd_subspace_similarity(
+                     original_param.float(), retrieved_param.float()
+                 )
+         return svd_similarities
+
+     def _load_state_dicts(self, modelpool: BaseModelPool) -> Dict[str, StateDictType]:
+         """
+         Load the state dicts of the models in the modelpool.
+
+         Args:
+             modelpool (BaseModelPool): The modelpool to load the state dicts from.
+
+         Returns:
+             Dict[str, StateDictType]: A dictionary of state dicts, keyed by model name.
+         """
+         state_dicts = {}
+         for model_name in modelpool.model_names:
+             with self.profile("load model"):
+                 model = modelpool.load_model(model_name)
+             state_dicts[model_name] = model.state_dict(keep_vars=True)
+         return state_dicts
+
+     def _compute_absorber(
+         self,
+         state_dicts: Dict[str, StateDictType],
+         pretrained_model: Optional[nn.Module] = None,
+     ) -> Optional[StateDictType]:
+         """
+         Compute the absorber state dict.
+
+         Args:
+             state_dicts (Dict[str, StateDictType]): The state dicts of the models, keyed by model name, i.e. `{model_name: state_dict}`.
+             pretrained_model (Optional[nn.Module]): The pretrained model.
+
+         Returns:
+             Optional[StateDictType]: The absorber state dict.
+         """
+         if self.absorber == "average":
+             return state_dict_avg(list(state_dicts.values()))
+         elif self.absorber == "pretrained":
+             return pretrained_model.state_dict(keep_vars=True)
+         elif self.absorber == "None":
+             return None
+         else:
+             raise ValueError(
+                 f"Unsupported absorber type: {self.absorber}. Must be one of 'average', 'pretrained', or 'None'."
+             )
+
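The absorber is the state dict that stands in for the non-target layers during compression: the parameter-wise average of all models, the pretrained weights, or nothing at all. A minimal sketch of the "average" case, assuming state_dict_avg averages state dicts entry-wise; the toy tensors are illustrative:

    import torch

    from fusion_bench.utils.state_dict_arithmetic import state_dict_avg

    sd_a = {"w": torch.ones(2, 2)}
    sd_b = {"w": torch.full((2, 2), 3.0)}

    # absorber="average" reduces to an entry-wise mean over the pool.
    avg = state_dict_avg([sd_a, sd_b])
    assert torch.allclose(avg["w"], torch.full((2, 2), 2.0))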
+     @staticmethod
+     def svd_decomposition(A, r):
+         if torch.cuda.is_available():
+             A = A.cuda()
+         U, S, V = torch.svd(A)
+         return (U[:, :r] @ torch.diag(S[:r])).cpu(), V.t()[:r, :].cpu()
+
+     @staticmethod
+     def svd_decomposition_bm(A, r_most, r_mid):
+         if torch.cuda.is_available():
+             A = A.cuda()
+
+         # Perform SVD
+         U, S, V = torch.svd(A)
+
+         # Get the most significant 'r_most' dimensions
+         U_most = U[:, :r_most]
+         S_most = S[:r_most]
+         V_most = V[:, :r_most]
+
+         # Get the middle 'r_mid' dimensions
+         start_mid = len(S) // 2 - r_mid // 2
+         end_mid = start_mid + r_mid
+         U_mid = U[:, start_mid:end_mid]
+         S_mid = S[start_mid:end_mid]
+         V_mid = V[:, start_mid:end_mid]
+
+         # Combine the results into two sets
+         U_combined = torch.cat([U_most, U_mid], dim=1)
+         S_combined = torch.cat([S_most, S_mid])
+         V_combined = torch.cat([V_most, V_mid], dim=1)
+
+         return (U_combined @ torch.diag(S_combined)).cpu(), V_combined.t().cpu()
+
+     @staticmethod
+     def svd_decomposition(A, r=None, r_most=None, r_mid=None, random_components=False):
+         """
+         Perform SVD decomposition with options for:
+         1. Truncated SVD with 'r' components (if r is provided and random_components=False).
+         2. Most significant 'r_most' and middle 'r_mid' components (if r_most and r_mid are provided).
+         3. Randomly selected 'r' components (if r is provided and random_components=True).
+
+         Args:
+             A (torch.Tensor): The input matrix to decompose.
+             r (int, optional): Number of components for standard or random SVD.
+             r_most (int, optional): Number of most significant components.
+             r_mid (int, optional): Number of middle components.
+             random_components (bool, optional): Whether to sample 'r' random components.
+
+         Returns:
+             (torch.Tensor, torch.Tensor): Two matrices resulting from the SVD decomposition.
+         """
+         if torch.cuda.is_available():
+             A = A.cuda()
+
+         # Perform SVD
+         U, S, V = torch.svd(A)
+
+         if r is not None and not random_components:
+             # Standard SVD decomposition with 'r' components
+             return (U[:, :r] @ torch.diag(S[:r])).cpu(), V.t()[:r, :].cpu()
+
+         elif r_most is not None and r_mid is not None:
+             # SVD decomposition with 'r_most' most significant and 'r_mid' middle components
+             # Most significant components
+             U_most = U[:, :r_most]
+             S_most = S[:r_most]
+             V_most = V[:, :r_most]
+
+             # Middle components
+             start_mid = len(S) // 2 - r_mid // 2
+             end_mid = start_mid + r_mid
+             U_mid = U[:, start_mid:end_mid]
+             S_mid = S[start_mid:end_mid]
+             V_mid = V[:, start_mid:end_mid]
+
+             # Combine the most and middle components
+             U_combined = torch.cat([U_most, U_mid], dim=1)
+             S_combined = torch.cat([S_most, S_mid])
+             V_combined = torch.cat([V_most, V_mid], dim=1)
+
+             return (U_combined @ torch.diag(S_combined)).cpu(), V_combined.t().cpu()
+
+         elif r is not None and random_components:
+             # SVD decomposition with random 'r' components
+             indices = torch.randperm(len(S))[:r]
+             U_rand = U[:, indices]
+             S_rand = S[indices]
+             V_rand = V[:, indices]
+
+             return (U_rand @ torch.diag(S_rand)).cpu(), V_rand.t().cpu()
+
+         else:
+             raise ValueError(
+                 "Invalid combination of arguments. Provide correct parameters."
+             )
+
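The docstring above lists three call patterns for svd_decomposition (note the second definition shadows the earlier two-argument one). A minimal usage sketch, assuming the class above is importable; the matrix size and ranks are illustrative:

    import torch

    A = torch.randn(128, 64)

    # 1. Truncated SVD: B @ C is the best rank-16 approximation of A.
    B, C = SuperposedAlgorithmBase.svd_decomposition(A, r=16)
    assert B.shape == (128, 16) and C.shape == (16, 64)

    # 2. Top-8 plus 8 middle singular directions.
    B, C = SuperposedAlgorithmBase.svd_decomposition(A, r_most=8, r_mid=8)

    # 3. 16 randomly chosen singular directions.
    B, C = SuperposedAlgorithmBase.svd_decomposition(A, r=16, random_components=True)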
+     @staticmethod
+     def _get_rank(A, rank):
+         if isinstance(rank, str):
+             r1, r2 = rank.split("-")
+             r1 = int(float(r1) * min(A.shape)) if "." in r1 else int(r1)
+             r2 = int(float(r2) * min(A.shape)) if "." in r2 else int(r2)
+             return r1, r2
+         if isinstance(rank, int):
+             return rank
+         elif isinstance(rank, float):
+             return int(rank * min(A.shape))
+
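_get_rank accepts an int (used as-is), a float (a fraction of the smaller matrix dimension), or an "r1-r2" string where each side is parsed as a fraction if it contains a decimal point. A minimal sketch; the matrix shape is illustrative:

    import torch

    A = torch.zeros(100, 300)  # min(A.shape) == 100

    SuperposedAlgorithmBase._get_rank(A, 32)       # -> 32
    SuperposedAlgorithmBase._get_rank(A, 0.25)     # -> 25
    SuperposedAlgorithmBase._get_rank(A, "0.5-8")  # -> (50, 8)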
+     def _target_layer_flag(self, layer: str):
+         """
+         The method takes a layer name as input and returns a boolean indicating whether this layer should be targeted.
+
+         The current implementation assumes a Transformer architecture and that the layer number is the first number in the layer name.
+
+         Args:
+             layer (str): The name of the layer.
+
+         Returns:
+             bool: True if the layer should be targeted, False otherwise.
+         """
+         target_layers = self.target_layer  # e.g. ["mlp_w", "attn_w"]
+         # TODO: figure out what wo is in flan-t5
+         mlp_flag = "mlp" in layer or "Dense" in layer
+         attn_flag = "attn" in layer or "Attention" in layer
+         weight_flag = "weight" in layer
+         bias_flag = "bias" in layer
+         target_flags = []
+         for target_layer in target_layers:
+             if target_layer == "mlp_w":
+                 target_flags.append(mlp_flag and not bias_flag)
+             elif target_layer == "attn_w":
+                 target_flags.append(attn_flag and not bias_flag)
+             elif target_layer == "all":
+                 target_flags.append(True)
+             elif target_layer == "mlp":
+                 target_flags.append(mlp_flag)
+             elif target_layer == "attn":
+                 target_flags.append(attn_flag)
+             else:
+                 raise ValueError(f"Unsupported target layer: {target_layer}")
+         target_flag = any(target_flags)
+         return target_flag
+
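_target_layer_flag keys off substrings of the layer name, so with target_layer=["mlp_w"] an MLP weight is targeted while its bias and any attention weight are not. A minimal sketch, assuming BaseAlgorithm requires no extra constructor arguments of its own; all constructor values and layer names here are illustrative placeholders:

    algo = SuperposedAlgorithmBase(
        mode="identity_matrix",
        target_layer=["mlp_w"],
        random_seed=42,
        different_across_layers=False,
        joint_matrix_mode="none",
        rank=16,
        random_components=False,
        shift_layers=0,
        absorber="average",
        debug=0,
        ms_mode="none",
        verbose=0,
        dropout_rate=1,
    )
    assert algo._target_layer_flag("encoder.layers.3.mlp.fc1.weight")
    assert not algo._target_layer_flag("encoder.layers.3.mlp.fc1.bias")
    assert not algo._target_layer_flag("encoder.layers.3.self_attn.q_proj.weight")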
+     def _compress_and_retrieve(self, state_dicts: Dict[str, StateDictType], mode: str):
+         """
+         Compress and retrieve the state dicts.
+
+         Args:
+             state_dicts (Dict[str, StateDictType]): The state dicts of the models, keyed by model name, i.e. `{model_name: state_dict}`.
+             mode (str): The mode of the compression and retrieval.
+
+         Returns:
+             Dict[str, StateDictType]: The compressed and retrieved state dicts, keyed by model name, i.e. `{model_name: state_dict}`.
+         """
+         # Assume the state_dicts have the same layers.
+         layers = state_dicts[list(state_dicts.keys())[0]].keys()
+         models = list(state_dicts.keys())
+         compressed_layers = {}
+         compression_context = {model: {} for model in models}
+         retrieval_context = {model: {} for model in models}
+         retrieval_models = deepcopy(state_dicts)
+         # target_layer_flags = [self._target_layer_flag(layer) for layer in layers]
+         # implement target_layer_flags with dropout
+         target_layer_flags: List[bool] = []
+         count = 0
+         for layer in layers:
+             if self._target_layer_flag(layer):
+                 # take the target layer per `self.dropout_rate` target layers.
+                 # e.g. if self.dropout_rate = 2, then take the 2nd and 4th target layers, skip the first and third target layers.
+                 # if self.dropout_rate = 1, then take all target layers.
+                 count += 1
+                 if count == self.dropout_rate:
+                     target_layer_flags.append(True)
+                     count = 0
+                 else:
+                     target_layer_flags.append(False)
+             else:
+                 target_layer_flags.append(False)
+
+         target_layers = [
+             layer for layer, flag in zip(layers, target_layer_flags) if flag
+         ]
+         log.info(
+             f"filtered {len(target_layers)} target layers out of {len(layers)} layers"
+         )
+
+         metadata = {
+             "nonzero_parameter_count": 0,
+             "nonzero_param_count_context": 0,
+             "task_vector_retrieval_similarity": {},
+             "superposed_model_retrieval_similarity": {},
+             "model_retrieval_similarity": {},
+             "target_layers": target_layers,
+             "task_vector_svd_subspace_similarities": {},
+             "superposed_model_svd_subspace_similarities": {},
+             "model_svd_subspace_similarities": {},
+             "total_param_count_original": 0,
+             "total_gb_original": 0,
+             "total_gb_retrieved": 0,
+         }
+
+         if "absorber" in models:
+             models.remove("absorber")
+             absorber = state_dicts["absorber"]
+         else:
+             absorber = None
+
+         # get the total number of parameters and bytes (in GB) of the original model
+         original_param_summary = get_parameter_summary(state_dicts[models[0]])
+         gbs = original_param_summary["bytes"] / 1e9
+         log.info(
+             f"Total parameters: {human_readable(original_param_summary['all_param'])}"
+         )
+         log.info(f"Total gigabytes: {gbs}")
+         metadata["total_param_count_original"] = original_param_summary["all_param"]
+         metadata["total_gb_original"] = gbs
+
+         # for analysis purposes
+         if self.debug >= 2:
+             test_models = models
+             # test_models = models[:2]
+             # layers_old = {model: OrderedDict() for model in models}
+             layers_old = {model: deepcopy(state_dicts[model]) for model in models}
+             tv_new = {
+                 model: {model: OrderedDict() for model in models}
+                 for model in test_models
+             }
+             # layers_new = {model: {model: OrderedDict() for model in models} for model in test_models}
+
+         # Shift the layers
+         # TODO: make this more robust to other models.
+         if self.shift_layers != 0:
+             # random shuffling. Do not shuffle layers with no number in their name.
+             # because they are likely to be special layers like text embeddings.
+             if self.shift_layers == -1:
+                 layer_mappings = {model: {} for model in models}
+                 temp_state_dicts = deepcopy(state_dicts)
+
+                 # get layer number index, assume the first number in the layer name is the layer number
+                 # assume all numbered layers have their number at the same index
+                 # assume components separated by '.' in the layer name
+                 found_digit = False
+                 for layer_idx, layer in enumerate(layers):
+                     if target_layer_flags[layer_idx]:
+                         layer_parts = layer.split(".")
+                         for i, part in enumerate(layer_parts):
+                             if part.isdigit():
+                                 layer_number_idx = i
+                                 break
+                     if found_digit:
+                         break
+
+                 # get groups of target layers with same name except the layer number
+                 target_layer_groups = {}
+                 for layer_idx, layer in enumerate(layers):
+                     if target_layer_flags[layer_idx]:
+                         layer_parts = layer.split(".")
+                         if (
+                             layer_number_idx >= len(layer_parts)
+                             or not layer_parts[layer_number_idx].isdigit()
+                         ):
+                             continue  # skip layers without number
+                         base_name = ".".join(
+                             layer_parts[:layer_number_idx]
+                             + layer_parts[layer_number_idx + 1 :]
+                         )
+                         layer_number = int(layer_parts[layer_number_idx])
+                         if base_name not in target_layer_groups:
+                             target_layer_groups[base_name] = []
+                         target_layer_groups[base_name].append(layer_number)
+
+                 # construct random shuffled mapping
+                 random_shuffle_mapping = {model: {} for model in models}
+                 for model_idx, model in enumerate(models):
+                     for layer_idx, layer in enumerate(layers):
+                         if target_layer_flags[layer_idx]:
+                             layer_parts = layer.split(".")
+                             if (
+                                 layer_number_idx >= len(layer_parts)
+                                 or not layer_parts[layer_number_idx].isdigit()
+                             ):
+                                 continue  # skip layers without number
+                             base_name = ".".join(
+                                 layer_parts[:layer_number_idx]
+                                 + layer_parts[layer_number_idx + 1 :]
+                             )
+
+                             if base_name not in random_shuffle_mapping[model]:
+                                 rng_state = random.getstate()
+                                 # Shuffle the layer numbers differently for each model
+                                 random.seed(self.config.random_seed + model_idx)
+                                 shuffled_layer_numbers = target_layer_groups[
+                                     base_name
+                                 ].copy()
+                                 random.shuffle(shuffled_layer_numbers)
+                                 random_shuffle_mapping[model][base_name] = {
+                                     orig: str(shuffled)
+                                     for orig, shuffled in zip(
+                                         target_layer_groups[base_name],
+                                         shuffled_layer_numbers,
+                                     )
+                                 }
+                                 random.setstate(rng_state)
+
+                     for layer_idx, layer in enumerate(layers):
+                         if target_layer_flags[layer_idx]:
+                             layer_parts = layer.split(".")
+                             if (
+                                 layer_number_idx >= len(layer_parts)
+                                 or not layer_parts[layer_number_idx].isdigit()
+                             ):
+                                 continue  # skip layers without number
+                             base_name = ".".join(
+                                 layer_parts[:layer_number_idx]
+                                 + layer_parts[layer_number_idx + 1 :]
+                             )
+                             layer_number = int(layer_parts[layer_number_idx])
+                             new_layer_number = random_shuffle_mapping[model][base_name][
+                                 layer_number
+                             ]
+                             new_layer_name = ".".join(
+                                 layer_parts[:layer_number_idx]
+                                 + [new_layer_number]
+                                 + layer_parts[layer_number_idx + 1 :]
+                             )
+                             temp_state_dicts[model][new_layer_name] = state_dicts[
+                                 model
+                             ][layer]
+                             layer_mappings[model][new_layer_name] = layer
+                 state_dicts = temp_state_dicts
+             else:
+                 layer_numbers = {}
+                 for layer_idx, layer in enumerate(layers):
+                     if target_layer_flags[layer_idx]:
+                         layer_parts = layer.split(".")
+                         for part in layer_parts:
+                             if part.isdigit():
+                                 layer_numbers[layer] = int(part)
+                                 break  # Only consider the first number for each layer
+                 if layer_numbers:
+                     max_layer_number = max(layer_numbers.values())
+                 else:
+                     max_layer_number = 0
+                 temp_state_dicts = deepcopy(state_dicts)
+                 # Wrap around and shift each model by a different amount
+                 for model_idx, model in enumerate(models):
+                     for layer_idx, layer in enumerate(layers):
+                         target_flag = target_layer_flags[layer_idx]
+                         if not target_flag:
+                             continue
+                         layer_number = layer_numbers.get(layer)
+                         if layer_number is None:
+                             continue
+                         new_layer_number = (
+                             layer_number + model_idx * self.config.shift_layers
+                         ) % (max_layer_number + 1)
+                         new_layer_parts = []
+                         replaced = False  # Only replace the first numeric part FIXME: make it more robust
+                         for part in layer.split("."):
+                             if part.isdigit() and not replaced:
+                                 new_layer_parts.append(str(new_layer_number))
+                                 replaced = True
+                             else:
+                                 new_layer_parts.append(part)
+                         new_layer = ".".join(new_layer_parts)
+                         temp_state_dicts[model][new_layer] = state_dicts[model][layer]
+                 state_dicts = temp_state_dicts
+
+         if self.debug >= 2:
+             # for evaluating pairwise cosine similarity
+             unmerged_task_vectors = deepcopy(state_dicts)
+
+         # compress
+         for layer_idx, layer in enumerate(layers):
+             shape = state_dicts[models[0]][layer].shape
+             compressed_layer = None
+             target_flag = target_layer_flags[layer_idx]
+             # self.verbose = 1
+             if self.verbose >= 1:
+                 log.info(f"{layer} | {shape} | {target_flag}")
+             if not target_flag:
+                 if absorber is not None:
+                     compressed_layer = absorber[layer]
+                 else:
+                     for model in models:
+                         if compressed_layer is None:
+                             compressed_layer = deepcopy(state_dicts[model][layer])
+                         else:
+                             compressed_layer += deepcopy(state_dicts[model][layer])
+             else:
+                 if self.mode == "random_binary_diagonal_matrix":
+                     for model_idx, model in enumerate(models):
+                         if self.different_across_layers:
+                             seed = self.random_seed + model_idx + hash(layer) % 1e6
+                         else:
+                             seed = self.random_seed + model_idx
+                         numpy_state = np.random.get_state()
+                         np.random.seed(int(seed))
+                         context = (
+                             np.random.binomial(p=0.5, n=1, size=(1, shape[-1])).astype(
+                                 np.float32
+                             )
+                             * 2
+                             - 1
+                         )
+                         context = torch.from_numpy(context)
+                         np.random.set_state(numpy_state)
+                         compression_context[model][
+                             layer
+                         ] = context  # for analysis purposes
+                         retrieval_context[model][layer] = context
+                         if compressed_layer is None:
+                             compressed_layer = state_dicts[model][layer] * context
+                         else:
+                             compressed_layer += state_dicts[model][layer] * context
+                         if self.debug >= 2:
+                             # hadamard product is not linear, convert it back to diagonal matrix and apply matrix multiplication
+                             context_diag = torch.diag(context.squeeze())
+                             unmerged_task_vectors[model][layer] = (
+                                 unmerged_task_vectors[model][layer] @ context_diag
+                             )
+                 elif self.mode == "random_rotation_matrix":
+                     for model_idx, model in enumerate(models):
+                         if self.different_across_layers:
+                             seed = self.random_seed + model_idx + hash(layer) % 1e6
+                         else:
+                             seed = self.random_seed + model_idx
+                         context = torch.from_numpy(
+                             ortho_group.rvs(shape[-1], random_state=seed).astype(
+                                 "float32"
+                             )
+                         )
+                         compression_context[model][
+                             layer
+                         ] = context  # for analysis purposes
+                         retrieval_context[model][layer] = context.t()
+                         if compressed_layer is None:
+                             compressed_layer = state_dicts[model][layer] @ context
+                         else:
+                             compressed_layer += state_dicts[model][layer] @ context
+                         if self.debug >= 2:
+                             unmerged_task_vectors[model][layer] = (
+                                 unmerged_task_vectors[model][layer] @ context
+                             )
+                 elif self.mode == "random_dense_matrix":
+                     for model_idx, model in enumerate(models):
+                         if self.different_across_layers:
+                             seed = self.random_seed + model_idx + hash(layer) % 1e6
+                         else:
+                             seed = self.random_seed + model_idx
+                         numpy_state = np.random.get_state()
+                         np.random.seed(int(seed))
+                         context = torch.from_numpy(
+                             np.random.randn(shape[-1], shape[-1]).astype(np.float32)
+                         )
+                         np.random.set_state(numpy_state)
+                         compression_context[model][
+                             layer
+                         ] = context  # for analysis purposes
+                         retrieval_context[model][layer] = torch.linalg.pinv(
+                             context.to("cuda")
+                         ).to("cpu")
+                         if compressed_layer is None:
+                             compressed_layer = state_dicts[model][layer] @ context
+                         else:
+                             compressed_layer += state_dicts[model][layer] @ context
+                         if self.debug >= 2:
+                             unmerged_task_vectors[model][layer] = (
+                                 unmerged_task_vectors[model][layer] @ context
+                             )
+                 elif self.mode == "random_diagonal_matrix":
+                     for model_idx, model in enumerate(models):
+                         if self.different_across_layers:
+                             seed = self.random_seed + model_idx + hash(layer) % 1e6
+                         else:
+                             seed = self.random_seed + model_idx
+                         numpy_state = np.random.get_state()
+                         np.random.seed(int(seed))
+                         context = torch.from_numpy(
+                             np.random.randn(1, shape[-1]).astype(np.float32)
+                         )
+                         np.random.set_state(numpy_state)
+                         compression_context[model][
+                             layer
+                         ] = context  # for analysis purposes
+                         retrieval_context[model][layer] = 1 / context
+                         if compressed_layer is None:
+                             compressed_layer = state_dicts[model][layer] * context
+                         else:
+                             compressed_layer += state_dicts[model][layer] * context
+                         if self.debug >= 2:
+                             unmerged_task_vectors[model][layer] = (
+                                 unmerged_task_vectors[model][layer] * context
+                             )
+                 elif self.mode == "identity_matrix":
+                     for model_idx, model in enumerate(models):
+                         context = torch.eye(shape[-1])
+                         compression_context[model][
+                             layer
+                         ] = context  # for analysis purposes
+                         retrieval_context[model][layer] = context
+                         if compressed_layer is None:
+                             compressed_layer = state_dicts[model][layer] @ context
+                         else:
+                             compressed_layer += state_dicts[model][layer] @ context
+                         if self.debug >= 2:
+                             unmerged_task_vectors[model][layer] = (
+                                 unmerged_task_vectors[model][layer] @ context
+                             )
+                 else:
+                     raise ValueError(f"Unsupported mode: {self.mode}")
+
+             compressed_layers[layer] = compressed_layer
+
+         # retrieve: for purpose of benchmarking, retrieve all models at once. In practice, retrieval should be done per model request.
+         nonzero_param_count = 0
+         nonzero_param_count_context = 0
+         total_bytes_retrieved = 0
+
+         if self.debug >= 2:
+             for model in test_models:
+                 tv_new[model] = deepcopy(unmerged_task_vectors)
+
+         for layer_idx, layer in enumerate(layers):
+             shape = state_dicts[models[0]][layer].shape
+             target_flag = target_layer_flags[layer_idx]
+             if not target_flag:
+                 if mode == "superposed_model_soup":
+                     # we don't count non-target layers for superposed task arithmetic
+                     # because they can be absorbed into the pretrained weights
+                     param_count = torch.numel(compressed_layers[layer])
+                     total_bytes_retrieved += (
+                         param_count * compressed_layers[layer].element_size()
+                     )
+                     nonzero_param_count += param_count
+                 for model in models:
+                     retrieval_models[model][layer] = compressed_layers[layer]
+             else:
+                 if (
+                     mode == "superposed_task_arithmetic"
+                     and self.mode == "identity_matrix"
+                     and self.shift_layers == 0
+                 ):
+                     # we don't count target layers for task arithmetic
+                     # because they can be absorbed into the pretrained weights
+                     pass
+                 else:
+                     param_count = torch.numel(compressed_layers[layer])
+                     total_bytes_retrieved += (
+                         param_count * compressed_layers[layer].element_size()
+                     )
+                     nonzero_param_count += torch.numel(compressed_layers[layer])
+
+                 if self.mode in [
+                     "random_binary_diagonal_matrix",
+                     "random_rotation_matrix",
+                     "random_dense_matrix",
+                     "random_diagonal_matrix",
+                     "identity_matrix",
+                 ]:
+                     for model in models:
+                         if self.mode not in ["identity_matrix"]:
+                             nonzero_count = torch.numel(retrieval_context[model][layer])
+                             if self.mode == "random_binary_diagonal_matrix":
+                                 total_bytes_retrieved += (
+                                     nonzero_count * 1
+                                 )  # 1 byte per element for binary
+                             else:
+                                 total_bytes_retrieved += (
+                                     nonzero_count
+                                     * retrieval_context[model][layer].element_size()
+                                 )
+                             nonzero_param_count += nonzero_count
+                             nonzero_param_count_context += nonzero_count
+                         if retrieval_context[model][layer].shape[0] == 1:
+                             retrieval_models[model][layer] = (
+                                 compressed_layers[layer]
+                                 * retrieval_context[model][layer]
+                             )
+                         else:
+                             retrieval_models[model][layer] = (
+                                 compressed_layers[layer]
+                                 @ retrieval_context[model][layer]
+                             )
+                         if self.debug >= 2 and model in test_models:
+                             if retrieval_context[model][layer].shape[0] == 1:
+                                 retrieval_context_diag = torch.diag(
+                                     retrieval_context[model][layer].squeeze()
+                                 )
+                                 for m in models:
+                                     tv_new[model][m][layer] = (
+                                         tv_new[model][m][layer] @ retrieval_context_diag
+                                     )
+                             else:
+                                 for m in models:
+                                     tv_new[model][m][layer] = (
+                                         tv_new[model][m][layer]
+                                         @ retrieval_context[model][layer]
+                                     )
+                 else:
+                     raise ValueError(f"Unsupported mode: {self.mode}")
+         # for model in test_models:
+         #     # print(retrieval_context[model]['vision_model.encoder.layers.4.self_attn.q_proj.weight'])
+         #     # print('a')
+         #     print(tv_new[model][models[3]]['vision_model.encoder.layers.4.self_attn.q_proj.weight'])
+
+         # Shift the layers back
+         if self.shift_layers != 0:
+             if self.shift_layers == -1:  # random shuffling
+                 if self.debug >= 2:
+                     temp_tv_new = deepcopy(tv_new)
+                 temp_retrieval_models = deepcopy(retrieval_models)
+                 for model_idx, model in enumerate(models):
+                     # reverse_layer_mapping = {shuffled: original for original, shuffled in layer_mappings[model].items()}
+                     for shuffled_layer, original_layer in layer_mappings[model].items():
+                         temp_retrieval_models[model][original_layer] = retrieval_models[
+                             model
+                         ][shuffled_layer]
+                         if self.debug >= 2 and model in test_models:
+                             for m in models:
+                                 temp_tv_new[model][m][original_layer] = tv_new[model][
+                                     m
+                                 ][shuffled_layer]
+                 retrieval_models = temp_retrieval_models
+                 if self.debug >= 2:
+                     tv_new = temp_tv_new
+             else:  # TODO: check the correctness of this mode
+                 # raise NotImplementedError("Shift back mode not implemented yet. No tv_new support yet.")
+                 if self.debug >= 2:
+                     temp_tv_new = deepcopy(tv_new)
+                 temp_retrieval_models = deepcopy(retrieval_models)
+                 for model_idx, model in enumerate(models):
+                     for layer_idx, layer in enumerate(layers):
+                         target_flag = target_layer_flags[layer_idx]
+                         if not target_flag:
+                             continue
+                         layer_parts = layer.split(".")
+                         layer_number = None
+                         for part in layer_parts:
+                             if part.isdigit():
+                                 layer_number = int(part)
+                                 break  # Only consider the first number
+                         if layer_number is None:
+                             continue
+                         new_layer_number = (
+                             layer_number - model_idx * self.shift_layers
+                         ) % (max_layer_number + 1)
+                         new_layer_parts = []
+                         replaced = False
+                         for part in layer_parts:
+                             if part.isdigit() and not replaced:
+                                 new_layer_parts.append(str(new_layer_number))
+                                 replaced = True  # Only replace the first numeric part
+                             else:
+                                 new_layer_parts.append(part)
+                         new_layer = ".".join(new_layer_parts)
+                         temp_retrieval_models[model][new_layer] = retrieval_models[
+                             model
+                         ][layer]
+                         if self.debug >= 2 and model in test_models:
+                             for m in models:
+                                 temp_tv_new[model][m][new_layer] = tv_new[model][m][
+                                     layer
+                                 ]
+                 retrieval_models = temp_retrieval_models
+                 if self.debug >= 2:
+                     tv_new = temp_tv_new
+
+         # for model in test_models:
+         #     # print(retrieval_context[model]['vision_model.encoder.layers.4.self_attn.q_proj.weight'])
+         #     # print('a')
+         #     print(tv_new[model][models[3]]['vision_model.encoder.layers.4.self_attn.q_proj.weight'])
+
+         # metadata
+         if self.debug >= 2:
+             if self.mode in [
+                 "random_binary_diagonal_matrix",
+                 "random_rotation_matrix",
+                 "random_dense_matrix",
+                 "random_diagonal_matrix",
+                 "identity_matrix",
+             ]:
+                 layers = list(layers_old[models[0]].keys())
+                 layers_old_flattened = [
+                     torch.cat([layers_old[model][layer].flatten() for layer in layers])
+                     for model in models
+                 ]
+                 metadata["pairwise_cosine_similarity_matrix_before"] = (
+                     pairwise_cosine_similarity_matrix(layers_old_flattened).tolist()
+                 )
+                 metadata["task_vector_dim"] = layers_old_flattened[0].shape[0]
+                 # layers_new = deepcopy(retrieval_models)
+                 rms = []
+                 for retrieval_model in test_models:
+                     print(f"Retrieval model: {retrieval_model}")
+                     layers_new_flattened = [
+                         torch.cat(
+                             [
+                                 tv_new[retrieval_model][model][layer].flatten()
+                                 for layer in layers
+                             ]
+                         )
+                         for model in models
+                     ]
+                     rms.append(layers_new_flattened)
+                     # print(layers_new_flattened[0][:50])
+                     # layers_new_flattened = [torch.cat([layers_new[retrieval_model][layer].flatten() for layer in layers]) for model in models]
+                     pcsm = pairwise_cosine_similarity_matrix(
+                         layers_new_flattened
+                     ).tolist()
+                     print(pcsm)
+                     metadata[
+                         f"pairwise_cosine_similarity_matrix_after_{retrieval_model}"
+                     ] = pcsm
+         if self.debug >= 0:
+             metadata["nonzero_parameter_count"] = (
+                 nonzero_param_count.item()
+                 if isinstance(nonzero_param_count, torch.Tensor)
+                 else nonzero_param_count
+             )
+             metadata["nonzero_param_count_context"] = (
+                 nonzero_param_count_context.item()
+                 if isinstance(nonzero_param_count_context, torch.Tensor)
+                 else nonzero_param_count_context
+             )
+             gbs = total_bytes_retrieved / 1e9
+             metadata["total_gb_retrieved"] = gbs
+
+         return retrieval_models, metadata
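
_compress_and_retrieve superposes the per-model weights of each target layer into a single matrix by binding each one with a model-specific random context and summing; retrieval unbinds with the context's inverse (its transpose for rotations, the reciprocal for diagonals), so the other models' contributions survive only as roughly zero-mean noise. A self-contained sketch of the random_binary_diagonal_matrix case, independent of the package; sizes and seed are illustrative:

    import torch

    torch.manual_seed(0)
    d = 512
    W = [torch.randn(d, d) for _ in range(3)]  # three "task" weight matrices

    # One random +/-1 diagonal context per model, stored as 1 x d sign vectors.
    C = [torch.randint(0, 2, (1, d)).float() * 2 - 1 for _ in range(3)]

    # Compress: bind elementwise, then superpose into one matrix.
    superposed = sum(w * c for w, c in zip(W, C))

    # Retrieve model 0: rebinding with its own context undoes the binding
    # (c * c == 1); the other two terms stay bound and act as noise.
    W0_hat = superposed * C[0]
    cos = torch.nn.functional.cosine_similarity(W0_hat.flatten(), W[0].flatten(), dim=0)
    print(f"cosine(retrieved, original) = {cos:.3f}")  # about 1/sqrt(3), well above 0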