fusion-bench 0.2.25__py3-none-any.whl → 0.2.27__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/dataset/clip_dataset.py +1 -0
- fusion_bench/method/__init__.py +4 -0
- fusion_bench/method/adamerging/__init__.py +28 -5
- fusion_bench/method/adamerging/resnet_adamerging.py +279 -0
- fusion_bench/method/adamerging/task_wise_adamerging.py +2 -14
- fusion_bench/method/adamerging/utils.py +58 -0
- fusion_bench/method/classification/clip_finetune.py +6 -4
- fusion_bench/method/classification/image_classification_finetune.py +156 -12
- fusion_bench/method/dare/simple_average.py +3 -2
- fusion_bench/method/dare/task_arithmetic.py +3 -2
- fusion_bench/method/dop/__init__.py +1 -0
- fusion_bench/method/dop/dop.py +366 -0
- fusion_bench/method/dop/min_norm_solvers.py +227 -0
- fusion_bench/method/dop/utils.py +73 -0
- fusion_bench/method/simple_average.py +6 -4
- fusion_bench/mixins/lightning_fabric.py +9 -0
- fusion_bench/modelpool/causal_lm/causal_lm.py +2 -1
- fusion_bench/modelpool/resnet_for_image_classification.py +285 -4
- fusion_bench/models/hf_clip.py +4 -7
- fusion_bench/models/hf_utils.py +4 -1
- fusion_bench/taskpool/__init__.py +2 -0
- fusion_bench/taskpool/clip_vision/taskpool.py +1 -1
- fusion_bench/taskpool/resnet_for_image_classification.py +231 -0
- fusion_bench/utils/state_dict_arithmetic.py +91 -10
- {fusion_bench-0.2.25.dist-info → fusion_bench-0.2.27.dist-info}/METADATA +9 -3
- {fusion_bench-0.2.25.dist-info → fusion_bench-0.2.27.dist-info}/RECORD +140 -77
- fusion_bench_config/fabric/auto.yaml +1 -1
- fusion_bench_config/fabric/loggers/swandb_logger.yaml +5 -0
- fusion_bench_config/fabric/loggers/wandb_logger.yaml +2 -0
- fusion_bench_config/fabric_model_fusion.yaml +1 -0
- fusion_bench_config/method/adamerging/resnet.yaml +18 -0
- fusion_bench_config/method/bitdelta/bitdelta.yaml +3 -0
- fusion_bench_config/method/classification/clip_finetune.yaml +5 -0
- fusion_bench_config/method/classification/image_classification_finetune.yaml +9 -0
- fusion_bench_config/method/depth_upscaling.yaml +9 -0
- fusion_bench_config/method/dop/dop.yaml +30 -0
- fusion_bench_config/method/dummy.yaml +6 -0
- fusion_bench_config/method/ensemble/max_model_predictor.yaml +6 -0
- fusion_bench_config/method/ensemble/simple_ensemble.yaml +8 -1
- fusion_bench_config/method/ensemble/weighted_ensemble.yaml +8 -0
- fusion_bench_config/method/linear/expo.yaml +5 -0
- fusion_bench_config/method/linear/linear_interpolation.yaml +8 -0
- fusion_bench_config/method/linear/llama_expo.yaml +5 -0
- fusion_bench_config/method/linear/llama_expo_with_dare.yaml +3 -0
- fusion_bench_config/method/linear/simple_average_for_causallm.yaml +5 -0
- fusion_bench_config/method/linear/task_arithmetic_for_causallm.yaml +3 -0
- fusion_bench_config/method/linear/ties_merging_for_causallm.yaml +5 -0
- fusion_bench_config/method/linear/weighted_average.yaml +3 -0
- fusion_bench_config/method/linear/weighted_average_for_llama.yaml +6 -1
- fusion_bench_config/method/mixtral_moe_merging.yaml +3 -0
- fusion_bench_config/method/mixtral_moe_upscaling.yaml +5 -0
- fusion_bench_config/method/model_recombination.yaml +8 -0
- fusion_bench_config/method/model_stock/model_stock.yaml +4 -1
- fusion_bench_config/method/opcm/opcm.yaml +5 -0
- fusion_bench_config/method/opcm/task_arithmetic.yaml +6 -0
- fusion_bench_config/method/opcm/ties_merging.yaml +5 -0
- fusion_bench_config/method/opcm/weight_average.yaml +5 -0
- fusion_bench_config/method/regmean/clip_regmean.yaml +3 -0
- fusion_bench_config/method/regmean/gpt2_regmean.yaml +3 -0
- fusion_bench_config/method/regmean/regmean.yaml +3 -0
- fusion_bench_config/method/regmean_plusplus/clip_regmean_plusplus.yaml +3 -0
- fusion_bench_config/method/simple_average.yaml +9 -0
- fusion_bench_config/method/slerp/slerp.yaml +9 -0
- fusion_bench_config/method/slerp/slerp_lm.yaml +5 -0
- fusion_bench_config/method/smile_upscaling/causal_lm_upscaling.yaml +6 -0
- fusion_bench_config/method/smile_upscaling/error_accumulation.yaml +5 -0
- fusion_bench_config/method/smile_upscaling/projected_energy.yaml +5 -0
- fusion_bench_config/method/smile_upscaling/singular_projection_merging.yaml +3 -0
- fusion_bench_config/method/smile_upscaling/smile_mistral_upscaling.yaml +5 -0
- fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +5 -0
- fusion_bench_config/method/smile_upscaling/smile_upscaling.yaml +3 -0
- fusion_bench_config/method/task_arithmetic.yaml +9 -0
- fusion_bench_config/method/ties_merging.yaml +3 -0
- fusion_bench_config/method/wudi/wudi.yaml +3 -0
- fusion_bench_config/model_fusion.yaml +2 -1
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/_generate_config.py +138 -0
- fusion_bench_config/modelpool/{ResNetForImageClassfication → ResNetForImageClassification}/transformers/resnet152_cifar10.yaml +1 -1
- fusion_bench_config/modelpool/{ResNetForImageClassfication → ResNetForImageClassification}/transformers/resnet152_cifar100.yaml +1 -1
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_dtd.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_emnist_letters.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_eurosat.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_fashion_mnist.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_fer2013.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_food101.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_gtsrb.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_kmnist.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_mnist.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_oxford-iiit-pet.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_oxford_flowers102.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_pcam.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_rendered-sst2.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_resisc45.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_stanford-cars.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_stl10.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_sun397.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_svhn.yaml +14 -0
- fusion_bench_config/modelpool/{ResNetForImageClassfication → ResNetForImageClassification}/transformers/resnet18_cifar10.yaml +1 -1
- fusion_bench_config/modelpool/{ResNetForImageClassfication → ResNetForImageClassification}/transformers/resnet18_cifar100.yaml +1 -1
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_dtd.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_emnist_letters.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_eurosat.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_fashion_mnist.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_fer2013.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_food101.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_gtsrb.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_kmnist.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_mnist.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_oxford-iiit-pet.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_oxford_flowers102.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_pcam.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_rendered-sst2.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_resisc45.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_stanford-cars.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_stl10.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_sun397.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_svhn.yaml +14 -0
- fusion_bench_config/modelpool/{ResNetForImageClassfication → ResNetForImageClassification}/transformers/resnet50_cifar10.yaml +1 -1
- fusion_bench_config/modelpool/{ResNetForImageClassfication → ResNetForImageClassification}/transformers/resnet50_cifar100.yaml +1 -1
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_dtd.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_emnist_letters.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_eurosat.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_fashion_mnist.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_fer2013.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_food101.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_gtsrb.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_kmnist.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_mnist.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_oxford-iiit-pet.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_oxford_flowers102.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_pcam.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_rendered-sst2.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_resisc45.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_stanford-cars.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_stl10.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_sun397.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_svhn.yaml +14 -0
- fusion_bench_config/method/clip_finetune.yaml +0 -26
- {fusion_bench-0.2.25.dist-info → fusion_bench-0.2.27.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.25.dist-info → fusion_bench-0.2.27.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.25.dist-info → fusion_bench-0.2.27.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.25.dist-info → fusion_bench-0.2.27.dist-info}/top_level.txt +0 -0
fusion_bench/method/classification/image_classification_finetune.py

@@ -1,3 +1,9 @@
+"""Image Classification Fine-tuning Module.
+
+This module provides algorithms for fine-tuning and evaluating image classification models
+using PyTorch Lightning.
+"""
+
 import os
 from typing import Optional
 
@@ -23,33 +29,82 @@ from fusion_bench import (
 from fusion_bench.dataset import CLIPDataset
 from fusion_bench.modelpool import ResNetForImageClassificationPool
 from fusion_bench.tasks.clip_classification import get_num_classes
+from torch.utils.data import random_split
 
 log = get_rankzero_logger(__name__)
 
 
 @auto_register_config
 class ImageClassificationFineTuning(BaseAlgorithm):
+    """Fine-tuning algorithm for image classification models.
+
+    This class implements end-to-end fine-tuning for image classification tasks using PyTorch Lightning.
+    It supports both epoch-based and step-based training with configurable optimizers, learning rate
+    schedulers, and data loaders.
+
+    Args:
+        max_epochs (Optional[int]): Maximum number of training epochs. Mutually exclusive with max_steps.
+        max_steps (Optional[int]): Maximum number of training steps. Mutually exclusive with max_epochs.
+        label_smoothing (float): Label smoothing factor for cross-entropy loss (0.0 = no smoothing).
+        optimizer (DictConfig): Configuration for the optimizer (e.g., Adam, SGD).
+        lr_scheduler (DictConfig): Configuration for the learning rate scheduler.
+        dataloader_kwargs (DictConfig): Additional keyword arguments for DataLoader construction.
+        **kwargs: Additional arguments passed to the base class.
+
+    Raises:
+        AssertionError: If both max_epochs and max_steps are provided.
+
+    Example:
+        ```python
+        >>> config = {
+        ...     'max_epochs': 10,
+        ...     'max_steps': None,
+        ...     'label_smoothing': 0.1,
+        ...     'optimizer': {'_target_': 'torch.optim.Adam', 'lr': 0.001},
+        ...     'lr_scheduler': {'_target_': 'torch.optim.lr_scheduler.StepLR', 'step_size': 5},
+        ...     'dataloader_kwargs': {'batch_size': 32, 'num_workers': 4}
+        ... }
+        >>> algorithm = ImageClassificationFineTuning(**config)
+        ```
+    """
+
     def __init__(
         self,
         max_epochs: Optional[int],
         max_steps: Optional[int],
+        training_data_ratio: Optional[float],
         label_smoothing: float,
         optimizer: DictConfig,
         lr_scheduler: DictConfig,
         dataloader_kwargs: DictConfig,
+        save_top_k: int,
+        save_interval: int,
+        save_on_train_epoch_end: bool,
         **kwargs,
     ):
         super().__init__(**kwargs)
-        assert (max_epochs is None) or (
+        assert (max_epochs is None or max_epochs < 0) or (
             max_steps is None or max_steps < 0
         ), "Only one of max_epochs or max_steps should be set."
-        self.training_interval =
+        self.training_interval = (
+            "epoch" if max_epochs is not None and max_epochs > 0 else "step"
+        )
         if self.training_interval == "epoch":
             self.max_steps = -1
         log.info(f"Training interval: {self.training_interval}")
         log.info(f"Max epochs: {max_epochs}, max steps: {max_steps}")
 
     def run(self, modelpool: ResNetForImageClassificationPool):
+        """Execute the fine-tuning process on the provided model pool.
+
+        This method performs the complete fine-tuning workflow:
+        1. Loads the pretrained model from the model pool
+        2. Prepares training and validation datasets
+        3. Configures optimizer and learning rate scheduler
+        4. Sets up Lightning trainer with appropriate callbacks
+        5. Executes the training process
+        6. Saves the final fine-tuned model
+        """
         # load model and dataset
         model = modelpool.load_pretrained_or_first_model()
         assert isinstance(model, nn.Module), "Loaded model is not a nn.Module."
@@ -59,7 +114,17 @@ class ImageClassificationFineTuning(BaseAlgorithm):
         ), "Exactly one training dataset is required."
         self.dataset_name = dataset_name = modelpool.train_dataset_names[0]
         num_classes = get_num_classes(dataset_name)
+        log.info(f"Number of classes for dataset {dataset_name}: {num_classes}")
         train_dataset = modelpool.load_train_dataset(dataset_name)
+        log.info(f"Training dataset size: {len(train_dataset)}")
+        if self.training_data_ratio is not None and 0 < self.training_data_ratio < 1:
+            train_dataset, _ = random_split(
+                train_dataset,
+                lengths=[self.training_data_ratio, 1 - self.training_data_ratio],
+            )
+            log.info(
+                f"Using {len(train_dataset)} samples for training after applying training_data_ratio={self.training_data_ratio}."
+            )
         train_dataset = CLIPDataset(
             train_dataset, processor=modelpool.load_processor(stage="train")
         )
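The `training_data_ratio` subsampling above relies on `random_split` accepting fractional lengths, which requires PyTorch 1.13 or newer. A minimal self-contained sketch of the pattern (the toy `TensorDataset` only stands in for the real training set):

```python
import torch
from torch.utils.data import TensorDataset, random_split

full = TensorDataset(torch.arange(100))  # stand-in for the real training set
ratio = 0.25  # plays the role of training_data_ratio

# Fractional lengths: random_split floors each fraction and distributes any
# remainder, so the two subsets always partition the full dataset.
subset, _ = random_split(full, lengths=[ratio, 1 - ratio])
print(len(subset))  # 25
```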
@@ -91,7 +156,11 @@ class ImageClassificationFineTuning(BaseAlgorithm):
             objective=nn.CrossEntropyLoss(label_smoothing=self.label_smoothing),
             metrics={
                 "acc@1": Accuracy(task="multiclass", num_classes=num_classes),
-                "acc@5": Accuracy(
+                f"acc@{min(5,num_classes)}": Accuracy(
+                    task="multiclass",
+                    num_classes=num_classes,
+                    top_k=min(5, num_classes),
+                ),
             },
         )
 
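The metric change replaces the hard-coded `"acc@5"` with a key and `top_k` clamped to `min(5, num_classes)`. This matters for the two-class pools added in this release (e.g. `rendered-sst2`, `pcam`), where asking for top-5 accuracy fails because `top_k` cannot exceed the number of classes. A sketch of the clamped construction with torchmetrics:

```python
import torch
from torchmetrics import Accuracy

num_classes = 2  # e.g. rendered-sst2 or pcam
k = min(5, num_classes)  # clamp as in the diff; top_k=5 would fail here

metric = Accuracy(task="multiclass", num_classes=num_classes, top_k=k, average="micro")
preds = torch.randn(8, num_classes).softmax(dim=-1)
target = torch.randint(num_classes, (8,))
print(metric(preds, target))  # top-2 accuracy over 2 classes is trivially 1.0
```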
@@ -108,11 +177,21 @@ class ImageClassificationFineTuning(BaseAlgorithm):
             callbacks=[
                 pl_callbacks.LearningRateMonitor(logging_interval="step"),
                 pl_callbacks.DeviceStatsMonitor(),
+                pl_callbacks.ModelCheckpoint(
+                    save_top_k=self.save_top_k,
+                    every_n_train_steps=(
+                        self.save_interval if self.training_interval == "step" else None
+                    ),
+                    every_n_epochs=(
+                        self.save_interval
+                        if self.training_interval == "epoch"
+                        else None
+                    ),
+                    save_on_train_epoch_end=self.save_on_train_epoch_end,
+                    save_last=True,
+                ),
             ],
-            logger=TensorBoardLogger(
-                save_dir=log_dir,
-                name="",
-            ),
+            logger=TensorBoardLogger(save_dir=log_dir, name="", version=""),
             fast_dev_run=RuntimeConstants.debug,
         )
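The new `ModelCheckpoint` callback wires `save_interval` to exactly one cadence argument: Lightning treats `every_n_train_steps`, `every_n_epochs`, and `train_time_interval` as mutually exclusive, so the unused one must stay `None`. A reduced sketch of the selection logic, outside the class:

```python
from lightning.pytorch.callbacks import ModelCheckpoint

training_interval = "epoch"  # mirrors self.training_interval ("epoch" or "step")
save_interval = 1            # mirrors self.save_interval

ckpt = ModelCheckpoint(
    save_top_k=-1,  # illustrative value; the real one comes from self.save_top_k
    every_n_train_steps=save_interval if training_interval == "step" else None,
    every_n_epochs=save_interval if training_interval == "epoch" else None,
    save_last=True,
)
```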
@@ -129,10 +208,26 @@ class ImageClassificationFineTuning(BaseAlgorithm):
                 "raw_checkpoints",
                 "final",
             ),
+            algorithm_config=self.config,
+            description=f"Fine-tuned ResNet model on dataset {dataset_name}.",
         )
         return model
 
     def get_dataloader(self, dataset, stage: str):
+        """Create a DataLoader for the specified dataset and training stage.
+
+        Constructs a PyTorch DataLoader with stage-appropriate configurations:
+        - Training stage: shuffling enabled by default
+        - Validation/test stages: shuffling disabled by default
+
+        Args:
+            dataset: The dataset to wrap in a DataLoader.
+            stage (str): Training stage, must be one of "train", "val", or "test".
+                Determines default shuffling behavior.
+
+        Returns:
+            DataLoader: Configured DataLoader for the given dataset and stage.
+        """
         assert stage in ["train", "val", "test"], f"Invalid stage: {stage}"
         dataloader_kwargs = dict(self.dataloader_kwargs)
         if "shuffle" not in dataloader_kwargs:
@@ -142,10 +237,42 @@ class ImageClassificationFineTuning(BaseAlgorithm):
 
 @auto_register_config
 class ImageClassificationFineTuning_Test(BaseAlgorithm):
+    """Test/evaluation algorithm for fine-tuned image classification models.
+
+    This class implements model evaluation on test or validation datasets using PyTorch Lightning.
+    It can either evaluate a model directly or load a model from a checkpoint before evaluation.
+    The evaluation computes standard classification metrics including top-1 and top-5 accuracy.
+
+    Args:
+        checkpoint_path (str): Path to the model checkpoint file. If None, uses the model
+            directly from the model pool without loading from checkpoint.
+        dataloader_kwargs (DictConfig): Additional keyword arguments for DataLoader construction.
+        **kwargs: Additional arguments passed to the base class.
+
+    Example:
+        ```python
+        >>> config = {
+        ...     'checkpoint_path': '/path/to/model/checkpoint.ckpt',
+        ...     'dataloader_kwargs': {'batch_size': 64, 'num_workers': 4}
+        ... }
+        >>> test_algorithm = ImageClassificationFineTuning_Test(**config)
+        ```
+    """
+
     def __init__(self, checkpoint_path: str, dataloader_kwargs: DictConfig, **kwargs):
         super().__init__(**kwargs)
 
-    def run(self, modelpool:
+    def run(self, modelpool: ResNetForImageClassificationPool):
+        """Execute model evaluation on the provided model pool's test/validation dataset.
+
+        This method performs the complete evaluation workflow:
+        1. Loads the model from the model pool (pretrained or first available)
+        2. Prepares the test or validation dataset (prioritizes test if both available)
+        3. Sets up the Lightning module with appropriate metrics (top-1 and top-5 accuracy)
+        4. Loads from checkpoint if specified, otherwise uses the model directly
+        5. Executes the evaluation using Lightning trainer
+        6. Logs and returns the test metrics
+        """
         assert (
             modelpool.has_val_dataset or modelpool.has_test_dataset
         ), "No validation or test dataset found in the model pool."
@@ -181,8 +308,10 @@ class ImageClassificationFineTuning_Test(BaseAlgorithm):
             model,
             metrics={
                 "acc@1": Accuracy(task="multiclass", num_classes=num_classes),
-                "acc@5": Accuracy(
-                    task="multiclass",
+                f"acc@{min(5,num_classes)}": Accuracy(
+                    task="multiclass",
+                    num_classes=num_classes,
+                    top_k=min(5, num_classes),
                 ),
             },
         )
@@ -192,8 +321,10 @@ class ImageClassificationFineTuning_Test(BaseAlgorithm):
             model=model,
             metrics={
                 "acc@1": Accuracy(task="multiclass", num_classes=num_classes),
-                "acc@5": Accuracy(
-                    task="multiclass",
+                f"acc@{min(5,num_classes)}": Accuracy(
+                    task="multiclass",
+                    num_classes=num_classes,
+                    top_k=min(5, num_classes),
                 ),
             },
         )
@@ -207,6 +338,19 @@ class ImageClassificationFineTuning_Test(BaseAlgorithm):
         return model
 
     def get_dataloader(self, dataset, stage: str):
+        """Create a DataLoader for the specified dataset and evaluation stage.
+
+        Constructs a PyTorch DataLoader with stage-appropriate configurations for evaluation.
+        Similar to the training version but typically used for test/validation datasets.
+
+        Args:
+            dataset: The dataset to wrap in a DataLoader.
+            stage (str): Evaluation stage, must be one of "train", "val", or "test".
+                Determines default shuffling behavior (disabled for non-train stages).
+
+        Returns:
+            DataLoader: Configured DataLoader for the given dataset and stage.
+        """
         assert stage in ["train", "val", "test"], f"Invalid stage: {stage}"
         dataloader_kwargs = dict(self.dataloader_kwargs)
         if "shuffle" not in dataloader_kwargs:
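Both `get_dataloader` docstrings describe the same stage-dependent shuffle default. The method bodies are truncated in this diff, but the documented behavior reduces to a pattern like this (a sketch, not the actual implementation):

```python
from torch.utils.data import DataLoader

def get_dataloader(dataset, stage: str, **dataloader_kwargs):
    assert stage in ["train", "val", "test"], f"Invalid stage: {stage}"
    # Shuffle only during training, unless the caller overrides it explicitly.
    dataloader_kwargs.setdefault("shuffle", stage == "train")
    return DataLoader(dataset, **dataloader_kwargs)
```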
fusion_bench/method/dare/simple_average.py

@@ -1,6 +1,6 @@
 import logging
 
-from fusion_bench import BaseAlgorithm, BaseModelPool
+from fusion_bench import BaseAlgorithm, BaseModelPool, auto_register_config
 from fusion_bench.utils.state_dict_arithmetic import state_dict_add, state_dict_mul
 
 from .task_arithmetic import DareTaskArithmetic
@@ -8,6 +8,7 @@ from .task_arithmetic import DareTaskArithmetic
 log = logging.getLogger(__name__)
 
 
+@auto_register_config
 class DareSimpleAverage(BaseAlgorithm):
 
     def __init__(
@@ -17,10 +18,10 @@ class DareSimpleAverage(BaseAlgorithm):
         rescale: bool = True,
         **kwargs,
     ):
+        super().__init__(**kwargs)
         self.sparsity_ratio = sparsity_ratio
         self.only_on_linear_weight = only_on_linear_weights
         self.rescale = rescale
-        super().__init__(**kwargs)
 
     def run(self, modelpool: BaseModelPool):
         return DareTaskArithmetic(
fusion_bench/method/dare/task_arithmetic.py

@@ -1,7 +1,7 @@
 import torch
 from torch import Tensor, nn
 
-from fusion_bench import BaseAlgorithm, BaseModelPool
+from fusion_bench import BaseAlgorithm, BaseModelPool, auto_register_config
 from fusion_bench.utils.state_dict_arithmetic import state_dict_sum
 
 from .utils import (
@@ -12,6 +12,7 @@ from .utils import (
 )
 
 
+@auto_register_config
 class DareTaskArithmetic(BaseAlgorithm):
     """
     Implementation of Task Arithmetic w/ DARE.
@@ -27,11 +28,11 @@ class DareTaskArithmetic(BaseAlgorithm):
         rescale: bool = True,
         **kwargs,
    ):
+        super().__init__(**kwargs)
         self.scaling_factor = scaling_factor
         self.sparsity_ratio = sparsity_ratio
         self.only_on_linear_weights = only_on_linear_weights
         self.rescale = rescale
-        super().__init__(**kwargs)
 
     def _load_task_vector(
         self,
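Both DARE classes pick up `@auto_register_config` and move `super().__init__(**kwargs)` from the end of `__init__` to the top. A plausible motivation for this ordering: with a decorator that records constructor arguments onto the instance, running the base initializer first guarantees that the explicit attribute assignments below it are applied last and cannot be clobbered. A toy illustration of such a decorator (not fusion_bench's actual implementation):

```python
import inspect

def auto_register_config(cls):
    """Toy decorator: capture __init__ arguments into self.config."""
    original_init = cls.__init__

    def wrapped_init(self, *args, **kwargs):
        bound = inspect.signature(original_init).bind(self, *args, **kwargs)
        bound.apply_defaults()
        # record every constructor argument except self
        self.config = {k: v for k, v in bound.arguments.items() if k != "self"}
        original_init(self, *args, **kwargs)

    cls.__init__ = wrapped_init
    return cls
```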
fusion_bench/method/dop/__init__.py

@@ -0,0 +1 @@
+from .dop import ContinualDOPForCLIP
fusion_bench/method/dop/dop.py

@@ -0,0 +1,366 @@
+"""
+Continual Model Merging without Data: Dual Projections for Balancing Stability and Plasticity. NeurIPS, 2025.
+
+
+Example:
+
+    fusion_bench \
+        method=dop/dop \
+        modelpool=CLIPVisionModelPool/clip-vit-base-patch32_TA8_model_only \
+        taskpool=CLIPVisionModelTaskPool/clip-vit-classification_TA8
+"""
+
+import logging
+import os
+import random
+from copy import deepcopy
+from pathlib import Path
+from typing import Dict, List, Literal, Optional, Tuple, cast
+
+import lightning as L
+import numpy as np
+import torch
+from omegaconf import DictConfig
+from torch import Tensor, nn
+from torch.autograd import Variable
+from tqdm.auto import tqdm
+from transformers import CLIPVisionModel
+
+from fusion_bench import BaseAlgorithm, BaseModelPool, auto_register_config
+from fusion_bench.method.simple_average import simple_average
+from fusion_bench.mixins import LightningFabricMixin
+from fusion_bench.taskpool import CLIPVisionModelTaskPool
+from fusion_bench.utils import seed_everything_by_time
+from fusion_bench.utils.json import save_to_json
+
+from .min_norm_solvers import MinNormSolver, gradient_normalizers
+from .utils import is_leaf_module, svd
+
+log = logging.getLogger(__name__)
+
+
+@auto_register_config
+class ContinualDOPForCLIP(BaseAlgorithm, LightningFabricMixin):
+
+    def __init__(
+        self,
+        seed: Optional[int] = None,
+        shuffle_order: bool = False,
+        save_on_every_step: bool = True,
+        evaluate_on_every_step: bool = False,
+        lr: float = 1e-4,
+        num_steps: int = 200,
+        mgda: bool = True,
+        ema: bool = True,
+        ema_beta: float = 0.99,
+        alpha: float = None,
+        svd_epsilon: float = 1.0,
+        svd_proj_space: str = "uv",
+        **kwargs,
+    ):
+        self.lr = lr
+        self.num_steps = num_steps
+        self.mgda = mgda
+        self.ema = ema
+        self.ema_beta = ema_beta
+        self.alpha = alpha
+        self.svd_epsilon = svd_epsilon
+        self.svd_proj_space = svd_proj_space
+        self.seed = seed
+        self.shuffle_order = shuffle_order
+        self.save_on_every_step = save_on_every_step
+        self.evaluate_on_every_step = evaluate_on_every_step
+
+        assert (
+            self.svd_epsilon >= 0 and self.svd_epsilon <= 1
+        ), "The svd_epsilon should be in the range of [0, 1]"
+        assert (
+            self.alpha >= 0 and self.alpha <= 1
+        ), "The alpha should be in the range of [0, 1]"
+        super().__init__(**kwargs)
+
+    def print_params(self, pretrained_model):
+        total_params = 0
+        linear_params = 0
+        linear_weight_params = 0
+        for module_name, module in pretrained_model.named_modules():
+            if not is_leaf_module(module):
+                continue
+            if isinstance(module, nn.Linear):
+                linear_params += sum(p.numel() for n, p in module.named_parameters())
+                linear_weight_params += sum(
+                    p.numel() for n, p in module.named_parameters() if "weight" in n
+                )
+            total_params += sum(p.numel() for p in module.parameters())
+
+        linear_ratio = linear_params / total_params * 100
+        linear_weight_ratio = linear_weight_params / total_params * 100
+        print(f"Total Parameters: {total_params}")
+        print(f"Linear Parameters: {linear_params}")
+        print(f"Linear Weight Parameters: {linear_weight_params}")
+        print(f"Linear Ratio: {linear_ratio:.2f}%")
+        print(f"Linear Weight Ratio: {linear_weight_ratio:.2f}%")
+
+    def run(self, modelpool: BaseModelPool):
+        if self.seed is not None:
+            L.seed_everything(self.seed)
+        else:
+            seed_everything_by_time(self.fabric)
+
+        # get the model names, shuffle if needed
+        # the model names will be saved to the log directory as `model_names.json`
+        model_names = modelpool.model_names
+        if self.shuffle_order:
+            random.shuffle(model_names)
+        if self.log_dir is not None:
+            save_to_json(model_names, os.path.join(self.log_dir, "model_names.json"))
+
+        if self.evaluate_on_every_step:
+            """Configuration for the test datasets"""
+            self.taskpool = cast(CLIPVisionModelTaskPool, self._program.taskpool)
+            self._test_datasets = deepcopy(self.taskpool._test_datasets)
+
+        pretrained_model = modelpool.load_pretrained_model()
+
+        merged_model = None
+        for model_idx, model_name in enumerate(model_names):
+            print(
+                f"--------- Optimizing {model_idx + 1}/{len(model_names)}-th with {model_name} ---------"
+            )
+            if model_idx == 0:
+                merged_model = modelpool.load_model(model_names[0])
+            else:
+                merged_model = self._layer_wise_optimize(
+                    model_names=["merged", model_name],
+                    pretrained_model=deepcopy(pretrained_model),
+                    finetuned_models={
+                        "merged": merged_model,
+                        model_name: modelpool.load_model(model_name),
+                    },
+                    model_idx=model_idx,
+                )
+
+            if self.save_on_every_step:
+                self.save_merged_model(merged_model, model_idx)
+
+            if self.evaluate_on_every_step:
+                self.taskpool._is_setup = False
+                self.taskpool._test_datasets = DictConfig(
+                    {n: self._test_datasets[n] for n in model_names[: model_idx + 1]}
+                )
+                report = self.taskpool.evaluate(deepcopy(merged_model))
+                save_to_json(report, Path(self.log_dir) / f"report_{model_idx}.json")
+
+        return merged_model
+
+    def _layer_wise_optimize(
+        self,
+        model_names: List[str],
+        pretrained_model: nn.Module,
+        finetuned_models: Dict[str, nn.Module],
+        model_idx: int,
+    ):
+        time_cost = []
+        for module_name, module in pretrained_model.named_modules():
+            if not is_leaf_module(module):
+                continue
+
+            if isinstance(module, nn.Linear):
+                if module.weight.requires_grad:
+                    import time
+
+                    start_time = time.time()
+                    merged_weight = self._optimize_weight(
+                        module.weight,
+                        {
+                            model_name: finetuned_models[model_name]
+                            .get_submodule(module_name)
+                            .weight
+                            for model_name in model_names
+                        },
+                        module_name,
+                        model_idx,
+                    )
+                    end_time = time.time()
+                    time_cost.append(end_time - start_time)
+                    module.weight.data = merged_weight.data
+                else:
+                    module.weight.data = simple_average(
+                        [
+                            finetuned_models[model_name]
+                            .get_submodule(module_name)
+                            .weight
+                            for model_name in model_names
+                        ]
+                    )
+                if module.bias is not None:
+                    module.bias.data = simple_average(
+                        [
+                            finetuned_models[model_name].get_submodule(module_name).bias
+                            for model_name in model_names
+                        ]
+                    )
+            else:
+                simple_average(
+                    [
+                        finetuned_models[model_name].get_submodule(module_name)
+                        for model_name in model_names
+                    ],
+                    base_module=module,
+                )
+
+        return pretrained_model
+
+    def _optimize_weight(
+        self,
+        pretrained_weight: Tensor,
+        finetuned_weights: Dict[str, Tensor],
+        module_name: str,
+        model_idx: int,
+    ):
+        assert (
+            self.fabric.world_size == 1
+        ), "This algorithm is not currently supported in distributed training"
+
+        pretrained_weight = self.fabric.to_device(pretrained_weight.detach())
+        finetuned_weights = {
+            model_name: self.fabric.to_device(finetuned_weight.detach())
+            for model_name, finetuned_weight in finetuned_weights.items()
+        }
+
+        merged_weight = self.fabric.to_device(
+            nn.Parameter(
+                simple_average(
+                    [
+                        finetuned_weight.detach()
+                        for finetuned_weight in finetuned_weights.values()
+                    ]
+                ),
+                requires_grad=True,
+            )
+        )
+
+        # Compute SVD of the difference between the finetuned and pretrained weights
+        proj_u_dict = {}
+        proj_v_dict = {}
+        proj_s_dict = {}
+        for i, finetuned_weight in enumerate(finetuned_weights.values()):
+            finetuned_tv = finetuned_weight - pretrained_weight
+            u, s, v = svd(finetuned_tv, full_matrices=True)
+            epsilon = 1.0 if self.svd_epsilon > 1.0 else self.svd_epsilon
+            cumsum_ratio = s.cumsum(dim=0) / s.sum()
+            split_rank = torch.searchsorted(cumsum_ratio, epsilon).item()
+            u_main = u[:, :split_rank]
+            v_main = v[:, :split_rank]
+            s_main = s[:split_rank]
+            proj_u_dict[i] = u_main
+            proj_v_dict[i] = v_main
+            proj_s_dict[i] = s_main
+
+        if self.mgda:
+            if self.ema:
+                ema_sol = [self.alpha, 1 - self.alpha]
+            # This is multiple-gradient descent algorithm (MGDA) optimization
+            optimizer = torch.optim.Adam([merged_weight], lr=self.lr)
+            all_losses = [[], []]
+            all_alphas = [[], []]
+            for step_idx in tqdm(
+                range(self.num_steps), desc=f"Optimizing {module_name} weight"
+            ):
+                # Scaling the loss functions based on the algorithm choice
+                loss_data = {}
+                grads = {}
+                for i, finetuned_weight in enumerate(finetuned_weights.values()):
+                    proj_u = proj_u_dict[i]
+                    proj_v = proj_v_dict[i]
+                    proj_s = proj_s_dict[i]
+                    delta_tv = merged_weight - finetuned_weight
+                    loss_i = self.cal_loss_i(delta_tv, proj_s, proj_u, proj_v)
+                    loss_data[i] = float(loss_i.data)
+
+                    all_losses[i].append(float(loss_i.data))
+
+                    optimizer.zero_grad()
+                    loss_i.backward()
+                    grads[i] = Variable(
+                        merged_weight.grad.data.clone(), requires_grad=False
+                    )
+
+                # Normalize all gradients
+                gn = gradient_normalizers(
+                    grads=grads, losses=loss_data, normalization_type="loss"
+                )
+                for i, _ in enumerate(finetuned_weights.values()):
+                    grads[i] = grads[i] / float(gn[i])
+
+                # Frank-Wolfe iteration to compute scales.
+                sol, min_norm = MinNormSolver.find_min_norm_element(
+                    [[grads[i]] for i in range(len(finetuned_weights.values()))]
+                )
+
+                if self.ema:
+                    ema_sol = [
+                        self.ema_beta * ema_sol[i] + (1 - self.ema_beta) * float(sol[i])
+                        for i in range(len(sol))
+                    ]
+                    sol = ema_sol
+                    all_alphas[0].append(ema_sol[0])
+                    all_alphas[1].append(ema_sol[1])
+
+                # Scaled back-propagation
+                loss = 0
+                for i, finetuned_weight in enumerate(finetuned_weights.values()):
+                    # Compute gradients of each loss function wrt parameters
+                    proj_u = proj_u_dict[i]
+                    proj_v = proj_v_dict[i]
+                    proj_s = proj_s_dict[i]
+                    delta_tv = merged_weight - finetuned_weight
+                    loss_i = self.cal_loss_i(delta_tv, proj_s, proj_u, proj_v)
+                    loss += float(sol[i]) * loss_i
+
+                optimizer.zero_grad()
+                loss.backward()
+                optimizer.step()
+
+        else:
+            # This is a naive weighted optimization
+            optimizer = torch.optim.Adam([merged_weight], lr=self.lr)
+            for step_idx in tqdm(
+                range(self.num_steps), desc=f"Optimizing {module_name} weight"
+            ):
+                loss = 0
+                for i, finetuned_weight in enumerate(finetuned_weights.values()):
+                    proj_u = proj_u_dict[i]
+                    proj_v = proj_v_dict[i]
+                    proj_s = proj_s_dict[i]
+                    delta_tv = merged_weight - finetuned_weight
+                    loss_i = self.cal_loss_i(delta_tv, proj_s, proj_u, proj_v)
+                    loss += self.alpha * loss_i if i == 0 else (1 - self.alpha) * loss_i
+
+                optimizer.zero_grad()
+                loss.backward()
+                optimizer.step()
+
+        return merged_weight.detach().cpu()
+
+    def cal_loss_i(self, delta_tv, proj_s, proj_u, proj_v):
+        proj_delta_1 = torch.diag(proj_s) @ proj_u.T @ delta_tv
+        proj_delta_2 = delta_tv @ proj_v @ torch.diag(proj_s)
+        loss_i_u = torch.linalg.matrix_norm(proj_delta_1, ord="fro") ** 2
+        loss_i_v = torch.linalg.matrix_norm(proj_delta_2, ord="fro") ** 2
+        if self.svd_proj_space == "uv":
+            loss_i = loss_i_u + loss_i_v
+        elif self.svd_proj_space == "u":
+            loss_i = loss_i_u
+        elif self.svd_proj_space == "v":
+            loss_i = loss_i_v
+        else:
+            raise ValueError("Invalid svd_proj_space")
+
+        return loss_i
+
+    def save_merged_model(self, merged_model: CLIPVisionModel, step: int):
+        os.makedirs(Path(self.log_dir) / "checkpoints", exist_ok=True)
+        merged_model.save_pretrained(
+            Path(self.log_dir) / "checkpoints" / f"merged_model_{step}"
+        )