fusion-bench 0.2.26__py3-none-any.whl → 0.2.28__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/__init__.py +4 -0
- fusion_bench/dataset/clip_dataset.py +1 -0
- fusion_bench/method/__init__.py +2 -0
- fusion_bench/method/adamerging/__init__.py +28 -5
- fusion_bench/method/adamerging/resnet_adamerging.py +279 -0
- fusion_bench/method/adamerging/task_wise_adamerging.py +2 -14
- fusion_bench/method/adamerging/utils.py +58 -0
- fusion_bench/method/classification/image_classification_finetune.py +168 -12
- fusion_bench/method/dare/simple_average.py +3 -2
- fusion_bench/method/dare/task_arithmetic.py +3 -2
- fusion_bench/method/simple_average.py +6 -4
- fusion_bench/method/task_arithmetic/task_arithmetic.py +4 -1
- fusion_bench/mixins/lightning_fabric.py +9 -0
- fusion_bench/modelpool/__init__.py +24 -2
- fusion_bench/modelpool/base_pool.py +8 -1
- fusion_bench/modelpool/causal_lm/causal_lm.py +2 -1
- fusion_bench/modelpool/convnext_for_image_classification.py +198 -0
- fusion_bench/modelpool/dinov2_for_image_classification.py +197 -0
- fusion_bench/modelpool/resnet_for_image_classification.py +289 -5
- fusion_bench/models/hf_clip.py +4 -7
- fusion_bench/models/hf_utils.py +4 -1
- fusion_bench/models/model_card_templates/default.md +1 -1
- fusion_bench/taskpool/__init__.py +2 -0
- fusion_bench/taskpool/clip_vision/taskpool.py +1 -1
- fusion_bench/taskpool/resnet_for_image_classification.py +231 -0
- fusion_bench/utils/json.py +49 -8
- fusion_bench/utils/state_dict_arithmetic.py +91 -10
- {fusion_bench-0.2.26.dist-info → fusion_bench-0.2.28.dist-info}/METADATA +2 -2
- {fusion_bench-0.2.26.dist-info → fusion_bench-0.2.28.dist-info}/RECORD +124 -62
- fusion_bench_config/fabric/auto.yaml +1 -1
- fusion_bench_config/fabric/loggers/swandb_logger.yaml +5 -0
- fusion_bench_config/fabric/loggers/wandb_logger.yaml +2 -0
- fusion_bench_config/fabric_model_fusion.yaml +1 -0
- fusion_bench_config/method/adamerging/resnet.yaml +18 -0
- fusion_bench_config/method/classification/clip_finetune.yaml +5 -0
- fusion_bench_config/method/classification/image_classification_finetune.yaml +9 -0
- fusion_bench_config/method/linear/expo.yaml +5 -0
- fusion_bench_config/method/linear/llama_expo.yaml +5 -0
- fusion_bench_config/method/linear/llama_expo_with_dare.yaml +3 -0
- fusion_bench_config/method/linear/simple_average_for_causallm.yaml +5 -0
- fusion_bench_config/method/linear/task_arithmetic_for_causallm.yaml +3 -0
- fusion_bench_config/method/linear/ties_merging_for_causallm.yaml +5 -0
- fusion_bench_config/method/linear/weighted_average_for_llama.yaml +5 -0
- fusion_bench_config/method/mixtral_moe_merging.yaml +3 -0
- fusion_bench_config/method/mixtral_moe_upscaling.yaml +5 -0
- fusion_bench_config/method/regmean/clip_regmean.yaml +3 -0
- fusion_bench_config/method/regmean/gpt2_regmean.yaml +3 -0
- fusion_bench_config/method/regmean/regmean.yaml +3 -0
- fusion_bench_config/method/regmean_plusplus/clip_regmean_plusplus.yaml +3 -0
- fusion_bench_config/method/smile_upscaling/causal_lm_upscaling.yaml +6 -0
- fusion_bench_config/method/smile_upscaling/error_accumulation.yaml +5 -0
- fusion_bench_config/method/smile_upscaling/projected_energy.yaml +5 -0
- fusion_bench_config/method/smile_upscaling/singular_projection_merging.yaml +3 -0
- fusion_bench_config/method/smile_upscaling/smile_mistral_upscaling.yaml +5 -0
- fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +5 -0
- fusion_bench_config/method/wudi/wudi.yaml +3 -0
- fusion_bench_config/model_fusion.yaml +2 -1
- fusion_bench_config/modelpool/ConvNextForImageClassification/convnext-base-224.yaml +10 -0
- fusion_bench_config/modelpool/Dinov2ForImageClassification/dinov2-base-imagenet1k-1-layer.yaml +10 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/_generate_config.py +138 -0
- fusion_bench_config/modelpool/{ResNetForImageClassfication → ResNetForImageClassification}/transformers/resnet152_cifar10.yaml +1 -1
- fusion_bench_config/modelpool/{ResNetForImageClassfication → ResNetForImageClassification}/transformers/resnet152_cifar100.yaml +1 -1
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_dtd.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_emnist_letters.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_eurosat.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_fashion_mnist.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_fer2013.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_food101.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_gtsrb.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_kmnist.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_mnist.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_oxford-iiit-pet.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_oxford_flowers102.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_pcam.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_rendered-sst2.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_resisc45.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_stanford-cars.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_stl10.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_sun397.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_svhn.yaml +14 -0
- fusion_bench_config/modelpool/{ResNetForImageClassfication → ResNetForImageClassification}/transformers/resnet18_cifar10.yaml +1 -1
- fusion_bench_config/modelpool/{ResNetForImageClassfication → ResNetForImageClassification}/transformers/resnet18_cifar100.yaml +1 -1
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_dtd.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_emnist_letters.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_eurosat.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_fashion_mnist.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_fer2013.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_food101.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_gtsrb.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_kmnist.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_mnist.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_oxford-iiit-pet.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_oxford_flowers102.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_pcam.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_rendered-sst2.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_resisc45.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_stanford-cars.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_stl10.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_sun397.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_svhn.yaml +14 -0
- fusion_bench_config/modelpool/{ResNetForImageClassfication → ResNetForImageClassification}/transformers/resnet50_cifar10.yaml +1 -1
- fusion_bench_config/modelpool/{ResNetForImageClassfication → ResNetForImageClassification}/transformers/resnet50_cifar100.yaml +1 -1
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_dtd.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_emnist_letters.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_eurosat.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_fashion_mnist.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_fer2013.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_food101.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_gtsrb.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_kmnist.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_mnist.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_oxford-iiit-pet.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_oxford_flowers102.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_pcam.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_rendered-sst2.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_resisc45.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_stanford-cars.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_stl10.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_sun397.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_svhn.yaml +14 -0
- fusion_bench_config/method/clip_finetune.yaml +0 -26
- {fusion_bench-0.2.26.dist-info → fusion_bench-0.2.28.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.26.dist-info → fusion_bench-0.2.28.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.26.dist-info → fusion_bench-0.2.28.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.26.dist-info → fusion_bench-0.2.28.dist-info}/top_level.txt +0 -0
fusion_bench/method/classification/image_classification_finetune.py

@@ -1,3 +1,9 @@
+"""Image Classification Fine-tuning Module.
+
+This module provides algorithms for fine-tuning and evaluating image classification models
+using PyTorch Lightning.
+"""
+
 import os
 from typing import Optional
 
@@ -23,35 +29,93 @@ from fusion_bench import (
 from fusion_bench.dataset import CLIPDataset
 from fusion_bench.modelpool import ResNetForImageClassificationPool
 from fusion_bench.tasks.clip_classification import get_num_classes
+from torch.utils.data import random_split
 
 log = get_rankzero_logger(__name__)
 
 
+def _get_base_model_name(model) -> Optional[str]:
+    if hasattr(model, "config") and hasattr(model.config, "_name_or_path"):
+        return model.config._name_or_path
+    else:
+        return None
+
+
 @auto_register_config
 class ImageClassificationFineTuning(BaseAlgorithm):
+    """Fine-tuning algorithm for image classification models.
+
+    This class implements end-to-end fine-tuning for image classification tasks using PyTorch Lightning.
+    It supports both epoch-based and step-based training with configurable optimizers, learning rate
+    schedulers, and data loaders.
+
+    Args:
+        max_epochs (Optional[int]): Maximum number of training epochs. Mutually exclusive with max_steps.
+        max_steps (Optional[int]): Maximum number of training steps. Mutually exclusive with max_epochs.
+        label_smoothing (float): Label smoothing factor for cross-entropy loss (0.0 = no smoothing).
+        optimizer (DictConfig): Configuration for the optimizer (e.g., Adam, SGD).
+        lr_scheduler (DictConfig): Configuration for the learning rate scheduler.
+        dataloader_kwargs (DictConfig): Additional keyword arguments for DataLoader construction.
+        **kwargs: Additional arguments passed to the base class.
+
+    Raises:
+        AssertionError: If both max_epochs and max_steps are provided.
+
+    Example:
+        ```python
+        >>> config = {
+        ...     'max_epochs': 10,
+        ...     'max_steps': None,
+        ...     'label_smoothing': 0.1,
+        ...     'optimizer': {'_target_': 'torch.optim.Adam', 'lr': 0.001},
+        ...     'lr_scheduler': {'_target_': 'torch.optim.lr_scheduler.StepLR', 'step_size': 5},
+        ...     'dataloader_kwargs': {'batch_size': 32, 'num_workers': 4}
+        ... }
+        >>> algorithm = ImageClassificationFineTuning(**config)
+        ```
+    """
+
     def __init__(
         self,
         max_epochs: Optional[int],
         max_steps: Optional[int],
+        training_data_ratio: Optional[float],
         label_smoothing: float,
         optimizer: DictConfig,
         lr_scheduler: DictConfig,
         dataloader_kwargs: DictConfig,
+        save_top_k: int,
+        save_interval: int,
+        save_on_train_epoch_end: bool,
         **kwargs,
     ):
         super().__init__(**kwargs)
-        assert (max_epochs is None) or (
+        assert (max_epochs is None or max_epochs < 0) or (
             max_steps is None or max_steps < 0
         ), "Only one of max_epochs or max_steps should be set."
-        self.training_interval =
+        self.training_interval = (
+            "epoch" if max_epochs is not None and max_epochs > 0 else "step"
+        )
         if self.training_interval == "epoch":
             self.max_steps = -1
         log.info(f"Training interval: {self.training_interval}")
         log.info(f"Max epochs: {max_epochs}, max steps: {max_steps}")
 
     def run(self, modelpool: ResNetForImageClassificationPool):
+        """Execute the fine-tuning process on the provided model pool.
+
+        This method performs the complete fine-tuning workflow:
+        1. Loads the pretrained model from the model pool
+        2. Prepares training and validation datasets
+        3. Configures optimizer and learning rate scheduler
+        4. Sets up Lightning trainer with appropriate callbacks
+        5. Executes the training process
+        6. Saves the final fine-tuned model
+        """
         # load model and dataset
         model = modelpool.load_pretrained_or_first_model()
+        base_model_name = _get_base_model_name(model)
+
         assert isinstance(model, nn.Module), "Loaded model is not a nn.Module."
 
         assert (
@@ -59,7 +123,17 @@ class ImageClassificationFineTuning(BaseAlgorithm):
         ), "Exactly one training dataset is required."
         self.dataset_name = dataset_name = modelpool.train_dataset_names[0]
         num_classes = get_num_classes(dataset_name)
+        log.info(f"Number of classes for dataset {dataset_name}: {num_classes}")
         train_dataset = modelpool.load_train_dataset(dataset_name)
+        log.info(f"Training dataset size: {len(train_dataset)}")
+        if self.training_data_ratio is not None and 0 < self.training_data_ratio < 1:
+            train_dataset, _ = random_split(
+                train_dataset,
+                lengths=[self.training_data_ratio, 1 - self.training_data_ratio],
+            )
+            log.info(
+                f"Using {len(train_dataset)} samples for training after applying training_data_ratio={self.training_data_ratio}."
+            )
         train_dataset = CLIPDataset(
             train_dataset, processor=modelpool.load_processor(stage="train")
         )
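
The new `training_data_ratio` option subsamples the training split via `torch.utils.data.random_split`, which accepts fractional `lengths` summing to 1 (PyTorch >= 1.13) and resolves them to integer subset sizes. A minimal sketch with a toy dataset (names and sizes here are illustrative, not from the package):

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Toy stand-in for modelpool.load_train_dataset(dataset_name).
full_dataset = TensorDataset(torch.randn(1000, 3), torch.randint(0, 10, (1000,)))

training_data_ratio = 0.1  # keep 10% of the samples
# Fractional lengths are resolved to integer sizes internally; the
# complementary split is simply discarded, as in the diff above.
train_subset, _ = random_split(
    full_dataset, lengths=[training_data_ratio, 1 - training_data_ratio]
)
print(len(train_subset))  # 100
```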
@@ -70,6 +144,8 @@ class ImageClassificationFineTuning(BaseAlgorithm):
                 val_dataset, processor=modelpool.load_processor(stage="val")
             )
             val_loader = self.get_dataloader(val_dataset, stage="val")
+        else:
+            val_loader = None
 
         # configure optimizer
         optimizer = instantiate(self.optimizer, params=model.parameters())
@@ -91,7 +167,11 @@ class ImageClassificationFineTuning(BaseAlgorithm):
             objective=nn.CrossEntropyLoss(label_smoothing=self.label_smoothing),
             metrics={
                 "acc@1": Accuracy(task="multiclass", num_classes=num_classes),
-                "acc@5": Accuracy(
+                f"acc@{min(5,num_classes)}": Accuracy(
+                    task="multiclass",
+                    num_classes=num_classes,
+                    top_k=min(5, num_classes),
+                ),
             },
         )
 
@@ -108,11 +188,21 @@ class ImageClassificationFineTuning(BaseAlgorithm):
             callbacks=[
                 pl_callbacks.LearningRateMonitor(logging_interval="step"),
                 pl_callbacks.DeviceStatsMonitor(),
+                pl_callbacks.ModelCheckpoint(
+                    save_top_k=self.save_top_k,
+                    every_n_train_steps=(
+                        self.save_interval if self.training_interval == "step" else None
+                    ),
+                    every_n_epochs=(
+                        self.save_interval
+                        if self.training_interval == "epoch"
+                        else None
+                    ),
+                    save_on_train_epoch_end=self.save_on_train_epoch_end,
+                    save_last=True,
+                ),
             ],
-            logger=TensorBoardLogger(
-                save_dir=log_dir,
-                name="",
-            ),
+            logger=TensorBoardLogger(save_dir=log_dir, name="", version=""),
             fast_dev_run=RuntimeConstants.debug,
         )
 
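
The new `ModelCheckpoint` callback mirrors the epoch/step duality of the trainer: exactly one of `every_n_train_steps` and `every_n_epochs` carries the interval while the other stays `None`, since Lightning treats these arguments as mutually exclusive. A sketch of the same pattern (values are illustrative):

```python
from lightning.pytorch.callbacks import ModelCheckpoint

training_interval = "step"  # set to "epoch" for epoch-based training
save_interval = 500
save_top_k = 3

# Exactly one interval argument carries a value; setting both at once is
# rejected by ModelCheckpoint's mutual-exclusivity validation.
checkpoint_cb = ModelCheckpoint(
    save_top_k=save_top_k,
    every_n_train_steps=save_interval if training_interval == "step" else None,
    every_n_epochs=save_interval if training_interval == "epoch" else None,
    save_last=True,
)
```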
@@ -129,10 +219,27 @@ class ImageClassificationFineTuning(BaseAlgorithm):
                 "raw_checkpoints",
                 "final",
             ),
+            algorithm_config=self.config,
+            description=f"Fine-tuned ResNet model on dataset {dataset_name}.",
+            base_model=base_model_name,
         )
         return model
 
     def get_dataloader(self, dataset, stage: str):
+        """Create a DataLoader for the specified dataset and training stage.
+
+        Constructs a PyTorch DataLoader with stage-appropriate configurations:
+        - Training stage: shuffling enabled by default
+        - Validation/test stages: shuffling disabled by default
+
+        Args:
+            dataset: The dataset to wrap in a DataLoader.
+            stage (str): Training stage, must be one of "train", "val", or "test".
+                Determines default shuffling behavior.
+
+        Returns:
+            DataLoader: Configured DataLoader for the given dataset and stage.
+        """
         assert stage in ["train", "val", "test"], f"Invalid stage: {stage}"
         dataloader_kwargs = dict(self.dataloader_kwargs)
         if "shuffle" not in dataloader_kwargs:
@@ -142,10 +249,42 @@ class ImageClassificationFineTuning(BaseAlgorithm):
 
 @auto_register_config
 class ImageClassificationFineTuning_Test(BaseAlgorithm):
+    """Test/evaluation algorithm for fine-tuned image classification models.
+
+    This class implements model evaluation on test or validation datasets using PyTorch Lightning.
+    It can either evaluate a model directly or load a model from a checkpoint before evaluation.
+    The evaluation computes standard classification metrics including top-1 and top-5 accuracy.
+
+    Args:
+        checkpoint_path (str): Path to the model checkpoint file. If None, uses the model
+            directly from the model pool without loading from checkpoint.
+        dataloader_kwargs (DictConfig): Additional keyword arguments for DataLoader construction.
+        **kwargs: Additional arguments passed to the base class.
+
+    Example:
+        ```python
+        >>> config = {
+        ...     'checkpoint_path': '/path/to/model/checkpoint.ckpt',
+        ...     'dataloader_kwargs': {'batch_size': 64, 'num_workers': 4}
+        ... }
+        >>> test_algorithm = ImageClassificationFineTuning_Test(**config)
+        ```
+    """
+
     def __init__(self, checkpoint_path: str, dataloader_kwargs: DictConfig, **kwargs):
         super().__init__(**kwargs)
 
-    def run(self, modelpool:
+    def run(self, modelpool: ResNetForImageClassificationPool):
+        """Execute model evaluation on the provided model pool's test/validation dataset.
+
+        This method performs the complete evaluation workflow:
+        1. Loads the model from the model pool (pretrained or first available)
+        2. Prepares the test or validation dataset (prioritizes test if both available)
+        3. Sets up the Lightning module with appropriate metrics (top-1 and top-5 accuracy)
+        4. Loads from checkpoint if specified, otherwise uses the model directly
+        5. Executes the evaluation using Lightning trainer
+        6. Logs and returns the test metrics
+        """
         assert (
             modelpool.has_val_dataset or modelpool.has_test_dataset
         ), "No validation or test dataset found in the model pool."
@@ -181,8 +320,10 @@ class ImageClassificationFineTuning_Test(BaseAlgorithm):
             model,
             metrics={
                 "acc@1": Accuracy(task="multiclass", num_classes=num_classes),
-                "acc@5": Accuracy(
-                    task="multiclass",
+                f"acc@{min(5,num_classes)}": Accuracy(
+                    task="multiclass",
+                    num_classes=num_classes,
+                    top_k=min(5, num_classes),
                 ),
             },
         )
@@ -192,8 +333,10 @@ class ImageClassificationFineTuning_Test(BaseAlgorithm):
             model=model,
             metrics={
                 "acc@1": Accuracy(task="multiclass", num_classes=num_classes),
-                "acc@5": Accuracy(
-                    task="multiclass",
+                f"acc@{min(5,num_classes)}": Accuracy(
+                    task="multiclass",
+                    num_classes=num_classes,
+                    top_k=min(5, num_classes),
                ),
             },
         )
@@ -207,6 +350,19 @@ class ImageClassificationFineTuning_Test(BaseAlgorithm):
         return model
 
     def get_dataloader(self, dataset, stage: str):
+        """Create a DataLoader for the specified dataset and evaluation stage.
+
+        Constructs a PyTorch DataLoader with stage-appropriate configurations for evaluation.
+        Similar to the training version but typically used for test/validation datasets.
+
+        Args:
+            dataset: The dataset to wrap in a DataLoader.
+            stage (str): Evaluation stage, must be one of "train", "val", or "test".
+                Determines default shuffling behavior (disabled for non-train stages).
+
+        Returns:
+            DataLoader: Configured DataLoader for the given dataset and stage.
+        """
         assert stage in ["train", "val", "test"], f"Invalid stage: {stage}"
         dataloader_kwargs = dict(self.dataloader_kwargs)
         if "shuffle" not in dataloader_kwargs:
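
A recurring fix in this file replaces the hard-coded "acc@5" metric with `f"acc@{min(5,num_classes)}"`: for datasets with fewer than five classes (e.g., a two-class task such as PCam), top-5 accuracy is ill-defined, so both the metric key and `top_k` are capped at the class count. Illustrated with torchmetrics (values are illustrative):

```python
from torchmetrics import Accuracy

num_classes = 2  # e.g., a binary classification dataset
k = min(5, num_classes)

# top_k greater than the number of classes is meaningless, so cap it;
# the metric key is derived from the same capped value.
metric = Accuracy(task="multiclass", num_classes=num_classes, top_k=k)
print(f"acc@{k}")  # acc@2
```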
fusion_bench/method/dare/simple_average.py

@@ -1,6 +1,6 @@
 import logging
 
-from fusion_bench import BaseAlgorithm, BaseModelPool
+from fusion_bench import BaseAlgorithm, BaseModelPool, auto_register_config
 from fusion_bench.utils.state_dict_arithmetic import state_dict_add, state_dict_mul
 
 from .task_arithmetic import DareTaskArithmetic
@@ -8,6 +8,7 @@ from .task_arithmetic import DareTaskArithmetic
 log = logging.getLogger(__name__)
 
 
+@auto_register_config
 class DareSimpleAverage(BaseAlgorithm):
 
     def __init__(
@@ -17,10 +18,10 @@ class DareSimpleAverage(BaseAlgorithm):
         rescale: bool = True,
         **kwargs,
     ):
+        super().__init__(**kwargs)
         self.sparsity_ratio = sparsity_ratio
         self.only_on_linear_weight = only_on_linear_weights
         self.rescale = rescale
-        super().__init__(**kwargs)
 
     def run(self, modelpool: BaseModelPool):
         return DareTaskArithmetic(
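
Both DARE classes gain `@auto_register_config` and move `super().__init__(**kwargs)` ahead of the manual attribute assignments, so the explicit assignments take effect after base-class initialization. The decorator is FusionBench's mechanism for recording constructor arguments into the algorithm's config; a minimal sketch of the general pattern (the actual fusion_bench implementation may differ):

```python
import inspect

def auto_register_config(cls):
    # Sketch only: capture the __init__ arguments (with defaults applied)
    # and expose them as `self.config`.
    original_init = cls.__init__

    def wrapped_init(self, *args, **kwargs):
        bound = inspect.signature(original_init).bind(self, *args, **kwargs)
        bound.apply_defaults()
        self.config = {k: v for k, v in bound.arguments.items() if k != "self"}
        original_init(self, *args, **kwargs)

    cls.__init__ = wrapped_init
    return cls

@auto_register_config
class Example:
    def __init__(self, sparsity_ratio: float = 0.5, rescale: bool = True):
        pass

print(Example(sparsity_ratio=0.9).config)  # {'sparsity_ratio': 0.9, 'rescale': True}
```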
fusion_bench/method/dare/task_arithmetic.py

@@ -1,7 +1,7 @@
 import torch
 from torch import Tensor, nn
 
-from fusion_bench import BaseAlgorithm, BaseModelPool
+from fusion_bench import BaseAlgorithm, BaseModelPool, auto_register_config
 from fusion_bench.utils.state_dict_arithmetic import state_dict_sum
 
 from .utils import (
@@ -12,6 +12,7 @@ from .utils import (
 )
 
 
+@auto_register_config
 class DareTaskArithmetic(BaseAlgorithm):
     """
     Implementation of Task Arithmetic w/ DARE.
@@ -27,11 +28,11 @@ class DareTaskArithmetic(BaseAlgorithm):
         rescale: bool = True,
         **kwargs,
     ):
+        super().__init__(**kwargs)
         self.scaling_factor = scaling_factor
         self.sparsity_ratio = sparsity_ratio
         self.only_on_linear_weights = only_on_linear_weights
         self.rescale = rescale
-        super().__init__(**kwargs)
 
     def _load_task_vector(
         self,
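
`DareTaskArithmetic` applies DARE (drop-and-rescale) to each task vector before merging: entries are dropped at random with probability `sparsity_ratio` and, when `rescale=True`, the survivors are scaled by `1/(1 - sparsity_ratio)` so the expected delta is preserved. A single-tensor sketch of the idea (not the package's exact implementation, which operates on state dicts via the helpers in `.utils`):

```python
import torch

def dare_drop(delta: torch.Tensor, sparsity_ratio: float, rescale: bool = True):
    # Randomly zero out ~sparsity_ratio of the task-vector entries.
    keep_mask = torch.rand_like(delta) >= sparsity_ratio
    dropped = delta * keep_mask
    if rescale:
        # Rescale survivors so the expected value of the delta is preserved.
        dropped = dropped / (1.0 - sparsity_ratio)
    return dropped

delta = torch.randn(4, 4)
sparse = dare_drop(delta, sparsity_ratio=0.9)
print(sparse.count_nonzero().item())  # roughly 10% of 16 entries survive
```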
fusion_bench/method/simple_average.py

@@ -64,10 +64,12 @@ class SimpleAverageAlgorithm(
     SimpleProfilerMixin,
     BaseAlgorithm,
 ):
-    def __init__(self, show_pbar: bool = False, **kwargs):
+    def __init__(self, show_pbar: bool = False, inplace: bool = True, **kwargs):
         """
         Args:
             show_pbar (bool): If True, shows a progress bar during model loading and merging. Default is False.
+            inplace (bool): If True, overwrites the weights of the first model in the model pool.
+                If False, creates a new model for the merged weights. Default is True.
         """
         super().__init__(**kwargs)
 
@@ -104,12 +106,12 @@ class SimpleAverageAlgorithm(
             with self.profile("merge weights"):
                 if sd is None:
                     # Initialize the state dictionary with the first model's state dictionary
-                    sd = model.state_dict(
-                    forward_model = model
+                    sd = model.state_dict()
+                    forward_model = model if self.inplace else deepcopy(model)
                 else:
                     # Add the current model's state dictionary to the accumulated state dictionary
                     sd = state_dict_add(
-                        sd, model.state_dict(
+                        sd, model.state_dict(), show_pbar=self.show_pbar
                     )
         with self.profile("merge weights"):
             # Divide the accumulated state dictionary by the number of models to get the average
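
The `inplace` flag decides whether the first pool model doubles as the output container for the merged weights; `inplace=False` deep-copies it first, leaving the pool untouched. The merge itself is a running sum divided by the model count. A hedged sketch of that accumulation (toy models, not the pool's loading logic):

```python
from copy import deepcopy
import torch.nn as nn

models = [nn.Linear(2, 2) for _ in range(3)]  # illustrative model pool
inplace = False

sd, forward_model = None, None
for model in models:
    if sd is None:
        sd = model.state_dict()
        # inplace=False keeps the pool's first model untouched.
        forward_model = model if inplace else deepcopy(model)
    else:
        sd = {k: sd[k] + v for k, v in model.state_dict().items()}

sd = {k: v / len(models) for k, v in sd.items()}  # element-wise average
forward_model.load_state_dict(sd)
```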
fusion_bench/method/task_arithmetic/task_arithmetic.py

@@ -149,7 +149,10 @@ class TaskArithmeticAlgorithm(
             )
         with self.profile("merge weights"):
             # scale the task vector
-
+            # here we keep the dtype when the elements of value are all zeros to avoid dtype mismatch
+            task_vector = state_dict_mul(
+                task_vector, self.config.scaling_factor, keep_dtype_when_zero=True
+            )
             # add the task vector to the pretrained model
             state_dict = state_dict_add(pretrained_model.state_dict(), task_vector)
 
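
`keep_dtype_when_zero=True` guards integer-typed entries: buffers such as BatchNorm's `num_batches_tracked` often have an all-zero integer delta in the task vector, and multiplying that by a float scaling factor would promote it to floating point, clashing with the pretrained state dict in the subsequent `state_dict_add`. A sketch of the flag's likely semantics (the real helper lives in `fusion_bench.utils.state_dict_arithmetic`):

```python
import torch

def state_dict_mul(state_dict, scalar, keep_dtype_when_zero=False):
    # Sketch: all-zero tensors pass through unchanged, keeping their dtype;
    # everything else is scaled (and may be dtype-promoted) as usual.
    out = {}
    for key, value in state_dict.items():
        if keep_dtype_when_zero and not value.any():
            out[key] = value.clone()
        else:
            out[key] = value * scalar
    return out

sd = {"weight": torch.ones(2), "num_batches_tracked": torch.zeros((), dtype=torch.long)}
scaled = state_dict_mul(sd, 0.3, keep_dtype_when_zero=True)
print(scaled["num_batches_tracked"].dtype)  # torch.int64, not promoted to float
```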
fusion_bench/mixins/lightning_fabric.py

@@ -111,6 +111,15 @@ class LightningFabricMixin:
         """
         if self.fabric is not None and len(self.fabric._loggers) > 0:
             log_dir = self.fabric.logger.log_dir
+
+            # Special handling for SwanLabLogger to get the correct log directory
+            if (
+                log_dir is None
+                and self.fabric.logger.__class__.__name__ == "SwanLabLogger"
+            ):
+                log_dir = self.fabric.logger.save_dir or self.fabric.logger._logdir
+
+            assert log_dir is not None, "log_dir should not be None"
             if self.fabric.is_global_zero and not os.path.exists(log_dir):
                 os.makedirs(log_dir, exist_ok=True)
             return log_dir
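
SwanLab's Fabric logger reports `log_dir` as `None`, so the mixin now falls back to `save_dir` and then the private `_logdir`. Matching on the class name rather than using `isinstance` keeps `swanlab` an optional dependency. A toy illustration of the fallback chain (the logger class here is a hypothetical stand-in):

```python
class FakeSwanLabLogger:  # hypothetical stand-in for SwanLabLogger
    log_dir = None
    save_dir = None
    _logdir = "outputs/swanlab/run-1"

logger = FakeSwanLabLogger()
log_dir = logger.log_dir
if log_dir is None and logger.__class__.__name__ == "FakeSwanLabLogger":
    # String comparison avoids importing swanlab just for an isinstance check.
    log_dir = logger.save_dir or logger._logdir
assert log_dir == "outputs/swanlab/run-1"
```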
fusion_bench/modelpool/__init__.py

@@ -8,6 +8,14 @@ _import_structure = {
     "base_pool": ["BaseModelPool"],
     "causal_lm": ["CausalLMPool", "CausalLMBackbonePool"],
     "clip_vision": ["CLIPVisionModelPool"],
+    "convnext_for_image_classification": [
+        "ConvNextForImageClassificationPool",
+        "load_transformers_convnext",
+    ],
+    "dinov2_for_image_classification": [
+        "Dinov2ForImageClassificationPool",
+        "load_transformers_dinov2",
+    ],
     "nyuv2_modelpool": ["NYUv2ModelPool"],
     "huggingface_automodel": ["AutoModelPool"],
     "seq2seq_lm": ["Seq2SeqLMPool"],
@@ -18,7 +26,10 @@ _import_structure = {
         "GPT2ForSequenceClassificationPool",
     ],
     "seq_classification_lm": ["SequenceClassificationModelPool"],
-    "resnet_for_image_classification": [
+    "resnet_for_image_classification": [
+        "ResNetForImageClassificationPool",
+        "load_transformers_resnet",
+    ],
 }
 
 
@@ -26,6 +37,14 @@ if TYPE_CHECKING:
     from .base_pool import BaseModelPool
     from .causal_lm import CausalLMBackbonePool, CausalLMPool
     from .clip_vision import CLIPVisionModelPool
+    from .convnext_for_image_classification import (
+        ConvNextForImageClassificationPool,
+        load_transformers_convnext,
+    )
+    from .dinov2_for_image_classification import (
+        Dinov2ForImageClassificationPool,
+        load_transformers_dinov2,
+    )
     from .huggingface_automodel import AutoModelPool
     from .huggingface_gpt2_classification import (
         GPT2ForSequenceClassificationPool,
@@ -34,7 +53,10 @@ if TYPE_CHECKING:
     from .nyuv2_modelpool import NYUv2ModelPool
     from .openclip_vision import OpenCLIPVisionModelPool
     from .PeftModelForSeq2SeqLM import PeftModelForSeq2SeqLMPool
-    from .resnet_for_image_classification import
+    from .resnet_for_image_classification import (
+        ResNetForImageClassificationPool,
+        load_transformers_resnet,
+    )
     from .seq2seq_lm import Seq2SeqLMPool
     from .seq_classification_lm import SequenceClassificationModelPool
 
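
The `_import_structure` dict paired with a `TYPE_CHECKING` block is the lazy-module pattern familiar from Hugging Face Transformers: heavyweight submodules are imported only when one of their names is first accessed, while static type checkers still see ordinary imports. A condensed sketch of the mechanism, as it might appear in a package `__init__.py` (fusion_bench's actual lazy importer may differ in detail):

```python
import importlib
import sys
from typing import TYPE_CHECKING

_import_structure = {"base_pool": ["BaseModelPool"]}

if TYPE_CHECKING:
    from .base_pool import BaseModelPool  # real import, for type checkers only
else:
    class _LazyModule(type(sys)):  # subclass of ModuleType
        def __init__(self, name, import_structure):
            super().__init__(name)
            self._attr_to_module = {
                attr: mod for mod, attrs in import_structure.items() for attr in attrs
            }

        def __getattr__(self, attr):  # runs only on the first access of `attr`
            module = importlib.import_module(
                "." + self._attr_to_module[attr], self.__name__
            )
            return getattr(module, attr)

    sys.modules[__name__] = _LazyModule(__name__, _import_structure)
```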
fusion_bench/modelpool/base_pool.py

@@ -3,7 +3,7 @@ from copy import deepcopy
 from typing import Dict, Generator, List, Optional, Tuple, Union
 
 import torch
-from omegaconf import DictConfig
+from omegaconf import DictConfig, OmegaConf, UnsupportedValueType
 from torch import nn
 from torch.utils.data import Dataset
 
@@ -52,6 +52,13 @@ class BaseModelPool(
     ):
         if isinstance(models, List):
             models = {str(model_idx): model for model_idx, model in enumerate(models)}
+
+        if isinstance(models, dict):
+            try:  # try to convert to DictConfig
+                models = OmegaConf.create(models)
+            except UnsupportedValueType:
+                pass
+
         self._models = models
         self._train_datasets = train_datasets
         self._val_datasets = val_datasets
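
The pool now normalizes plain dicts of model specifications into a `DictConfig`, but dicts holding live objects (e.g., instantiated `nn.Module`s) cannot be represented in OmegaConf, so `UnsupportedValueType` is swallowed and the plain dict is kept. A small demonstration (model names are placeholders):

```python
import torch.nn as nn
from omegaconf import OmegaConf
from omegaconf.errors import UnsupportedValueType

# Config-style entries (strings, nested dicts) convert cleanly.
models = {"_pretrained_": "org/resnet-18", "mnist": "org/resnet18-mnist"}
models = OmegaConf.create(models)

# Live modules are not representable, so the plain dict is kept as-is.
live_models = {"model_0": nn.Linear(4, 4)}
try:
    live_models = OmegaConf.create(live_models)
except UnsupportedValueType:
    pass
```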
fusion_bench/modelpool/causal_lm/causal_lm.py

@@ -8,6 +8,7 @@ from copy import deepcopy
 from typing import Any, Dict, Optional, TypeAlias, Union, cast  # noqa: F401
 
 import peft
+from lightning_utilities.core.rank_zero import rank_zero_only
 from omegaconf import DictConfig, OmegaConf, flag_override
 from torch import nn
 from torch.nn.modules import Module
@@ -342,7 +343,7 @@ class CausalLMPool(BaseModelPool):
         )
 
         # Create and save model card if algorithm_config is provided
-        if algorithm_config is not None:
+        if algorithm_config is not None and rank_zero_only.rank == 0:
             if description is None:
                 description = "Model created using FusionBench."
             model_card_str = create_default_model_card(
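
The added `rank_zero_only.rank == 0` condition prevents every process in a distributed run from racing to write the same model card; `rank_zero_only` from `lightning_utilities` carries the global rank as an attribute and can also be used as a decorator that turns a function into a no-op on non-zero ranks. A brief sketch:

```python
from lightning_utilities.core.rank_zero import rank_zero_only

# Explicit check, as in the diff: only the global-zero process writes.
if rank_zero_only.rank == 0:
    print("writing model card on rank 0")

# Equivalent decorator form: the body runs only on rank 0.
@rank_zero_only
def save_model_card(path: str, content: str) -> None:
    with open(path, "w") as f:
        f.write(content)
```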
|