fusion-bench 0.2.8__py3-none-any.whl → 0.2.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58)
  1. fusion_bench/__main__.py +4 -0
  2. fusion_bench/dataset/fer2013.py +1 -0
  3. fusion_bench/method/__init__.py +26 -4
  4. fusion_bench/method/classification/__init__.py +1 -0
  5. fusion_bench/method/classification/clip_finetune.py +1 -3
  6. fusion_bench/method/classification/continual_clip_finetune.py +297 -0
  7. fusion_bench/method/dare/__init__.py +1 -0
  8. fusion_bench/method/dare/task_arithmetic.py +14 -7
  9. fusion_bench/method/dare/ties_merging.py +100 -0
  10. fusion_bench/method/isotropic_merging/__init__.py +15 -0
  11. fusion_bench/method/isotropic_merging/iso.py +114 -0
  12. fusion_bench/method/isotropic_merging/iso_utils.py +176 -0
  13. fusion_bench/method/opcm/__init__.py +4 -0
  14. fusion_bench/method/opcm/opcm.py +277 -0
  15. fusion_bench/method/opcm/task_arithmetic.py +115 -0
  16. fusion_bench/method/opcm/ties_merging.py +156 -0
  17. fusion_bench/method/opcm/utils.py +73 -0
  18. fusion_bench/method/opcm/weight_average.py +120 -0
  19. fusion_bench/method/slerp/slerp.py +1 -1
  20. fusion_bench/method/task_singular_vector/TSVM.py +22 -2
  21. fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +91 -93
  22. fusion_bench/method/ties_merging/ties_merging.py +10 -0
  23. fusion_bench/metrics/continual_learning/backward_transfer.py +22 -0
  24. fusion_bench/mixins/clip_classification.py +4 -1
  25. fusion_bench/programs/fabric_fusion_program.py +22 -11
  26. fusion_bench/scripts/cli.py +1 -0
  27. fusion_bench/taskpool/base_pool.py +1 -1
  28. fusion_bench/taskpool/clip_vision/taskpool.py +12 -7
  29. fusion_bench/utils/__init__.py +2 -1
  30. fusion_bench/utils/dict.py +43 -0
  31. fusion_bench/utils/expr.py +90 -0
  32. fusion_bench/utils/fabric.py +17 -0
  33. fusion_bench/utils/instantiate.py +7 -1
  34. fusion_bench/utils/json.py +30 -0
  35. fusion_bench/utils/parameters.py +27 -7
  36. fusion_bench/utils/path.py +15 -0
  37. fusion_bench/utils/plot/color_data.py +1726 -0
  38. fusion_bench/utils/rich_utils.py +15 -0
  39. fusion_bench/utils/set.py +8 -0
  40. fusion_bench/utils/tensorboard.py +51 -0
  41. {fusion_bench-0.2.8.dist-info → fusion_bench-0.2.10.dist-info}/METADATA +17 -18
  42. {fusion_bench-0.2.8.dist-info → fusion_bench-0.2.10.dist-info}/RECORD +58 -29
  43. {fusion_bench-0.2.8.dist-info → fusion_bench-0.2.10.dist-info}/WHEEL +1 -1
  44. fusion_bench_config/method/classification/clip_continual_finetune.yaml +28 -0
  45. fusion_bench_config/method/classification/clip_finetune.yaml +26 -0
  46. fusion_bench_config/method/clip_finetune.yaml +2 -2
  47. fusion_bench_config/method/dare/ties_merging.yaml +15 -0
  48. fusion_bench_config/method/isotropic_merging/iso_c.yaml +4 -0
  49. fusion_bench_config/method/isotropic_merging/iso_cts.yaml +5 -0
  50. fusion_bench_config/method/opcm/opcm.yaml +12 -0
  51. fusion_bench_config/method/opcm/task_arithmetic.yaml +12 -0
  52. fusion_bench_config/method/opcm/ties_merging.yaml +18 -0
  53. fusion_bench_config/method/opcm/weight_average.yaml +10 -0
  54. fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +6 -0
  55. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +18 -0
  56. {fusion_bench-0.2.8.dist-info → fusion_bench-0.2.10.dist-info}/LICENSE +0 -0
  57. {fusion_bench-0.2.8.dist-info → fusion_bench-0.2.10.dist-info}/entry_points.txt +0 -0
  58. {fusion_bench-0.2.8.dist-info → fusion_bench-0.2.10.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,4 @@
1
"""Module entry point so FusionBench can be launched with ``python -m fusion_bench``."""

from fusion_bench.scripts.cli import main

if __name__ == "__main__":
    # Hand control to the regular command-line interface.
    main()
@@ -7,6 +7,7 @@ def load_fer2013(path: str = "clip-benchmark/wds_fer2013", split: str = "train")
7
7
  dataset = dataset.rename_columns({"jpg": "image", "cls": "label"})
8
8
  return dataset
9
9
 
10
+
10
11
  if __name__ == "__main__":
11
12
  dataset = load_fer2013(split="test")
12
13
  print(dataset)
@@ -9,7 +9,10 @@ _import_structure = {
9
9
  "base_algorithm": ["BaseModelFusionAlgorithm", "BaseAlgorithm"],
10
10
  "dummy": ["DummyAlgorithm"],
11
11
  # single task learning (fine-tuning)
12
- "classification": ["ImageClassificationFineTuningForCLIP"],
12
+ "classification": [
13
+ "ImageClassificationFineTuningForCLIP",
14
+ "ContinualImageClassificationFineTuningForCLIP",
15
+ ],
13
16
  "lm_finetune": ["FullFinetuneSFT", "PeftFinetuneSFT", "BradleyTerryRewardModeling"],
14
17
  # analysis
15
18
  "analysis": ["TaskVectorCosSimilarity", "TaskVectorViolinPlot"],
@@ -27,11 +30,12 @@ _import_structure = {
27
30
  "TaskArithmeticForLlama",
28
31
  "LinearInterpolationAlgorithm",
29
32
  ],
33
+ "slerp": ["SlerpMergeAlgorithm"],
30
34
  "simple_average": ["SimpleAverageAlgorithm"],
31
35
  "weighted_average": ["WeightedAverageAlgorithm", "WeightedAverageForLLama"],
32
36
  "task_arithmetic": ["TaskArithmeticAlgorithm"],
33
37
  "ties_merging": ["TiesMergingAlgorithm"],
34
- "dare": ["DareSimpleAverage", "DareTaskArithmetic"],
38
+ "dare": ["DareSimpleAverage", "DareTaskArithmetic", "DareTiesMerging"],
35
39
  "fisher_merging": [
36
40
  "FisherMergingForCLIPVisionModel",
37
41
  "FisherMergingAlgorithmForGPT2",
@@ -50,6 +54,13 @@ _import_structure = {
50
54
  ],
51
55
  "ada_svd": ["AdaSVDMergingForCLIPVisionModel"],
52
56
  "task_singular_vector": ["TaskSingularVectorMerging"],
57
+ "isotropic_merging": [
58
+ "ISO_C_Merge", # alias
59
+ "ISO_CTS_Merge", # alias
60
+ "IsotropicMergingInCommonAndTaskSubspace",
61
+ "IsotropicMergingInCommonSubspace",
62
+ ],
63
+ "opcm": ["OPCMForCLIP"],
53
64
  # plug-and-play model merging methods
54
65
  "concrete_subspace": [
55
66
  "ConcreteTaskArithmeticAlgorithmForCLIP",
@@ -96,13 +107,16 @@ if TYPE_CHECKING:
96
107
  from .adamerging import *
97
108
  from .analysis import TaskVectorCosSimilarity, TaskVectorViolinPlot
98
109
  from .base_algorithm import BaseAlgorithm, BaseModelFusionAlgorithm
99
- from .classification import ImageClassificationFineTuningForCLIP
110
+ from .classification import (
111
+ ContinualImageClassificationFineTuningForCLIP,
112
+ ImageClassificationFineTuningForCLIP,
113
+ )
100
114
  from .concrete_subspace import (
101
115
  ConcreteLayerWiseAdaMergingForCLIP,
102
116
  ConcreteTaskArithmeticAlgorithmForCLIP,
103
117
  ConcreteTaskWiseAdaMergingForCLIP,
104
118
  )
105
- from .dare import DareSimpleAverage, DareTaskArithmetic
119
+ from .dare import DareSimpleAverage, DareTaskArithmetic, DareTiesMerging
106
120
  from .dawe import DataAdaptiveWeightEnsemblingForCLIP
107
121
  from .depth_upscaling import DepthUpscalingAlgorithm, DepthUpscalingForLlama
108
122
  from .dummy import DummyAlgorithm
@@ -112,6 +126,12 @@ if TYPE_CHECKING:
112
126
  WeightedEnsembleAlgorithm,
113
127
  )
114
128
  from .fisher_merging import FisherMergingForCLIPVisionModel
129
+ from .isotropic_merging import (
130
+ ISO_C_Merge,
131
+ ISO_CTS_Merge,
132
+ IsotropicMergingInCommonAndTaskSubspace,
133
+ IsotropicMergingInCommonSubspace,
134
+ )
115
135
  from .linear import (
116
136
  ExPOAlgorithm,
117
137
  ExPOAlgorithmForLlama,
@@ -127,6 +147,7 @@ if TYPE_CHECKING:
127
147
  MixtralUpscalingAlgorithm,
128
148
  )
129
149
  from .model_recombination import ModelRecombinationAlgorithm
150
+ from .opcm import OPCMForCLIP
130
151
  from .pruning import (
131
152
  MagnitudeDiffPruningAlgorithm,
132
153
  MagnitudePruningForLlama,
@@ -140,6 +161,7 @@ if TYPE_CHECKING:
140
161
  from .rankone_moe import CLIPRankOneMoEAlgorithm, RankOneMoEAlgorithm
141
162
  from .regmean import RegMeanAlgorithmForCLIP, RegMeanAlgorithmForGPT2
142
163
  from .simple_average import SimpleAverageAlgorithm
164
+ from .slerp import SlerpMergeAlgorithm
143
165
  from .smile_upscaling import (
144
166
  SingularProjectionMergingAlgorithm,
145
167
  SmileUpscalingAlgorithm,
@@ -1,2 +1,3 @@
1
1
  # flake8: noqa F401
2
2
  from .clip_finetune import ImageClassificationFineTuningForCLIP
3
+ from .continual_clip_finetune import ContinualImageClassificationFineTuningForCLIP
@@ -184,9 +184,7 @@ class ImageClassificationFineTuningForCLIP(
184
184
  self.save_model(classifier, save_path)
185
185
 
186
186
  if config.state_dict_save_path is not None:
187
- self.save_model(
188
- classifier, config.state_dict_save_path, trainable_only=True
189
- )
187
+ self.save_model(classifier, config.state_dict_save_path)
190
188
  self.print_profile_summary()
191
189
  return classifier.clip_model.vision_model
192
190
 
@@ -0,0 +1,297 @@
1
+ import os
2
+ import random
3
+ import time
4
+ from copy import deepcopy
5
+ from typing import Optional, Tuple, cast
6
+
7
+ import lightning as L
8
+ import torch
9
+ from omegaconf import DictConfig, OmegaConf
10
+ from peft import LoraConfig, PeftModel, get_peft_model
11
+ from peft.tuners.lora import LoraLayer
12
+ from safetensors.torch import save_file
13
+ from torch import nn
14
+ from torch.utils.data import DataLoader
15
+ from tqdm.auto import tqdm
16
+ from transformers import CLIPModel, CLIPProcessor, CLIPVisionModel
17
+ from transformers.models.clip.modeling_clip import CLIPVisionTransformer
18
+
19
+ from fusion_bench import BaseAlgorithm, print_parameters
20
+ from fusion_bench.compat.modelpool import to_modelpool
21
+ from fusion_bench.dataset.clip_dataset import CLIPDataset
22
+ from fusion_bench.mixins import CLIPClassificationMixin
23
+ from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
24
+ from fusion_bench.modelpool import CLIPVisionModelPool
25
+ from fusion_bench.models.hf_clip import HFCLIPClassifier
26
+ from fusion_bench.models.linearized.linearized_model_utils import LinearizedModelWraper
27
+ from fusion_bench.taskpool import CLIPVisionModelTaskPool
28
+ from fusion_bench.utils.data import InfiniteDataLoader
29
+ from fusion_bench.utils.fabric import seed_everything_by_time
30
+ from fusion_bench.utils.json import load_from_json, save_to_json
31
+
32
+
33
class ContinualImageClassificationFineTuningForCLIP(
    CLIPClassificationMixin,
    SimpleProfilerMixin,
    BaseAlgorithm,
):
    """
    Continually fine-tune the vision encoder of a CLIP model on a sequence of
    image-classification tasks.

    The training datasets of the model pool are visited one task at a time
    (optionally in a shuffled order). Before each task, the optimizer and the
    cosine learning-rate scheduler are restored to their initial state. The
    vision model is then trained for ``num_steps`` steps using the zero-shot
    classification head of that task. When the program's task pool is a
    ``CLIPVisionModelTaskPool``, the model is evaluated after each task on the
    test sets of all tasks seen so far.
    """

    # attributes to configuration keys mapping
    _config_mapping = BaseAlgorithm._config_mapping | {
        "seed": "seed",
        "shuffle_order": "shuffle_order",
        "learning_rate": "learning_rate",
        "weight_decay": "weight_decay",
        "num_steps": "num_steps",
        "batch_size": "batch_size",
        "num_workers": "num_workers",
        "save_interval": "save_interval",
        "state_dict_load_path": "state_dict_load_path",
        "state_dict_save_path": "state_dict_save_path",
        "skip_training": "skip_training",
        "use_lora": "use_lora",
        "lora_config": "lora_config",
    }

    def __init__(
        self,
        seed: int = 42,
        shuffle_order: bool = True,
        learning_rate: float = 1e-5,
        weight_decay: float = 0,
        num_steps: int = 4000,
        batch_size: int = 128,
        num_workers: int = 16,
        save_interval: int = 500,
        state_dict_load_path: Optional[str] = None,
        state_dict_save_path: Optional[str] = None,
        skip_training: bool = False,
        use_lora: bool = False,
        lora_config: Optional[LoraConfig] = None,
    ):
        """
        Args:
            seed: Random seed. If ``None``, seeding is delegated to
                ``seed_everything_by_time``.
            shuffle_order: Whether to shuffle the order in which tasks are trained.
            learning_rate: Learning rate of the Adam optimizer.
            weight_decay: Weight decay of the Adam optimizer.
            num_steps: Number of optimization steps per task. Also used as
                ``T_max`` of the cosine learning-rate scheduler.
            batch_size: Training batch size.
            num_workers: Number of data-loader workers per task.
            save_interval: Save a checkpoint every ``save_interval`` steps within
                each task. A falsy value (``0``/``None``) disables intermediate
                checkpoints.
            state_dict_load_path: Optional checkpoint to load the vision model from
                before training.
            state_dict_save_path: Optional path to save the final vision model to.
            skip_training: If set and a checkpoint was loaded, return the loaded
                vision model without training.
            use_lora: Fine-tune with LoRA adapters instead of full fine-tuning.
            lora_config: Either a ``peft.LoraConfig`` or an OmegaConf config whose
                entries are passed as keyword arguments to ``LoraConfig``. Only used
                when ``use_lora`` is set.
        """
        self.seed = seed
        self.shuffle_order = shuffle_order
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.num_steps = num_steps
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.save_interval = save_interval
        self.state_dict_load_path = state_dict_load_path
        self.state_dict_save_path = state_dict_save_path
        self.skip_training = skip_training
        self.use_lora = use_lora
        self.lora_config = lora_config

    def run(self, modelpool: CLIPVisionModelPool):
        """
        Continually fine-tune the vision model on every training task of the pool.

        Args:
            modelpool: Pool providing the pre-trained CLIP model, the processor and
                the training datasets. It is normalized with ``to_modelpool``.

        Returns:
            The fine-tuned vision model (``classifier.clip_model.vision_model``).
            If a checkpoint is loaded and ``skip_training`` is set, the loaded
            vision model is returned without training.
        """
        self.modelpool = to_modelpool(modelpool)
        self.log_hyperparams(self.config, filename="method_config.yaml")
        self.finetune_method = "fine-tune"

        if self.seed is not None:
            L.seed_everything(self.seed)
        else:
            seed_everything_by_time(self.fabric)

        # Always go through the normalized pool, as `setup_model` does.
        task_names = list(self.modelpool.train_dataset_names)
        if self.shuffle_order:
            random.shuffle(task_names)
        if self.fabric.is_global_zero:
            save_to_json(task_names, os.path.join(self.log_dir, "task_names.json"))

        taskpool = None
        test_datasets = None
        if isinstance(self._program.taskpool, CLIPVisionModelTaskPool):
            taskpool = cast(CLIPVisionModelTaskPool, self._program.taskpool)
            # Keep the full set of test datasets; it is narrowed per task below.
            test_datasets = taskpool._test_datasets
        has_taskpool = taskpool is not None

        with self.profile("setup model and optimizer"):
            processor, classifier, optimizer, lr_scheduler = self.setup_model()

            if self.state_dict_load_path is not None:
                self.fabric.load(
                    self.state_dict_load_path,
                    {"vision_model": classifier.clip_model.vision_model},
                )
                if self.skip_training:
                    return classifier.clip_model.vision_model

            self.setup_zero_shot_classification_head(
                clip_processor=processor,
                clip_model=classifier.clip_model,
                task_names=task_names,
            )

            # Snapshot the freshly created optimizer/scheduler state so it can be
            # restored at the beginning of every task.
            init_optimizer_state_dict = optimizer.state_dict()
            init_lr_scheduler_state_dict = lr_scheduler.state_dict()
            # NOTE(review): the wrapped objects returned by `fabric.setup` are
            # discarded, so the loop below keeps using the unwrapped classifier and
            # optimizer (which `save_model`'s isinstance checks rely on). Confirm
            # this is intended for multi-device strategies.
            self.fabric.setup(classifier, optimizer)

        with self.profile("setup data"):
            train_dataloader_iters = self._setup_train_dataloader_iters(
                task_names, processor
            )

        # continual train
        for task_idx, task_name in tqdm(
            enumerate(task_names),
            total=len(task_names),
            dynamic_ncols=True,
            disable=not self.fabric.is_global_zero,
        ):
            train_dataloader_iter = train_dataloader_iters[task_idx]

            # reset optimizer and lr scheduler
            print("reset optimizer and lr scheduler")
            optimizer.load_state_dict(init_optimizer_state_dict)
            lr_scheduler.load_state_dict(init_lr_scheduler_state_dict)

            for step_idx in tqdm(
                range(self.num_steps),
                desc=f"continual fine-tune on {task_name}",
                disable=not self.fabric.is_global_zero,
                dynamic_ncols=True,
                leave=False,
            ):
                optimizer.zero_grad()
                with self.profile("data loading"):
                    images, labels = next(train_dataloader_iter)
                with self.profile("forward"):
                    # Switch the classification head to the current task.
                    classifier.zeroshot_weights = self.zeroshot_weights[task_name]
                    logits = classifier(images)
                    assert (
                        labels.max() < logits.shape[1]
                    ), f"for task {task_name}, labels.max() = {labels.max()}, logits.shape[1] = {logits.shape[1]}"
                    loss = nn.functional.cross_entropy(logits, labels)

                with self.profile("backward"):
                    self.fabric.backward(loss)
                with self.profile("optimizer step"):
                    optimizer.step()
                    lr_scheduler.step()

                # `step_idx` restarts at 0 for every task. Log against a global
                # step so that later tasks are not logged at the same steps as
                # earlier ones.
                metrics = {"train/loss": loss}
                self.fabric.log_dict(
                    metrics, step=task_idx * self.num_steps + step_idx
                )

                if self.save_interval and (step_idx + 1) % self.save_interval == 0:
                    save_path = os.path.join(
                        self.log_dir,
                        "checkpoints",
                        f"task={task_idx}_step={step_idx}.ckpt",
                    )
                    self.save_model(classifier, save_path)

            if has_taskpool:
                self._evaluate_seen_tasks(
                    taskpool, test_datasets, classifier, task_names, task_idx
                )

        if self.state_dict_save_path is not None:
            self.save_model(classifier, self.state_dict_save_path)
        self.print_profile_summary()
        return classifier.clip_model.vision_model

    def _setup_train_dataloader_iters(self, task_names, processor):
        """
        Build one infinite batch iterator per task, in the order of ``task_names``.

        Args:
            task_names: Names of the training tasks, in training order.
            processor: CLIP processor used by ``CLIPDataset`` to preprocess images.

        Returns:
            list: Iterators over ``InfiniteDataLoader``-wrapped, Fabric-prepared
            data loaders, one per task.
        """
        train_datasets = [
            CLIPDataset(self.modelpool.load_train_dataset(task_name), processor)
            for task_name in task_names
        ]
        train_dataloaders = [
            DataLoader(
                dataset,
                shuffle=True,
                batch_size=self.batch_size,
                num_workers=self.num_workers,
            )
            for dataset in train_datasets
        ]
        train_dataloaders = self.fabric.setup_dataloaders(*train_dataloaders)
        # `setup_dataloaders` returns a bare loader (not a sequence) for a single input.
        if not isinstance(train_dataloaders, (list, tuple)):
            train_dataloaders = [train_dataloaders]
        return [iter(InfiniteDataLoader(loader)) for loader in train_dataloaders]

    def _evaluate_seen_tasks(
        self,
        taskpool: CLIPVisionModelTaskPool,
        test_datasets,
        classifier: HFCLIPClassifier,
        task_names,
        task_idx: int,
    ):
        """
        Evaluate a copy of the current vision model on the tasks trained so far.

        The report is written to ``results_{task_idx}.json`` in ``self.log_dir``
        on the global-zero process.
        """
        task_name = task_names[task_idx]
        # Restrict the task pool to the first `task_idx + 1` tasks. Clearing
        # `_is_setup` presumably makes the pool re-run its setup with the new test
        # sets (TODO confirm against CLIPVisionModelTaskPool).
        taskpool._is_setup = False
        taskpool._test_datasets = DictConfig(
            {t: test_datasets[t] for t in task_names[: task_idx + 1]}
        )
        eval_report = taskpool.evaluate(
            deepcopy(classifier.clip_model.vision_model),
            name=task_name,
        )
        if self.fabric.is_global_zero:
            save_to_json(
                eval_report,
                os.path.join(self.log_dir, f"results_{task_idx}.json"),
            )

    def save_model(
        self,
        model: HFCLIPClassifier | CLIPModel | CLIPVisionModel | CLIPVisionTransformer,
        save_path: str,
    ):
        """
        Save the vision model to the specified path.

        Args:
            model (Union[HFCLIPClassifier, CLIPModel, CLIPVisionModel, CLIPVisionTransformer]): The model to save.
            save_path (str): The path to save the model.

        Raises:
            ValueError: If ``model`` is not one of the supported types.
        """
        if isinstance(model, HFCLIPClassifier):
            vision_model = model.clip_model.vision_model
        elif isinstance(model, CLIPModel):
            vision_model = model.vision_model
        elif isinstance(model, CLIPVisionModel):
            vision_model = model.vision_model
        elif isinstance(model, CLIPVisionTransformer):
            vision_model = model
        else:
            raise ValueError(f"Unsupported model type: {type(model)}")

        save_dir = os.path.dirname(save_path)
        if save_dir and not os.path.exists(save_dir):
            os.makedirs(save_dir, exist_ok=True)
        self.fabric.save(save_path, {"vision_model": vision_model})

    def setup_model(self):
        """
        Sets up the model, optimizer, and learning rate scheduler.

        This method loads the pre-trained CLIP model and processor, applies LoRA
        if ``use_lora`` is set, and creates an Adam optimizer over the trainable
        vision-model parameters together with a cosine annealing scheduler
        (``T_max = num_steps``).

        Returns:
            Tuple: A tuple containing the processor, classifier, optimizer, and learning rate scheduler.
        """
        modelpool = self.modelpool

        clip_model: CLIPModel = modelpool.load_clip_model("_pretrained_")
        processor = modelpool.load_processor()

        self.finetune_method = "full fine-tune"
        if self.use_lora:
            self.finetune_method = "lora fine-tune"
            if isinstance(self.lora_config, LoraConfig):
                # Accept a ready-made peft config, as the type annotation suggests.
                lora_config = self.lora_config
            else:
                lora_config = LoraConfig(
                    **OmegaConf.to_container(
                        self.lora_config, resolve=True, enum_to_str=True
                    )
                )
            clip_model.vision_model = get_peft_model(
                clip_model.vision_model, lora_config
            )

        classifier = HFCLIPClassifier(clip_model, processor=processor)

        if self.fabric.is_global_zero:
            print("=== Model Summary (For Vision Model Only) ===")
            print_parameters(classifier.clip_model.vision_model)
        # configure optimizers: only parameters that require gradients are optimized
        optimizer = torch.optim.Adam(
            [
                p
                for p in classifier.clip_model.vision_model.parameters()
                if p.requires_grad
            ],
            lr=self.learning_rate,
            weight_decay=self.weight_decay,
        )
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer=optimizer, T_max=self.num_steps
        )

        return processor, classifier, optimizer, lr_scheduler
@@ -1,3 +1,4 @@
1
1
  # flake8: noqa F401
2
2
  from .simple_average import DareSimpleAverage
3
3
  from .task_arithmetic import DareTaskArithmetic
4
+ from .ties_merging import DareTiesMerging
@@ -33,21 +33,28 @@ class DareTaskArithmetic(BaseAlgorithm):
33
33
  self.rescale = rescale
34
34
  super().__init__(**kwargs)
35
35
 
36
+ def _load_task_vector(
37
+ self,
38
+ modelpool: BaseModelPool,
39
+ model_name: str,
40
+ pretrained_model: nn.Module,
41
+ ):
42
+ finetuned_model = modelpool.load_model(model_name)
43
+ task_vector = module_sub_(finetuned_model, pretrained_model)
44
+ return task_vector
45
+
36
46
  @torch.no_grad()
37
47
  def run(self, modelpool: BaseModelPool):
38
48
  assert (
39
49
  self.sparsity_ratio >= 0 and self.sparsity_ratio <= 1
40
50
  ), "Sparsity ratio must be between 0 and 1"
41
51
  pretrained_model = modelpool.load_pretrained_model()
42
- finetuned_models = {
43
- model_name: modelpool.load_model(model_name)
44
- for model_name in modelpool.model_names
45
- }
52
+
53
+ # load task vectors
46
54
  task_vectors = {
47
- model_name: module_sub_(finetuned_models[model_name], pretrained_model)
48
- for model_name in finetuned_models
55
+ model_name: self._load_task_vector(modelpool, model_name, pretrained_model)
56
+ for model_name in modelpool.model_names
49
57
  }
50
- del finetuned_models
51
58
 
52
59
  # drop and rescale task vectors
53
60
  for model_name, tv in task_vectors.items():
@@ -0,0 +1,100 @@
1
+ from typing import Literal
2
+
3
+ import torch
4
+ from torch import Tensor, nn
5
+
6
+ from fusion_bench import BaseAlgorithm, BaseModelPool
7
+ from fusion_bench.method.ties_merging.ties_merging_utils import ties_merging
8
+ from fusion_bench.utils.parameters import state_dict_to_vector, vector_to_state_dict
9
+ from fusion_bench.utils.state_dict_arithmetic import state_dict_sum
10
+
11
+ from .utils import (
12
+ module_random_drop_,
13
+ module_sub_,
14
+ param_random_drop_,
15
+ trainable_state_dict,
16
+ )
17
+
18
+
19
class DareTiesMerging(BaseAlgorithm):
    """
    Merge fine-tuned models by combining DARE sparsification with TIES merging.

    For every model in the pool, the task vector (fine-tuned minus pre-trained
    parameters) is randomly sparsified with DARE. The sparsified task vectors are
    then flattened and merged with ``ties_merging``. The merged vector, scaled by
    ``scaling_factor``, is added to the pre-trained parameters.
    """

    def __init__(
        self,
        # DARE parameters
        sparsity_ratio: float,
        only_on_linear_weights: bool,
        rescale: bool,
        # Ties merging parameters
        scaling_factor: float,
        threshold: int,
        remove_keys: list[str],
        merge_func: Literal["sum", "mean", "max"],
        **kwargs,
    ):
        """
        Args:
            sparsity_ratio: Drop ratio for DARE; must lie in ``[0, 1]``.
            only_on_linear_weights: If True, only the ``weight`` of ``nn.Linear``
                modules in each task vector is dropped. Otherwise the whole task
                vector is.
            rescale: Forwarded as ``rescale`` to the DARE drop helpers.
            scaling_factor: Coefficient applied to the merged task vector.
            threshold: Forwarded to ``ties_merging`` as ``reset_thresh``.
            remove_keys: State-dict keys excluded from flattening/merging.
            merge_func: Merge function forwarded to ``ties_merging``.
            **kwargs: Forwarded to ``BaseAlgorithm``.
        """
        self.sparsity_ratio = sparsity_ratio
        self.only_on_linear_weights = only_on_linear_weights
        self.rescale = rescale
        self.scaling_factor = scaling_factor
        self.threshold = threshold
        self.remove_keys = remove_keys
        self.merge_func = merge_func
        super().__init__(**kwargs)

    @torch.no_grad()
    def _load_task_vector(
        self,
        modelpool: BaseModelPool,
        model_name: str,
        pretrained_model: nn.Module,
    ):
        """Load ``model_name`` from the pool and return its task vector as a module."""
        finetuned_model = modelpool.load_model(model_name)
        # NOTE(review): `module_sub_` presumably subtracts in place, reusing the
        # fine-tuned module's storage for the task vector.
        task_vector = module_sub_(finetuned_model, pretrained_model)
        return task_vector

    # Gradients are never needed here, and the DARE helpers modify parameters
    # in place; this matches `DareTaskArithmetic.run`.
    @torch.no_grad()
    def run(self, modelpool: BaseModelPool):
        """
        Merge all models of ``modelpool`` into the pre-trained model.

        Returns:
            The pre-trained model, with its state dict replaced by the merged
            parameters.
        """
        assert 0 <= self.sparsity_ratio <= 1, "Sparsity ratio must be between 0 and 1"
        pretrained_model = modelpool.load_pretrained_model()

        # load task vectors
        task_vectors = {
            model_name: self._load_task_vector(modelpool, model_name, pretrained_model)
            for model_name in modelpool.model_names
        }

        # drop and rescale task vectors
        for model_name, tv in task_vectors.items():
            if self.only_on_linear_weights:
                for module_name, module in tv.named_modules():
                    if isinstance(module, nn.Linear):
                        print(f"pruning model: `{model_name}`, layer: {module_name}.")
                        param_random_drop_(
                            module.weight, self.sparsity_ratio, rescale=self.rescale
                        )
            else:
                print(f"pruning model: `{model_name}`")
                module_random_drop_(tv, self.sparsity_ratio, rescale=self.rescale)

        # Flatten the pre-trained model and every task vector (minus `remove_keys`).
        ptm_check = pretrained_model.state_dict()
        flat_ptm = state_dict_to_vector(ptm_check, self.remove_keys)
        tv_flat_checks = torch.vstack(
            [
                state_dict_to_vector(check.state_dict(), self.remove_keys)
                for check in task_vectors.values()
            ]
        )
        del task_vectors  # free memory held by the task-vector modules

        # Perform TIES Merging
        merged_tv = ties_merging(
            tv_flat_checks,
            reset_thresh=self.threshold,
            merge_func=self.merge_func,
        )
        merged_check = flat_ptm + self.scaling_factor * merged_tv
        merged_state_dict = vector_to_state_dict(
            merged_check, ptm_check, remove_keys=self.remove_keys
        )

        pretrained_model.load_state_dict(merged_state_dict)
        return pretrained_model
@@ -0,0 +1,15 @@
1
+ """
2
+ This module contains the implementation of the Isotropic Merging in Common Subspace (Iso-C) and Isotropic Merging in Common and Task-Specific Subspaces (Iso-CTS) algorithms.
3
+ Modified from the original implementation: https://github.com/danielm1405/iso-merging
4
+
5
+ Reference:
6
+ - Daniel Marczak, et al. No Task Left Behind: Isotropic Model Merging with Common and Task-Specific Subspaces. 2025.
7
+ https://arxiv.org/abs/2502.04959
8
+ """
9
+
10
+ from .iso import (
11
+ ISO_C_Merge,
12
+ ISO_CTS_Merge,
13
+ IsotropicMergingInCommonSubspace,
14
+ IsotropicMergingInCommonAndTaskSubspace,
15
+ )
@@ -0,0 +1,114 @@
from typing import List, Optional

import torch

from fusion_bench import BaseAlgorithm, BaseModelPool
from fusion_bench.mixins import LightningFabricMixin
from fusion_bench.utils.state_dict_arithmetic import (
    state_dict_add,
    state_dict_mul,
    state_dict_sub,
)

from .iso_utils import check_parameterNamesMatch, iso_c, iso_cts
14
+
15
+
16
def _load_pretrained_and_task_vectors(modelpool: BaseModelPool):
    """
    Load the pre-trained model and the task vector of every fine-tuned model.

    Args:
        modelpool: Pool providing the pre-trained model and the fine-tuned models.

    Returns:
        tuple: ``(pretrained_model, task_vectors)``. ``task_vectors`` is a list of
        state dicts ``finetuned.state_dict() - pretrained.state_dict()``, in the
        order of ``modelpool.model_names``.
    """
    with torch.no_grad():
        pretrained_model = modelpool.load_pretrained_model()
        # The pre-trained state dict is the same for every task; build it once.
        pretrained_state_dict = pretrained_model.state_dict()
        task_vectors = []
        for model_name in modelpool.model_names:
            finetuned_model = modelpool.load_model(model_name)
            task_vectors.append(
                state_dict_sub(finetuned_model.state_dict(), pretrained_state_dict)
            )
            del finetuned_model  # free memory
        check_parameterNamesMatch(task_vectors)
    return pretrained_model, task_vectors


def _add_scaled_task_vector(pretrained_model, merged_tv, scaling_factor: float):
    """
    Load ``pretrained + scaling_factor * merged_tv`` into ``pretrained_model``.

    Returns:
        The same ``pretrained_model`` instance, updated in place.
    """
    pretrained_model.load_state_dict(
        state_dict_add(
            pretrained_model.state_dict(),
            state_dict_mul(merged_tv, scaling_factor),
        )
    )
    return pretrained_model


class IsotropicMergingInCommonSubspace(BaseAlgorithm, LightningFabricMixin):
    """
    Isotropic Merging in Common Subspace (Iso-C)

    The task vectors of all fine-tuned models are merged with ``iso_c``. The
    merged vector, scaled by ``scaling_factor``, is added to the pre-trained
    model.
    """

    def __init__(
        self,
        scaling_factor: float,
        exclude_keys: Optional[List[str]] = None,
    ):
        """
        Args:
            scaling_factor: Coefficient applied to the merged task vector.
            exclude_keys: Optional list of keys forwarded to ``iso_c`` as
                ``exclude_keys``.
        """
        self.scaling_factor = scaling_factor
        self.exclude_keys = exclude_keys
        super().__init__()

    def run(self, modelpool: BaseModelPool):
        """Merge all models of ``modelpool`` and return the merged pre-trained model."""
        pretrained_model, task_vectors = _load_pretrained_and_task_vectors(modelpool)

        # compute the merged task vector
        merged_tv = iso_c(
            task_vectors,
            accelerator=self.fabric.device,
            exclude_keys=self.exclude_keys,
        )

        # merged_parameters = pretrained_parameters + scaling_factor * merged_task_vector
        return _add_scaled_task_vector(pretrained_model, merged_tv, self.scaling_factor)


class IsotropicMergingInCommonAndTaskSubspace(BaseAlgorithm, LightningFabricMixin):
    """
    Isotropic Merging in Common and Task-Specific Subspaces (Iso-CTS)

    The task vectors of all fine-tuned models are merged with ``iso_cts``. The
    merged vector, scaled by ``scaling_factor``, is added to the pre-trained
    model.
    """

    def __init__(
        self,
        scaling_factor: float,
        common_space_fraction: float,
        exclude_keys: Optional[List[str]] = None,
    ):
        """
        Args:
            scaling_factor: Coefficient applied to the merged task vector.
            common_space_fraction: Forwarded to ``iso_cts``. It is presumably the
                fraction of the subspace reserved for the common component;
                see ``iso_cts`` to confirm.
            exclude_keys: Optional list of keys forwarded to ``iso_cts`` as
                ``exclude_keys``.
        """
        self.common_space_fraction = common_space_fraction
        self.scaling_factor = scaling_factor
        self.exclude_keys = exclude_keys
        super().__init__()

    def run(self, modelpool: BaseModelPool):
        """Merge all models of ``modelpool`` and return the merged pre-trained model."""
        pretrained_model, task_vectors = _load_pretrained_and_task_vectors(modelpool)

        # compute the merged task vector
        merged_tv = iso_cts(
            task_vectors,
            common_space_fraction=self.common_space_fraction,
            accelerator=self.fabric.device,
            exclude_keys=self.exclude_keys,
        )

        # merged_parameters = pretrained_parameters + scaling_factor * merged_task_vector
        return _add_scaled_task_vector(pretrained_model, merged_tv, self.scaling_factor)


ISO_C_Merge = IsotropicMergingInCommonSubspace  # alias
ISO_CTS_Merge = IsotropicMergingInCommonAndTaskSubspace  # alias