fusion-bench 0.2.15__py3-none-any.whl → 0.2.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (104)
  1. fusion_bench/method/__init__.py +4 -0
  2. fusion_bench/method/adamerging/flan_t5_layer_wise_adamerging.py +1 -1
  3. fusion_bench/method/adamerging/gpt2_layer_wise_adamerging.py +1 -1
  4. fusion_bench/method/base_algorithm.py +1 -0
  5. fusion_bench/method/dawe/dawe_for_clip.py +1 -1
  6. fusion_bench/method/depth_upscaling/depth_upscaling_for_llama.py +3 -2
  7. fusion_bench/method/fw_merging/__init__.py +2 -0
  8. fusion_bench/method/fw_merging/fw_hard.py +448 -0
  9. fusion_bench/method/fw_merging/fw_soft.py +519 -0
  10. fusion_bench/method/fw_merging/utils.py +331 -0
  11. fusion_bench/method/gossip/flan_t5_layer_wise_gossip.py +1 -1
  12. fusion_bench/method/moe_pruner/__init__.py +7 -0
  13. fusion_bench/method/moe_pruner/hooks/__init__.py +6 -0
  14. fusion_bench/method/moe_pruner/hooks/deepseek_v2.py +85 -0
  15. fusion_bench/method/moe_pruner/hooks/hook.py +23 -0
  16. fusion_bench/method/moe_pruner/hooks/mixtral.py +93 -0
  17. fusion_bench/method/moe_pruner/moe_pruner.py +304 -0
  18. fusion_bench/method/moe_pruner/utils/__init__.py +1 -0
  19. fusion_bench/method/moe_pruner/utils/data.py +154 -0
  20. fusion_bench/method/moe_pruner/utils/layerwrapper.py +61 -0
  21. fusion_bench/method/moe_pruner/utils/prune.py +313 -0
  22. fusion_bench/method/moe_pruner/utils/score.py +41 -0
  23. fusion_bench/method/pruning/__init__.py +1 -0
  24. fusion_bench/method/pruning/llama_sparsegpt_prune.py +223 -0
  25. fusion_bench/method/pruning/sparsegpt_utils/__init__.py +1 -0
  26. fusion_bench/method/pruning/sparsegpt_utils/sparsegpt.py +128 -0
  27. fusion_bench/method/pruning/wanda_utils/data.py +33 -14
  28. fusion_bench/method/pwe_moe/module.py +2 -7
  29. fusion_bench/method/randes/__init__.py +15 -0
  30. fusion_bench/method/randes/base_algorithm.py +1013 -0
  31. fusion_bench/method/randes/modelsoup.py +126 -0
  32. fusion_bench/method/randes/task_arithmetic.py +318 -0
  33. fusion_bench/method/simple_average.py +3 -2
  34. fusion_bench/method/sparselo/sparselo.py +20 -2
  35. fusion_bench/method/tall_mask/__init__.py +1 -0
  36. fusion_bench/method/tall_mask/task_arithmetic.py +133 -0
  37. fusion_bench/method/task_singular_vector/TSVM.py +238 -25
  38. fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +52 -20
  39. fusion_bench/mixins/hydra_config.py +1 -1
  40. fusion_bench/mixins/lightning_fabric.py +25 -1
  41. fusion_bench/mixins/serialization.py +18 -2
  42. fusion_bench/modelpool/base_pool.py +1 -0
  43. fusion_bench/modelpool/clip_vision/modelpool.py +21 -13
  44. fusion_bench/modelpool/lazy_state_dict_pool.py +15 -0
  45. fusion_bench/models/modeling_deepseek_v2/__init__.py +15 -0
  46. fusion_bench/models/modeling_deepseek_v2/configuration_deepseek.py +208 -0
  47. fusion_bench/models/modeling_deepseek_v2/modeling_deepseek.py +1922 -0
  48. fusion_bench/models/modeling_deepseek_v2/tokenization_deepseek_fast.py +38 -0
  49. fusion_bench/models/parameter_dict.py +6 -1
  50. fusion_bench/programs/fabric_fusion_program.py +14 -5
  51. fusion_bench/taskpool/base_pool.py +1 -0
  52. fusion_bench/taskpool/clip_vision/taskpool.py +8 -1
  53. fusion_bench/taskpool/dummy.py +6 -4
  54. fusion_bench/utils/__init__.py +2 -1
  55. fusion_bench/utils/data.py +1 -1
  56. fusion_bench/utils/{instantiate.py → instantiate_utils.py} +3 -0
  57. fusion_bench/utils/lazy_state_dict.py +268 -0
  58. fusion_bench/utils/parameters.py +33 -0
  59. fusion_bench/utils/pylogger.py +28 -0
  60. fusion_bench/utils/state_dict_arithmetic.py +74 -2
  61. fusion_bench/utils/type.py +1 -0
  62. {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.17.dist-info}/METADATA +8 -2
  63. {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.17.dist-info}/RECORD +104 -44
  64. {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.17.dist-info}/WHEEL +1 -1
  65. fusion_bench_config/dataset/image_classification/test/TALL10.yaml +28 -0
  66. fusion_bench_config/dataset/image_classification/test/TALL12.yaml +28 -0
  67. fusion_bench_config/dataset/image_classification/test/TALL16.yaml +28 -0
  68. fusion_bench_config/dataset/image_classification/test/TALL18.yaml +28 -0
  69. fusion_bench_config/dataset/image_classification/train/TALL10.yaml +28 -0
  70. fusion_bench_config/dataset/image_classification/train/TALL12.yaml +28 -0
  71. fusion_bench_config/dataset/image_classification/train/TALL16.yaml +28 -0
  72. fusion_bench_config/dataset/image_classification/train/TALL18.yaml +28 -0
  73. fusion_bench_config/fabric_model_fusion.yaml +2 -2
  74. fusion_bench_config/method/fw_merging/fw_hard.yaml +11 -0
  75. fusion_bench_config/method/fw_merging/fw_soft.yaml +12 -0
  76. fusion_bench_config/method/moe_pruner/moe_pruner.yaml +15 -0
  77. fusion_bench_config/method/pruning/llama_sparsegpt_pruning.yaml +16 -0
  78. fusion_bench_config/method/randes/superposed_model_soup.yaml +18 -0
  79. fusion_bench_config/method/randes/superposed_task_arithmetic.yaml +20 -0
  80. fusion_bench_config/method/randes/superposed_task_arithmetic_lora.yaml +20 -0
  81. fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml +2 -1
  82. fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml +1 -1
  83. fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml +1 -1
  84. fusion_bench_config/method/tall_mask/task_arithmetic.yaml +4 -0
  85. fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +2 -1
  86. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL10.yaml +29 -0
  87. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL12.yaml +29 -0
  88. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL16.yaml +29 -0
  89. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL18.yaml +29 -0
  90. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL10.yaml +8 -0
  91. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL12.yaml +8 -0
  92. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL16.yaml +8 -0
  93. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL18.yaml +8 -0
  94. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_cars_and_dtd.yaml +16 -0
  95. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_and_cars.yaml +16 -0
  96. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_and_dtd.yaml +16 -0
  97. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_cars_and_dtd.yaml +19 -0
  98. fusion_bench_config/modelpool/CausalLMPool/deepseek-v2-lite.yaml +15 -0
  99. fusion_bench_config/modelpool/CausalLMPool/mixtral-8x7b.yaml +14 -0
  100. fusion_bench_config/modelpool/SeqenceClassificationModelPool/roberta-base_glue.yaml +69 -0
  101. fusion_bench_config/taskpool/LMEvalHarnessTaskPool/lm_eval.yaml +0 -1
  102. {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.17.dist-info}/entry_points.txt +0 -0
  103. {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.17.dist-info}/licenses/LICENSE +0 -0
  104. {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.17.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,31 @@
1
- """
2
- Example:
1
+ R"""
2
+ # Task Singular Vector Merging (TSVM) Algorithm Implementation
3
+
4
+ This module implements the Task Singular Vector Merging algorithm for combining multiple fine-tuned models
5
+ into a single unified model.
6
+
7
+ ## Example Usage:
8
+
9
+ Merge 8 CLIP-ViT-B/32 models with TSVM:
10
+
11
+ ```bash
12
+ fusion_bench \
13
+ method=task_singular_vector/TaskSingularVectorMerging \
14
+ modelpool=CLIPVisionModelPool/clip-vit-base-patch32_TA8_model_only \
15
+ taskpool=CLIPVisionModelTaskPool/clip-vit-classification_TA8
16
+ ```
17
+
18
+ Merge 8 CLIP-ViT-B/32 models with TSVM and return individual transformed models:
19
+
20
+ ```bash
21
+ fusion_bench \
22
+ method=task_singular_vector/TaskSingularVectorMerging \
23
+ method.return_single_task_models=true \
24
+ modelpool=CLIPVisionModelPool/clip-vit-base-patch32_TA8_model_only \
25
+ taskpool=CLIPVisionModelTaskPool/clip-vit-classification_TA8
26
+ ```
27
+
28
+ Merge 20 CLIP-VIT-B/32 models with TSVM:
3
29
 
4
30
  ```bash
5
31
  fusion_bench \
@@ -7,14 +33,22 @@ fusion_bench \
7
33
  modelpool=CLIPVisionModelPool/clip-vit-base-patch32_TALL20_model_only \
8
34
  taskpool=CLIPVisionModelTaskPool/clip-vit-classification_TALL20
9
35
  ```
36
+
37
+ ## References:
38
+
39
+ - Gargiulo, et al. Task Singular Vectors: Reducing Task Interference in Model Merging.
40
+ https://arxiv.org/abs/2412.00081
41
+ - See `docs/algorithms/task_singular_vector.md` for more details.
10
42
  """
11
43
 
44
+ from copy import deepcopy
12
45
  from typing import Iterable, List, Optional, Union
13
46
 
14
47
  import torch
15
48
  from omegaconf import ListConfig
16
49
  from torch import Tensor, nn
17
50
 
51
+ import fusion_bench as fb
18
52
  from fusion_bench import BaseAlgorithm
19
53
  from fusion_bench.mixins import LightningFabricMixin
20
54
  from fusion_bench.utils import timeit_context
@@ -35,49 +69,228 @@ from .utils import (
35
69
 
36
70
 
37
71
  class TaskSingularVectorMerging(BaseAlgorithm, LightningFabricMixin):
72
+ """
73
+ Task Singular Vector Merging (TSVM) Algorithm
74
+
75
+ This class implements a model merging technique that leverages Singular Value
76
+ Decomposition (SVD) to identify and combine the most important directions in the task vector
77
+ space. The algorithm is particularly effective for merging multiple models fine-tuned on
78
+ different tasks while preserving their essential capabilities.
79
+
80
+ Key Concepts:
81
+ - Task Vector: The difference between a fine-tuned model and its pretrained base model,
82
+ representing the knowledge gained during fine-tuning for a specific task.
83
+ - Singular Value Decomposition: A matrix factorization technique used to find the principal
84
+ components (most important directions) in the space of task vectors.
85
+ - Model Merging: The process of combining multiple models into a single unified model that
86
+ retains capabilities from all constituent models.
87
+
88
+ Algorithm Steps:
89
+ 1. Extract task vectors from all fine-tuned models by subtracting the pretrained model
90
+ 2. Apply SVD to the matrix of task vectors to find principal components
91
+ 3. Reconstruct task vectors using only the most significant singular vectors
92
+ 4. Merge the reconstructed task vectors (either individually scaled or as a sum)
93
+ 5. Add the final merged task vector to the pretrained model to create the unified model
94
+
95
+ see `docs/algorithms/task_singular_vector.md` for comprehensive algorithmic details.
96
+ """
38
97
 
39
98
  def __init__(
40
99
  self,
41
- alpha: Union[float, Iterable[float]] = None,
42
- remove_keys: Optional[List[str]] = None,
100
+ alpha: Optional[Union[float, Iterable[float]]] = None,
101
+ exclude_keys: Optional[List[str]] = None,
102
+ return_single_task_models: bool = False,
43
103
  **kwargs,
44
104
  ):
105
+ """
106
+ Initialize the Task Singular Vector Merging algorithm.
107
+
108
+ Args:
109
+ alpha (Union[float, Iterable[float]], optional): Scaling factor(s) for task vectors.
110
+ This parameter controls the strength of the task-specific adaptations in the final model.
111
+
112
+ - If a single float: Applied to the final merged task vector after SVD reconstruction.
113
+ This uniformly scales the entire merged adaptation.
114
+
115
+ - If an iterable of floats: Applied to individual task vectors before SVD and merging.
116
+ Must have the same length as the number of models in the modelpool.
117
+ This allows for task-specific weighting (e.g., giving more importance to certain tasks).
118
+
119
+ - If None: No scaling is applied (equivalent to alpha=1.0).
120
+
121
+ Example: alpha=[0.8, 1.2, 0.5] would apply different weights to three different task vectors.
122
+
123
+ exclude_keys (Optional[List[str]], optional): List of parameter names to exclude from TSVM.
124
+ These parameters will not participate in the SVD computation and merging process.
125
+ Useful for excluding certain layers (e.g., task-specific heads, normalization layers)
126
+ that should not be merged across tasks. Defaults to an empty list.
127
+
128
+ Example: exclude_keys=['classifier.weight', 'classifier.bias'] to skip classification heads.
129
+
130
+ return_single_task_models (bool, optional): Whether to return individual transformed models.
131
+
132
+ - If True: Returns a dictionary containing both individual models with their transformed
133
+ task vectors applied AND the final merged model. The dictionary has the structure:
134
+
135
+ >>> {'model_name_1': transformed_model_1, ..., 'merged': final_merged_model}
136
+
137
+ - If False: Returns only the final merged model.
138
+
139
+ This is useful for analysis or when you need access to intermediate results.
140
+ Defaults to False.
141
+
142
+ **kwargs: Additional arguments passed to the parent BaseAlgorithm class.
143
+
144
+ Note:
145
+ The choice between single alpha vs. list of alphas affects the merging strategy:
146
+ - Single alpha: SVD is applied first, then the result is scaled
147
+ - List of alphas: Individual task vectors are scaled first, then SVD is applied
148
+ """
45
149
  self.alpha = alpha
46
- self.remove_keys = remove_keys if remove_keys is not None else []
150
+ self.exclude_keys = exclude_keys if exclude_keys is not None else []
151
+ self.return_single_task_models = return_single_task_models
47
152
  super().__init__(**kwargs)
48
153
 
49
- def run(self, modelpool):
50
- # Load the pre-trained model and the fine-tuned models
154
+ def load_pretrained_model_and_task_vectors(self, modelpool: fb.BaseModelPool):
155
+ """
156
+ Load the pretrained base model and compute task vectors from all fine-tuned models.
157
+
158
+ This method performs the initial step of the TSVM algorithm by:
159
+ 1. Loading the original pretrained model (before any task-specific fine-tuning)
160
+ 2. For each fine-tuned model in the pool:
161
+ - Load the fine-tuned model
162
+ - Compute the task vector (fine-tuned params - pretrained params)
163
+ - Optionally apply individual scaling if alpha is provided as a list
164
+
165
+ Task vectors represent the knowledge gained during fine-tuning and are the core
166
+ data structure that TSVM operates on.
167
+
168
+ Args:
169
+ modelpool (fb.BaseModelPool): Pool containing the pretrained model and all
170
+ fine-tuned models to be merged.
171
+
172
+ Returns:
173
+ tuple: A tuple containing:
174
+ - pretrained_model: The original pretrained model (torch.nn.Module)
175
+ - task_vectors: List of task vectors (List[StateDictType]), where each
176
+ task vector is a state dictionary representing the parameter differences
177
+ for one specific task
178
+ """
179
+ # Load the original pretrained model that serves as the base for all fine-tuned variants
51
180
  pretrained_model = modelpool.load_pretrained_model()
52
- finetuned_models = list(modelpool.models())
53
181
 
54
- ptm_check = pretrained_model.state_dict()
55
- ft_checks = [model.state_dict() for model in finetuned_models]
56
- check_parameterNamesMatch(ft_checks + [ptm_check])
182
+ # Initialize list to store computed task vectors
183
+ task_vectors = []
184
+
185
+ # Process each fine-tuned model in the modelpool
186
+ for model_idx, model_name in enumerate(modelpool.model_names):
187
+ # Load the current fine-tuned model
188
+ finetuned_model = modelpool.load_model(model_name)
189
+
190
+ # Compute task vector: difference between fine-tuned and pretrained parameters
191
+ # This captures the task-specific adaptations learned during fine-tuning
192
+ task_vector = state_dict_sub(
193
+ finetuned_model.state_dict(), pretrained_model.state_dict()
194
+ )
195
+ task_vectors.append(task_vector)
57
196
 
58
- with timeit_context("Flattening out Checkpoints"):
59
- task_vectors = [state_dict_sub(check, ptm_check) for check in ft_checks]
60
- if isinstance(self.alpha, Iterable):
197
+ # Apply individual scaling to task vectors if alpha is provided as a list
198
+ # This allows for task-specific weighting before the SVD computation
199
+ if self.alpha is not None and isinstance(self.alpha, Iterable):
200
+ # Ensure the number of alpha values matches the number of models
61
201
  assert len(self.alpha) == len(
62
- task_vectors
63
- ), "Alpha and task vectors must have the same length"
64
- task_vectors = [
65
- state_dict_mul(state_dict=tv, scalar=alpha)
66
- for alpha, tv in zip(self.alpha, task_vectors)
67
- ]
202
+ modelpool.model_names
203
+ ), f"Alpha list length ({len(self.alpha)}) must match number of models ({len(modelpool.model_names)})"
68
204
 
205
+ # Scale the current task vector by its corresponding alpha value
206
+ task_vectors[-1] = state_dict_mul(
207
+ state_dict=task_vectors[-1], scalar=self.alpha[model_idx]
208
+ )
209
+
210
+ return pretrained_model, task_vectors
211
+
212
+ def run(self, modelpool: fb.BaseModelPool):
213
+ """
214
+ Execute the complete Task Singular Vector Merging algorithm.
215
+
216
+ This is the main entry point that orchestrates the entire TSVM process:
217
+
218
+ The algorithm leverages the mathematical insight that task vectors often lie in a
219
+ lower-dimensional subspace, and SVD helps identify the most important directions
220
+ in this subspace while filtering out noise and interference.
221
+
222
+ Args:
223
+ modelpool (fb.BaseModelPool): Pool of models to merge, including:
224
+ - The pretrained base model
225
+ - Multiple fine-tuned models (one per task)
226
+ All models must have compatible architectures.
227
+
228
+ Returns:
229
+ Union[torch.nn.Module, Dict[str, torch.nn.Module]]:
230
+ - If return_single_task_models=False: Returns the merged model
231
+ - If return_single_task_models=True: Returns a dictionary with:
232
+ * Individual transformed models keyed by their original names
233
+ * Final merged model under the key 'merged'
234
+
235
+ Raises:
236
+ AssertionError: If alpha list length doesn't match the number of models
237
+ """
238
+ # Determine the compute device for SVD operations (GPU if available for faster computation)
239
+ accelerator = self.fabric.device
240
+
241
+ # Phase 1: Load pretrained model and compute task vectors from all fine-tuned models
242
+ pretrained_model, task_vectors = self.load_pretrained_model_and_task_vectors(
243
+ modelpool
244
+ )
245
+
246
+ # Phase 2: Apply SVD-based merging to the task vectors
247
+ # This is the core of the TSVM algorithm where:
248
+ # - Task vectors are organized into a matrix
249
+ # - SVD finds the principal components (most important directions)
250
+ # - Task vectors are reconstructed using only the most significant components
251
+ # - The reconstructed vectors are merged (summed) to create a unified task vector
69
252
  new_merged_tv = TSVM_utils.compute_and_sum_svd_mem_reduction(
70
253
  task_vectors,
71
- exclude_keys=self.remove_keys,
72
- accelerator=self.fabric.device,
254
+ exclude_keys=self.exclude_keys, # Skip certain parameters from SVD
255
+ accelerator=accelerator, # Use GPU if available
256
+ return_single_task_models=self.return_single_task_models,
73
257
  )
74
258
 
75
- # If alpha is a float, we need to scale the new merged task vector by alpha
76
- if self.alpha is not None and isinstance(self.alpha, float):
259
+ # Handle the case where individual transformed task vectors are also returned
260
+ if self.return_single_task_models:
261
+ new_merged_tv, single_task_models = new_merged_tv
262
+
263
+ # Phase 3: Apply global scaling to the merged task vector (if alpha is a single value)
264
+ # This is different from individual scaling applied earlier - here we scale the
265
+ # final merged result, which affects the overall strength of all merged adaptations
266
+ if self.alpha is not None and isinstance(self.alpha, (float, int)):
77
267
  print(f"Scaling new merged task vector by alpha: {self.alpha}")
78
268
  new_merged_tv = state_dict_mul(state_dict=new_merged_tv, scalar=self.alpha)
79
269
 
270
+ # Phase 4: Prepare individual transformed models if requested
271
+ if self.return_single_task_models:
272
+ models = {}
273
+ # Create individual models by adding each transformed task vector to the pretrained base
274
+ for model_idx, model_name in enumerate(modelpool.model_names):
275
+ # Create a deep copy to avoid modifying the original pretrained model
276
+ model = deepcopy(pretrained_model)
277
+ # Apply the transformed task vector to get the individual model
278
+ model.load_state_dict(
279
+ state_dict_add(model.state_dict(), single_task_models[model_idx])
280
+ )
281
+ models[model_name] = model
282
+
283
+ # Phase 5: Create the final merged model by adding the merged task vector to pretrained model
284
+ # This produces a single model that combines capabilities from all input models
80
285
  pretrained_model.load_state_dict(
81
286
  state_dict_add(new_merged_tv, pretrained_model.state_dict())
82
287
  )
83
- return pretrained_model
288
+
289
+ # Phase 6: Return results based on the requested output format
290
+ if self.return_single_task_models:
291
+ # Include the final merged model in the dictionary of results
292
+ models["merged"] = pretrained_model
293
+ return models
294
+ else:
295
+ # Return only the merged model
296
+ return pretrained_model
@@ -1,7 +1,9 @@
1
+ import collections
1
2
  import math
2
- from typing import List, Optional
3
+ from typing import Dict, List, Optional
3
4
 
4
5
  import torch
6
+ from torch import Tensor, nn
5
7
 
6
8
  from fusion_bench.utils.type import StateDictType
7
9
 
@@ -314,7 +316,8 @@ def compute_and_sum_svd_mem_reduction(
314
316
  task_vectors: List[StateDictType],
315
317
  exclude_keys: Optional[List[str]] = None,
316
318
  accelerator: torch.device = "cuda" if torch.cuda.is_available() else "cpu",
317
- ) -> StateDictType:
319
+ return_single_task_models: bool = False,
320
+ ):
318
321
  """
319
322
  Computes the Singular Value Decomposition (SVD) for each vector in the task_vectors,
320
323
  reduces the dimensionality of the vectors based on the sv_reduction factor, and concatenate
@@ -326,23 +329,25 @@ def compute_and_sum_svd_mem_reduction(
326
329
  dictionary of vectors.
327
330
  exclude_keys (list): A list of keys to exclude from the TSVM.
328
331
  accelerator (torch.device): The device to use for the computation.
332
+ return_single_task_models (bool): Whether to return the single task models after the TSVM.
329
333
 
330
334
  Returns:
331
335
  dict: A dictionary containing the new vectors after SVD computation and merging.
332
336
  """
333
337
  if exclude_keys is None:
334
338
  exclude_keys = []
335
- sv_reduction = 1 / len(task_vectors)
339
+ num_tasks = len(task_vectors)
340
+ sv_reduction = 1 / num_tasks
336
341
 
337
- new_vector = {}
342
+ new_vector: Dict[str, Tensor] = {}
343
+ if return_single_task_models:
344
+ single_task_models = [{} for _ in range(num_tasks)]
338
345
  for key in task_vectors[0]:
339
346
  original_device = task_vectors[0][key].device
340
347
  original_dtype = task_vectors[0][key].dtype
341
348
 
342
- new_vector[key] = {}
343
349
  for i, task_vector in enumerate(task_vectors):
344
- vec = task_vector[key].to(accelerator)
345
-
350
+ vec = task_vector[key].to(device=accelerator, non_blocking=True)
346
351
  if len(task_vector[key].shape) == 2 and key not in exclude_keys:
347
352
  # at current, the SVD is not supported for half precision, so we need to convert to float32
348
353
  if not (
@@ -350,13 +355,14 @@ def compute_and_sum_svd_mem_reduction(
350
355
  ):
351
356
  vec = vec.to(dtype=torch.float32)
352
357
 
353
- u, s, v = torch.linalg.svd(vec, full_matrices=False)
358
+ # vec = u @ torch.diag(s) @ vh
359
+ u, s, vh = torch.linalg.svd(vec, full_matrices=False)
354
360
 
355
361
  if i == 0:
356
362
  print(f"Computed SVD for {key}...")
357
363
  sum_u = torch.zeros_like(u, device=accelerator)
358
364
  sum_s = torch.zeros_like(s, device=accelerator)
359
- sum_v = torch.zeros_like(v, device=accelerator)
365
+ sum_vh = torch.zeros_like(vh, device=accelerator)
360
366
  reduced_index_s = int(s.shape[0] * sv_reduction)
361
367
 
362
368
  # select only the first reduced_index_s columns of u and place them
@@ -367,10 +373,9 @@ def compute_and_sum_svd_mem_reduction(
367
373
  :reduced_index_s
368
374
  ]
369
375
  # select only the first reduced_index_s rows of v and place them
370
- sum_v[i * reduced_index_s : (i + 1) * reduced_index_s, :] = v[
376
+ sum_vh[i * reduced_index_s : (i + 1) * reduced_index_s, :] = vh[
371
377
  :reduced_index_s, :
372
378
  ]
373
-
374
379
  else:
375
380
  # if the vector is not a 2D tensor or is in exclude_keys, compute the mean
376
381
  if i == 0:
@@ -379,22 +384,49 @@ def compute_and_sum_svd_mem_reduction(
379
384
  new_vector[key] += (vec - new_vector[key]) / (i + 1)
380
385
 
381
386
  if len(task_vector[key].shape) == 2 and key not in exclude_keys:
382
- u_u, s_u, v_u = torch.linalg.svd(sum_u, full_matrices=False)
383
- u_v, s_v, v_v = torch.linalg.svd(sum_v, full_matrices=False)
387
+ u_u, s_u, vh_u = torch.linalg.svd(sum_u, full_matrices=False)
388
+ u_vh, s_vh, vh_vh = torch.linalg.svd(sum_vh, full_matrices=False)
384
389
 
385
390
  new_vector[key] = torch.linalg.multi_dot(
386
391
  (
387
392
  u_u,
388
- v_u,
393
+ vh_u,
389
394
  torch.diag(sum_s),
390
- u_v,
391
- v_v,
395
+ u_vh,
396
+ vh_vh,
392
397
  )
393
398
  )
394
- new_vector[key] = new_vector[key].to(
395
- device=original_device, dtype=original_dtype, non_blocking=True
396
- )
397
- return new_vector
399
+ new_vector[key] = new_vector[key].to(
400
+ device=original_device, dtype=original_dtype, non_blocking=True
401
+ )
402
+ if return_single_task_models:
403
+ reduced_index_s = int(sum_s.shape[0] * sv_reduction)
404
+ new_u = u_u @ vh_u
405
+ new_vh = u_vh @ vh_vh
406
+ for i in range(num_tasks):
407
+ single_task_models[i][key] = torch.linalg.multi_dot(
408
+ (
409
+ new_u[:, i * reduced_index_s : (i + 1) * reduced_index_s],
410
+ torch.diag(
411
+ sum_s[i * reduced_index_s : (i + 1) * reduced_index_s]
412
+ ),
413
+ new_vh[i * reduced_index_s : (i + 1) * reduced_index_s, :],
414
+ )
415
+ ).to(
416
+ device=original_device, dtype=original_dtype, non_blocking=True
417
+ )
418
+ else:
419
+ new_vector[key] = new_vector[key].to(
420
+ device=original_device, dtype=original_dtype, non_blocking=True
421
+ )
422
+ if return_single_task_models:
423
+ for i in range(num_tasks):
424
+ single_task_models[i][key] = new_vector[key].clone()
425
+
426
+ if not return_single_task_models:
427
+ return new_vector
428
+ else:
429
+ return new_vector, single_task_models
398
430
 
399
431
 
400
432
  ###############
@@ -9,7 +9,7 @@ from hydra import compose, initialize
9
9
  from omegaconf import DictConfig, OmegaConf
10
10
 
11
11
  from fusion_bench.utils import import_object, instantiate
12
- from fusion_bench.utils.instantiate import set_print_function_call
12
+ from fusion_bench.utils.instantiate_utils import set_print_function_call
13
13
 
14
14
  log = logging.getLogger(__name__)
15
15
 
@@ -11,7 +11,7 @@ from lightning.fabric.utilities.rank_zero import rank_zero_only
11
11
  from omegaconf import DictConfig, OmegaConf
12
12
 
13
13
  from fusion_bench.utils import import_object
14
- from fusion_bench.utils.instantiate import instantiate
14
+ from fusion_bench.utils.instantiate_utils import instantiate
15
15
 
16
16
  if TYPE_CHECKING:
17
17
  import lightning.fabric.loggers.tensorboard
@@ -172,3 +172,27 @@ class LightningFabricMixin:
172
172
  return True
173
173
  else:
174
174
  return False
175
+
176
+ def log(self, name: str, value: Any, step: Optional[int] = None):
177
+ """
178
+ Logs the metric to the fabric's logger.
179
+ """
180
+ self.fabric.log(name, value, step=step)
181
+
182
+ def log_dict(self, metrics: dict, step: Optional[int] = None):
183
+ """
184
+ Logs the metrics to the fabric's logger.
185
+ """
186
+ self.fabric.log_dict(metrics, step=step)
187
+
188
+ def log_optimizer_lr(
189
+ self,
190
+ optimizer: torch.optim.Optimizer,
191
+ step: Optional[int] = None,
192
+ name_template: str = "train/lr_group_{0}",
193
+ ):
194
+ """
195
+ Logs the learning rate of the optimizer to the fabric's logger.
196
+ """
197
+ for i, param_group in enumerate(optimizer.param_groups):
198
+ self.fabric.log(name_template.format(i), param_group["lr"], step=step)
@@ -4,13 +4,14 @@ from typing import Dict, Optional, Union
4
4
 
5
5
  from omegaconf import OmegaConf
6
6
 
7
- from fusion_bench.utils import instantiate
7
+ from fusion_bench.utils import import_object, instantiate
8
8
 
9
9
  log = logging.getLogger(__name__)
10
10
 
11
11
 
12
12
  class YAMLSerializationMixin:
13
13
  _recursive_: bool = False
14
+ _config_key: Optional[str] = None
14
15
  _config_mapping: Dict[str, str] = {
15
16
  "_recursive_": "_recursive_",
16
17
  }
@@ -99,7 +100,22 @@ class YAMLSerializationMixin:
99
100
  BaseModelPool: The loaded model pool.
100
101
  """
101
102
  config = OmegaConf.load(path)
102
- return instantiate(config, _recursive_=cls._recursive_)
103
+ if cls._config_key is not None and cls._config_key in config:
104
+ config = config[cls._config_key]
105
+ target_cls = import_object(config["_target_"])
106
+ if target_cls != cls:
107
+ log.warning(
108
+ f"The class {target_cls.__name__} is not the same as the class {cls.__name__}. "
109
+ f"Instantiating the class {target_cls.__name__} instead."
110
+ )
111
+ return instantiate(
112
+ config,
113
+ _recursive_=(
114
+ cls._recursive_
115
+ if config.get("_recursive_") is None
116
+ else config.get("_recursive_")
117
+ ),
118
+ )
103
119
 
104
120
  def to_config(self):
105
121
  """
@@ -29,6 +29,7 @@ class BaseModelPool(BaseYAMLSerializableModel, HydraConfigMixin):
29
29
  """
30
30
 
31
31
  _program = None
32
+ _config_key = "modelpool"
32
33
  _models: Union[DictConfig, Dict[str, nn.Module]]
33
34
  _config_mapping = BaseYAMLSerializableModel._config_mapping | {
34
35
  "_models": "models",
@@ -3,6 +3,7 @@ from copy import deepcopy
3
3
  from typing import Optional, Union
4
4
 
5
5
  from datasets import load_dataset
6
+ from lightning.fabric.utilities import rank_zero_only
6
7
  from omegaconf import DictConfig, open_dict
7
8
  from torch import nn
8
9
  from torch.utils.data import Dataset
@@ -40,7 +41,8 @@ class CLIPVisionModelPool(BaseModelPool):
40
41
  def load_processor(self, *args, **kwargs) -> CLIPProcessor:
41
42
  assert self._processor is not None, "Processor is not defined in the config"
42
43
  if isinstance(self._processor, str):
43
- log.info(f"Loading `transformers.CLIPProcessor`: {self._processor}")
44
+ if rank_zero_only.rank == 0:
45
+ log.info(f"Loading `transformers.CLIPProcessor`: {self._processor}")
44
46
  processor = CLIPProcessor.from_pretrained(self._processor)
45
47
  else:
46
48
  processor = instantiate(self._processor, *args, **kwargs)
@@ -50,7 +52,8 @@ class CLIPVisionModelPool(BaseModelPool):
50
52
  model_config = self._models[model_name]
51
53
 
52
54
  if isinstance(model_config, str):
53
- log.info(f"Loading `transformers.CLIPModel`: {model_config}")
55
+ if rank_zero_only.rank == 0:
56
+ log.info(f"Loading `transformers.CLIPModel`: {model_config}")
54
57
  clip_model = CLIPModel.from_pretrained(model_config, *args, **kwargs)
55
58
  return clip_model
56
59
  else:
@@ -102,10 +105,12 @@ class CLIPVisionModelPool(BaseModelPool):
102
105
  ):
103
106
  model = self._models[model_name_or_config]
104
107
  if isinstance(model, str):
105
- log.info(f"Loading `transformers.CLIPVisionModel`: {model}")
108
+ if rank_zero_only.rank == 0:
109
+ log.info(f"Loading `transformers.CLIPVisionModel`: {model}")
106
110
  return CLIPVisionModel.from_pretrained(model, *args, **kwargs)
107
111
  if isinstance(model, nn.Module):
108
- log.info(f"Returning existing model: {model}")
112
+ if rank_zero_only.rank == 0:
113
+ log.info(f"Returning existing model: {model}")
109
114
  return model
110
115
 
111
116
  # If the model is not a string, we use the default load_model method
@@ -114,9 +119,10 @@ class CLIPVisionModelPool(BaseModelPool):
114
119
  def load_train_dataset(self, dataset_name: str, *args, **kwargs):
115
120
  dataset_config = self._train_datasets[dataset_name]
116
121
  if isinstance(dataset_config, str):
117
- log.info(
118
- f"Loading train dataset using `datasets.load_dataset`: {dataset_config}"
119
- )
122
+ if rank_zero_only.rank == 0:
123
+ log.info(
124
+ f"Loading train dataset using `datasets.load_dataset`: {dataset_config}"
125
+ )
120
126
  dataset = load_dataset(dataset_config, split="train")
121
127
  else:
122
128
  dataset = super().load_train_dataset(dataset_name, *args, **kwargs)
@@ -125,9 +131,10 @@ class CLIPVisionModelPool(BaseModelPool):
125
131
  def load_val_dataset(self, dataset_name: str, *args, **kwargs):
126
132
  dataset_config = self._val_datasets[dataset_name]
127
133
  if isinstance(dataset_config, str):
128
- log.info(
129
- f"Loading validation dataset using `datasets.load_dataset`: {dataset_config}"
130
- )
134
+ if rank_zero_only.rank == 0:
135
+ log.info(
136
+ f"Loading validation dataset using `datasets.load_dataset`: {dataset_config}"
137
+ )
131
138
  dataset = load_dataset(dataset_config, split="validation")
132
139
  else:
133
140
  dataset = super().load_val_dataset(dataset_name, *args, **kwargs)
@@ -136,9 +143,10 @@ class CLIPVisionModelPool(BaseModelPool):
136
143
  def load_test_dataset(self, dataset_name: str, *args, **kwargs):
137
144
  dataset_config = self._test_datasets[dataset_name]
138
145
  if isinstance(dataset_config, str):
139
- log.info(
140
- f"Loading test dataset using `datasets.load_dataset`: {dataset_config}"
141
- )
146
+ if rank_zero_only.rank == 0:
147
+ log.info(
148
+ f"Loading test dataset using `datasets.load_dataset`: {dataset_config}"
149
+ )
142
150
  dataset = load_dataset(dataset_config, split="test")
143
151
  else:
144
152
  dataset = super().load_test_dataset(dataset_name, *args, **kwargs)
@@ -0,0 +1,15 @@
1
+ from fusion_bench import BaseModelPool
2
+ from fusion_bench.utils import instantiate
3
+ from fusion_bench.utils.lazy_state_dict import LazyStateDict
4
+
5
+
6
+ class LazyStateDictPool(BaseModelPool):
7
+ def load_model(self, model_name_or_config: str, *args, **kwargs) -> LazyStateDict:
8
+ if model_name_or_config in self._models:
9
+ checkpoint_config = self._models[model_name_or_config]
10
+ else:
11
+ checkpoint_config = model_name_or_config
12
+ if isinstance(checkpoint_config, str):
13
+ return LazyStateDict(checkpoint_config, *args, **kwargs)
14
+ else:
15
+ return instantiate(checkpoint_config, *args, **kwargs)
@@ -0,0 +1,15 @@
1
+ """
2
+ This is a direct copy of the DeepSeek-V2-Lite model from HuggingFace https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite/tree/main
3
+ """
4
+
5
+ from .configuration_deepseek import DeepseekV2Config
6
+ from .modeling_deepseek import (
7
+ DeepseekV2ForCausalLM,
8
+ DeepseekV2ForSequenceClassification,
9
+ DeepseekV2MLP,
10
+ DeepseekV2Model,
11
+ DeepseekV2MoE,
12
+ DeepseekV2DecoderLayer,
13
+ )
14
+ from .modeling_deepseek import MoEGate as DeepseekV2MoEGate
15
+ from .tokenization_deepseek_fast import DeepseekTokenizerFast