fusion-bench 0.2.15__py3-none-any.whl → 0.2.17__py3-none-any.whl
This diff compares two publicly available versions of the package as released to a supported registry. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their public registry.
- fusion_bench/method/__init__.py +4 -0
- fusion_bench/method/adamerging/flan_t5_layer_wise_adamerging.py +1 -1
- fusion_bench/method/adamerging/gpt2_layer_wise_adamerging.py +1 -1
- fusion_bench/method/base_algorithm.py +1 -0
- fusion_bench/method/dawe/dawe_for_clip.py +1 -1
- fusion_bench/method/depth_upscaling/depth_upscaling_for_llama.py +3 -2
- fusion_bench/method/fw_merging/__init__.py +2 -0
- fusion_bench/method/fw_merging/fw_hard.py +448 -0
- fusion_bench/method/fw_merging/fw_soft.py +519 -0
- fusion_bench/method/fw_merging/utils.py +331 -0
- fusion_bench/method/gossip/flan_t5_layer_wise_gossip.py +1 -1
- fusion_bench/method/moe_pruner/__init__.py +7 -0
- fusion_bench/method/moe_pruner/hooks/__init__.py +6 -0
- fusion_bench/method/moe_pruner/hooks/deepseek_v2.py +85 -0
- fusion_bench/method/moe_pruner/hooks/hook.py +23 -0
- fusion_bench/method/moe_pruner/hooks/mixtral.py +93 -0
- fusion_bench/method/moe_pruner/moe_pruner.py +304 -0
- fusion_bench/method/moe_pruner/utils/__init__.py +1 -0
- fusion_bench/method/moe_pruner/utils/data.py +154 -0
- fusion_bench/method/moe_pruner/utils/layerwrapper.py +61 -0
- fusion_bench/method/moe_pruner/utils/prune.py +313 -0
- fusion_bench/method/moe_pruner/utils/score.py +41 -0
- fusion_bench/method/pruning/__init__.py +1 -0
- fusion_bench/method/pruning/llama_sparsegpt_prune.py +223 -0
- fusion_bench/method/pruning/sparsegpt_utils/__init__.py +1 -0
- fusion_bench/method/pruning/sparsegpt_utils/sparsegpt.py +128 -0
- fusion_bench/method/pruning/wanda_utils/data.py +33 -14
- fusion_bench/method/pwe_moe/module.py +2 -7
- fusion_bench/method/randes/__init__.py +15 -0
- fusion_bench/method/randes/base_algorithm.py +1013 -0
- fusion_bench/method/randes/modelsoup.py +126 -0
- fusion_bench/method/randes/task_arithmetic.py +318 -0
- fusion_bench/method/simple_average.py +3 -2
- fusion_bench/method/sparselo/sparselo.py +20 -2
- fusion_bench/method/tall_mask/__init__.py +1 -0
- fusion_bench/method/tall_mask/task_arithmetic.py +133 -0
- fusion_bench/method/task_singular_vector/TSVM.py +238 -25
- fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +52 -20
- fusion_bench/mixins/hydra_config.py +1 -1
- fusion_bench/mixins/lightning_fabric.py +25 -1
- fusion_bench/mixins/serialization.py +18 -2
- fusion_bench/modelpool/base_pool.py +1 -0
- fusion_bench/modelpool/clip_vision/modelpool.py +21 -13
- fusion_bench/modelpool/lazy_state_dict_pool.py +15 -0
- fusion_bench/models/modeling_deepseek_v2/__init__.py +15 -0
- fusion_bench/models/modeling_deepseek_v2/configuration_deepseek.py +208 -0
- fusion_bench/models/modeling_deepseek_v2/modeling_deepseek.py +1922 -0
- fusion_bench/models/modeling_deepseek_v2/tokenization_deepseek_fast.py +38 -0
- fusion_bench/models/parameter_dict.py +6 -1
- fusion_bench/programs/fabric_fusion_program.py +14 -5
- fusion_bench/taskpool/base_pool.py +1 -0
- fusion_bench/taskpool/clip_vision/taskpool.py +8 -1
- fusion_bench/taskpool/dummy.py +6 -4
- fusion_bench/utils/__init__.py +2 -1
- fusion_bench/utils/data.py +1 -1
- fusion_bench/utils/{instantiate.py → instantiate_utils.py} +3 -0
- fusion_bench/utils/lazy_state_dict.py +268 -0
- fusion_bench/utils/parameters.py +33 -0
- fusion_bench/utils/pylogger.py +28 -0
- fusion_bench/utils/state_dict_arithmetic.py +74 -2
- fusion_bench/utils/type.py +1 -0
- {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.17.dist-info}/METADATA +8 -2
- {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.17.dist-info}/RECORD +104 -44
- {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.17.dist-info}/WHEEL +1 -1
- fusion_bench_config/dataset/image_classification/test/TALL10.yaml +28 -0
- fusion_bench_config/dataset/image_classification/test/TALL12.yaml +28 -0
- fusion_bench_config/dataset/image_classification/test/TALL16.yaml +28 -0
- fusion_bench_config/dataset/image_classification/test/TALL18.yaml +28 -0
- fusion_bench_config/dataset/image_classification/train/TALL10.yaml +28 -0
- fusion_bench_config/dataset/image_classification/train/TALL12.yaml +28 -0
- fusion_bench_config/dataset/image_classification/train/TALL16.yaml +28 -0
- fusion_bench_config/dataset/image_classification/train/TALL18.yaml +28 -0
- fusion_bench_config/fabric_model_fusion.yaml +2 -2
- fusion_bench_config/method/fw_merging/fw_hard.yaml +11 -0
- fusion_bench_config/method/fw_merging/fw_soft.yaml +12 -0
- fusion_bench_config/method/moe_pruner/moe_pruner.yaml +15 -0
- fusion_bench_config/method/pruning/llama_sparsegpt_pruning.yaml +16 -0
- fusion_bench_config/method/randes/superposed_model_soup.yaml +18 -0
- fusion_bench_config/method/randes/superposed_task_arithmetic.yaml +20 -0
- fusion_bench_config/method/randes/superposed_task_arithmetic_lora.yaml +20 -0
- fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml +2 -1
- fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml +1 -1
- fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml +1 -1
- fusion_bench_config/method/tall_mask/task_arithmetic.yaml +4 -0
- fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +2 -1
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL10.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL12.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL16.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL18.yaml +29 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL10.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL12.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL16.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL18.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_cars_and_dtd.yaml +16 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_and_cars.yaml +16 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_and_dtd.yaml +16 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_cars_and_dtd.yaml +19 -0
- fusion_bench_config/modelpool/CausalLMPool/deepseek-v2-lite.yaml +15 -0
- fusion_bench_config/modelpool/CausalLMPool/mixtral-8x7b.yaml +14 -0
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/roberta-base_glue.yaml +69 -0
- fusion_bench_config/taskpool/LMEvalHarnessTaskPool/lm_eval.yaml +0 -1
- {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.17.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.17.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.17.dist-info}/top_level.txt +0 -0
fusion_bench/method/task_singular_vector/TSVM.py

@@ -1,5 +1,31 @@
-"""
-
+R"""
+# Task Singular Vector Merging (TSVM) Algorithm Implementation
+
+This module implements the Task Singular Vector Merging algorithm for combining multiple fine-tuned models
+into a single unified model.
+
+## Example Usage:
+
+Merge 8 CLIP-ViT-B/32 models with TSVM:
+
+```bash
+fusion_bench \
+    method=task_singular_vector/TaskSingularVectorMerging \
+    modelpool=CLIPVisionModelPool/clip-vit-base-patch32_TA8_model_only \
+    taskpool=CLIPVisionModelTaskPool/clip-vit-classification_TA8
+```
+
+Merge 8 CLIP-ViT-B/32 models with TSVM and return individual transformed models:
+
+```bash
+fusion_bench \
+    method=task_singular_vector/TaskSingularVectorMerging \
+    method.return_single_task_models=true \
+    modelpool=CLIPVisionModelPool/clip-vit-base-patch32_TA8_model_only \
+    taskpool=CLIPVisionModelTaskPool/clip-vit-classification_TA8
+```
+
+Merge 20 CLIP-VIT-B/32 models with TSVM:
 
 ```bash
 fusion_bench \
@@ -7,14 +33,22 @@ fusion_bench \
     modelpool=CLIPVisionModelPool/clip-vit-base-patch32_TALL20_model_only \
     taskpool=CLIPVisionModelTaskPool/clip-vit-classification_TALL20
 ```
+
+## References:
+
+- Gargiulo, et al. Task Singular Vectors: Reducing Task Interference in Model Merging.
+  https://arxiv.org/abs/2412.00081
+- See `docs/algorithms/task_singular_vector.md` for more details.
 """
 
+from copy import deepcopy
 from typing import Iterable, List, Optional, Union
 
 import torch
 from omegaconf import ListConfig
 from torch import Tensor, nn
 
+import fusion_bench as fb
 from fusion_bench import BaseAlgorithm
 from fusion_bench.mixins import LightningFabricMixin
 from fusion_bench.utils import timeit_context
@@ -35,49 +69,228 @@ from .utils import (
 
 
 class TaskSingularVectorMerging(BaseAlgorithm, LightningFabricMixin):
+    """
+    Task Singular Vector Merging (TSVM) Algorithm
+
+    This class implements a model merging technique that leverages Singular Value
+    Decomposition (SVD) to identify and combine the most important directions in the task vector
+    space. The algorithm is particularly effective for merging multiple models fine-tuned on
+    different tasks while preserving their essential capabilities.
+
+    Key Concepts:
+    - Task Vector: The difference between a fine-tuned model and its pretrained base model,
+      representing the knowledge gained during fine-tuning for a specific task.
+    - Singular Value Decomposition: A matrix factorization technique used to find the principal
+      components (most important directions) in the space of task vectors.
+    - Model Merging: The process of combining multiple models into a single unified model that
+      retains capabilities from all constituent models.
+
+    Algorithm Steps:
+    1. Extract task vectors from all fine-tuned models by subtracting the pretrained model
+    2. Apply SVD to the matrix of task vectors to find principal components
+    3. Reconstruct task vectors using only the most significant singular vectors
+    4. Merge the reconstructed task vectors (either individually scaled or as a sum)
+    5. Add the final merged task vector to the pretrained model to create the unified model
+
+    see `docs/algorithms/task_singular_vector.md` for comprehensive algorithmic details.
+    """
 
     def __init__(
         self,
-        alpha: Union[float, Iterable[float]] = None,
-
+        alpha: Optional[Union[float, Iterable[float]]] = None,
+        exclude_keys: Optional[List[str]] = None,
+        return_single_task_models: bool = False,
         **kwargs,
     ):
+        """
+        Initialize the Task Singular Vector Merging algorithm.
+
+        Args:
+            alpha (Union[float, Iterable[float]], optional): Scaling factor(s) for task vectors.
+                This parameter controls the strength of the task-specific adaptations in the final model.
+
+                - If a single float: Applied to the final merged task vector after SVD reconstruction.
+                  This uniformly scales the entire merged adaptation.
+
+                - If an iterable of floats: Applied to individual task vectors before SVD and merging.
+                  Must have the same length as the number of models in the modelpool.
+                  This allows for task-specific weighting (e.g., giving more importance to certain tasks).
+
+                - If None: No scaling is applied (equivalent to alpha=1.0).
+
+                Example: alpha=[0.8, 1.2, 0.5] would apply different weights to three different task vectors.
+
+            exclude_keys (Optional[List[str]], optional): List of parameter names to exclude from TSVM.
+                These parameters will not participate in the SVD computation and merging process.
+                Useful for excluding certain layers (e.g., task-specific heads, normalization layers)
+                that should not be merged across tasks. Defaults to an empty list.
+
+                Example: exclude_keys=['classifier.weight', 'classifier.bias'] to skip classification heads.
+
+            return_single_task_models (bool, optional): Whether to return individual transformed models.
+
+                - If True: Returns a dictionary containing both individual models with their transformed
+                  task vectors applied AND the final merged model. The dictionary has the structure:
+
+                >>> {'model_name_1': transformed_model_1, ..., 'merged': final_merged_model}
+
+                - If False: Returns only the final merged model.
+
+                This is useful for analysis or when you need access to intermediate results.
+                Defaults to False.
+
+            **kwargs: Additional arguments passed to the parent BaseAlgorithm class.
+
+        Note:
+            The choice between single alpha vs. list of alphas affects the merging strategy:
+            - Single alpha: SVD is applied first, then the result is scaled
+            - List of alphas: Individual task vectors are scaled first, then SVD is applied
+        """
         self.alpha = alpha
-        self.
+        self.exclude_keys = exclude_keys if exclude_keys is not None else []
+        self.return_single_task_models = return_single_task_models
         super().__init__(**kwargs)
 
-    def
-
+    def load_pretrained_model_and_task_vectors(self, modelpool: fb.BaseModelPool):
+        """
+        Load the pretrained base model and compute task vectors from all fine-tuned models.
+
+        This method performs the initial step of the TSVM algorithm by:
+        1. Loading the original pretrained model (before any task-specific fine-tuning)
+        2. For each fine-tuned model in the pool:
+           - Load the fine-tuned model
+           - Compute the task vector (fine-tuned params - pretrained params)
+           - Optionally apply individual scaling if alpha is provided as a list
+
+        Task vectors represent the knowledge gained during fine-tuning and are the core
+        data structure that TSVM operates on.
+
+        Args:
+            modelpool (fb.BaseModelPool): Pool containing the pretrained model and all
+                fine-tuned models to be merged.
+
+        Returns:
+            tuple: A tuple containing:
+                - pretrained_model: The original pretrained model (torch.nn.Module)
+                - task_vectors: List of task vectors (List[StateDictType]), where each
+                  task vector is a state dictionary representing the parameter differences
+                  for one specific task
+        """
+        # Load the original pretrained model that serves as the base for all fine-tuned variants
         pretrained_model = modelpool.load_pretrained_model()
-        finetuned_models = list(modelpool.models())
 
-
-
-
+        # Initialize list to store computed task vectors
+        task_vectors = []
+
+        # Process each fine-tuned model in the modelpool
+        for model_idx, model_name in enumerate(modelpool.model_names):
+            # Load the current fine-tuned model
+            finetuned_model = modelpool.load_model(model_name)
+
+            # Compute task vector: difference between fine-tuned and pretrained parameters
+            # This captures the task-specific adaptations learned during fine-tuning
+            task_vector = state_dict_sub(
+                finetuned_model.state_dict(), pretrained_model.state_dict()
+            )
+            task_vectors.append(task_vector)
 
-
-
-        if isinstance(self.alpha, Iterable):
+            # Apply individual scaling to task vectors if alpha is provided as a list
+            # This allows for task-specific weighting before the SVD computation
+            if self.alpha is not None and isinstance(self.alpha, Iterable):
+                # Ensure the number of alpha values matches the number of models
                 assert len(self.alpha) == len(
-
-                ), "Alpha
-            task_vectors = [
-                state_dict_mul(state_dict=tv, scalar=alpha)
-                for alpha, tv in zip(self.alpha, task_vectors)
-            ]
+                    modelpool.model_names
+                ), f"Alpha list length ({len(self.alpha)}) must match number of models ({len(modelpool.model_names)})"
 
+                # Scale the current task vector by its corresponding alpha value
+                task_vectors[-1] = state_dict_mul(
+                    state_dict=task_vectors[-1], scalar=self.alpha[model_idx]
+                )
+
+        return pretrained_model, task_vectors
+
+    def run(self, modelpool: fb.BaseModelPool):
+        """
+        Execute the complete Task Singular Vector Merging algorithm.
+
+        This is the main entry point that orchestrates the entire TSVM process:
+
+        The algorithm leverages the mathematical insight that task vectors often lie in a
+        lower-dimensional subspace, and SVD helps identify the most important directions
+        in this subspace while filtering out noise and interference.
+
+        Args:
+            modelpool (fb.BaseModelPool): Pool of models to merge, including:
+                - The pretrained base model
+                - Multiple fine-tuned models (one per task)
+                All models must have compatible architectures.
+
+        Returns:
+            Union[torch.nn.Module, Dict[str, torch.nn.Module]]:
+                - If return_single_task_models=False: Returns the merged model
+                - If return_single_task_models=True: Returns a dictionary with:
+                    * Individual transformed models keyed by their original names
+                    * Final merged model under the key 'merged'
+
+        Raises:
+            AssertionError: If alpha list length doesn't match the number of models
+        """
+        # Determine the compute device for SVD operations (GPU if available for faster computation)
+        accelerator = self.fabric.device
+
+        # Phase 1: Load pretrained model and compute task vectors from all fine-tuned models
+        pretrained_model, task_vectors = self.load_pretrained_model_and_task_vectors(
+            modelpool
+        )
+
+        # Phase 2: Apply SVD-based merging to the task vectors
+        # This is the core of the TSVM algorithm where:
+        # - Task vectors are organized into a matrix
+        # - SVD finds the principal components (most important directions)
+        # - Task vectors are reconstructed using only the most significant components
+        # - The reconstructed vectors are merged (summed) to create a unified task vector
         new_merged_tv = TSVM_utils.compute_and_sum_svd_mem_reduction(
             task_vectors,
-            exclude_keys=self.
-            accelerator=
+            exclude_keys=self.exclude_keys,  # Skip certain parameters from SVD
+            accelerator=accelerator,  # Use GPU if available
+            return_single_task_models=self.return_single_task_models,
        )
 
-        #
-        if self.
+        # Handle the case where individual transformed task vectors are also returned
+        if self.return_single_task_models:
+            new_merged_tv, single_task_models = new_merged_tv
+
+        # Phase 3: Apply global scaling to the merged task vector (if alpha is a single value)
+        # This is different from individual scaling applied earlier - here we scale the
+        # final merged result, which affects the overall strength of all merged adaptations
+        if self.alpha is not None and isinstance(self.alpha, (float, int)):
             print(f"Scaling new merged task vector by alpha: {self.alpha}")
             new_merged_tv = state_dict_mul(state_dict=new_merged_tv, scalar=self.alpha)
 
+        # Phase 4: Prepare individual transformed models if requested
+        if self.return_single_task_models:
+            models = {}
+            # Create individual models by adding each transformed task vector to the pretrained base
+            for model_idx, model_name in enumerate(modelpool.model_names):
+                # Create a deep copy to avoid modifying the original pretrained model
+                model = deepcopy(pretrained_model)
+                # Apply the transformed task vector to get the individual model
+                model.load_state_dict(
+                    state_dict_add(model.state_dict(), single_task_models[model_idx])
+                )
+                models[model_name] = model
+
+        # Phase 5: Create the final merged model by adding the merged task vector to pretrained model
+        # This produces a single model that combines capabilities from all input models
         pretrained_model.load_state_dict(
             state_dict_add(new_merged_tv, pretrained_model.state_dict())
        )
-
+
+        # Phase 6: Return results based on the requested output format
+        if self.return_single_task_models:
+            # Include the final merged model in the dictionary of results
+            models["merged"] = pretrained_model
+            return models
+        else:
+            # Return only the merged model
+            return pretrained_model
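The Note at the end of the new `__init__` docstring is the subtle part of this API: a scalar `alpha` scales the merged task vector after the SVD step, while a list of `alpha`s scales each task vector before it. A minimal sketch of the difference on plain dict-of-tensor state dicts; `scale` and `merge` below are hypothetical stand-ins for `state_dict_mul` and the SVD merge (simplified here to a plain sum):

```python
import torch

def scale(sd, a):  # stand-in for fusion_bench's state_dict_mul
    return {k: a * v for k, v in sd.items()}

def merge(sds):  # stand-in for the SVD merge; a plain sum for brevity
    return {k: sum(sd[k] for sd in sds) for k in sds[0]}

task_vectors = [{"w": torch.randn(4, 4)} for _ in range(3)]

# alpha as a list: per-task scaling is applied *before* the merge
per_task = merge([scale(tv, a) for tv, a in zip(task_vectors, [0.8, 1.2, 0.5])])

# alpha as a scalar: the *merged* task vector is scaled afterwards
global_scaled = scale(merge(task_vectors), 0.3)
```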
fusion_bench/method/task_singular_vector/utils/TSVM_utils.py

@@ -1,7 +1,9 @@
+import collections
 import math
-from typing import List, Optional
+from typing import Dict, List, Optional
 
 import torch
+from torch import Tensor, nn
 
 from fusion_bench.utils.type import StateDictType
 
@@ -314,7 +316,8 @@ def compute_and_sum_svd_mem_reduction(
     task_vectors: List[StateDictType],
     exclude_keys: Optional[List[str]] = None,
     accelerator: torch.device = "cuda" if torch.cuda.is_available() else "cpu",
-
+    return_single_task_models: bool = False,
+):
     """
     Computes the Singular Value Decomposition (SVD) for each vector in the task_vectors,
     reduces the dimensionality of the vectors based on the sv_reduction factor, and concatenate
@@ -326,23 +329,25 @@ def compute_and_sum_svd_mem_reduction(
             dictionary of vectors.
         exclude_keys (list): A list of keys to exclude from the TSVM.
         accelerator (torch.device): The device to use for the computation.
+        return_single_task_models (bool): Whether to return the single task models after the TSVM.
 
     Returns:
         dict: A dictionary containing the new vectors after SVD computation and merging.
     """
     if exclude_keys is None:
         exclude_keys = []
-
+    num_tasks = len(task_vectors)
+    sv_reduction = 1 / num_tasks
 
-    new_vector = {}
+    new_vector: Dict[str, Tensor] = {}
+    if return_single_task_models:
+        single_task_models = [{} for _ in range(num_tasks)]
     for key in task_vectors[0]:
         original_device = task_vectors[0][key].device
         original_dtype = task_vectors[0][key].dtype
 
-        new_vector[key] = {}
         for i, task_vector in enumerate(task_vectors):
-            vec = task_vector[key].to(accelerator)
-
+            vec = task_vector[key].to(device=accelerator, non_blocking=True)
             if len(task_vector[key].shape) == 2 and key not in exclude_keys:
                 # at current, the SVD is not supported for half precision, so we need to convert to float32
                 if not (
@@ -350,13 +355,14 @@ def compute_and_sum_svd_mem_reduction(
                 ):
                     vec = vec.to(dtype=torch.float32)
 
-
+                # vec = u @ torch.diag(s) @ vh
+                u, s, vh = torch.linalg.svd(vec, full_matrices=False)
 
                 if i == 0:
                     print(f"Computed SVD for {key}...")
                     sum_u = torch.zeros_like(u, device=accelerator)
                     sum_s = torch.zeros_like(s, device=accelerator)
-
+                    sum_vh = torch.zeros_like(vh, device=accelerator)
                 reduced_index_s = int(s.shape[0] * sv_reduction)
 
                 # select only the first reduced_index_s columns of u and place them
@@ -367,10 +373,9 @@ def compute_and_sum_svd_mem_reduction(
                     :reduced_index_s
                 ]
                 # select only the first reduced_index_s rows of v and place them
-
+                sum_vh[i * reduced_index_s : (i + 1) * reduced_index_s, :] = vh[
                     :reduced_index_s, :
                 ]
-
             else:
                 # if the vector is not a 2D tensor or is in exclude_keys, compute the mean
                 if i == 0:
@@ -379,22 +384,49 @@ def compute_and_sum_svd_mem_reduction(
                     new_vector[key] += (vec - new_vector[key]) / (i + 1)
 
         if len(task_vector[key].shape) == 2 and key not in exclude_keys:
-            u_u, s_u,
-
+            u_u, s_u, vh_u = torch.linalg.svd(sum_u, full_matrices=False)
+            u_vh, s_vh, vh_vh = torch.linalg.svd(sum_vh, full_matrices=False)
 
             new_vector[key] = torch.linalg.multi_dot(
                 (
                     u_u,
-
+                    vh_u,
                     torch.diag(sum_s),
-
-
+                    u_vh,
+                    vh_vh,
                 )
             )
-
-
-
-
+            new_vector[key] = new_vector[key].to(
+                device=original_device, dtype=original_dtype, non_blocking=True
+            )
+            if return_single_task_models:
+                reduced_index_s = int(sum_s.shape[0] * sv_reduction)
+                new_u = u_u @ vh_u
+                new_vh = u_vh @ vh_vh
+                for i in range(num_tasks):
+                    single_task_models[i][key] = torch.linalg.multi_dot(
+                        (
+                            new_u[:, i * reduced_index_s : (i + 1) * reduced_index_s],
+                            torch.diag(
+                                sum_s[i * reduced_index_s : (i + 1) * reduced_index_s]
+                            ),
+                            new_vh[i * reduced_index_s : (i + 1) * reduced_index_s, :],
+                        )
+                    ).to(
+                        device=original_device, dtype=original_dtype, non_blocking=True
+                    )
+        else:
+            new_vector[key] = new_vector[key].to(
+                device=original_device, dtype=original_dtype, non_blocking=True
+            )
+            if return_single_task_models:
+                for i in range(num_tasks):
+                    single_task_models[i][key] = new_vector[key].clone()
+
+    if not return_single_task_models:
+        return new_vector
+    else:
+        return new_vector, single_task_models
 
 
 ###############
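For readers following `compute_and_sum_svd_mem_reduction`: `sv_reduction = 1 / num_tasks` means each 2-D parameter keeps only the top `1/num_tasks` fraction of every task's singular triplets; the kept factors are written into disjoint blocks of shared `sum_u`/`sum_s`/`sum_vh` buffers, and a second SVD re-orthogonalizes the stacked bases (`u_u @ vh_u` is the orthogonal polar factor of `sum_u`). A self-contained sketch of that flow for square matrices; variable names mirror the diff, and the way `sum_s` is filled is an assumption, since that part of the function lies outside the visible hunks:

```python
import torch

num_tasks, d = 4, 8
task_mats = [torch.randn(d, d) for _ in range(num_tasks)]
sv_keep = int(d * (1 / num_tasks))  # reduced_index_s in the diff

sum_u = torch.zeros(d, d)
sum_s = torch.zeros(d)
sum_vh = torch.zeros(d, d)
for i, m in enumerate(task_mats):
    u, s, vh = torch.linalg.svd(m, full_matrices=False)
    block = slice(i * sv_keep, (i + 1) * sv_keep)
    sum_u[:, block] = u[:, :sv_keep]    # each task owns a column block of U
    sum_s[block] = s[:sv_keep]          # assumption: singular values stacked likewise
    sum_vh[block, :] = vh[:sv_keep, :]  # ...and a row block of Vh

# the per-task blocks are not mutually orthogonal, so re-orthogonalize them
u_u, _, vh_u = torch.linalg.svd(sum_u, full_matrices=False)
u_vh, _, vh_vh = torch.linalg.svd(sum_vh, full_matrices=False)
merged = torch.linalg.multi_dot((u_u @ vh_u, torch.diag(sum_s), u_vh @ vh_vh))
print(merged.shape)  # torch.Size([8, 8])
```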
fusion_bench/mixins/hydra_config.py

@@ -9,7 +9,7 @@ from hydra import compose, initialize
 from omegaconf import DictConfig, OmegaConf
 
 from fusion_bench.utils import import_object, instantiate
-from fusion_bench.utils.
+from fusion_bench.utils.instantiate_utils import set_print_function_call
 
 log = logging.getLogger(__name__)
 
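This import fix tracks the module rename shown in the file list (`fusion_bench/utils/{instantiate.py → instantiate_utils.py}`); code importing from the old path needs the same one-line change. The old import below is inferred from the truncated removed line, so treat it as an assumption:

```python
# before (0.2.15), presumably:
#   from fusion_bench.utils.instantiate import set_print_function_call
# after (0.2.17):
from fusion_bench.utils.instantiate_utils import set_print_function_call
```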
fusion_bench/mixins/lightning_fabric.py

@@ -11,7 +11,7 @@ from lightning.fabric.utilities.rank_zero import rank_zero_only
 from omegaconf import DictConfig, OmegaConf
 
 from fusion_bench.utils import import_object
-from fusion_bench.utils.
+from fusion_bench.utils.instantiate_utils import instantiate
 
 if TYPE_CHECKING:
     import lightning.fabric.loggers.tensorboard
@@ -172,3 +172,27 @@ class LightningFabricMixin:
             return True
         else:
             return False
+
+    def log(self, name: str, value: Any, step: Optional[int] = None):
+        """
+        Logs the metric to the fabric's logger.
+        """
+        self.fabric.log(name, value, step=step)
+
+    def log_dict(self, metrics: dict, step: Optional[int] = None):
+        """
+        Logs the metrics to the fabric's logger.
+        """
+        self.fabric.log_dict(metrics, step=step)
+
+    def log_optimizer_lr(
+        self,
+        optimizer: torch.optim.Optimizer,
+        step: Optional[int] = None,
+        name_template: str = "train/lr_group_{0}",
+    ):
+        """
+        Logs the learning rate of the optimizer to the fabric's logger.
+        """
+        for i, param_group in enumerate(optimizer.param_groups):
+            self.fabric.log(name_template.format(i), param_group["lr"], step=step)
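The three new mixin methods are thin wrappers over Lightning Fabric's logging API, so their behavior can be previewed against a bare `Fabric` object. A sketch under the assumption that a TensorBoard logger (or any other Fabric logger) is attached, which is how the mixin is typically configured:

```python
import torch
from lightning.fabric import Fabric
from lightning.fabric.loggers import TensorBoardLogger

fabric = Fabric(loggers=[TensorBoardLogger("logs")])  # assumes tensorboard is installed
fabric.launch()

w = torch.nn.Parameter(torch.zeros(4))
opt = torch.optim.SGD([w], lr=0.1)
for step in range(3):
    loss = (w - 1.0).pow(2).sum()
    opt.zero_grad()
    loss.backward()
    opt.step()
    fabric.log("train/loss", loss.item(), step=step)       # what self.log(...) forwards to
    fabric.log_dict({"train/w0": w[0].item()}, step=step)  # what self.log_dict(...) forwards to
    for i, group in enumerate(opt.param_groups):           # what self.log_optimizer_lr(...) expands to
        fabric.log(f"train/lr_group_{i}", group["lr"], step=step)
```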
fusion_bench/mixins/serialization.py

@@ -4,13 +4,14 @@ from typing import Dict, Optional, Union
 
 from omegaconf import OmegaConf
 
-from fusion_bench.utils import instantiate
+from fusion_bench.utils import import_object, instantiate
 
 log = logging.getLogger(__name__)
 
 
 class YAMLSerializationMixin:
     _recursive_: bool = False
+    _config_key: Optional[str] = None
     _config_mapping: Dict[str, str] = {
         "_recursive_": "_recursive_",
     }
@@ -99,7 +100,22 @@ class YAMLSerializationMixin:
             BaseModelPool: The loaded model pool.
         """
         config = OmegaConf.load(path)
-
+        if cls._config_key is not None and cls._config_key in config:
+            config = config[cls._config_key]
+        target_cls = import_object(config["_target_"])
+        if target_cls != cls:
+            log.warning(
+                f"The class {target_cls.__name__} is not the same as the class {cls.__name__}. "
+                f"Instantiating the class {target_cls.__name__} instead."
+            )
+        return instantiate(
+            config,
+            _recursive_=(
+                cls._recursive_
+                if config.get("_recursive_") is None
+                else config.get("_recursive_")
+            ),
+        )
 
     def to_config(self):
         """
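The practical effect of `_config_key` (the loader above sits in a classmethod whose name is outside the hunk, so treat `from_yaml` as an assumption): a subclass that declares a config key can be re-created from a full experiment config rather than a bare section. `BaseModelPool` sets `_config_key = "modelpool"` in the next hunk, so the narrowing step reduces to:

```python
from omegaconf import OmegaConf

config = OmegaConf.create(
    """
modelpool:
  _target_: fusion_bench.modelpool.CLIPVisionModelPool
  models:
    _pretrained_: openai/clip-vit-base-patch32
"""
)

_config_key = "modelpool"  # as declared on BaseModelPool in this release
if _config_key is not None and _config_key in config:
    config = config[_config_key]  # narrow to the pool section
print(config["_target_"])  # fusion_bench.modelpool.CLIPVisionModelPool
```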
fusion_bench/modelpool/base_pool.py

@@ -29,6 +29,7 @@ class BaseModelPool(BaseYAMLSerializableModel, HydraConfigMixin):
     """
 
     _program = None
+    _config_key = "modelpool"
     _models: Union[DictConfig, Dict[str, nn.Module]]
     _config_mapping = BaseYAMLSerializableModel._config_mapping | {
         "_models": "models",
fusion_bench/modelpool/clip_vision/modelpool.py

@@ -3,6 +3,7 @@ from copy import deepcopy
 from typing import Optional, Union
 
 from datasets import load_dataset
+from lightning.fabric.utilities import rank_zero_only
 from omegaconf import DictConfig, open_dict
 from torch import nn
 from torch.utils.data import Dataset
@@ -40,7 +41,8 @@ class CLIPVisionModelPool(BaseModelPool):
     def load_processor(self, *args, **kwargs) -> CLIPProcessor:
         assert self._processor is not None, "Processor is not defined in the config"
         if isinstance(self._processor, str):
-
+            if rank_zero_only.rank == 0:
+                log.info(f"Loading `transformers.CLIPProcessor`: {self._processor}")
             processor = CLIPProcessor.from_pretrained(self._processor)
         else:
             processor = instantiate(self._processor, *args, **kwargs)
@@ -50,7 +52,8 @@ class CLIPVisionModelPool(BaseModelPool):
         model_config = self._models[model_name]
 
         if isinstance(model_config, str):
-
+            if rank_zero_only.rank == 0:
+                log.info(f"Loading `transformers.CLIPModel`: {model_config}")
             clip_model = CLIPModel.from_pretrained(model_config, *args, **kwargs)
             return clip_model
         else:
@@ -102,10 +105,12 @@ class CLIPVisionModelPool(BaseModelPool):
         ):
             model = self._models[model_name_or_config]
             if isinstance(model, str):
-
+                if rank_zero_only.rank == 0:
+                    log.info(f"Loading `transformers.CLIPVisionModel`: {model}")
                 return CLIPVisionModel.from_pretrained(model, *args, **kwargs)
             if isinstance(model, nn.Module):
-
+                if rank_zero_only.rank == 0:
+                    log.info(f"Returning existing model: {model}")
                 return model
 
         # If the model is not a string, we use the default load_model method
@@ -114,9 +119,10 @@ class CLIPVisionModelPool(BaseModelPool):
     def load_train_dataset(self, dataset_name: str, *args, **kwargs):
         dataset_config = self._train_datasets[dataset_name]
         if isinstance(dataset_config, str):
-
-
-
+            if rank_zero_only.rank == 0:
+                log.info(
+                    f"Loading train dataset using `datasets.load_dataset`: {dataset_config}"
+                )
             dataset = load_dataset(dataset_config, split="train")
         else:
             dataset = super().load_train_dataset(dataset_name, *args, **kwargs)
@@ -125,9 +131,10 @@ class CLIPVisionModelPool(BaseModelPool):
     def load_val_dataset(self, dataset_name: str, *args, **kwargs):
         dataset_config = self._val_datasets[dataset_name]
         if isinstance(dataset_config, str):
-
-
-
+            if rank_zero_only.rank == 0:
+                log.info(
+                    f"Loading validation dataset using `datasets.load_dataset`: {dataset_config}"
+                )
             dataset = load_dataset(dataset_config, split="validation")
         else:
             dataset = super().load_val_dataset(dataset_name, *args, **kwargs)
@@ -136,9 +143,10 @@ class CLIPVisionModelPool(BaseModelPool):
     def load_test_dataset(self, dataset_name: str, *args, **kwargs):
         dataset_config = self._test_datasets[dataset_name]
         if isinstance(dataset_config, str):
-
-
-
+            if rank_zero_only.rank == 0:
+                log.info(
+                    f"Loading test dataset using `datasets.load_dataset`: {dataset_config}"
+                )
             dataset = load_dataset(dataset_config, split="test")
         else:
             dataset = super().load_test_dataset(dataset_name, *args, **kwargs)
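All of the logging changes in this file follow one pattern: gate `log.info` behind `rank_zero_only.rank == 0` so that multi-process (distributed) runs do not print the same message once per worker. A minimal sketch of the guard:

```python
import logging

from lightning.fabric.utilities import rank_zero_only

log = logging.getLogger(__name__)


def load_quietly(name: str) -> str:
    # rank_zero_only.rank is 0 in single-process runs and equals the
    # global rank under distributed launchers, so only rank 0 logs
    if rank_zero_only.rank == 0:
        log.info(f"Loading: {name}")
    return name
```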
fusion_bench/modelpool/lazy_state_dict_pool.py (new file)

@@ -0,0 +1,15 @@
+from fusion_bench import BaseModelPool
+from fusion_bench.utils import instantiate
+from fusion_bench.utils.lazy_state_dict import LazyStateDict
+
+
+class LazyStateDictPool(BaseModelPool):
+    def load_model(self, model_name_or_config: str, *args, **kwargs) -> LazyStateDict:
+        if model_name_or_config in self._models:
+            checkpoint_config = self._models[model_name_or_config]
+        else:
+            checkpoint_config = model_name_or_config
+        if isinstance(checkpoint_config, str):
+            return LazyStateDict(checkpoint_config, *args, **kwargs)
+        else:
+            return instantiate(checkpoint_config, *args, **kwargs)
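`LazyStateDict` itself is new in this release (`fusion_bench/utils/lazy_state_dict.py`, +268 lines, not shown in this diff), so its exact interface is not visible here; the pool above only implies it can be built from a checkpoint path or name string. A hedged sketch, assuming a Mapping-style interface that materializes tensors on access:

```python
from fusion_bench.utils.lazy_state_dict import LazyStateDict

# assumption: constructible from a Hugging Face repo id or local checkpoint
# path, as LazyStateDictPool.load_model suggests
sd = LazyStateDict("openai/clip-vit-base-patch32")

# assumption: dict-like keys()/__getitem__, loading each tensor lazily
for name in list(sd.keys())[:3]:
    print(name, tuple(sd[name].shape))
```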
fusion_bench/models/modeling_deepseek_v2/__init__.py (new file)

@@ -0,0 +1,15 @@
+"""
+This is a direct copy of the DeepSeek-V2-Lite model from HuggingFace https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite/tree/main
+"""
+
+from .configuration_deepseek import DeepseekV2Config
+from .modeling_deepseek import (
+    DeepseekV2ForCausalLM,
+    DeepseekV2ForSequenceClassification,
+    DeepseekV2MLP,
+    DeepseekV2Model,
+    DeepseekV2MoE,
+    DeepseekV2DecoderLayer,
+)
+from .modeling_deepseek import MoEGate as DeepseekV2MoEGate
+from .tokenization_deepseek_fast import DeepseekTokenizerFast
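These re-exports define the import surface the rest of fusion_bench uses for the vendored DeepSeek-V2 code (for example, the new `moe_pruner` hooks). A quick smoke test of that surface; default-constructing the config is an assumption based on how Hugging Face config classes usually behave:

```python
from fusion_bench.models.modeling_deepseek_v2 import (
    DeepseekV2Config,
    DeepseekV2ForCausalLM,
    DeepseekV2MoEGate,
)

# assumption: default-constructible like other Hugging Face configs
config = DeepseekV2Config()
print(config.model_type)
```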