fusion-bench 0.2.25__py3-none-any.whl → 0.2.27__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (141)
  1. fusion_bench/dataset/clip_dataset.py +1 -0
  2. fusion_bench/method/__init__.py +4 -0
  3. fusion_bench/method/adamerging/__init__.py +28 -5
  4. fusion_bench/method/adamerging/resnet_adamerging.py +279 -0
  5. fusion_bench/method/adamerging/task_wise_adamerging.py +2 -14
  6. fusion_bench/method/adamerging/utils.py +58 -0
  7. fusion_bench/method/classification/clip_finetune.py +6 -4
  8. fusion_bench/method/classification/image_classification_finetune.py +156 -12
  9. fusion_bench/method/dare/simple_average.py +3 -2
  10. fusion_bench/method/dare/task_arithmetic.py +3 -2
  11. fusion_bench/method/dop/__init__.py +1 -0
  12. fusion_bench/method/dop/dop.py +366 -0
  13. fusion_bench/method/dop/min_norm_solvers.py +227 -0
  14. fusion_bench/method/dop/utils.py +73 -0
  15. fusion_bench/method/simple_average.py +6 -4
  16. fusion_bench/mixins/lightning_fabric.py +9 -0
  17. fusion_bench/modelpool/causal_lm/causal_lm.py +2 -1
  18. fusion_bench/modelpool/resnet_for_image_classification.py +285 -4
  19. fusion_bench/models/hf_clip.py +4 -7
  20. fusion_bench/models/hf_utils.py +4 -1
  21. fusion_bench/taskpool/__init__.py +2 -0
  22. fusion_bench/taskpool/clip_vision/taskpool.py +1 -1
  23. fusion_bench/taskpool/resnet_for_image_classification.py +231 -0
  24. fusion_bench/utils/state_dict_arithmetic.py +91 -10
  25. {fusion_bench-0.2.25.dist-info → fusion_bench-0.2.27.dist-info}/METADATA +9 -3
  26. {fusion_bench-0.2.25.dist-info → fusion_bench-0.2.27.dist-info}/RECORD +140 -77
  27. fusion_bench_config/fabric/auto.yaml +1 -1
  28. fusion_bench_config/fabric/loggers/swandb_logger.yaml +5 -0
  29. fusion_bench_config/fabric/loggers/wandb_logger.yaml +2 -0
  30. fusion_bench_config/fabric_model_fusion.yaml +1 -0
  31. fusion_bench_config/method/adamerging/resnet.yaml +18 -0
  32. fusion_bench_config/method/bitdelta/bitdelta.yaml +3 -0
  33. fusion_bench_config/method/classification/clip_finetune.yaml +5 -0
  34. fusion_bench_config/method/classification/image_classification_finetune.yaml +9 -0
  35. fusion_bench_config/method/depth_upscaling.yaml +9 -0
  36. fusion_bench_config/method/dop/dop.yaml +30 -0
  37. fusion_bench_config/method/dummy.yaml +6 -0
  38. fusion_bench_config/method/ensemble/max_model_predictor.yaml +6 -0
  39. fusion_bench_config/method/ensemble/simple_ensemble.yaml +8 -1
  40. fusion_bench_config/method/ensemble/weighted_ensemble.yaml +8 -0
  41. fusion_bench_config/method/linear/expo.yaml +5 -0
  42. fusion_bench_config/method/linear/linear_interpolation.yaml +8 -0
  43. fusion_bench_config/method/linear/llama_expo.yaml +5 -0
  44. fusion_bench_config/method/linear/llama_expo_with_dare.yaml +3 -0
  45. fusion_bench_config/method/linear/simple_average_for_causallm.yaml +5 -0
  46. fusion_bench_config/method/linear/task_arithmetic_for_causallm.yaml +3 -0
  47. fusion_bench_config/method/linear/ties_merging_for_causallm.yaml +5 -0
  48. fusion_bench_config/method/linear/weighted_average.yaml +3 -0
  49. fusion_bench_config/method/linear/weighted_average_for_llama.yaml +6 -1
  50. fusion_bench_config/method/mixtral_moe_merging.yaml +3 -0
  51. fusion_bench_config/method/mixtral_moe_upscaling.yaml +5 -0
  52. fusion_bench_config/method/model_recombination.yaml +8 -0
  53. fusion_bench_config/method/model_stock/model_stock.yaml +4 -1
  54. fusion_bench_config/method/opcm/opcm.yaml +5 -0
  55. fusion_bench_config/method/opcm/task_arithmetic.yaml +6 -0
  56. fusion_bench_config/method/opcm/ties_merging.yaml +5 -0
  57. fusion_bench_config/method/opcm/weight_average.yaml +5 -0
  58. fusion_bench_config/method/regmean/clip_regmean.yaml +3 -0
  59. fusion_bench_config/method/regmean/gpt2_regmean.yaml +3 -0
  60. fusion_bench_config/method/regmean/regmean.yaml +3 -0
  61. fusion_bench_config/method/regmean_plusplus/clip_regmean_plusplus.yaml +3 -0
  62. fusion_bench_config/method/simple_average.yaml +9 -0
  63. fusion_bench_config/method/slerp/slerp.yaml +9 -0
  64. fusion_bench_config/method/slerp/slerp_lm.yaml +5 -0
  65. fusion_bench_config/method/smile_upscaling/causal_lm_upscaling.yaml +6 -0
  66. fusion_bench_config/method/smile_upscaling/error_accumulation.yaml +5 -0
  67. fusion_bench_config/method/smile_upscaling/projected_energy.yaml +5 -0
  68. fusion_bench_config/method/smile_upscaling/singular_projection_merging.yaml +3 -0
  69. fusion_bench_config/method/smile_upscaling/smile_mistral_upscaling.yaml +5 -0
  70. fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +5 -0
  71. fusion_bench_config/method/smile_upscaling/smile_upscaling.yaml +3 -0
  72. fusion_bench_config/method/task_arithmetic.yaml +9 -0
  73. fusion_bench_config/method/ties_merging.yaml +3 -0
  74. fusion_bench_config/method/wudi/wudi.yaml +3 -0
  75. fusion_bench_config/model_fusion.yaml +2 -1
  76. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/_generate_config.py +138 -0
  77. fusion_bench_config/modelpool/{ResNetForImageClassfication → ResNetForImageClassification}/transformers/resnet152_cifar10.yaml +1 -1
  78. fusion_bench_config/modelpool/{ResNetForImageClassfication → ResNetForImageClassification}/transformers/resnet152_cifar100.yaml +1 -1
  79. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_dtd.yaml +14 -0
  80. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_emnist_letters.yaml +14 -0
  81. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_eurosat.yaml +14 -0
  82. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_fashion_mnist.yaml +14 -0
  83. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_fer2013.yaml +14 -0
  84. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_food101.yaml +14 -0
  85. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_gtsrb.yaml +14 -0
  86. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_kmnist.yaml +14 -0
  87. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_mnist.yaml +14 -0
  88. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_oxford-iiit-pet.yaml +14 -0
  89. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_oxford_flowers102.yaml +14 -0
  90. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_pcam.yaml +14 -0
  91. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_rendered-sst2.yaml +14 -0
  92. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_resisc45.yaml +14 -0
  93. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_stanford-cars.yaml +14 -0
  94. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_stl10.yaml +14 -0
  95. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_sun397.yaml +14 -0
  96. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_svhn.yaml +14 -0
  97. fusion_bench_config/modelpool/{ResNetForImageClassfication → ResNetForImageClassification}/transformers/resnet18_cifar10.yaml +1 -1
  98. fusion_bench_config/modelpool/{ResNetForImageClassfication → ResNetForImageClassification}/transformers/resnet18_cifar100.yaml +1 -1
  99. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_dtd.yaml +14 -0
  100. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_emnist_letters.yaml +14 -0
  101. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_eurosat.yaml +14 -0
  102. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_fashion_mnist.yaml +14 -0
  103. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_fer2013.yaml +14 -0
  104. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_food101.yaml +14 -0
  105. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_gtsrb.yaml +14 -0
  106. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_kmnist.yaml +14 -0
  107. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_mnist.yaml +14 -0
  108. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_oxford-iiit-pet.yaml +14 -0
  109. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_oxford_flowers102.yaml +14 -0
  110. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_pcam.yaml +14 -0
  111. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_rendered-sst2.yaml +14 -0
  112. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_resisc45.yaml +14 -0
  113. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_stanford-cars.yaml +14 -0
  114. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_stl10.yaml +14 -0
  115. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_sun397.yaml +14 -0
  116. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_svhn.yaml +14 -0
  117. fusion_bench_config/modelpool/{ResNetForImageClassfication → ResNetForImageClassification}/transformers/resnet50_cifar10.yaml +1 -1
  118. fusion_bench_config/modelpool/{ResNetForImageClassfication → ResNetForImageClassification}/transformers/resnet50_cifar100.yaml +1 -1
  119. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_dtd.yaml +14 -0
  120. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_emnist_letters.yaml +14 -0
  121. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_eurosat.yaml +14 -0
  122. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_fashion_mnist.yaml +14 -0
  123. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_fer2013.yaml +14 -0
  124. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_food101.yaml +14 -0
  125. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_gtsrb.yaml +14 -0
  126. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_kmnist.yaml +14 -0
  127. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_mnist.yaml +14 -0
  128. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_oxford-iiit-pet.yaml +14 -0
  129. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_oxford_flowers102.yaml +14 -0
  130. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_pcam.yaml +14 -0
  131. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_rendered-sst2.yaml +14 -0
  132. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_resisc45.yaml +14 -0
  133. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_stanford-cars.yaml +14 -0
  134. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_stl10.yaml +14 -0
  135. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_sun397.yaml +14 -0
  136. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_svhn.yaml +14 -0
  137. fusion_bench_config/method/clip_finetune.yaml +0 -26
  138. {fusion_bench-0.2.25.dist-info → fusion_bench-0.2.27.dist-info}/WHEEL +0 -0
  139. {fusion_bench-0.2.25.dist-info → fusion_bench-0.2.27.dist-info}/entry_points.txt +0 -0
  140. {fusion_bench-0.2.25.dist-info → fusion_bench-0.2.27.dist-info}/licenses/LICENSE +0 -0
  141. {fusion_bench-0.2.25.dist-info → fusion_bench-0.2.27.dist-info}/top_level.txt +0 -0
fusion_bench/method/classification/image_classification_finetune.py
@@ -1,3 +1,9 @@
+"""Image Classification Fine-tuning Module.
+
+This module provides algorithms for fine-tuning and evaluating image classification models
+using PyTorch Lightning.
+"""
+
 import os
 from typing import Optional
 
@@ -23,33 +29,82 @@ from fusion_bench import (
 from fusion_bench.dataset import CLIPDataset
 from fusion_bench.modelpool import ResNetForImageClassificationPool
 from fusion_bench.tasks.clip_classification import get_num_classes
+from torch.utils.data import random_split
 
 log = get_rankzero_logger(__name__)
 
 
 @auto_register_config
 class ImageClassificationFineTuning(BaseAlgorithm):
+    """Fine-tuning algorithm for image classification models.
+
+    This class implements end-to-end fine-tuning for image classification tasks using PyTorch Lightning.
+    It supports both epoch-based and step-based training with configurable optimizers, learning rate
+    schedulers, and data loaders.
+
+    Args:
+        max_epochs (Optional[int]): Maximum number of training epochs. Mutually exclusive with max_steps.
+        max_steps (Optional[int]): Maximum number of training steps. Mutually exclusive with max_epochs.
+        label_smoothing (float): Label smoothing factor for cross-entropy loss (0.0 = no smoothing).
+        optimizer (DictConfig): Configuration for the optimizer (e.g., Adam, SGD).
+        lr_scheduler (DictConfig): Configuration for the learning rate scheduler.
+        dataloader_kwargs (DictConfig): Additional keyword arguments for DataLoader construction.
+        **kwargs: Additional arguments passed to the base class.
+
+    Raises:
+        AssertionError: If both max_epochs and max_steps are provided.
+
+    Example:
+        ```python
+        >>> config = {
+        ...     'max_epochs': 10,
+        ...     'max_steps': None,
+        ...     'label_smoothing': 0.1,
+        ...     'optimizer': {'_target_': 'torch.optim.Adam', 'lr': 0.001},
+        ...     'lr_scheduler': {'_target_': 'torch.optim.lr_scheduler.StepLR', 'step_size': 5},
+        ...     'dataloader_kwargs': {'batch_size': 32, 'num_workers': 4}
+        ... }
+        >>> algorithm = ImageClassificationFineTuning(**config)
+        ```
+    """
+
     def __init__(
         self,
         max_epochs: Optional[int],
         max_steps: Optional[int],
+        training_data_ratio: Optional[float],
         label_smoothing: float,
         optimizer: DictConfig,
         lr_scheduler: DictConfig,
         dataloader_kwargs: DictConfig,
+        save_top_k: int,
+        save_interval: int,
+        save_on_train_epoch_end: bool,
         **kwargs,
     ):
         super().__init__(**kwargs)
-        assert (max_epochs is None) or (
+        assert (max_epochs is None or max_epochs < 0) or (
             max_steps is None or max_steps < 0
         ), "Only one of max_epochs or max_steps should be set."
-        self.training_interval = "epoch" if max_epochs is not None else "step"
+        self.training_interval = (
+            "epoch" if max_epochs is not None and max_epochs > 0 else "step"
+        )
         if self.training_interval == "epoch":
             self.max_steps = -1
         log.info(f"Training interval: {self.training_interval}")
         log.info(f"Max epochs: {max_epochs}, max steps: {max_steps}")
 
     def run(self, modelpool: ResNetForImageClassificationPool):
+        """Execute the fine-tuning process on the provided model pool.
+
+        This method performs the complete fine-tuning workflow:
+        1. Loads the pretrained model from the model pool
+        2. Prepares training and validation datasets
+        3. Configures optimizer and learning rate scheduler
+        4. Sets up Lightning trainer with appropriate callbacks
+        5. Executes the training process
+        6. Saves the final fine-tuned model
+        """
         # load model and dataset
         model = modelpool.load_pretrained_or_first_model()
         assert isinstance(model, nn.Module), "Loaded model is not a nn.Module."
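The reworked assertion treats a negative value the same as `None` (disabled), mirroring Lightning's `max_steps=-1` convention, and the interval only becomes "epoch" for a positive `max_epochs`. A compact, runnable restatement of the new guard:

```python
# Restates the guard from the hunk above: at most one of max_epochs /
# max_steps may be "active" (a positive value); None or a negative value
# both mean "disabled", as in Lightning's max_steps=-1 convention.
def training_interval(max_epochs, max_steps):
    assert (max_epochs is None or max_epochs < 0) or (
        max_steps is None or max_steps < 0
    ), "Only one of max_epochs or max_steps should be set."
    return "epoch" if max_epochs is not None and max_epochs > 0 else "step"

assert training_interval(10, None) == "epoch"
assert training_interval(None, 1000) == "step"
assert training_interval(-1, 1000) == "step"  # -1 now counts as "not set"
```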
@@ -59,7 +114,17 @@ class ImageClassificationFineTuning(BaseAlgorithm):
         ), "Exactly one training dataset is required."
         self.dataset_name = dataset_name = modelpool.train_dataset_names[0]
         num_classes = get_num_classes(dataset_name)
+        log.info(f"Number of classes for dataset {dataset_name}: {num_classes}")
         train_dataset = modelpool.load_train_dataset(dataset_name)
+        log.info(f"Training dataset size: {len(train_dataset)}")
+        if self.training_data_ratio is not None and 0 < self.training_data_ratio < 1:
+            train_dataset, _ = random_split(
+                train_dataset,
+                lengths=[self.training_data_ratio, 1 - self.training_data_ratio],
+            )
+            log.info(
+                f"Using {len(train_dataset)} samples for training after applying training_data_ratio={self.training_data_ratio}."
+            )
         train_dataset = CLIPDataset(
             train_dataset, processor=modelpool.load_processor(stage="train")
         )
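The new `training_data_ratio` branch relies on `torch.utils.data.random_split` accepting fractional lengths that sum to 1 (supported since PyTorch 2.0). A minimal sketch of the same subsampling, with a stand-in dataset:

```python
# Stand-in illustration of the subsampling above; the TensorDataset here is
# not part of fusion-bench, just a small dataset to split.
import torch
from torch.utils.data import TensorDataset, random_split

full = TensorDataset(torch.arange(100))
ratio = 0.1
# Fractional lengths summing to 1 are valid; any rounding remainder is
# distributed so the two subsets cover the whole dataset.
train_subset, _ = random_split(full, lengths=[ratio, 1 - ratio])
print(len(train_subset))  # 10
```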
@@ -91,7 +156,11 @@ class ImageClassificationFineTuning(BaseAlgorithm):
             objective=nn.CrossEntropyLoss(label_smoothing=self.label_smoothing),
             metrics={
                 "acc@1": Accuracy(task="multiclass", num_classes=num_classes),
-                "acc@5": Accuracy(task="multiclass", num_classes=num_classes, top_k=5),
+                f"acc@{min(5,num_classes)}": Accuracy(
+                    task="multiclass",
+                    num_classes=num_classes,
+                    top_k=min(5, num_classes),
+                ),
             },
         )
 
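The `min(5, num_classes)` guard exists because torchmetrics rejects a `top_k` larger than the number of classes, so a hard-coded `acc@5` fails on few-class datasets (e.g. the 2-class PCAM and Rendered-SST2 pools listed above). A small helper showing the guarded construction:

```python
# Guarded top-k metric: falls back to acc@num_classes when a dataset has
# fewer than 5 classes, instead of failing at metric-construction time.
from torchmetrics import Accuracy

def make_topk_accuracy(num_classes: int) -> Accuracy:
    k = min(5, num_classes)
    return Accuracy(task="multiclass", num_classes=num_classes, top_k=k)

acc = make_topk_accuracy(num_classes=2)  # yields acc@2 rather than a failing acc@5
```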
@@ -108,11 +177,21 @@ class ImageClassificationFineTuning(BaseAlgorithm):
             callbacks=[
                 pl_callbacks.LearningRateMonitor(logging_interval="step"),
                 pl_callbacks.DeviceStatsMonitor(),
+                pl_callbacks.ModelCheckpoint(
+                    save_top_k=self.save_top_k,
+                    every_n_train_steps=(
+                        self.save_interval if self.training_interval == "step" else None
+                    ),
+                    every_n_epochs=(
+                        self.save_interval
+                        if self.training_interval == "epoch"
+                        else None
+                    ),
+                    save_on_train_epoch_end=self.save_on_train_epoch_end,
+                    save_last=True,
+                ),
             ],
-            logger=TensorBoardLogger(
-                save_dir=log_dir,
-                name="",
-            ),
+            logger=TensorBoardLogger(save_dir=log_dir, name="", version=""),
             fast_dev_run=RuntimeConstants.debug,
         )
 
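Lightning's `ModelCheckpoint` treats `every_n_train_steps` and `every_n_epochs` as mutually exclusive triggers, which is why the new callback routes the single `save_interval` to exactly one of them. A condensed sketch of that routing:

```python
# Condensed version of the routing above: exactly one of the two
# every_n_* triggers receives save_interval; the other stays None.
from lightning.pytorch import callbacks as pl_callbacks

def make_checkpoint_callback(training_interval: str, save_interval: int,
                             save_top_k: int = -1) -> pl_callbacks.ModelCheckpoint:
    return pl_callbacks.ModelCheckpoint(
        save_top_k=save_top_k,  # -1 keeps every checkpoint
        every_n_train_steps=save_interval if training_interval == "step" else None,
        every_n_epochs=save_interval if training_interval == "epoch" else None,
        save_last=True,  # also maintain a last.ckpt pointer
    )

ckpt_cb = make_checkpoint_callback("epoch", save_interval=1)
```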
@@ -129,10 +208,26 @@ class ImageClassificationFineTuning(BaseAlgorithm):
                 "raw_checkpoints",
                 "final",
             ),
+            algorithm_config=self.config,
+            description=f"Fine-tuned ResNet model on dataset {dataset_name}.",
         )
         return model
 
     def get_dataloader(self, dataset, stage: str):
+        """Create a DataLoader for the specified dataset and training stage.
+
+        Constructs a PyTorch DataLoader with stage-appropriate configurations:
+        - Training stage: shuffling enabled by default
+        - Validation/test stages: shuffling disabled by default
+
+        Args:
+            dataset: The dataset to wrap in a DataLoader.
+            stage (str): Training stage, must be one of "train", "val", or "test".
+                Determines default shuffling behavior.
+
+        Returns:
+            DataLoader: Configured DataLoader for the given dataset and stage.
+        """
         assert stage in ["train", "val", "test"], f"Invalid stage: {stage}"
         dataloader_kwargs = dict(self.dataloader_kwargs)
         if "shuffle" not in dataloader_kwargs:
@@ -142,10 +237,42 @@ class ImageClassificationFineTuning(BaseAlgorithm):
 
 @auto_register_config
 class ImageClassificationFineTuning_Test(BaseAlgorithm):
+    """Test/evaluation algorithm for fine-tuned image classification models.
+
+    This class implements model evaluation on test or validation datasets using PyTorch Lightning.
+    It can either evaluate a model directly or load a model from a checkpoint before evaluation.
+    The evaluation computes standard classification metrics including top-1 and top-5 accuracy.
+
+    Args:
+        checkpoint_path (str): Path to the model checkpoint file. If None, uses the model
+            directly from the model pool without loading from checkpoint.
+        dataloader_kwargs (DictConfig): Additional keyword arguments for DataLoader construction.
+        **kwargs: Additional arguments passed to the base class.
+
+    Example:
+        ```python
+        >>> config = {
+        ...     'checkpoint_path': '/path/to/model/checkpoint.ckpt',
+        ...     'dataloader_kwargs': {'batch_size': 64, 'num_workers': 4}
+        ... }
+        >>> test_algorithm = ImageClassificationFineTuning_Test(**config)
+        ```
+    """
+
     def __init__(self, checkpoint_path: str, dataloader_kwargs: DictConfig, **kwargs):
         super().__init__(**kwargs)
 
-    def run(self, modelpool: BaseModelPool):
+    def run(self, modelpool: ResNetForImageClassificationPool):
+        """Execute model evaluation on the provided model pool's test/validation dataset.
+
+        This method performs the complete evaluation workflow:
+        1. Loads the model from the model pool (pretrained or first available)
+        2. Prepares the test or validation dataset (prioritizes test if both available)
+        3. Sets up the Lightning module with appropriate metrics (top-1 and top-5 accuracy)
+        4. Loads from checkpoint if specified, otherwise uses the model directly
+        5. Executes the evaluation using Lightning trainer
+        6. Logs and returns the test metrics
+        """
         assert (
             modelpool.has_val_dataset or modelpool.has_test_dataset
         ), "No validation or test dataset found in the model pool."
@@ -181,8 +308,10 @@ class ImageClassificationFineTuning_Test(BaseAlgorithm):
                 model,
                 metrics={
                     "acc@1": Accuracy(task="multiclass", num_classes=num_classes),
-                    "acc@5": Accuracy(
-                        task="multiclass", num_classes=num_classes, top_k=5
+                    f"acc@{min(5,num_classes)}": Accuracy(
+                        task="multiclass",
+                        num_classes=num_classes,
+                        top_k=min(5, num_classes),
                     ),
                 },
             )
@@ -192,8 +321,10 @@ class ImageClassificationFineTuning_Test(BaseAlgorithm):
                 model=model,
                 metrics={
                     "acc@1": Accuracy(task="multiclass", num_classes=num_classes),
-                    "acc@5": Accuracy(
-                        task="multiclass", num_classes=num_classes, top_k=5
+                    f"acc@{min(5,num_classes)}": Accuracy(
+                        task="multiclass",
+                        num_classes=num_classes,
+                        top_k=min(5, num_classes),
                    ),
                 },
             )
@@ -207,6 +338,19 @@ class ImageClassificationFineTuning_Test(BaseAlgorithm):
         return model
 
     def get_dataloader(self, dataset, stage: str):
+        """Create a DataLoader for the specified dataset and evaluation stage.
+
+        Constructs a PyTorch DataLoader with stage-appropriate configurations for evaluation.
+        Similar to the training version but typically used for test/validation datasets.
+
+        Args:
+            dataset: The dataset to wrap in a DataLoader.
+            stage (str): Evaluation stage, must be one of "train", "val", or "test".
+                Determines default shuffling behavior (disabled for non-train stages).
+
+        Returns:
+            DataLoader: Configured DataLoader for the given dataset and stage.
+        """
         assert stage in ["train", "val", "test"], f"Invalid stage: {stage}"
         dataloader_kwargs = dict(self.dataloader_kwargs)
         if "shuffle" not in dataloader_kwargs:
fusion_bench/method/dare/simple_average.py
@@ -1,6 +1,6 @@
 import logging
 
-from fusion_bench import BaseAlgorithm, BaseModelPool
+from fusion_bench import BaseAlgorithm, BaseModelPool, auto_register_config
 from fusion_bench.utils.state_dict_arithmetic import state_dict_add, state_dict_mul
 
 from .task_arithmetic import DareTaskArithmetic
@@ -8,6 +8,7 @@ from .task_arithmetic import DareTaskArithmetic
 log = logging.getLogger(__name__)
 
 
+@auto_register_config
 class DareSimpleAverage(BaseAlgorithm):
 
     def __init__(
@@ -17,10 +18,10 @@ class DareSimpleAverage(BaseAlgorithm):
         rescale: bool = True,
         **kwargs,
     ):
+        super().__init__(**kwargs)
         self.sparsity_ratio = sparsity_ratio
         self.only_on_linear_weight = only_on_linear_weights
         self.rescale = rescale
-        super().__init__(**kwargs)
 
     def run(self, modelpool: BaseModelPool):
         return DareTaskArithmetic(
fusion_bench/method/dare/task_arithmetic.py
@@ -1,7 +1,7 @@
 import torch
 from torch import Tensor, nn
 
-from fusion_bench import BaseAlgorithm, BaseModelPool
+from fusion_bench import BaseAlgorithm, BaseModelPool, auto_register_config
 from fusion_bench.utils.state_dict_arithmetic import state_dict_sum
 
 from .utils import (
@@ -12,6 +12,7 @@ from .utils import (
 )
 
 
+@auto_register_config
 class DareTaskArithmetic(BaseAlgorithm):
     """
     Implementation of Task Arithmetic w/ DARE.
@@ -27,11 +28,11 @@ class DareTaskArithmetic(BaseAlgorithm):
         rescale: bool = True,
         **kwargs,
     ):
+        super().__init__(**kwargs)
         self.scaling_factor = scaling_factor
         self.sparsity_ratio = sparsity_ratio
         self.only_on_linear_weights = only_on_linear_weights
         self.rescale = rescale
-        super().__init__(**kwargs)
 
     def _load_task_vector(
         self,
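Both DARE classes move `super().__init__(**kwargs)` ahead of their manual attribute assignments after gaining `@auto_register_config`. The general rule being followed (the decorator's exact mechanics are not shown in this diff): when a base initializer performs bookkeeping such as config registration, or installs `__setattr__` hooks, it must run before the subclass touches `self` — the same reason `nn.Module` subclasses call `super().__init__()` first. A hypothetical illustration of the failure mode:

```python
# Hypothetical illustration (not fusion_bench's code): a base class whose
# __setattr__ depends on state created in __init__ breaks if a subclass
# assigns attributes before calling super().__init__().
class Recording:
    def __init__(self):
        object.__setattr__(self, "_registered", {})

    def __setattr__(self, name, value):
        self._registered[name] = value  # AttributeError if __init__ has not run yet
        object.__setattr__(self, name, value)

class Good(Recording):
    def __init__(self, sparsity_ratio: float):
        super().__init__()                     # bookkeeping first ...
        self.sparsity_ratio = sparsity_ratio   # ... then attributes
```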
fusion_bench/method/dop/__init__.py
@@ -0,0 +1 @@
+from .dop import ContinualDOPForCLIP
fusion_bench/method/dop/dop.py
@@ -0,0 +1,366 @@
+"""
+Continual Model Merging without Data: Dual Projections for Balancing Stability and Plasticity. NeurIPS, 2025.
+
+
+Example:
+
+    fusion_bench \
+        method=dop/dop \
+        modelpool=CLIPVisionModelPool/clip-vit-base-patch32_TA8_model_only \
+        taskpool=CLIPVisionModelTaskPool/clip-vit-classification_TA8
+"""
+
+import logging
+import os
+import random
+from copy import deepcopy
+from pathlib import Path
+from typing import Dict, List, Literal, Optional, Tuple, cast
+
+import lightning as L
+import numpy as np
+import torch
+from omegaconf import DictConfig
+from torch import Tensor, nn
+from torch.autograd import Variable
+from tqdm.auto import tqdm
+from transformers import CLIPVisionModel
+
+from fusion_bench import BaseAlgorithm, BaseModelPool, auto_register_config
+from fusion_bench.method.simple_average import simple_average
+from fusion_bench.mixins import LightningFabricMixin
+from fusion_bench.taskpool import CLIPVisionModelTaskPool
+from fusion_bench.utils import seed_everything_by_time
+from fusion_bench.utils.json import save_to_json
+
+from .min_norm_solvers import MinNormSolver, gradient_normalizers
+from .utils import is_leaf_module, svd
+
+log = logging.getLogger(__name__)
+
+
+@auto_register_config
+class ContinualDOPForCLIP(BaseAlgorithm, LightningFabricMixin):
+
+    def __init__(
+        self,
+        seed: Optional[int] = None,
+        shuffle_order: bool = False,
+        save_on_every_step: bool = True,
+        evaluate_on_every_step: bool = False,
+        lr: float = 1e-4,
+        num_steps: int = 200,
+        mgda: bool = True,
+        ema: bool = True,
+        ema_beta: float = 0.99,
+        alpha: float = None,
+        svd_epsilon: float = 1.0,
+        svd_proj_space: str = "uv",
+        **kwargs,
+    ):
+        self.lr = lr
+        self.num_steps = num_steps
+        self.mgda = mgda
+        self.ema = ema
+        self.ema_beta = ema_beta
+        self.alpha = alpha
+        self.svd_epsilon = svd_epsilon
+        self.svd_proj_space = svd_proj_space
+        self.seed = seed
+        self.shuffle_order = shuffle_order
+        self.save_on_every_step = save_on_every_step
+        self.evaluate_on_every_step = evaluate_on_every_step
+
+        assert (
+            self.svd_epsilon >= 0 and self.svd_epsilon <= 1
+        ), "The svd_epsilon should be in the range of [0, 1]"
+        assert (
+            self.alpha >= 0 and self.alpha <= 1
+        ), "The alpha should be in the range of [0, 1]"
+        super().__init__(**kwargs)
+
+    def print_params(self, pretrained_model):
+        total_params = 0
+        linear_params = 0
+        linear_weight_params = 0
+        for module_name, module in pretrained_model.named_modules():
+            if not is_leaf_module(module):
+                continue
+            if isinstance(module, nn.Linear):
+                linear_params += sum(p.numel() for n, p in module.named_parameters())
+                linear_weight_params += sum(
+                    p.numel() for n, p in module.named_parameters() if "weight" in n
+                )
+            total_params += sum(p.numel() for p in module.parameters())
+
+        linear_ratio = linear_params / total_params * 100
+        linear_weight_ratio = linear_weight_params / total_params * 100
+        print(f"Total Parameters: {total_params}")
+        print(f"Linear Parameters: {linear_params}")
+        print(f"Linear Weight Parameters: {linear_weight_params}")
+        print(f"Linear Ratio: {linear_ratio:.2f}%")
+        print(f"Linear Weight Ratio: {linear_weight_ratio:.2f}%")
+
+    def run(self, modelpool: BaseModelPool):
+        if self.seed is not None:
+            L.seed_everything(self.seed)
+        else:
+            seed_everything_by_time(self.fabric)
+
+        # get the model names, shuffle if needed
+        # the model names will be saved to the log directory as `model_names.json`
+        model_names = modelpool.model_names
+        if self.shuffle_order:
+            random.shuffle(model_names)
+        if self.log_dir is not None:
+            save_to_json(model_names, os.path.join(self.log_dir, "model_names.json"))
+
+        if self.evaluate_on_every_step:
+            """Configuration for the test datasets"""
+            self.taskpool = cast(CLIPVisionModelTaskPool, self._program.taskpool)
+            self._test_datasets = deepcopy(self.taskpool._test_datasets)
+
+        pretrained_model = modelpool.load_pretrained_model()
+
+        merged_model = None
+        for model_idx, model_name in enumerate(model_names):
+            print(
+                f"--------- Optimizing {model_idx + 1}/{len(model_names)}-th with {model_name} ---------"
+            )
+            if model_idx == 0:
+                merged_model = modelpool.load_model(model_names[0])
+            else:
+                merged_model = self._layer_wise_optimize(
+                    model_names=["merged", model_name],
+                    pretrained_model=deepcopy(pretrained_model),
+                    finetuned_models={
+                        "merged": merged_model,
+                        model_name: modelpool.load_model(model_name),
+                    },
+                    model_idx=model_idx,
+                )
+
+            if self.save_on_every_step:
+                self.save_merged_model(merged_model, model_idx)
+
+            if self.evaluate_on_every_step:
+                self.taskpool._is_setup = False
+                self.taskpool._test_datasets = DictConfig(
+                    {n: self._test_datasets[n] for n in model_names[: model_idx + 1]}
+                )
+                report = self.taskpool.evaluate(deepcopy(merged_model))
+                save_to_json(report, Path(self.log_dir) / f"report_{model_idx}.json")
+
+        return merged_model
+
+    def _layer_wise_optimize(
+        self,
+        model_names: List[str],
+        pretrained_model: nn.Module,
+        finetuned_models: Dict[str, nn.Module],
+        model_idx: int,
+    ):
+        time_cost = []
+        for module_name, module in pretrained_model.named_modules():
+            if not is_leaf_module(module):
+                continue
+
+            if isinstance(module, nn.Linear):
+                if module.weight.requires_grad:
+                    import time
+
+                    start_time = time.time()
+                    merged_weight = self._optimize_weight(
+                        module.weight,
+                        {
+                            model_name: finetuned_models[model_name]
+                            .get_submodule(module_name)
+                            .weight
+                            for model_name in model_names
+                        },
+                        module_name,
+                        model_idx,
+                    )
+                    end_time = time.time()
+                    time_cost.append(end_time - start_time)
+                    module.weight.data = merged_weight.data
+                else:
+                    module.weight.data = simple_average(
+                        [
+                            finetuned_models[model_name]
+                            .get_submodule(module_name)
+                            .weight
+                            for model_name in model_names
+                        ]
+                    )
+                if module.bias is not None:
+                    module.bias.data = simple_average(
+                        [
+                            finetuned_models[model_name].get_submodule(module_name).bias
+                            for model_name in model_names
+                        ]
+                    )
+            else:
+                simple_average(
+                    [
+                        finetuned_models[model_name].get_submodule(module_name)
+                        for model_name in model_names
+                    ],
+                    base_module=module,
+                )
+
+        return pretrained_model
+
+    def _optimize_weight(
+        self,
+        pretrained_weight: Tensor,
+        finetuned_weights: Dict[str, Tensor],
+        module_name: str,
+        model_idx: int,
+    ):
+        assert (
+            self.fabric.world_size == 1
+        ), "This algorithm is not currently supported in distributed training"
+
+        pretrained_weight = self.fabric.to_device(pretrained_weight.detach())
+        finetuned_weights = {
+            model_name: self.fabric.to_device(finetuned_weight.detach())
+            for model_name, finetuned_weight in finetuned_weights.items()
+        }
+
+        merged_weight = self.fabric.to_device(
+            nn.Parameter(
+                simple_average(
+                    [
+                        finetuned_weight.detach()
+                        for finetuned_weight in finetuned_weights.values()
+                    ]
+                ),
+                requires_grad=True,
+            )
+        )
+
+        # Compute SVD of the difference between the finetuned and pretrained weights
+        proj_u_dict = {}
+        proj_v_dict = {}
+        proj_s_dict = {}
+        for i, finetuned_weight in enumerate(finetuned_weights.values()):
+            finetuned_tv = finetuned_weight - pretrained_weight
+            u, s, v = svd(finetuned_tv, full_matrices=True)
+            epsilon = 1.0 if self.svd_epsilon > 1.0 else self.svd_epsilon
+            cumsum_ratio = s.cumsum(dim=0) / s.sum()
+            split_rank = torch.searchsorted(cumsum_ratio, epsilon).item()
+            u_main = u[:, :split_rank]
+            v_main = v[:, :split_rank]
+            s_main = s[:split_rank]
+            proj_u_dict[i] = u_main
+            proj_v_dict[i] = v_main
+            proj_s_dict[i] = s_main
+
+        if self.mgda:
+            if self.ema:
+                ema_sol = [self.alpha, 1 - self.alpha]
+            # This is multiple-gradient descent algorithm (MGDA) optimization
+            optimizer = torch.optim.Adam([merged_weight], lr=self.lr)
+            all_losses = [[], []]
+            all_alphas = [[], []]
+            for step_idx in tqdm(
+                range(self.num_steps), desc=f"Optimizing {module_name} weight"
+            ):
+                # Scaling the loss functions based on the algorithm choice
+                loss_data = {}
+                grads = {}
+                for i, finetuned_weight in enumerate(finetuned_weights.values()):
+                    proj_u = proj_u_dict[i]
+                    proj_v = proj_v_dict[i]
+                    proj_s = proj_s_dict[i]
+                    delta_tv = merged_weight - finetuned_weight
+                    loss_i = self.cal_loss_i(delta_tv, proj_s, proj_u, proj_v)
+                    loss_data[i] = float(loss_i.data)
+
+                    all_losses[i].append(float(loss_i.data))
+
+                    optimizer.zero_grad()
+                    loss_i.backward()
+                    grads[i] = Variable(
+                        merged_weight.grad.data.clone(), requires_grad=False
+                    )
+
+                # Normalize all gradients
+                gn = gradient_normalizers(
+                    grads=grads, losses=loss_data, normalization_type="loss"
+                )
+                for i, _ in enumerate(finetuned_weights.values()):
+                    grads[i] = grads[i] / float(gn[i])
+
+                # Frank-Wolfe iteration to compute scales.
+                sol, min_norm = MinNormSolver.find_min_norm_element(
+                    [[grads[i]] for i in range(len(finetuned_weights.values()))]
+                )
+
+                if self.ema:
+                    ema_sol = [
+                        self.ema_beta * ema_sol[i] + (1 - self.ema_beta) * float(sol[i])
+                        for i in range(len(sol))
+                    ]
+                    sol = ema_sol
+                    all_alphas[0].append(ema_sol[0])
+                    all_alphas[1].append(ema_sol[1])
+
+                # Scaled back-propagation
+                loss = 0
+                for i, finetuned_weight in enumerate(finetuned_weights.values()):
+                    # Compute gradients of each loss function wrt parameters
+                    proj_u = proj_u_dict[i]
+                    proj_v = proj_v_dict[i]
+                    proj_s = proj_s_dict[i]
+                    delta_tv = merged_weight - finetuned_weight
+                    loss_i = self.cal_loss_i(delta_tv, proj_s, proj_u, proj_v)
+                    loss += float(sol[i]) * loss_i
+
+                optimizer.zero_grad()
+                loss.backward()
+                optimizer.step()
+
+        else:
+            # This is a naive weighted optimization
+            optimizer = torch.optim.Adam([merged_weight], lr=self.lr)
+            for step_idx in tqdm(
+                range(self.num_steps), desc=f"Optimizing {module_name} weight"
+            ):
+                loss = 0
+                for i, finetuned_weight in enumerate(finetuned_weights.values()):
+                    proj_u = proj_u_dict[i]
+                    proj_v = proj_v_dict[i]
+                    proj_s = proj_s_dict[i]
+                    delta_tv = merged_weight - finetuned_weight
+                    loss_i = self.cal_loss_i(delta_tv, proj_s, proj_u, proj_v)
+                    loss += self.alpha * loss_i if i == 0 else (1 - self.alpha) * loss_i
+
+                optimizer.zero_grad()
+                loss.backward()
+                optimizer.step()
+
+        return merged_weight.detach().cpu()
+
+    def cal_loss_i(self, delta_tv, proj_s, proj_u, proj_v):
+        proj_delta_1 = torch.diag(proj_s) @ proj_u.T @ delta_tv
+        proj_delta_2 = delta_tv @ proj_v @ torch.diag(proj_s)
+        loss_i_u = torch.linalg.matrix_norm(proj_delta_1, ord="fro") ** 2
+        loss_i_v = torch.linalg.matrix_norm(proj_delta_2, ord="fro") ** 2
+        if self.svd_proj_space == "uv":
+            loss_i = loss_i_u + loss_i_v
+        elif self.svd_proj_space == "u":
+            loss_i = loss_i_u
+        elif self.svd_proj_space == "v":
+            loss_i = loss_i_v
+        else:
+            raise ValueError("Invalid svd_proj_space")
+
+        return loss_i
+
+    def save_merged_model(self, merged_model: CLIPVisionModel, step: int):
+        os.makedirs(Path(self.log_dir) / "checkpoints", exist_ok=True)
+        merged_model.save_pretrained(
+            Path(self.log_dir) / "checkpoints" / f"merged_model_{step}"
+        )
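Transcribing `cal_loss_i` above into notation (my reading of the code, not a formula quoted from the paper): with delta_tv written as ΔW_i = W_merged − W_i, and U_i Σ_i V_i^T the rank-truncated SVD of the i-th task vector, the default `svd_proj_space="uv"` objective is

```latex
% Transcribed from cal_loss_i: proj_delta_1 = diag(s) @ u.T @ delta_tv and
% proj_delta_2 = delta_tv @ v @ diag(s), each squared in Frobenius norm.
\mathcal{L}_i
  = \left\lVert \Sigma_i U_i^\top \, \Delta W_i \right\rVert_F^2
  + \left\lVert \Delta W_i \, V_i \Sigma_i \right\rVert_F^2
% Under MGDA with EMA enabled, the per-task weights returned by
% MinNormSolver.find_min_norm_element are smoothed each step as
% \alpha_i \leftarrow \beta\,\alpha_i + (1-\beta)\,\mathrm{sol}_i ,
% where \beta is ema_beta.
```

The `"u"` and `"v"` settings keep only the corresponding term, so the merged weight is penalized for drifting from each finetuned model within the singular subspaces where that model's task vector carries the most energy.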