fusion-bench 0.2.7__py3-none-any.whl → 0.2.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/compat/method/base_algorithm.py +1 -1
- fusion_bench/dataset/clip_dataset.py +3 -0
- fusion_bench/dataset/fer2013.py +12 -0
- fusion_bench/dataset/llama/preference_700k.py +1 -1
- fusion_bench/method/__init__.py +2 -0
- fusion_bench/method/classification/clip_finetune.py +10 -13
- fusion_bench/method/surgery/__init__.py +1 -3
- fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +1 -1
- fusion_bench/method/tall_mask/__init__.py +0 -0
- fusion_bench/method/tall_mask/utils.py +234 -0
- fusion_bench/method/task_singular_vector/TSVC.py +16 -0
- fusion_bench/method/task_singular_vector/TSVM.py +63 -0
- fusion_bench/method/task_singular_vector/__init__.py +9 -0
- fusion_bench/method/task_singular_vector/utils/TSVC_utils.py +50 -0
- fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +642 -0
- fusion_bench/method/task_singular_vector/utils/__init__.py +7 -0
- fusion_bench/method/ties_merging/ties_merging_utils.py +7 -2
- fusion_bench/mixins/clip_classification.py +6 -6
- fusion_bench/mixins/lightning_fabric.py +3 -1
- fusion_bench/modelpool/base_pool.py +0 -1
- fusion_bench/modelpool/clip_vision/modelpool.py +92 -8
- fusion_bench/models/surgery/__init__.py +1 -0
- fusion_bench/models/surgery/surgerymodelwrapper.py +2 -1
- fusion_bench/models/wrappers/layer_wise_fusion.py +1 -1
- fusion_bench/models/wrappers/task_wise_fusion.py +1 -1
- fusion_bench/programs/fabric_fusion_program.py +7 -4
- fusion_bench/taskpool/llama/reward_model.py +1 -1
- fusion_bench/tasks/clip_classification/__init__.py +13 -45
- fusion_bench/tasks/clip_classification/clip_dataset.py +1 -16
- fusion_bench/tasks/clip_classification/cub_200_2011.py +208 -0
- fusion_bench/tasks/clip_classification/emnist_letters.py +31 -0
- fusion_bench/tasks/clip_classification/emnist_mnist.py +5 -0
- fusion_bench/tasks/clip_classification/fashion_mnist.py +18 -0
- fusion_bench/tasks/clip_classification/fer2013.py +18 -0
- fusion_bench/tasks/clip_classification/food101.py +105 -0
- fusion_bench/tasks/clip_classification/kmnist.py +17 -0
- fusion_bench/tasks/clip_classification/mongo_leaf_disease.py +19 -0
- fusion_bench/tasks/clip_classification/pcam.py +5 -0
- fusion_bench/utils/parameters.py +12 -3
- fusion_bench/utils/type.py +10 -1
- {fusion_bench-0.2.7.dist-info → fusion_bench-0.2.8.dist-info}/METADATA +1 -1
- {fusion_bench-0.2.7.dist-info → fusion_bench-0.2.8.dist-info}/RECORD +195 -62
- fusion_bench_config/dataset/image_classification/README.md +6 -0
- fusion_bench_config/dataset/image_classification/test/TALL14.yaml +20 -0
- fusion_bench_config/dataset/image_classification/test/TALL20.yaml +28 -0
- fusion_bench_config/dataset/image_classification/test/cifar10.yaml +1 -1
- fusion_bench_config/dataset/image_classification/test/cifar100.yaml +1 -1
- fusion_bench_config/dataset/image_classification/test/cub-200-2011.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/emnist_letters.yaml +5 -0
- fusion_bench_config/dataset/image_classification/test/emnist_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/fashion_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/fer2013.yaml +3 -0
- fusion_bench_config/dataset/image_classification/test/food101.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/kmnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/mango-leaf-disease.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/oxford-iiit-pet.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/oxford_flowers102.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/pcam.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/rendered-sst2.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/stl10.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/TALL14.yaml +20 -0
- fusion_bench_config/dataset/image_classification/train/TALL20.yaml +28 -0
- fusion_bench_config/dataset/image_classification/train/cifar10.yaml +1 -1
- fusion_bench_config/dataset/image_classification/train/cifar100.yaml +1 -1
- fusion_bench_config/dataset/image_classification/train/cub-200-2011.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/emnist_letters.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/emnist_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/fashion_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/fer2013.yaml +3 -0
- fusion_bench_config/dataset/image_classification/train/food101.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/kmnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/mango-leaf-disease.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/oxford-iiit-pet.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/oxford_flowers102.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/pcam.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/rendered-sst2.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/stl10.yaml +4 -0
- fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +2 -0
- fusion_bench_config/model/clip-vit/README.md +38 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL14.yaml +22 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL20.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar100.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_dtd.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_emnist_letters.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_eurosat.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fashion_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fer2013.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_food101.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_gtsrb.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_kmnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_mnist.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford-iiit-pet.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford_flowers102.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_pcam.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_rendered-sst2.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_resisc45.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stanford-cars.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stl10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_sun397.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_svhn.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL14.yaml +22 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL20.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar100.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_dtd.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eight_tasks.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_emnist_letters.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eurosat.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fashion_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fer2013.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_food101.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_gtsrb.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_kmnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_mnist.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford-iiit-pet.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford_flowers102.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_pcam.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_rendered-sst2.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_resisc45.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stanford-cars.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stl10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_sun397.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_svhn.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL14.yaml +22 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL20.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar100.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_dtd.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_emnist_letters.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_eurosat.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fashion_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fer2013.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_food101.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_gtsrb.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_kmnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_mnist.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford-iiit-pet.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford_flowers102.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_pcam.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_rendered-sst2.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stl10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml +1 -3
- fusion_bench_config/model/clip-vit/download_TALL20_models.sh +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_model_only.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml +15 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14_model_only.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20_model_only.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +9 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +15 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +15 -3
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_robustness_corrupted.yaml +27 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL14.yaml +19 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL20.yaml +26 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar10.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar100.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_dtd.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_emnist_letters.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_eurosat.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fashion_mnist.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fer2013.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_food101.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_gtsrb.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_kmnist.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_mnist.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford-iiit-pet.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102_val.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_pcam.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_rendered-sst2.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_resisc45.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stanford-cars.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stl10.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_sun397.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_svhn.yaml +3 -0
- {fusion_bench-0.2.7.dist-info → fusion_bench-0.2.8.dist-info}/LICENSE +0 -0
- {fusion_bench-0.2.7.dist-info → fusion_bench-0.2.8.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.7.dist-info → fusion_bench-0.2.8.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.7.dist-info → fusion_bench-0.2.8.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,642 @@
|
|
|
1
|
+
import math
|
|
2
|
+
from typing import List, Optional
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
|
|
6
|
+
from fusion_bench.utils.type import StateDictType
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def compute_svd_dict(task_vectors, config):
    """
    Decompose every 2-D parameter of each task vector with an SVD and collect
    the per-dataset components into a nested dictionary.

    For each dataset, every 2-D matrix (except keys containing
    "text_projection") is factored as ``u @ diag(s) @ v``. Only the leading
    ``1 / len(config.DATASETS)`` fraction of singular directions is kept and
    written into the i-th slot of zero-padded tensors, so different datasets
    occupy disjoint slices. Tensors that are not 2-D (or are excluded) are
    stored unchanged under the key "dim1".

    Args:
        task_vectors (list): Task-vector objects exposing a ``.vector`` dict
            of parameter tensors.
        config (object): Provides the dataset names via ``DATASETS``.

    Returns:
        dict: ``{dataset: {key: {"u"/"s"/"v": tensor}}}`` for decomposed
        matrices, or ``{dataset: {key: {"dim1": tensor}}}`` otherwise.
    """
    keep_fraction = 1 / len(config.DATASETS)
    with torch.no_grad():
        result = {}
        for slot, (task_vector, dataset) in enumerate(
            zip(task_vectors, config.DATASETS)
        ):
            per_dataset = {}
            result[dataset] = per_dataset
            print(f"Computing SVD for {dataset}...")
            for key, weight in task_vector.vector.items():
                entry = {}
                per_dataset[key] = entry
                if weight.dim() == 2 and "text_projection" not in key:
                    u, s, v = torch.linalg.svd(weight, full_matrices=False)
                    rank_keep = int(s.shape[0] * keep_fraction)
                    lo, hi = slot * rank_keep, (slot + 1) * rank_keep

                    # Zero-padded containers: this dataset's components live in
                    # its own disjoint slice [lo:hi] of each factor.
                    padded_u = torch.zeros_like(u)
                    padded_u[:, lo:hi] = u[:, :rank_keep]
                    entry["u"] = padded_u

                    padded_s = torch.zeros_like(s)
                    padded_s[lo:hi] = s[:rank_keep]
                    entry["s"] = padded_s

                    padded_v = torch.zeros_like(v)
                    padded_v[lo:hi, :] = v[:rank_keep, :]
                    entry["v"] = padded_v
                else:
                    # 1-D tensors and excluded keys are kept verbatim.
                    entry["dim1"] = weight
        return result
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def sum_svd_dict(svd_dict, config):
    """
    Merge per-dataset SVD components into a single task-vector dictionary.

    For keys carrying SVD components, the zero-padded "u", "s" and "v"
    tensors are summed over all datasets (the per-dataset slices are
    disjoint, so the sum is a concatenation), the summed U and V are
    re-orthogonalized via a second SVD (``u_u @ v_u`` is the orthogonal polar
    factor of ``sum_u``), and the merged matrix is reconstructed as
    ``(u_u @ v_u) @ diag(sum_s) @ (u_v @ v_v)``. Keys stored under "dim1"
    are averaged across datasets with a running mean.

    Args:
        svd_dict (dict): ``{dataset: {key: {"u"/"s"/"v": tensor}}}`` or
            ``{dataset: {key: {"dim1": tensor}}}`` as produced by
            :func:`compute_svd_dict`.
        config (object): Provides the dataset names via ``DATASETS``.

    Returns:
        dict: Mapping from parameter key to the merged tensor.
    """
    print("Summing SVD...")
    new_vector = {}
    for key in svd_dict[config.DATASETS[0]]:
        if "u" in svd_dict[config.DATASETS[0]][key]:
            sum_u = sum(svd_dict[dataset][key]["u"] for dataset in config.DATASETS)
            sum_s = sum(svd_dict[dataset][key]["s"] for dataset in config.DATASETS)
            sum_v = sum(svd_dict[dataset][key]["v"] for dataset in config.DATASETS)

            # Orthogonal (polar) factors of the stacked bases; the singular
            # values of sum_u / sum_v themselves are discarded.
            u_u, _, v_u = torch.linalg.svd(sum_u, full_matrices=False)
            u_v, _, v_v = torch.linalg.svd(sum_v, full_matrices=False)
            new_vector[key] = torch.linalg.multi_dot(
                (
                    u_u,
                    v_u,
                    torch.diag(sum_s),
                    u_v,
                    v_v,
                )
            )
        else:
            for i, dataset in enumerate(config.DATASETS, start=1):
                if i == 1:
                    # Clone so the in-place running mean below does not mutate
                    # the tensor stored in (and aliased by) svd_dict.
                    new_vector[key] = svd_dict[dataset][key]["dim1"].clone()
                else:
                    new_vector[key] += (
                        svd_dict[dataset][key]["dim1"] - new_vector[key]
                    ) / i
    return new_vector
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
###############
|
|
118
|
+
##### LOSSLESS Orthogonalization
|
|
119
|
+
def compute_and_sum_svd_mem_reduction_lossless(task_vectors, config):
    """
    Merge task vectors by stacking the *full* SVD factors of every 2-D
    parameter across tasks, then re-orthogonalizing the stacked bases.

    Per parameter key:
      * 2-D tensors (except keys containing "text_projection") are decomposed
        with ``torch.linalg.svd``; all singular directions are kept
        ("lossless") and written side by side into buffers sized for
        ``config.num_tasks`` tasks.
      * All other tensors are combined with a running mean over tasks.
    After stacking, the summed U and V are replaced by their orthogonal polar
    factors (via a second SVD) and the merged matrix is rebuilt as
    ``U_orth @ diag(sum_s) @ V_orth``.

    Args:
        task_vectors (list): Objects exposing a ``.vector`` dict of tensors.
        config (object): Provides ``device``, ``DATASETS`` and ``num_tasks``.

    Returns:
        dict: Mapping from parameter key to the merged tensor.
    """
    # Be careful: with ViT-L on 20 tasks this does not fit on the GPU or in
    # 64 GB of RAM (try without the last layer).
    target_device = config.device
    print("Computing SVD...")
    with torch.no_grad():
        merged = {}
        for key in task_vectors[0].vector:
            merged[key] = {}
            use_svd = False
            for idx, (task_vector, _dataset) in enumerate(
                zip(task_vectors, config.DATASETS)
            ):
                mat = task_vector.vector[key].to(target_device)
                use_svd = (
                    len(task_vector.vector[key].shape) == 2
                    and "text_projection" not in key
                )
                if use_svd:
                    u, s, v = torch.linalg.svd(mat, full_matrices=False)
                    if idx == 0:
                        print(f"Computed SVD for {key}...")
                        block = s.shape[0]  # full rank kept per task
                        stacked_u = torch.zeros(
                            u.shape[0],
                            u.shape[1] * config.num_tasks,
                            device=target_device,
                        )
                        stacked_s = torch.zeros(
                            block * config.num_tasks, device=target_device
                        )
                        stacked_v = torch.zeros(
                            v.shape[0] * config.num_tasks,
                            v.shape[1],
                            device=target_device,
                        )
                    # Each task owns the disjoint slice [lo:hi] of the buffers.
                    lo, hi = idx * block, (idx + 1) * block
                    stacked_u[:, lo:hi] = u[:, :block]
                    stacked_s[lo:hi] = s[:block]
                    stacked_v[lo:hi, :] = v[:block, :]
                elif idx == 0:
                    merged[key] = mat.clone()
                else:
                    # Incremental mean over tasks seen so far.
                    merged[key] += (mat - merged[key]) / (idx + 1)

            if use_svd:
                u_u, _, v_u = torch.linalg.svd(stacked_u, full_matrices=False)
                u_v, _, v_v = torch.linalg.svd(stacked_v, full_matrices=False)
                merged[key] = torch.linalg.multi_dot(
                    (
                        u_u,
                        v_u,
                        torch.diag(stacked_s),
                        u_v,
                        v_v,
                    )
                )
    return merged
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
###############
|
|
205
|
+
##### LOSSLESS EIGENDECOMP
|
|
206
|
+
def compute_and_sum_svd_mem_reduction_lossless_eigen(task_vectors, config):
    """
    Computes the Singular Value Decomposition (SVD) for each task vector and merge the results.

    This function performs the following steps:
    1. Iterates over each layer in the task vectors.
    2. For each layer, it computes the SVD of the corresponding matrix if it is a 2D tensor excluding "text_projection".
    3. Concatenate the U_i, S_i, and V_i matrices from the SVD across all tasks.
    4. If the vector is not a 2D tensor or is "text_projection", it computes the mean of the vectors.
    5. After concatenating the SVD components, recomputes the eigendecomposition of the summed U and V matrices and constructs the merged layer.

    Args:
        task_vectors (list): A list of task vectors, where each task vector is a dictionary containing the vectors for each task.
        config (object): A configuration object containing the device and dataset information.

    Returns:
        dict: A dictionary containing the new vectors after merging the SVD components.
    """
    # Be careful: with ViT-L on 20 tasks this does not fit on the GPU or in
    # 64 GB of RAM (try without the last layer).
    device = config.device
    print("Computing SVD...")
    with torch.no_grad():
        new_vector = {}
        for key in task_vectors[0].vector:
            # Placeholder; always overwritten below by either the running mean
            # (non-2-D keys) or the reconstructed matrix (2-D keys).
            new_vector[key] = {}
            for i, (task_vector, dataset) in enumerate(
                zip(task_vectors, config.DATASETS)
            ):
                vec = task_vector.vector[key].to(device)

                if (
                    len(task_vector.vector[key].shape) == 2
                    and "text_projection" not in key
                ):

                    u, s, v = torch.linalg.svd(vec, full_matrices=False)

                    if i == 0:
                        print(f"Computed SVD for {key}...")
                        # Buffers sized for config.num_tasks blocks: each task's
                        # full set of singular directions occupies its own slice
                        # ("lossless": reduced_index_s is the full rank).
                        sum_u = torch.zeros(
                            u.shape[0], u.shape[1] * config.num_tasks, device=device
                        )
                        sum_s = torch.zeros(
                            s.shape[0] * config.num_tasks, device=device
                        )
                        sum_v = torch.zeros(
                            v.shape[0] * config.num_tasks, v.shape[1], device=device
                        )
                        reduced_index_s = s.shape[0]

                    # select only the first reduced_index_s columns of u and place them
                    sum_u[:, i * reduced_index_s : (i + 1) * reduced_index_s] = u[
                        :, :reduced_index_s
                    ]
                    sum_s[i * reduced_index_s : (i + 1) * reduced_index_s] = s[
                        :reduced_index_s
                    ]
                    # select only the first reduced_index_s rows of v and place them
                    sum_v[i * reduced_index_s : (i + 1) * reduced_index_s, :] = v[
                        :reduced_index_s, :
                    ]

                else:
                    # Non-2-D (or excluded) keys: incremental mean over tasks.
                    if i == 0:
                        new_vector[key] = vec.clone()
                    else:
                        new_vector[key] += (vec - new_vector[key]) / (i + 1)

            # NOTE: this condition re-reads `task_vector`, i.e. the *last* task
            # vector from the loop above; it assumes all tasks share shapes.
            if len(task_vector.vector[key].shape) == 2 and "text_projection" not in key:
                # torch.sort is ascending by default; stable=True keeps ties in
                # their original (per-task) order. Columns of sum_u and rows of
                # sum_v are permuted with the same indices to stay aligned with
                # the sorted singular values.
                sum_s, indices = torch.sort(sum_s, stable=True)

                sum_u = torch.index_select(sum_u, 1, indices)
                # q_u diag(|l_u|^(-1/2)) q_u^T is the inverse square root of the
                # Gram matrix sum_u^T @ sum_u (eps 1e-12 guards zero eigenvalues).
                l_u, q_u = torch.linalg.eigh(sum_u.mT @ sum_u)
                u_orth = (
                    q_u
                    @ torch.diag(1.0 / (torch.sqrt(torch.abs(l_u)) + 1e-12))
                    @ q_u.mT
                )

                sum_v = torch.index_select(sum_v, 0, indices)

                l_v, q_v = torch.linalg.eigh(sum_v @ sum_v.mT)
                v_orth = (
                    q_v
                    @ torch.diag(1.0 / (torch.sqrt(torch.abs(l_v)) + 1e-12))
                    @ q_v.mT
                )

                # NOTE(review): unlike the SVD-based lossless variant (which
                # multiplies the polar factors of sum_u and sum_v), this product
                # uses only the Gram-matrix inverse square roots, so its shape is
                # (num_tasks*rank) x (num_tasks*rank) rather than the layer's
                # original shape — confirm against the reference implementation
                # whether sum_u / sum_v were meant to appear in this product.
                new_vector[key] = torch.linalg.multi_dot(  # bool_mask *
                    (
                        u_orth,
                        torch.diag(sum_s),
                        v_orth,
                    )
                )

    return new_vector
|
|
303
|
+
|
|
304
|
+
|
|
305
|
+
###############
|
|
306
|
+
#### TSV Merge Orthogonalization
|
|
307
|
+
@torch.no_grad()
def compute_and_sum_svd_mem_reduction(
    task_vectors: "List[StateDictType]",
    exclude_keys: "Optional[List[str]]" = None,
    accelerator: "torch.device" = "cuda" if torch.cuda.is_available() else "cpu",
) -> "StateDictType":
    """
    Merge task vectors with Task Singular Vector Merging (TSVM).

    For every 2-D parameter (not listed in ``exclude_keys``) each task
    vector's matrix is decomposed with ``torch.linalg.svd``; only the leading
    ``1 / len(task_vectors)`` fraction of singular directions per task is
    kept and written into disjoint slices of shared ``sum_u``/``sum_s``/
    ``sum_v`` buffers. The stacked U and V are then replaced by their
    orthogonal polar factors (``u_u @ v_u`` and ``u_v @ v_v`` from a second
    SVD) and the merged matrix is reconstructed as
    ``U_orth @ diag(sum_s) @ V_orth``. Non-2-D tensors and excluded keys are
    combined with a running mean over tasks. Results are moved back to each
    parameter's original device and dtype.

    Args:
        task_vectors (list): A list of state dicts (parameter name -> tensor),
            all sharing the same keys and shapes.
        exclude_keys (list): Keys to exclude from the TSVM merge; they are
            averaged instead. Defaults to no exclusions.
        accelerator (torch.device): Device on which the SVDs are computed.

    Returns:
        dict: A state dict containing the merged tensors.
    """
    if exclude_keys is None:
        exclude_keys = []
    sv_reduction = 1 / len(task_vectors)

    new_vector = {}
    for key in task_vectors[0]:
        original_device = task_vectors[0][key].device
        original_dtype = task_vectors[0][key].dtype

        for i, task_vector in enumerate(task_vectors):
            vec = task_vector[key].to(accelerator)

            if len(task_vector[key].shape) == 2 and key not in exclude_keys:
                # SVD is not supported for half precision, so compute in
                # float32 and convert back to the original dtype at the end.
                if original_dtype not in (torch.float32, torch.float64):
                    vec = vec.to(dtype=torch.float32)

                u, s, v = torch.linalg.svd(vec, full_matrices=False)

                if i == 0:
                    print(f"Computed SVD for {key}...")
                    sum_u = torch.zeros_like(u, device=accelerator)
                    sum_s = torch.zeros_like(s, device=accelerator)
                    sum_v = torch.zeros_like(v, device=accelerator)
                    reduced_index_s = int(s.shape[0] * sv_reduction)

                # Each task's leading components occupy a disjoint slice, so
                # the "sum" buffers are effectively a concatenation.
                sum_u[:, i * reduced_index_s : (i + 1) * reduced_index_s] = u[
                    :, :reduced_index_s
                ]
                sum_s[i * reduced_index_s : (i + 1) * reduced_index_s] = s[
                    :reduced_index_s
                ]
                sum_v[i * reduced_index_s : (i + 1) * reduced_index_s, :] = v[
                    :reduced_index_s, :
                ]

            else:
                # Non-2-D tensors / excluded keys: running mean over tasks.
                if i == 0:
                    new_vector[key] = vec.clone()
                else:
                    new_vector[key] += (vec - new_vector[key]) / (i + 1)

        if len(task_vector[key].shape) == 2 and key not in exclude_keys:
            # Orthogonal (polar) factors of the stacked bases; their own
            # singular values are discarded.
            u_u, _, v_u = torch.linalg.svd(sum_u, full_matrices=False)
            u_v, _, v_v = torch.linalg.svd(sum_v, full_matrices=False)

            new_vector[key] = torch.linalg.multi_dot(
                (
                    u_u,
                    v_u,
                    torch.diag(sum_s),
                    u_v,
                    v_v,
                )
            )
        new_vector[key] = new_vector[key].to(
            device=original_device, dtype=original_dtype, non_blocking=True
        )
    return new_vector
|
|
393
|
+
|
|
394
|
+
|
|
395
|
+
###############
|
|
396
|
+
#### TSV Merge Eigendecomp
|
|
397
|
+
def compute_and_sum_svd_mem_reduction_2(task_vectors, config):
    """
    Merge task vectors by concatenating truncated per-task SVD factors, then
    re-orthogonalizing the stacked bases via eigendecomposition.

    For every 2D parameter whose key does not contain ``"text_projection"``,
    each task's SVD is truncated to ``1 / len(config.DATASETS)`` of its
    singular triplets and the truncations are placed side by side in shared
    ``U``/``S``/``V`` buffers.  The stacked bases are then whitened with an
    eigendecomposition of their Gram matrices (this replaces the second SVD
    used by the plain variant) before being multiplied back together.
    All other parameters are combined with a running mean over tasks.

    Args:
        task_vectors (list): Task vector objects, each exposing a ``.vector``
            dict mapping parameter names to tensors.
        config (object): Must provide ``DATASETS`` (list, one entry per task)
            and ``device`` (where the computation runs).

    Returns:
        dict: Parameter name -> merged tensor.
    """
    num_tasks = len(config.DATASETS)
    sv_reduction = 1 / num_tasks
    device = config.device
    print("Computing SVD...")
    with torch.no_grad():
        merged = {}
        for key in task_vectors[0].vector:
            merged[key] = {}
            for idx, (tv, _dataset) in enumerate(
                zip(task_vectors, config.DATASETS)
            ):
                param = tv.vector[key].to(device)

                if len(tv.vector[key].shape) == 2 and "text_projection" not in key:
                    u, s, v = torch.linalg.svd(param, full_matrices=False)

                    if idx == 0:
                        print(f"Computed SVD for {key}...")
                        # Shared buffers that collect every task's truncation.
                        sum_u = torch.zeros_like(u, device=device)
                        sum_s = torch.zeros_like(s, device=device)
                        sum_v = torch.zeros_like(v, device=device)
                    rank = int(s.shape[0] * sv_reduction)
                    lo, hi = idx * rank, (idx + 1) * rank

                    # Drop this task's leading `rank` singular triplets into
                    # its dedicated slice of the shared buffers.
                    sum_u[:, lo:hi] = u[:, :rank]
                    sum_s[lo:hi] = s[:rank]
                    sum_v[lo:hi, :] = v[:rank, :]
                else:
                    # Non-2D / excluded parameters: incremental mean over tasks.
                    if idx == 0:
                        merged[key] = param.clone()
                    else:
                        merged[key] += (param - merged[key]) / (idx + 1)

            if len(tv.vector[key].shape) == 2 and "text_projection" not in key:
                # Stable ascending sort of the stacked singular values, with
                # the same permutation applied to the stacked bases.
                sum_s, order = torch.sort(sum_s, stable=True)

                sum_u = torch.index_select(sum_u, 1, order)
                # Whiten U via eigendecomposition of its Gram matrix:
                # u_orth ~= (U^T U)^{-1/2}; the 1e-12 guards the division.
                l_u, q_u = torch.linalg.eigh(sum_u.mT @ sum_u)
                u_orth = (
                    q_u
                    @ torch.diag(1.0 / (torch.sqrt(torch.abs(l_u)) + 1e-12))
                    @ q_u.mT
                )

                sum_v = torch.index_select(sum_v, 0, order)
                # Same whitening for V, using V V^T.
                l_v, q_v = torch.linalg.eigh(sum_v @ sum_v.mT)
                v_orth = (
                    q_v
                    @ torch.diag(1.0 / (torch.sqrt(torch.abs(l_v)) + 1e-12))
                    @ q_v.mT
                )

                merged[key] = torch.linalg.multi_dot(  # bool_mask *
                    (
                        sum_u,
                        u_orth,
                        torch.diag(sum_s),
                        v_orth,
                        sum_v,
                    )
                )

    return merged
|
|
488
|
+
|
|
489
|
+
|
|
490
|
+
###############
|
|
491
|
+
#### Rank Reduction TV
|
|
492
|
+
def compute_and_sum_svd_mem_reduction_rank_reduction(task_vectors, config):
    """
    Merge task vectors by concatenating rank-reduced SVD factors.

    For every 2D parameter whose key does not contain ``"text_projection"``,
    each task's SVD is truncated to ``1 / len(config.DATASETS)`` of its
    singular triplets; the truncations are stacked side by side in shared
    ``U``/``S``/``V`` buffers and multiplied back together without any
    re-orthogonalization.  All other parameters are combined with a running
    mean over tasks.

    Args:
        task_vectors (list): Task vector objects, each exposing a ``.vector``
            dict mapping parameter names to tensors.
        config (object): Must provide ``DATASETS`` (list, one entry per task)
            and ``device`` (where the computation runs).

    Returns:
        dict: Parameter name -> merged tensor.
    """
    num_tasks = len(config.DATASETS)
    sv_reduction = 1 / num_tasks
    device = config.device
    print("Computing SVD...")
    with torch.no_grad():
        merged = {}
        for key in task_vectors[0].vector:
            merged[key] = {}
            for idx, (tv, _dataset) in enumerate(
                zip(task_vectors, config.DATASETS)
            ):
                param = tv.vector[key].to(device)

                if len(tv.vector[key].shape) == 2 and "text_projection" not in key:
                    u, s, v = torch.linalg.svd(param, full_matrices=False)

                    if idx == 0:
                        print(f"Computed SVD for {key}...")
                        # Shared buffers that collect every task's truncation.
                        sum_u = torch.zeros_like(u, device=device)
                        sum_s = torch.zeros_like(s, device=device)
                        sum_v = torch.zeros_like(v, device=device)
                    rank = int(s.shape[0] * sv_reduction)
                    lo, hi = idx * rank, (idx + 1) * rank

                    # Drop this task's leading `rank` singular triplets into
                    # its dedicated slice of the shared buffers.
                    sum_u[:, lo:hi] = u[:, :rank]
                    sum_s[lo:hi] = s[:rank]
                    sum_v[lo:hi, :] = v[:rank, :]
                else:
                    # Non-2D / excluded parameters: incremental mean over tasks.
                    if idx == 0:
                        merged[key] = param.clone()
                    else:
                        merged[key] += (param - merged[key]) / (idx + 1)

            if len(tv.vector[key].shape) == 2 and "text_projection" not in key:
                # Recombine the stacked low-rank factors directly.
                merged[key] = torch.linalg.multi_dot(
                    (
                        sum_u,
                        torch.diag(sum_s),
                        sum_v,
                    )
                )
    return merged
|
|
561
|
+
|
|
562
|
+
|
|
563
|
+
def compute_and_sum_svd_mem_reduction_dummy(task_vectors, config):
    """To perform dummy operations.

    Experimental/debug variant: only the FIRST task vector is decomposed
    (the number of slots is hard-coded to 8 via ``sv_reduction = 1 / 8``);
    the remaining 7 slots are filled with random normalized vectors, while
    the singular values from slot 0 are reused for every slot.
    ``config`` is accepted but never read in this body.

    Returns:
        dict: Parameter name -> reconstructed tensor (2D, non-"text_projection"
        keys) or a reference to the first task's tensor for everything else.
    """
    # Hard-coded to 8 slots regardless of config — intentional for a dummy run.
    sv_reduction = 1 / 8
    print("Computing SVD...")
    with torch.no_grad():
        new_vector = {}
        for key in task_vectors[0].vector:
            new_vector[key] = {}
            for i in range(0, 8):
                if (
                    len(task_vectors[0].vector[key].shape) == 2
                    and "text_projection" not in key
                ):
                    if i == 0:
                        # Real SVD only for slot 0; `u`, `s`, `v` and
                        # `reduced_index_s` computed here are reused by the
                        # i > 0 branch below.
                        u, s, v = torch.linalg.svd(
                            task_vectors[0].vector[key], full_matrices=False
                        )
                        reduced_index_s = int(s.shape[0] * sv_reduction)

                        print(f"Computed SVD for {key}...")
                        sum_u = torch.zeros_like(u)
                        sum_s = torch.zeros_like(s)
                        sum_v = torch.zeros_like(v)

                        # select only the first reduced_index_s columns of u and place them
                        sum_u[:, i * reduced_index_s : (i + 1) * reduced_index_s] = u[
                            :, :reduced_index_s
                        ]
                        sum_s[i * reduced_index_s : (i + 1) * reduced_index_s] = s[
                            :reduced_index_s
                        ]
                        # select only the first reduced_index_s rows of v and place them
                        sum_v[i * reduced_index_s : (i + 1) * reduced_index_s, :] = v[
                            :reduced_index_s, :
                        ]
                    else:
                        # generate u vectors orthogonal to the previous ones
                        # generate v vectors orthogonal to the previous ones
                        # NOTE(review): the vectors are random unit-norm, not
                        # actually orthogonalized against previous slots.
                        print("dummy")
                        u = torch.nn.functional.normalize(
                            torch.randn_like(sum_u), p=2, dim=-2
                        )
                        v = torch.nn.functional.normalize(
                            torch.randn_like(sum_v), p=2, dim=-1
                        )

                        # select only the first reduced_index_s columns of u and place them
                        sum_u[:, i * reduced_index_s : (i + 1) * reduced_index_s] = u[
                            :, :reduced_index_s
                        ]
                        # NOTE(review): `s` here is still the slot-0 spectrum,
                        # so every slot reuses the first task's singular values.
                        sum_s[i * reduced_index_s : (i + 1) * reduced_index_s] = s[
                            :reduced_index_s
                        ]
                        # select only the first reduced_index_s rows of v and place them
                        sum_v[i * reduced_index_s : (i + 1) * reduced_index_s, :] = v[
                            :reduced_index_s, :
                        ]

                else:
                    if i == 0:
                        # NOTE(review): stored by reference (no .clone()),
                        # unlike the sibling merge functions — the result
                        # aliases the first task vector's tensor.
                        new_vector[key] = task_vectors[0].vector[key]
                    # else:
                    #     new_vector[key] += (
                    #         task_vector.vector[key] - new_vector[key]
                    #     ) / (i + 1)

            if (
                len(task_vectors[0].vector[key].shape) == 2
                and "text_projection" not in key
            ):

                # Recombine the stacked (mostly random) factors.
                new_vector[key] = torch.linalg.multi_dot(
                    (
                        sum_u,
                        torch.diag(sum_s),
                        sum_v,
                    )
                )

    return new_vector
|
|
@@ -4,10 +4,13 @@ This is modified based on https://github.com/EnnengYang/AdaMerging/blob/main/src
|
|
|
4
4
|
|
|
5
5
|
import copy
|
|
6
6
|
from collections import OrderedDict
|
|
7
|
+
from typing import List
|
|
7
8
|
|
|
8
9
|
import torch
|
|
9
10
|
from torch import Tensor, nn
|
|
10
11
|
|
|
12
|
+
from fusion_bench.utils.type import StateDictType
|
|
13
|
+
|
|
11
14
|
|
|
12
15
|
# Model conversion utils
|
|
13
16
|
def state_dict_to_vector(state_dict, remove_keys=[]):
|
|
@@ -82,7 +85,7 @@ def add_ptm_to_tv(tv_dict, ptm_dict):
|
|
|
82
85
|
return final_dict
|
|
83
86
|
|
|
84
87
|
|
|
85
|
-
def check_parameterNamesMatch(checkpoints):
|
|
88
|
+
def check_parameterNamesMatch(checkpoints: List[StateDictType]) -> None:
|
|
86
89
|
"""
|
|
87
90
|
Check if the parameter names match across multiple checkpoints.
|
|
88
91
|
|
|
@@ -105,7 +108,9 @@ def check_parameterNamesMatch(checkpoints):
|
|
|
105
108
|
)
|
|
106
109
|
|
|
107
110
|
|
|
108
|
-
def check_state_dicts_equal(
|
|
111
|
+
def check_state_dicts_equal(
|
|
112
|
+
state_dict1: StateDictType, state_dict2: StateDictType
|
|
113
|
+
) -> bool:
|
|
109
114
|
"""
|
|
110
115
|
Check if two state dictionaries are equal.
|
|
111
116
|
|