fusion-bench 0.2.16__py3-none-any.whl → 0.2.18__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. fusion_bench/method/__init__.py +11 -0
  2. fusion_bench/method/adamerging/flan_t5_layer_wise_adamerging.py +1 -1
  3. fusion_bench/method/adamerging/gpt2_layer_wise_adamerging.py +1 -1
  4. fusion_bench/method/base_algorithm.py +1 -0
  5. fusion_bench/method/dawe/dawe_for_clip.py +1 -1
  6. fusion_bench/method/depth_upscaling/depth_upscaling_for_llama.py +3 -2
  7. fusion_bench/method/expert_sparsity/__init__.py +10 -0
  8. fusion_bench/method/expert_sparsity/mixtral/__init__.py +23 -0
  9. fusion_bench/method/expert_sparsity/mixtral/dynamic_skipping.py +175 -0
  10. fusion_bench/method/expert_sparsity/mixtral/layer_wise_pruning.py +159 -0
  11. fusion_bench/method/expert_sparsity/mixtral/progressive_pruning.py +173 -0
  12. fusion_bench/method/expert_sparsity/utils/calibration_data.py +153 -0
  13. fusion_bench/method/gossip/flan_t5_layer_wise_gossip.py +1 -1
  14. fusion_bench/method/knots/__init__.py +0 -0
  15. fusion_bench/method/knots/knots_utils.py +23 -0
  16. fusion_bench/method/pwe_moe/module.py +2 -7
  17. fusion_bench/method/simple_average.py +3 -2
  18. fusion_bench/method/task_singular_vector/TSVM.py +238 -25
  19. fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +52 -20
  20. fusion_bench/method/task_singular_vector/utils/__init__.py +1 -0
  21. fusion_bench/method/task_singular_vector/utils/task_singular_interference.py +41 -0
  22. fusion_bench/mixins/hydra_config.py +1 -1
  23. fusion_bench/mixins/lightning_fabric.py +25 -1
  24. fusion_bench/mixins/serialization.py +18 -2
  25. fusion_bench/modelpool/base_pool.py +1 -0
  26. fusion_bench/modelpool/causal_lm/causal_lm.py +8 -5
  27. fusion_bench/modelpool/clip_vision/modelpool.py +21 -13
  28. fusion_bench/models/__init__.py +1 -0
  29. fusion_bench/models/expert_sparsity/__init__.py +0 -0
  30. fusion_bench/models/expert_sparsity/mixtral/__init__.py +15 -0
  31. fusion_bench/models/expert_sparsity/mixtral/dataset.py +40 -0
  32. fusion_bench/models/expert_sparsity/mixtral/modeling_mixtral.py +207 -0
  33. fusion_bench/models/expert_sparsity/mixtral/wrapper.py +268 -0
  34. fusion_bench/models/parameter_dict.py +6 -1
  35. fusion_bench/programs/fabric_fusion_program.py +21 -13
  36. fusion_bench/taskpool/base_pool.py +1 -0
  37. fusion_bench/taskpool/dummy.py +6 -4
  38. fusion_bench/utils/__init__.py +4 -3
  39. fusion_bench/utils/dtype.py +2 -1
  40. fusion_bench/utils/fabric.py +11 -4
  41. fusion_bench/utils/{instantiate.py → instantiate_utils.py} +3 -0
  42. fusion_bench/utils/lazy_state_dict.py +80 -10
  43. fusion_bench/utils/pylogger.py +30 -0
  44. {fusion_bench-0.2.16.dist-info → fusion_bench-0.2.18.dist-info}/METADATA +3 -1
  45. {fusion_bench-0.2.16.dist-info → fusion_bench-0.2.18.dist-info}/RECORD +59 -38
  46. {fusion_bench-0.2.16.dist-info → fusion_bench-0.2.18.dist-info}/WHEEL +1 -1
  47. fusion_bench_config/fabric/loggers/mlflow_logger.yaml +2 -0
  48. fusion_bench_config/fabric_model_fusion.yaml +2 -2
  49. fusion_bench_config/method/expert_sparsity/README.md +6 -0
  50. fusion_bench_config/method/expert_sparsity/mixtral.yaml +17 -0
  51. fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +2 -1
  52. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_cars_and_dtd.yaml +16 -0
  53. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_and_cars.yaml +16 -0
  54. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_and_dtd.yaml +16 -0
  55. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_cars_and_dtd.yaml +19 -0
  56. fusion_bench_config/taskpool/LMEvalHarnessTaskPool/lm_eval.yaml +0 -1
  57. {fusion_bench-0.2.16.dist-info → fusion_bench-0.2.18.dist-info}/entry_points.txt +0 -0
  58. {fusion_bench-0.2.16.dist-info → fusion_bench-0.2.18.dist-info}/licenses/LICENSE +0 -0
  59. {fusion_bench-0.2.16.dist-info → fusion_bench-0.2.18.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,153 @@
1
+ """
2
+ This module contains the code for loading the calibration data.
3
+
4
+ - C4
5
+ - Math
6
+ """
7
+
8
+ import itertools
9
+ import logging
10
+ import os
11
+
12
+ import torch
13
+ import transformers
14
+ from datasets import load_dataset
15
+ from transformers import PreTrainedTokenizer, default_data_collator
16
+ from transformers.testing_utils import CaptureLogger
17
+ from huggingface_hub import hf_hub_download
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
+ DATASETS = {
23
+ # C4: Please download first part of the C4 training data `c4-train.00000-of-01024.json` from [allenai/c4](https://huggingface.co/datasets/allenai/c4/blob/main/en/c4-train.00000-of-01024.json.gz).
24
+ "c4": lambda: load_dataset(
25
+ "json",
26
+ data_files={
27
+ "train": hf_hub_download(
28
+ "allenai/c4",
29
+ filename="en/c4-train.00000-of-01024.json.gz",
30
+ repo_type="dataset",
31
+ )
32
+ },
33
+ ),
34
+ # MATH: You can use our pre-built calibration set in `./data/math_pretrain_style.json`. To reproduce our construction, please download the training set of [MATH](https://github.com/hendrycks/math) and use our [script](data/math_calib_construction.py).
35
+ # NOTE: I have uploaded the math_pretrain_style.json to my huggingface repo:
36
+ # https://huggingface.co/datasets/tanganke/math_pretrain_style/tree/main.
37
+ "math": lambda: load_dataset(
38
+ "json",
39
+ data_files={
40
+ "train": hf_hub_download(
41
+ "tanganke/math_pretrain_style",
42
+ filename="math_pretrain_style.json",
43
+ repo_type="dataset",
44
+ )
45
+ },
46
+ ),
47
+ }
48
+
49
+
50
+ def build_calib_loader(
51
+ dataset: str,
52
+ tokenizer: PreTrainedTokenizer,
53
+ max_block_size: int,
54
+ n_blocks_for_stat: int,
55
+ batch_size: int,
56
+ num_workers: int,
57
+ seed: int = 42,
58
+ ):
59
+ # dataset can be a string or a dataset object.
60
+ # If it is a string, it can be the name of the dataset in DATASETS or the path to the dataset (a json file).
61
+ if isinstance(dataset, str):
62
+ if dataset in DATASETS:
63
+ all_set = DATASETS[dataset]()
64
+ else:
65
+ assert os.path.exists(dataset), f"Dataset {dataset} not found."
66
+ all_set = load_dataset("json", data_files={"train": dataset})
67
+ else:
68
+ assert dataset is not None, "Dataset is not provided."
69
+ all_set = dataset
70
+
71
+ block_size = tokenizer.model_max_length
72
+ if block_size > max_block_size:
73
+ logger.info(
74
+ "The chosen tokenizer supports a `model_max_length` that is longer than the default `max_block_size` value"
75
+ f" of {max_block_size}. If you would like to use a longer `block_size` up to `tokenizer.model_max_length` you can"
76
+ " override this default with `--max_block_size xxx`."
77
+ )
78
+ block_size = max_block_size
79
+
80
+ if n_blocks_for_stat > 0: # Random choose `n_blocks_for_stat` blocks
81
+ calib_set = (
82
+ all_set["train"]
83
+ .shuffle(seed=seed)
84
+ .select(range(min(n_blocks_for_stat * 16, len(all_set["train"]))))
85
+ )
86
+ else: # Use the whole set
87
+ logger.warning("n_blocks_for_stat <= 0, using the whole dataset.")
88
+ calib_set = all_set["train"].shuffle(seed=seed)
89
+
90
+ logger.info(f"Calibration dataset: {calib_set}")
91
+ text_column_name = (
92
+ "text" if "text" in calib_set.features else list(calib_set.features)[0]
93
+ )
94
+
95
+ tok_logger = transformers.utils.logging.get_logger(
96
+ "transformers.tokenization_utils_base"
97
+ )
98
+
99
+ def tokenize_function(examples):
100
+ with CaptureLogger(tok_logger) as cl:
101
+ output = tokenizer(examples[text_column_name])
102
+ # clm input could be much much longer than block_size
103
+ if "Token indices sequence length is longer than the" in cl.out:
104
+ tok_logger.warning(
105
+ "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits"
106
+ " before being passed to the model."
107
+ )
108
+ return output
109
+
110
+ tokenized_calib_set = calib_set.map(
111
+ tokenize_function,
112
+ batched=True,
113
+ remove_columns=list(calib_set.features),
114
+ )
115
+
116
+ def group_texts(examples):
117
+ # Concatenate all texts.
118
+ concatenated_examples = {
119
+ k: list(itertools.chain(*examples[k])) for k in examples.keys()
120
+ }
121
+ total_length = len(concatenated_examples[list(examples.keys())[0]])
122
+ # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
123
+ # customize this part to your needs.
124
+ if total_length >= block_size:
125
+ total_length = (total_length // block_size) * block_size
126
+ # Split by chunks of max_len.
127
+ result = {
128
+ k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
129
+ for k, t in concatenated_examples.items()
130
+ }
131
+ result["labels"] = result["input_ids"].copy()
132
+ return result
133
+
134
+ lm_calib_set = tokenized_calib_set.map(
135
+ group_texts,
136
+ batched=True,
137
+ )
138
+
139
+ if n_blocks_for_stat > 0:
140
+ assert len(lm_calib_set) > n_blocks_for_stat
141
+ lm_calib_set = lm_calib_set.select(range(n_blocks_for_stat))
142
+
143
+ calib_loader = torch.utils.data.DataLoader(
144
+ lm_calib_set,
145
+ batch_size=batch_size,
146
+ num_workers=num_workers,
147
+ pin_memory=True,
148
+ drop_last=False,
149
+ shuffle=False,
150
+ collate_fn=default_data_collator,
151
+ )
152
+
153
+ return calib_loader
@@ -32,7 +32,7 @@ from fusion_bench.models.wrappers.layer_wise_fusion import (
32
32
  get_layer_wise_weights,
33
33
  )
34
34
  from fusion_bench.utils.data import InfiniteDataLoader, load_tensor_from_file
35
- from fusion_bench.utils.instantiate import instantiate
35
+ from fusion_bench.utils.instantiate_utils import instantiate
36
36
 
37
37
  from .entropy_loss import entropy_loss
38
38
  from .layer_wise_gossip import ModelScheduler
File without changes
@@ -0,0 +1,23 @@
1
+ import torch
2
+
3
+
4
+ def subspace_alignment(
5
+ delta_weights: list[torch.Tensor],
6
+ svd_dtype: torch.dtype | None = torch.float64,
7
+ eps: float = 1e-4,
8
+ ):
9
+ """
10
+ Reference: Model merging with SVD to tie the Knots. http://arxiv.org/abs/2410.19735
11
+ """
12
+ if svd_dtype is None:
13
+ svd_dtype = delta_weights[0].dtype
14
+ original_dtype = delta_weights[0].dtype
15
+ output_dim, input_dim = delta_weights[0].size()
16
+ concat_task_vector = torch.cat(delta_weights, dim=1)
17
+ U, S, Vh = torch.linalg.svd(concat_task_vector.to(svd_dtype), full_matrices=False)
18
+ # Keep only supported basis components
19
+ U = U[:, S > eps].to(original_dtype)
20
+ Vh = Vh[S > eps].to(original_dtype)
21
+ S = S[S > eps].to(original_dtype)
22
+ Vhs = torch.split(Vh, input_dim, dim=1)
23
+ return U, S, Vhs
@@ -13,14 +13,9 @@ import torch.func
13
13
  from torch import Tensor, nn
14
14
  from torch.nn import functional as F
15
15
 
16
- log = logging.getLogger(__name__)
17
-
16
+ from fusion_bench.utils import join_list
18
17
 
19
- def join_list(list_of_list: List[List]):
20
- ans = []
21
- for item in list_of_list:
22
- ans.extend(item)
23
- return ans
18
+ log = logging.getLogger(__name__)
24
19
 
25
20
 
26
21
  class PWEMoEGate(nn.Module):
@@ -39,12 +39,13 @@ def simple_average(
39
39
  >>> import torch.nn as nn
40
40
  >>> model1 = nn.Linear(10, 10)
41
41
  >>> model2 = nn.Linear(10, 10)
42
- >>> averaged_model = simple_averageing([model1, model2])
42
+ >>> averaged_model = simple_average([model1, model2])
43
43
 
44
44
  >>> state_dict1 = model1.state_dict()
45
45
  >>> state_dict2 = model2.state_dict()
46
- >>> averaged_state_dict = simple_averageing([state_dict1, state_dict2])
46
+ >>> averaged_state_dict = simple_average([state_dict1, state_dict2])
47
47
  """
48
+ assert len(modules) > 0, "modules must be a non-empty list"
48
49
  if isinstance(modules[0], nn.Module):
49
50
  if base_module is None:
50
51
  new_module = deepcopy(modules[0])
@@ -1,5 +1,31 @@
1
- """
2
- Example:
1
+ R"""
2
+ # Task Singular Vector Merging (TSVM) Algorithm Implementation
3
+
4
+ This module implements the Task Singular Vector Merging algorithm for combining multiple fine-tuned models
5
+ into a single unified model.
6
+
7
+ ## Example Usage:
8
+
9
+ Merge 8 CLIP-ViT-B/32 models with TSVM:
10
+
11
+ ```bash
12
+ fusion_bench \
13
+ method=task_singular_vector/TaskSingularVectorMerging \
14
+ modelpool=CLIPVisionModelPool/clip-vit-base-patch32_TA8_model_only \
15
+ taskpool=CLIPVisionModelTaskPool/clip-vit-classification_TA8
16
+ ```
17
+
18
+ Merge 8 CLIP-ViT-B/32 models with TSVM and return individual transformed models:
19
+
20
+ ```bash
21
+ fusion_bench \
22
+ method=task_singular_vector/TaskSingularVectorMerging \
23
+ method.return_single_task_models=true \
24
+ modelpool=CLIPVisionModelPool/clip-vit-base-patch32_TA8_model_only \
25
+ taskpool=CLIPVisionModelTaskPool/clip-vit-classification_TA8
26
+ ```
27
+
28
+ Merge 20 CLIP-VIT-B/32 models with TSVM:
3
29
 
4
30
  ```bash
5
31
  fusion_bench \
@@ -7,14 +33,22 @@ fusion_bench \
7
33
  modelpool=CLIPVisionModelPool/clip-vit-base-patch32_TALL20_model_only \
8
34
  taskpool=CLIPVisionModelTaskPool/clip-vit-classification_TALL20
9
35
  ```
36
+
37
+ ## References:
38
+
39
+ - Gargiulo, et al. Task Singular Vectors: Reducing Task Interference in Model Merging.
40
+ https://arxiv.org/abs/2412.00081
41
+ - See `docs/algorithms/task_singular_vector.md` for more details.
10
42
  """
11
43
 
44
+ from copy import deepcopy
12
45
  from typing import Iterable, List, Optional, Union
13
46
 
14
47
  import torch
15
48
  from omegaconf import ListConfig
16
49
  from torch import Tensor, nn
17
50
 
51
+ import fusion_bench as fb
18
52
  from fusion_bench import BaseAlgorithm
19
53
  from fusion_bench.mixins import LightningFabricMixin
20
54
  from fusion_bench.utils import timeit_context
@@ -35,49 +69,228 @@ from .utils import (
35
69
 
36
70
 
37
71
  class TaskSingularVectorMerging(BaseAlgorithm, LightningFabricMixin):
72
+ """
73
+ Task Singular Vector Merging (TSVM) Algorithm
74
+
75
+ This class implements a model merging technique that leverages Singular Value
76
+ Decomposition (SVD) to identify and combine the most important directions in the task vector
77
+ space. The algorithm is particularly effective for merging multiple models fine-tuned on
78
+ different tasks while preserving their essential capabilities.
79
+
80
+ Key Concepts:
81
+ - Task Vector: The difference between a fine-tuned model and its pretrained base model,
82
+ representing the knowledge gained during fine-tuning for a specific task.
83
+ - Singular Value Decomposition: A matrix factorization technique used to find the principal
84
+ components (most important directions) in the space of task vectors.
85
+ - Model Merging: The process of combining multiple models into a single unified model that
86
+ retains capabilities from all constituent models.
87
+
88
+ Algorithm Steps:
89
+ 1. Extract task vectors from all fine-tuned models by subtracting the pretrained model
90
+ 2. Apply SVD to the matrix of task vectors to find principal components
91
+ 3. Reconstruct task vectors using only the most significant singular vectors
92
+ 4. Merge the reconstructed task vectors (either individually scaled or as a sum)
93
+ 5. Add the final merged task vector to the pretrained model to create the unified model
94
+
95
+ see `docs/algorithms/task_singular_vector.md` for comprehensive algorithmic details.
96
+ """
38
97
 
39
98
  def __init__(
40
99
  self,
41
- alpha: Union[float, Iterable[float]] = None,
42
- remove_keys: Optional[List[str]] = None,
100
+ alpha: Optional[Union[float, Iterable[float]]] = None,
101
+ exclude_keys: Optional[List[str]] = None,
102
+ return_single_task_models: bool = False,
43
103
  **kwargs,
44
104
  ):
105
+ """
106
+ Initialize the Task Singular Vector Merging algorithm.
107
+
108
+ Args:
109
+ alpha (Union[float, Iterable[float]], optional): Scaling factor(s) for task vectors.
110
+ This parameter controls the strength of the task-specific adaptations in the final model.
111
+
112
+ - If a single float: Applied to the final merged task vector after SVD reconstruction.
113
+ This uniformly scales the entire merged adaptation.
114
+
115
+ - If an iterable of floats: Applied to individual task vectors before SVD and merging.
116
+ Must have the same length as the number of models in the modelpool.
117
+ This allows for task-specific weighting (e.g., giving more importance to certain tasks).
118
+
119
+ - If None: No scaling is applied (equivalent to alpha=1.0).
120
+
121
+ Example: alpha=[0.8, 1.2, 0.5] would apply different weights to three different task vectors.
122
+
123
+ exclude_keys (Optional[List[str]], optional): List of parameter names to exclude from TSVM.
124
+ These parameters will not participate in the SVD computation and merging process.
125
+ Useful for excluding certain layers (e.g., task-specific heads, normalization layers)
126
+ that should not be merged across tasks. Defaults to an empty list.
127
+
128
+ Example: exclude_keys=['classifier.weight', 'classifier.bias'] to skip classification heads.
129
+
130
+ return_single_task_models (bool, optional): Whether to return individual transformed models.
131
+
132
+ - If True: Returns a dictionary containing both individual models with their transformed
133
+ task vectors applied AND the final merged model. The dictionary has the structure:
134
+
135
+ >>> {'model_name_1': transformed_model_1, ..., 'merged': final_merged_model}
136
+
137
+ - If False: Returns only the final merged model.
138
+
139
+ This is useful for analysis or when you need access to intermediate results.
140
+ Defaults to False.
141
+
142
+ **kwargs: Additional arguments passed to the parent BaseAlgorithm class.
143
+
144
+ Note:
145
+ The choice between single alpha vs. list of alphas affects the merging strategy:
146
+ - Single alpha: SVD is applied first, then the result is scaled
147
+ - List of alphas: Individual task vectors are scaled first, then SVD is applied
148
+ """
45
149
  self.alpha = alpha
46
- self.remove_keys = remove_keys if remove_keys is not None else []
150
+ self.exclude_keys = exclude_keys if exclude_keys is not None else []
151
+ self.return_single_task_models = return_single_task_models
47
152
  super().__init__(**kwargs)
48
153
 
49
- def run(self, modelpool):
50
- # Load the pre-trained model and the fine-tuned models
154
+ def load_pretrained_model_and_task_vectors(self, modelpool: fb.BaseModelPool):
155
+ """
156
+ Load the pretrained base model and compute task vectors from all fine-tuned models.
157
+
158
+ This method performs the initial step of the TSVM algorithm by:
159
+ 1. Loading the original pretrained model (before any task-specific fine-tuning)
160
+ 2. For each fine-tuned model in the pool:
161
+ - Load the fine-tuned model
162
+ - Compute the task vector (fine-tuned params - pretrained params)
163
+ - Optionally apply individual scaling if alpha is provided as a list
164
+
165
+ Task vectors represent the knowledge gained during fine-tuning and are the core
166
+ data structure that TSVM operates on.
167
+
168
+ Args:
169
+ modelpool (fb.BaseModelPool): Pool containing the pretrained model and all
170
+ fine-tuned models to be merged.
171
+
172
+ Returns:
173
+ tuple: A tuple containing:
174
+ - pretrained_model: The original pretrained model (torch.nn.Module)
175
+ - task_vectors: List of task vectors (List[StateDictType]), where each
176
+ task vector is a state dictionary representing the parameter differences
177
+ for one specific task
178
+ """
179
+ # Load the original pretrained model that serves as the base for all fine-tuned variants
51
180
  pretrained_model = modelpool.load_pretrained_model()
52
- finetuned_models = list(modelpool.models())
53
181
 
54
- ptm_check = pretrained_model.state_dict()
55
- ft_checks = [model.state_dict() for model in finetuned_models]
56
- check_parameterNamesMatch(ft_checks + [ptm_check])
182
+ # Initialize list to store computed task vectors
183
+ task_vectors = []
184
+
185
+ # Process each fine-tuned model in the modelpool
186
+ for model_idx, model_name in enumerate(modelpool.model_names):
187
+ # Load the current fine-tuned model
188
+ finetuned_model = modelpool.load_model(model_name)
189
+
190
+ # Compute task vector: difference between fine-tuned and pretrained parameters
191
+ # This captures the task-specific adaptations learned during fine-tuning
192
+ task_vector = state_dict_sub(
193
+ finetuned_model.state_dict(), pretrained_model.state_dict()
194
+ )
195
+ task_vectors.append(task_vector)
57
196
 
58
- with timeit_context("Flattening out Checkpoints"):
59
- task_vectors = [state_dict_sub(check, ptm_check) for check in ft_checks]
60
- if isinstance(self.alpha, Iterable):
197
+ # Apply individual scaling to task vectors if alpha is provided as a list
198
+ # This allows for task-specific weighting before the SVD computation
199
+ if self.alpha is not None and isinstance(self.alpha, Iterable):
200
+ # Ensure the number of alpha values matches the number of models
61
201
  assert len(self.alpha) == len(
62
- task_vectors
63
- ), "Alpha and task vectors must have the same length"
64
- task_vectors = [
65
- state_dict_mul(state_dict=tv, scalar=alpha)
66
- for alpha, tv in zip(self.alpha, task_vectors)
67
- ]
202
+ modelpool.model_names
203
+ ), f"Alpha list length ({len(self.alpha)}) must match number of models ({len(modelpool.model_names)})"
68
204
 
205
+ # Scale the current task vector by its corresponding alpha value
206
+ task_vectors[-1] = state_dict_mul(
207
+ state_dict=task_vectors[-1], scalar=self.alpha[model_idx]
208
+ )
209
+
210
+ return pretrained_model, task_vectors
211
+
212
+ def run(self, modelpool: fb.BaseModelPool):
213
+ """
214
+ Execute the complete Task Singular Vector Merging algorithm.
215
+
216
+ This is the main entry point that orchestrates the entire TSVM process:
217
+
218
+ The algorithm leverages the mathematical insight that task vectors often lie in a
219
+ lower-dimensional subspace, and SVD helps identify the most important directions
220
+ in this subspace while filtering out noise and interference.
221
+
222
+ Args:
223
+ modelpool (fb.BaseModelPool): Pool of models to merge, including:
224
+ - The pretrained base model
225
+ - Multiple fine-tuned models (one per task)
226
+ All models must have compatible architectures.
227
+
228
+ Returns:
229
+ Union[torch.nn.Module, Dict[str, torch.nn.Module]]:
230
+ - If return_single_task_models=False: Returns the merged model
231
+ - If return_single_task_models=True: Returns a dictionary with:
232
+ * Individual transformed models keyed by their original names
233
+ * Final merged model under the key 'merged'
234
+
235
+ Raises:
236
+ AssertionError: If alpha list length doesn't match the number of models
237
+ """
238
+ # Determine the compute device for SVD operations (GPU if available for faster computation)
239
+ accelerator = self.fabric.device
240
+
241
+ # Phase 1: Load pretrained model and compute task vectors from all fine-tuned models
242
+ pretrained_model, task_vectors = self.load_pretrained_model_and_task_vectors(
243
+ modelpool
244
+ )
245
+
246
+ # Phase 2: Apply SVD-based merging to the task vectors
247
+ # This is the core of the TSVM algorithm where:
248
+ # - Task vectors are organized into a matrix
249
+ # - SVD finds the principal components (most important directions)
250
+ # - Task vectors are reconstructed using only the most significant components
251
+ # - The reconstructed vectors are merged (summed) to create a unified task vector
69
252
  new_merged_tv = TSVM_utils.compute_and_sum_svd_mem_reduction(
70
253
  task_vectors,
71
- exclude_keys=self.remove_keys,
72
- accelerator=self.fabric.device,
254
+ exclude_keys=self.exclude_keys, # Skip certain parameters from SVD
255
+ accelerator=accelerator, # Use GPU if available
256
+ return_single_task_models=self.return_single_task_models,
73
257
  )
74
258
 
75
- # If alpha is a float, we need to scale the new merged task vector by alpha
76
- if self.alpha is not None and isinstance(self.alpha, float):
259
+ # Handle the case where individual transformed task vectors are also returned
260
+ if self.return_single_task_models:
261
+ new_merged_tv, single_task_models = new_merged_tv
262
+
263
+ # Phase 3: Apply global scaling to the merged task vector (if alpha is a single value)
264
+ # This is different from individual scaling applied earlier - here we scale the
265
+ # final merged result, which affects the overall strength of all merged adaptations
266
+ if self.alpha is not None and isinstance(self.alpha, (float, int)):
77
267
  print(f"Scaling new merged task vector by alpha: {self.alpha}")
78
268
  new_merged_tv = state_dict_mul(state_dict=new_merged_tv, scalar=self.alpha)
79
269
 
270
+ # Phase 4: Prepare individual transformed models if requested
271
+ if self.return_single_task_models:
272
+ models = {}
273
+ # Create individual models by adding each transformed task vector to the pretrained base
274
+ for model_idx, model_name in enumerate(modelpool.model_names):
275
+ # Create a deep copy to avoid modifying the original pretrained model
276
+ model = deepcopy(pretrained_model)
277
+ # Apply the transformed task vector to get the individual model
278
+ model.load_state_dict(
279
+ state_dict_add(model.state_dict(), single_task_models[model_idx])
280
+ )
281
+ models[model_name] = model
282
+
283
+ # Phase 5: Create the final merged model by adding the merged task vector to pretrained model
284
+ # This produces a single model that combines capabilities from all input models
80
285
  pretrained_model.load_state_dict(
81
286
  state_dict_add(new_merged_tv, pretrained_model.state_dict())
82
287
  )
83
- return pretrained_model
288
+
289
+ # Phase 6: Return results based on the requested output format
290
+ if self.return_single_task_models:
291
+ # Include the final merged model in the dictionary of results
292
+ models["merged"] = pretrained_model
293
+ return models
294
+ else:
295
+ # Return only the merged model
296
+ return pretrained_model
@@ -1,7 +1,9 @@
1
+ import collections
1
2
  import math
2
- from typing import List, Optional
3
+ from typing import Dict, List, Optional
3
4
 
4
5
  import torch
6
+ from torch import Tensor, nn
5
7
 
6
8
  from fusion_bench.utils.type import StateDictType
7
9
 
@@ -314,7 +316,8 @@ def compute_and_sum_svd_mem_reduction(
314
316
  task_vectors: List[StateDictType],
315
317
  exclude_keys: Optional[List[str]] = None,
316
318
  accelerator: torch.device = "cuda" if torch.cuda.is_available() else "cpu",
317
- ) -> StateDictType:
319
+ return_single_task_models: bool = False,
320
+ ):
318
321
  """
319
322
  Computes the Singular Value Decomposition (SVD) for each vector in the task_vectors,
320
323
  reduces the dimensionality of the vectors based on the sv_reduction factor, and concatenate
@@ -326,23 +329,25 @@ def compute_and_sum_svd_mem_reduction(
326
329
  dictionary of vectors.
327
330
  exclude_keys (list): A list of keys to exclude from the TSVM.
328
331
  accelerator (torch.device): The device to use for the computation.
332
+ return_single_task_models (bool): Whether to return the single task models after the TSVM.
329
333
 
330
334
  Returns:
331
335
  dict: A dictionary containing the new vectors after SVD computation and merging.
332
336
  """
333
337
  if exclude_keys is None:
334
338
  exclude_keys = []
335
- sv_reduction = 1 / len(task_vectors)
339
+ num_tasks = len(task_vectors)
340
+ sv_reduction = 1 / num_tasks
336
341
 
337
- new_vector = {}
342
+ new_vector: Dict[str, Tensor] = {}
343
+ if return_single_task_models:
344
+ single_task_models = [{} for _ in range(num_tasks)]
338
345
  for key in task_vectors[0]:
339
346
  original_device = task_vectors[0][key].device
340
347
  original_dtype = task_vectors[0][key].dtype
341
348
 
342
- new_vector[key] = {}
343
349
  for i, task_vector in enumerate(task_vectors):
344
- vec = task_vector[key].to(accelerator)
345
-
350
+ vec = task_vector[key].to(device=accelerator, non_blocking=True)
346
351
  if len(task_vector[key].shape) == 2 and key not in exclude_keys:
347
352
  # at current, the SVD is not supported for half precision, so we need to convert to float32
348
353
  if not (
@@ -350,13 +355,14 @@ def compute_and_sum_svd_mem_reduction(
350
355
  ):
351
356
  vec = vec.to(dtype=torch.float32)
352
357
 
353
- u, s, v = torch.linalg.svd(vec, full_matrices=False)
358
+ # vec = u @ torch.diag(s) @ vh
359
+ u, s, vh = torch.linalg.svd(vec, full_matrices=False)
354
360
 
355
361
  if i == 0:
356
362
  print(f"Computed SVD for {key}...")
357
363
  sum_u = torch.zeros_like(u, device=accelerator)
358
364
  sum_s = torch.zeros_like(s, device=accelerator)
359
- sum_v = torch.zeros_like(v, device=accelerator)
365
+ sum_vh = torch.zeros_like(vh, device=accelerator)
360
366
  reduced_index_s = int(s.shape[0] * sv_reduction)
361
367
 
362
368
  # select only the first reduced_index_s columns of u and place them
@@ -367,10 +373,9 @@ def compute_and_sum_svd_mem_reduction(
367
373
  :reduced_index_s
368
374
  ]
369
375
  # select only the first reduced_index_s rows of v and place them
370
- sum_v[i * reduced_index_s : (i + 1) * reduced_index_s, :] = v[
376
+ sum_vh[i * reduced_index_s : (i + 1) * reduced_index_s, :] = vh[
371
377
  :reduced_index_s, :
372
378
  ]
373
-
374
379
  else:
375
380
  # if the vector is not a 2D tensor or is in exclude_keys, compute the mean
376
381
  if i == 0:
@@ -379,22 +384,49 @@ def compute_and_sum_svd_mem_reduction(
379
384
  new_vector[key] += (vec - new_vector[key]) / (i + 1)
380
385
 
381
386
  if len(task_vector[key].shape) == 2 and key not in exclude_keys:
382
- u_u, s_u, v_u = torch.linalg.svd(sum_u, full_matrices=False)
383
- u_v, s_v, v_v = torch.linalg.svd(sum_v, full_matrices=False)
387
+ u_u, s_u, vh_u = torch.linalg.svd(sum_u, full_matrices=False)
388
+ u_vh, s_vh, vh_vh = torch.linalg.svd(sum_vh, full_matrices=False)
384
389
 
385
390
  new_vector[key] = torch.linalg.multi_dot(
386
391
  (
387
392
  u_u,
388
- v_u,
393
+ vh_u,
389
394
  torch.diag(sum_s),
390
- u_v,
391
- v_v,
395
+ u_vh,
396
+ vh_vh,
392
397
  )
393
398
  )
394
- new_vector[key] = new_vector[key].to(
395
- device=original_device, dtype=original_dtype, non_blocking=True
396
- )
397
- return new_vector
399
+ new_vector[key] = new_vector[key].to(
400
+ device=original_device, dtype=original_dtype, non_blocking=True
401
+ )
402
+ if return_single_task_models:
403
+ reduced_index_s = int(sum_s.shape[0] * sv_reduction)
404
+ new_u = u_u @ vh_u
405
+ new_vh = u_vh @ vh_vh
406
+ for i in range(num_tasks):
407
+ single_task_models[i][key] = torch.linalg.multi_dot(
408
+ (
409
+ new_u[:, i * reduced_index_s : (i + 1) * reduced_index_s],
410
+ torch.diag(
411
+ sum_s[i * reduced_index_s : (i + 1) * reduced_index_s]
412
+ ),
413
+ new_vh[i * reduced_index_s : (i + 1) * reduced_index_s, :],
414
+ )
415
+ ).to(
416
+ device=original_device, dtype=original_dtype, non_blocking=True
417
+ )
418
+ else:
419
+ new_vector[key] = new_vector[key].to(
420
+ device=original_device, dtype=original_dtype, non_blocking=True
421
+ )
422
+ if return_single_task_models:
423
+ for i in range(num_tasks):
424
+ single_task_models[i][key] = new_vector[key].clone()
425
+
426
+ if not return_single_task_models:
427
+ return new_vector
428
+ else:
429
+ return new_vector, single_task_models
398
430
 
399
431
 
400
432
  ###############
@@ -5,3 +5,4 @@ from fusion_bench.method.ties_merging.ties_merging_utils import (
5
5
  from fusion_bench.utils import state_dict_to_vector, vector_to_state_dict
6
6
 
7
7
  from . import TSVC_utils, TSVM_utils
8
+ from .task_singular_interference import compute_task_singular_interference