fusion-bench 0.2.7__py3-none-any.whl → 0.2.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/compat/method/base_algorithm.py +1 -1
- fusion_bench/dataset/clip_dataset.py +3 -0
- fusion_bench/dataset/fer2013.py +12 -0
- fusion_bench/dataset/llama/preference_700k.py +1 -1
- fusion_bench/method/__init__.py +2 -0
- fusion_bench/method/classification/clip_finetune.py +10 -13
- fusion_bench/method/surgery/__init__.py +1 -3
- fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +1 -1
- fusion_bench/method/tall_mask/__init__.py +0 -0
- fusion_bench/method/tall_mask/utils.py +234 -0
- fusion_bench/method/task_singular_vector/TSVC.py +16 -0
- fusion_bench/method/task_singular_vector/TSVM.py +63 -0
- fusion_bench/method/task_singular_vector/__init__.py +9 -0
- fusion_bench/method/task_singular_vector/utils/TSVC_utils.py +50 -0
- fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +642 -0
- fusion_bench/method/task_singular_vector/utils/__init__.py +7 -0
- fusion_bench/method/ties_merging/ties_merging_utils.py +7 -2
- fusion_bench/mixins/clip_classification.py +6 -6
- fusion_bench/mixins/lightning_fabric.py +3 -1
- fusion_bench/modelpool/base_pool.py +0 -1
- fusion_bench/modelpool/clip_vision/modelpool.py +92 -8
- fusion_bench/models/surgery/__init__.py +1 -0
- fusion_bench/models/surgery/surgerymodelwrapper.py +2 -1
- fusion_bench/models/wrappers/layer_wise_fusion.py +1 -1
- fusion_bench/models/wrappers/task_wise_fusion.py +1 -1
- fusion_bench/programs/fabric_fusion_program.py +7 -4
- fusion_bench/taskpool/llama/reward_model.py +1 -1
- fusion_bench/tasks/clip_classification/__init__.py +13 -45
- fusion_bench/tasks/clip_classification/clip_dataset.py +1 -16
- fusion_bench/tasks/clip_classification/cub_200_2011.py +208 -0
- fusion_bench/tasks/clip_classification/emnist_letters.py +31 -0
- fusion_bench/tasks/clip_classification/emnist_mnist.py +5 -0
- fusion_bench/tasks/clip_classification/fashion_mnist.py +18 -0
- fusion_bench/tasks/clip_classification/fer2013.py +18 -0
- fusion_bench/tasks/clip_classification/food101.py +105 -0
- fusion_bench/tasks/clip_classification/kmnist.py +17 -0
- fusion_bench/tasks/clip_classification/mongo_leaf_disease.py +19 -0
- fusion_bench/tasks/clip_classification/pcam.py +5 -0
- fusion_bench/utils/parameters.py +12 -3
- fusion_bench/utils/type.py +10 -1
- {fusion_bench-0.2.7.dist-info → fusion_bench-0.2.8.dist-info}/METADATA +1 -1
- {fusion_bench-0.2.7.dist-info → fusion_bench-0.2.8.dist-info}/RECORD +195 -62
- fusion_bench_config/dataset/image_classification/README.md +6 -0
- fusion_bench_config/dataset/image_classification/test/TALL14.yaml +20 -0
- fusion_bench_config/dataset/image_classification/test/TALL20.yaml +28 -0
- fusion_bench_config/dataset/image_classification/test/cifar10.yaml +1 -1
- fusion_bench_config/dataset/image_classification/test/cifar100.yaml +1 -1
- fusion_bench_config/dataset/image_classification/test/cub-200-2011.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/emnist_letters.yaml +5 -0
- fusion_bench_config/dataset/image_classification/test/emnist_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/fashion_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/fer2013.yaml +3 -0
- fusion_bench_config/dataset/image_classification/test/food101.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/kmnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/mango-leaf-disease.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/oxford-iiit-pet.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/oxford_flowers102.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/pcam.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/rendered-sst2.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/stl10.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/TALL14.yaml +20 -0
- fusion_bench_config/dataset/image_classification/train/TALL20.yaml +28 -0
- fusion_bench_config/dataset/image_classification/train/cifar10.yaml +1 -1
- fusion_bench_config/dataset/image_classification/train/cifar100.yaml +1 -1
- fusion_bench_config/dataset/image_classification/train/cub-200-2011.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/emnist_letters.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/emnist_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/fashion_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/fer2013.yaml +3 -0
- fusion_bench_config/dataset/image_classification/train/food101.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/kmnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/mango-leaf-disease.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/oxford-iiit-pet.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/oxford_flowers102.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/pcam.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/rendered-sst2.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/stl10.yaml +4 -0
- fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +2 -0
- fusion_bench_config/model/clip-vit/README.md +38 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL14.yaml +22 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL20.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar100.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_dtd.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_emnist_letters.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_eurosat.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fashion_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fer2013.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_food101.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_gtsrb.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_kmnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_mnist.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford-iiit-pet.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford_flowers102.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_pcam.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_rendered-sst2.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_resisc45.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stanford-cars.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stl10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_sun397.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_svhn.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL14.yaml +22 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL20.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar100.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_dtd.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eight_tasks.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_emnist_letters.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eurosat.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fashion_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fer2013.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_food101.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_gtsrb.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_kmnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_mnist.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford-iiit-pet.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford_flowers102.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_pcam.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_rendered-sst2.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_resisc45.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stanford-cars.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stl10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_sun397.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_svhn.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL14.yaml +22 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL20.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar100.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_dtd.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_emnist_letters.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_eurosat.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fashion_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fer2013.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_food101.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_gtsrb.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_kmnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_mnist.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford-iiit-pet.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford_flowers102.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_pcam.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_rendered-sst2.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stl10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml +1 -3
- fusion_bench_config/model/clip-vit/download_TALL20_models.sh +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_model_only.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml +15 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14_model_only.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20_model_only.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +9 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +15 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +15 -3
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_robustness_corrupted.yaml +27 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL14.yaml +19 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL20.yaml +26 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar10.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar100.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_dtd.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_emnist_letters.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_eurosat.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fashion_mnist.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fer2013.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_food101.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_gtsrb.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_kmnist.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_mnist.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford-iiit-pet.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102_val.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_pcam.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_rendered-sst2.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_resisc45.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stanford-cars.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stl10.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_sun397.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_svhn.yaml +3 -0
- {fusion_bench-0.2.7.dist-info → fusion_bench-0.2.8.dist-info}/LICENSE +0 -0
- {fusion_bench-0.2.7.dist-info → fusion_bench-0.2.8.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.7.dist-info → fusion_bench-0.2.8.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.7.dist-info → fusion_bench-0.2.8.dist-info}/top_level.txt +0 -0
|
@@ -65,4 +65,7 @@ class CLIPDataset(torch.utils.data.Dataset):
|
|
|
65
65
|
else:
|
|
66
66
|
# if processor is None, return the raw image directly
|
|
67
67
|
inputs = image
|
|
68
|
+
# convert boolean labels to int; boolean labels occur when the dataset is a binary classification task
|
|
69
|
+
if isinstance(item["label"], bool):
|
|
70
|
+
item["label"] = 1 if item["label"] else 0
|
|
68
71
|
return inputs, item["label"]
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
from datasets import load_dataset
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def load_fer2013(path: str = "clip-benchmark/wds_fer2013", split: str = "train"):
    """
    Load one split of the FER2013 dataset and normalise its column names.

    Args:
        path: dataset identifier forwarded to `datasets.load_dataset`.
        split: name of the split to load.

    Returns:
        The loaded dataset with the `__key__` and `__url__` columns removed,
        `jpg` renamed to `image` and `cls` renamed to `label`.
    """
    unused_columns = ["__key__", "__url__"]
    column_renames = {"jpg": "image", "cls": "label"}
    raw = load_dataset(path, split=split)
    return raw.remove_columns(unused_columns).rename_columns(column_renames)
|
|
9
|
+
|
|
10
|
+
if __name__ == "__main__":
|
|
11
|
+
dataset = load_fer2013(split="test")
|
|
12
|
+
print(dataset)
|
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
import logging
|
|
1
2
|
import os
|
|
2
3
|
from copy import deepcopy
|
|
3
4
|
from typing import TYPE_CHECKING, Optional
|
|
@@ -7,7 +8,6 @@ from lightning.fabric.utilities import rank_zero_only
|
|
|
7
8
|
from tqdm.auto import tqdm
|
|
8
9
|
|
|
9
10
|
from fusion_bench.utils import timeit_context
|
|
10
|
-
import logging
|
|
11
11
|
|
|
12
12
|
if TYPE_CHECKING:
|
|
13
13
|
from transformers import PreTrainedTokenizer
|
fusion_bench/method/__init__.py
CHANGED
|
@@ -49,6 +49,7 @@ _import_structure = {
|
|
|
49
49
|
"PWEMoExactParetoOptimalForCLIP",
|
|
50
50
|
],
|
|
51
51
|
"ada_svd": ["AdaSVDMergingForCLIPVisionModel"],
|
|
52
|
+
"task_singular_vector": ["TaskSingularVectorMerging"],
|
|
52
53
|
# plug-and-play model merging methods
|
|
53
54
|
"concrete_subspace": [
|
|
54
55
|
"ConcreteTaskArithmeticAlgorithmForCLIP",
|
|
@@ -153,6 +154,7 @@ if TYPE_CHECKING:
|
|
|
153
154
|
SparseLoForLlama,
|
|
154
155
|
)
|
|
155
156
|
from .task_arithmetic import TaskArithmeticAlgorithm
|
|
157
|
+
from .task_singular_vector import TaskSingularVectorMerging
|
|
156
158
|
from .ties_merging import TiesMergingAlgorithm
|
|
157
159
|
from .we_moe import CLIPWeightEnsemblingMoEAlgorithm
|
|
158
160
|
from .weighted_average import WeightedAverageAlgorithm, WeightedAverageForLLama
|
|
@@ -41,11 +41,10 @@ from transformers.models.clip.modeling_clip import CLIPVisionTransformer
|
|
|
41
41
|
from fusion_bench import print_parameters
|
|
42
42
|
from fusion_bench.compat.method import ModelFusionAlgorithm
|
|
43
43
|
from fusion_bench.compat.modelpool import to_modelpool
|
|
44
|
-
from fusion_bench.
|
|
45
|
-
HuggingFaceClipVisionPool,
|
|
46
|
-
)
|
|
44
|
+
from fusion_bench.dataset.clip_dataset import CLIPDataset
|
|
47
45
|
from fusion_bench.mixins import CLIPClassificationMixin
|
|
48
46
|
from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
|
|
47
|
+
from fusion_bench.modelpool import CLIPVisionModelPool
|
|
49
48
|
from fusion_bench.models.hf_clip import HFCLIPClassifier
|
|
50
49
|
from fusion_bench.models.linearized.linearized_model_utils import LinearizedModelWraper
|
|
51
50
|
from fusion_bench.utils.data import InfiniteDataLoader
|
|
@@ -92,12 +91,12 @@ class ImageClassificationFineTuningForCLIP(
|
|
|
92
91
|
A class for fine-tuning CLIP models for image classification tasks.
|
|
93
92
|
"""
|
|
94
93
|
|
|
95
|
-
def run(self, modelpool:
|
|
94
|
+
def run(self, modelpool: CLIPVisionModelPool):
|
|
96
95
|
"""
|
|
97
96
|
Executes the fine-tuning process.
|
|
98
97
|
|
|
99
98
|
Args:
|
|
100
|
-
modelpool (
|
|
99
|
+
modelpool (CLIPVisionModelPool): The modelpool is responsible for loading the pre-trained model and training datasets.
|
|
101
100
|
|
|
102
101
|
Returns:
|
|
103
102
|
VisionModel: The fine-tuned vision model.
|
|
@@ -109,9 +108,7 @@ class ImageClassificationFineTuningForCLIP(
|
|
|
109
108
|
|
|
110
109
|
L.seed_everything(config.seed)
|
|
111
110
|
|
|
112
|
-
task_names =
|
|
113
|
-
dataset_config["name"] for dataset_config in modelpool.config.train_datasets
|
|
114
|
-
]
|
|
111
|
+
task_names = modelpool.train_dataset_names
|
|
115
112
|
with self.profile("setup model and optimizer"):
|
|
116
113
|
processor, classifier, optimizer, lr_scheduler = self.setup_model()
|
|
117
114
|
|
|
@@ -133,7 +130,7 @@ class ImageClassificationFineTuningForCLIP(
|
|
|
133
130
|
|
|
134
131
|
with self.profile("setup data"):
|
|
135
132
|
train_datasets = [
|
|
136
|
-
modelpool.
|
|
133
|
+
CLIPDataset(modelpool.load_train_dataset(task_name), processor)
|
|
137
134
|
for task_name in task_names
|
|
138
135
|
]
|
|
139
136
|
train_dataloaders = [
|
|
@@ -157,6 +154,7 @@ class ImageClassificationFineTuningForCLIP(
|
|
|
157
154
|
range(config.num_steps),
|
|
158
155
|
desc=self.finetune_method,
|
|
159
156
|
disable=not self.fabric.is_global_zero,
|
|
157
|
+
dynamic_ncols=True,
|
|
160
158
|
):
|
|
161
159
|
optimizer.zero_grad()
|
|
162
160
|
loss = 0
|
|
@@ -183,7 +181,7 @@ class ImageClassificationFineTuningForCLIP(
|
|
|
183
181
|
save_path = os.path.join(
|
|
184
182
|
self.log_dir, "checkpoints", f"step={step_idx}.ckpt"
|
|
185
183
|
)
|
|
186
|
-
self.save_model(classifier, save_path
|
|
184
|
+
self.save_model(classifier, save_path)
|
|
187
185
|
|
|
188
186
|
if config.state_dict_save_path is not None:
|
|
189
187
|
self.save_model(
|
|
@@ -232,9 +230,8 @@ class ImageClassificationFineTuningForCLIP(
|
|
|
232
230
|
config = self.config
|
|
233
231
|
modelpool = self.modelpool
|
|
234
232
|
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
processor = CLIPProcessor.from_pretrained(pretrained_model_config.path)
|
|
233
|
+
clip_model: CLIPModel = modelpool.load_clip_model("_pretrained_")
|
|
234
|
+
processor = modelpool.load_processor()
|
|
238
235
|
|
|
239
236
|
self.finetune_method = "full fine-tune"
|
|
240
237
|
if config.use_lora or config.use_l_lora:
|
|
File without changes
|
|
@@ -0,0 +1,234 @@
|
|
|
1
|
+
import copy
|
|
2
|
+
import os
|
|
3
|
+
from typing import List, Optional
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import torch
|
|
7
|
+
|
|
8
|
+
from fusion_bench.utils import state_dict_to_vector, vector_to_state_dict
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def generate_task_masks(
    tv_flat_checks: torch.Tensor,
    flat_ft: torch.Tensor,
    flat_ptm: torch.Tensor,
    tv: Optional[torch.Tensor] = None,
    tall_mask_lambda: float = 1.0,
) -> torch.Tensor:
    """
    Generate task-specific TALL masks.

    A parameter is selected for task t when
    ``|theta_0 - theta_t| > |theta_mt - theta_t| * lambda``.

    Args:
        tv_flat_checks: individual task vectors.
        flat_ft: individual theta_t (fine-tuned weights).
        flat_ptm: theta_0 (pre-trained weight).
        tv: multi-task vector; when None it is the sum of `tv_flat_checks`
            over dimension 0.
        tall_mask_lambda: hyper-parameter lambda for generating TALL masks.

    Returns:
        final_mask: boolean TALL masks for the given lambda, in shape
            (n_task, n_parameter). Size-1 dimensions are squeezed out when
            `flat_ft` has the same shape as the squeezed `tv_flat_checks`.
    """

    print("Generating TALL masks.")

    if tv is None:
        tv = tv_flat_checks.sum(0)

    # theta_mt: the multi-task weights obtained by adding the merged task vector
    flat_multi = flat_ptm + tv

    original_shape = flat_ft.shape

    # element-wise distances |theta_0 - theta_t| and |theta_mt - theta_t|
    diff_pt_ft = (flat_ptm - flat_ft).abs()
    diff_multi_ft = (flat_multi - flat_ft).abs()
    # compare the distances, scaling the multi-task side with lambda
    mask = diff_pt_ft > diff_multi_ft * tall_mask_lambda

    if original_shape == tv_flat_checks.squeeze().shape:
        final_mask = mask.squeeze()
    else:
        final_mask = mask

    print(
        f"Average sparsity for the mask with tall_mask_lambda of {tall_mask_lambda}: {final_mask.float().mean():.4f}"
    )
    return final_mask
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def construct_tall_mask(
    tv_flat_checks: torch.Tensor,
    flat_ft: torch.Tensor,
    flat_ptm: torch.Tensor,
    merged_tv: torch.Tensor,
    ptm_check: torch.Tensor,
    remove_keys: List[str],
    config,
    tall_mask_lambdas: Optional[List[float]] = None,
):
    """
    Construct TALL masks for all tasks for each lambda, and store in dictionary

    Args:
        tv_flat_checks: individual task vectors
        flat_ft: individual theta_t (fine-tuned weights)
        flat_ptm: theta_0 (pre-trained weight)
        merged_tv: multi-task vector
        ptm_check: pre-trained weight as state dictionary
        remove_keys: the keys to be removed when converting between dictionary and vector
        config: object whose `DATASETS` attribute lists the task names; the
            names are paired with the generated masks in order
        tall_mask_lambdas: lambda values to generate masks for; when None the
            grid [0.2, 0.3, 0.4, 0.5, 0.6] is used
    Returns:
        tall_masks: constructed TALL masks in dictionary format of {lambda: {task: mask}}
    """
    if tall_mask_lambdas is None:
        tall_mask_lambdas = [0.2, 0.3, 0.4, 0.5, 0.6]

    tall_masks = {}
    for tall_mask_lambda in tall_mask_lambdas:
        # generate tall masks for each lambda
        masks_at_scale = generate_task_masks(
            tv_flat_checks,
            flat_ft,
            flat_ptm,
            tall_mask_lambda=tall_mask_lambda,
            tv=merged_tv,
        )
        # convert each mask vector to a state dictionary and store it as
        # {dataset: mask}; zip stops at the shorter of the two sequences
        tall_masks[tall_mask_lambda] = {
            dataset: vector_to_state_dict(mask, ptm_check, remove_keys=remove_keys)
            for dataset, mask in zip(config.DATASETS, masks_at_scale)
        }
    return tall_masks
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def find_optimal_mask(val_metrics, eval_masks, args, save_masks=True):
    """
    Respectively finds the optimal mask for each data task based on the validation accuracy

    Args:
        val_metrics: validation metrics for each lambda, as {lambda: {metric_name: value}}
        eval_masks: all generated masks, as {lambda: {task: mask}}
        args: configuration object; `DATASETS` is always read, and
            `method.load_mask`, `method.use_ties`, `method.ties_agg`,
            `model_location`, `model` and `num_tasks` are read when saving
        save_masks: whether to write the selected masks to disk (skipped when
            `args.method.load_mask` is set)

    Returns:
        best_masks_for_test: the best masks for each task, selected based on validation accuracy from each task
        best_val_metrics: best validation metrics for each task
    """
    # transpose the dict from lambda-task to task-lambda
    metrics_by_task = {}
    for lambda_key, per_task in val_metrics.items():
        for metric_name, value in per_task.items():
            metrics_by_task.setdefault(metric_name, {})[lambda_key] = value

    # for each task, find the best lambda
    best_lambda_by_metric = {
        metric_name: max(by_lambda, key=by_lambda.get)
        for metric_name, by_lambda in metrics_by_task.items()
    }

    # select the best mask for each task, which will be used for testing later
    best_masks_for_test = {}
    best_masks_for_test_vector = {}  # the selected masks as vectors
    best_val_metrics = {}
    # respectively for each task:
    for ds in args.DATASETS:
        metric_key = ds + "Val:top1"
        # select the lambda which achieves the best validation accuracy
        best_lambda = float(best_lambda_by_metric[metric_key])
        # select the mask based on the selected lambda, save as dictionaries
        best_masks_for_test[ds] = eval_masks[best_lambda][ds]
        # select the mask based on the selected lambda, save as vectors
        best_masks_for_test_vector[ds] = state_dict_to_vector(
            eval_masks[best_lambda][ds], remove_keys=[]
        )
        print(f"Best lambda for {ds} is {best_lambda}")
        # save the best validation metric based on the selected lambda
        best_val_metrics[metric_key] = val_metrics[best_lambda][metric_key]

    # save the best masks in disk
    if save_masks and not args.method.load_mask:
        # convert to numpy to save with np.packbits for saving storage
        packed_masks = {k: np.packbits(v) for k, v in best_masks_for_test_vector.items()}
        mask_save_dir = args.model_location.replace("checkpoints", "tall_masks")
        mask_name = (
            f"TALL_mask_{args.num_tasks}task.npy"
            if not args.method.use_ties
            else f"TALL_mask_{args.num_tasks}task_use_ties_{args.method.ties_agg}.npy"
        )
        # np.save does not create missing parent directories, so make sure
        # the per-model mask directory exists before writing
        target_dir = os.path.join(mask_save_dir, args.model)
        os.makedirs(target_dir, exist_ok=True)
        np.save(os.path.join(target_dir, mask_name), packed_masks)

    return best_masks_for_test, best_val_metrics
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
def load_tall_mask(remove_keys, ptm_check, config):
    """
    Loads TALL masks from disk, unpack and transform to state dictionaries.

    Args:
        remove_keys: the keys to be removed when converting between dictionary and vector
        ptm_check: pre-trained weight as state dictionary
        config: configuration object; `model_location`, `model`, `num_tasks`
            and `method.use_ties` are read to locate the mask file

    Returns:
        tall_masks: the loaded masks as {dataset: mask state dictionary}

    Raises:
        Exception: if the mask file cannot be loaded; the underlying error is
            attached as the cause.
    """
    mask_location = config.model_location.replace("checkpoints", "tall_masks")
    # NOTE(review): find_optimal_mask in this file writes the masks with np.save
    # and, for TIES, appends `_{ties_agg}` to the file name, whereas this loader
    # uses torch.load and a TIES name without that suffix -- confirm the
    # intended on-disk format and naming.
    if config.method.use_ties:
        print("==== Loading TALL Masks built with TIES ====")
        mask_name = f"TALL_mask_{config.num_tasks}task_use_ties.npy"
    else:
        print("==== Loading TALL Masks built with Task Arithmetic ====")
        mask_name = f"TALL_mask_{config.num_tasks}task.npy"
    mask_path = os.path.join(mask_location, config.model, mask_name)

    try:
        tall_masks = torch.load(mask_path)
    except Exception as err:
        # keep the original message, but preserve the real cause and let
        # KeyboardInterrupt / SystemExit propagate untouched
        raise Exception("TALL Masks are not constructed yet.") from err

    # unpack masks and convert back to torch tensors
    tall_masks = {k: torch.from_numpy(np.unpackbits(v)) for k, v in tall_masks.items()}

    # convert vectors to dictionaries
    tall_masks = {
        dataset: vector_to_state_dict(mask, ptm_check, remove_keys=remove_keys)
        for dataset, mask in tall_masks.items()
    }

    return tall_masks
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
def construct_consensus_mask(ptm_check, prun_thre_k, config, remove_keys=None):
    """
    Generate consensus mask by filtering out least-used parameters

    Args:
        ptm_check: pretrained_checkpoint as state dictionary
        prun_thre_k: weight-pruning threshold, stands for the least number of activated tasks for a parameter to be preserved from pruning
                if prun_thre_k is set to 2: remove both catastrophic and selfish weights;
                if prun_thre_k is set to 1: remove only catastrophic weights;
                if prun_thre_k is set to 0: remove no weights -> reduce to TA or TIES
                if prun_thre_k is set to > num_tasks: remove all weights -> reduce to zero-shot
        config: configuration object forwarded to `load_tall_mask`
        remove_keys: the keys to be removed when converting between dictionary
            and vector; None is treated as an empty list
    Returns:
        consensus_mask_vector: constructed consensus mask as vector (boolean in shape (n_parameter, ))
    """
    # use a None sentinel instead of a shared mutable default list
    if remove_keys is None:
        remove_keys = []

    print("==== Generating Consensus Mask ====")
    # load TALL masks (in shape (n_task, n_parameter))
    tall_masks = list(load_tall_mask(remove_keys, ptm_check, config).values())

    # generate consensus masks
    consensus_mask = {}
    for key, reference in tall_masks[0].items():
        # count for each parameter, the tasks it has been activated for
        activation_count = torch.zeros_like(reference)
        for mask in tall_masks:
            activation_count = activation_count + mask[key].float()
        # filter out the least-activated parameters based on given threshold
        consensus_mask[key] = activation_count.float() >= prun_thre_k

    consensus_mask_vector = state_dict_to_vector(
        consensus_mask, remove_keys=remove_keys
    )

    return consensus_mask_vector
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from torch import Tensor, nn
|
|
3
|
+
|
|
4
|
+
from fusion_bench import BaseAlgorithm
|
|
5
|
+
|
|
6
|
+
from .utils import TSVC_utils, check_parameterNamesMatch
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class TaskSingularVectorCompression(BaseAlgorithm):
    """
    Placeholder algorithm for Task Singular Vector Compression.

    Construction forwards every keyword argument to `BaseAlgorithm`; `run`
    has no implementation yet and always raises.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def run(self, modelpool):
        """Raise `NotImplementedError`; `modelpool` is not used."""
        message = "Task Singular Vector Compression is not implemented yet."
        raise NotImplementedError(message)
|
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Example:
|
|
3
|
+
|
|
4
|
+
```bash
|
|
5
|
+
fusion_bench \
|
|
6
|
+
method=task_singular_vector/TaskSingularVectorMerging \
|
|
7
|
+
modelpool=CLIPVisionModelPool/clip-vit-base-patch32_TALL20_model_only \
|
|
8
|
+
taskpool=CLIPVisionModelTaskPool/clip-vit-classification_TALL20
|
|
9
|
+
```
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from typing import List, Optional
|
|
13
|
+
|
|
14
|
+
import torch
|
|
15
|
+
from torch import Tensor, nn
|
|
16
|
+
|
|
17
|
+
from fusion_bench import BaseAlgorithm
|
|
18
|
+
from fusion_bench.mixins import LightningFabricMixin
|
|
19
|
+
from fusion_bench.utils import timeit_context
|
|
20
|
+
from fusion_bench.utils.state_dict_arithmetic import state_dict_sub, state_dict_add
|
|
21
|
+
from fusion_bench.utils.type import StateDictType
|
|
22
|
+
|
|
23
|
+
from .utils import (
|
|
24
|
+
TSVM_utils,
|
|
25
|
+
check_parameterNamesMatch,
|
|
26
|
+
check_state_dicts_equal,
|
|
27
|
+
state_dict_to_vector,
|
|
28
|
+
vector_to_state_dict,
|
|
29
|
+
)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class TaskSingularVectorMerging(BaseAlgorithm, LightningFabricMixin):
    """Merge fine-tuned models into the pre-trained model via ``TSVM_utils``.

    Each fine-tuned state dict is turned into a task vector by subtracting the
    pre-trained state dict. The task vectors are combined by
    ``TSVM_utils.compute_and_sum_svd_mem_reduction`` and the combined update is
    added back onto the pre-trained weights.
    """

    def __init__(
        self,
        remove_keys: Optional[List[str]] = None,
        **kwargs,
    ):
        """Store the merge options and initialise the parent classes.

        Args:
            remove_keys: Keys passed to the merge routine as ``exclude_keys``.
                ``None`` is stored as an empty list.
            **kwargs: Forwarded unchanged to the parent initializer.
        """
        if remove_keys is None:
            remove_keys = []
        self.remove_keys = remove_keys
        super().__init__(**kwargs)

    def run(self, modelpool):
        """Build the merged model from the models held by ``modelpool``.

        Args:
            modelpool: Object providing ``load_pretrained_model()`` and
                ``models()``.

        Returns:
            The pre-trained model object, with the merged state dict loaded
            into it in place.
        """
        # Load the base model first, then materialise every fine-tuned model.
        base_model = modelpool.load_pretrained_model()
        tuned_models = list(modelpool.models())

        base_state = base_model.state_dict()
        tuned_states = [tuned.state_dict() for tuned in tuned_models]
        all_states = tuned_states + [base_state]
        check_parameterNamesMatch(all_states)

        # The timer label is kept verbatim; the timed work is the per-model
        # subtraction "fine-tuned - pre-trained" that yields the task vectors.
        with timeit_context("Flattening out Checkpoints"):
            task_vectors = [
                state_dict_sub(tuned_state, base_state)
                for tuned_state in tuned_states
            ]

        merged_delta = TSVM_utils.compute_and_sum_svd_mem_reduction(
            task_vectors,
            exclude_keys=self.remove_keys,
            accelerator=self.fabric.device,
        )

        # Add the merged update to the current pre-trained weights and load
        # the sum back into the same model object.
        merged_state = state_dict_add(merged_delta, base_model.state_dict())
        base_model.load_state_dict(merged_state)
        return base_model
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This module is modified from the original code of the paper:
|
|
3
|
+
|
|
4
|
+
- Gargiulo et al. Task Singular Vectors: Reducing Task Interference in Model Merging
|
|
5
|
+
- http://arxiv.org/abs/2412.00081
|
|
6
|
+
- https://github.com/AntoAndGar/task_singular_vectors/
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from .TSVM import TaskSingularVectorMerging
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def compute_svd_and_compress(key, matrix, sv_reduction):
    """
    Computes the reduced Singular Value Decomposition (SVD) of a matrix and keeps only the leading singular components.

    The number of retained components is ``int(n * sv_reduction)``, where ``n`` is the
    number of singular values of ``matrix``; the product is truncated towards zero, so a
    small fraction on a small matrix can retain zero components.

    Args:
        key (Any): An identifier for the matrix; returned unchanged.
        matrix (torch.Tensor): The input matrix to decompose.
        sv_reduction (float): The fraction of singular values to retain (0 < sv_reduction <= 1).

    Returns:
        tuple: A tuple containing:
            - key (Any): The original identifier for the matrix.
            - u (torch.Tensor): The leading columns of U (left singular vectors).
            - s (torch.Tensor): The leading singular values.
            - v (torch.Tensor): The leading rows of Vh as returned by ``torch.linalg.svd``
              (right singular vectors stored as rows), so ``u @ diag(s) @ v``
              approximates ``matrix``.

    Raises:
        ValueError: If ``sv_reduction`` is outside the interval (0, 1].
    """
    # Enforce the documented range: a negative fraction would become a negative
    # slice bound and silently keep the wrong number of components, while a
    # value above 1 would be silently ignored.
    if not 0 < sv_reduction <= 1:
        raise ValueError(
            f"sv_reduction must satisfy 0 < sv_reduction <= 1, got {sv_reduction!r}"
        )

    # torch.linalg.svd returns (U, S, Vh); Vh already holds the right singular
    # vectors as rows, hence the row slice below.
    u, s, vh = torch.linalg.svd(matrix, full_matrices=False)
    reduced_index_s = int(s.shape[0] * sv_reduction)
    return key, u[:, :reduced_index_s], s[:reduced_index_s], vh[:reduced_index_s, :]
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def compress_tv(task_vectors, sv_reduction):
    """
    Replace every 2-D layer of each task vector by its truncated SVD factors.

    Args:
        task_vectors (dict): Maps a dataset name to a task-vector object. Each object must
            expose a ``vector`` attribute: a dict from layer name to layer tensor.
        sv_reduction (float): The fraction of singular values to keep, forwarded as-is to
            ``compute_svd_and_compress``.

    Returns:
        dict: ``{dataset: {layer_name: entry}}`` with the same dataset and layer keys as the
            input. For a layer with exactly two dimensions, ``entry`` is
            ``{"u": ..., "s": ..., "v": ...}`` holding the compressed SVD components; any
            other layer is stored untouched as ``{"dim1": layer}``.
    """
    compressed = {}
    # Gradient tracking is disabled for the whole decomposition pass.
    with torch.no_grad():
        for dataset_name, tv in task_vectors.items():
            per_layer = {}
            for layer_name, weight in tv.vector.items():
                # Anything that is not a matrix is passed through unchanged.
                if len(weight.shape) != 2:
                    per_layer[layer_name] = {"dim1": weight}
                    continue
                _, u, s, v = compute_svd_and_compress(layer_name, weight, sv_reduction)
                per_layer[layer_name] = {"u": u, "s": s, "v": v}
            compressed[dataset_name] = per_layer
    return compressed
|