fusion-bench 0.2.25__py3-none-any.whl → 0.2.27__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/dataset/clip_dataset.py +1 -0
- fusion_bench/method/__init__.py +4 -0
- fusion_bench/method/adamerging/__init__.py +28 -5
- fusion_bench/method/adamerging/resnet_adamerging.py +279 -0
- fusion_bench/method/adamerging/task_wise_adamerging.py +2 -14
- fusion_bench/method/adamerging/utils.py +58 -0
- fusion_bench/method/classification/clip_finetune.py +6 -4
- fusion_bench/method/classification/image_classification_finetune.py +156 -12
- fusion_bench/method/dare/simple_average.py +3 -2
- fusion_bench/method/dare/task_arithmetic.py +3 -2
- fusion_bench/method/dop/__init__.py +1 -0
- fusion_bench/method/dop/dop.py +366 -0
- fusion_bench/method/dop/min_norm_solvers.py +227 -0
- fusion_bench/method/dop/utils.py +73 -0
- fusion_bench/method/simple_average.py +6 -4
- fusion_bench/mixins/lightning_fabric.py +9 -0
- fusion_bench/modelpool/causal_lm/causal_lm.py +2 -1
- fusion_bench/modelpool/resnet_for_image_classification.py +285 -4
- fusion_bench/models/hf_clip.py +4 -7
- fusion_bench/models/hf_utils.py +4 -1
- fusion_bench/taskpool/__init__.py +2 -0
- fusion_bench/taskpool/clip_vision/taskpool.py +1 -1
- fusion_bench/taskpool/resnet_for_image_classification.py +231 -0
- fusion_bench/utils/state_dict_arithmetic.py +91 -10
- {fusion_bench-0.2.25.dist-info → fusion_bench-0.2.27.dist-info}/METADATA +9 -3
- {fusion_bench-0.2.25.dist-info → fusion_bench-0.2.27.dist-info}/RECORD +140 -77
- fusion_bench_config/fabric/auto.yaml +1 -1
- fusion_bench_config/fabric/loggers/swandb_logger.yaml +5 -0
- fusion_bench_config/fabric/loggers/wandb_logger.yaml +2 -0
- fusion_bench_config/fabric_model_fusion.yaml +1 -0
- fusion_bench_config/method/adamerging/resnet.yaml +18 -0
- fusion_bench_config/method/bitdelta/bitdelta.yaml +3 -0
- fusion_bench_config/method/classification/clip_finetune.yaml +5 -0
- fusion_bench_config/method/classification/image_classification_finetune.yaml +9 -0
- fusion_bench_config/method/depth_upscaling.yaml +9 -0
- fusion_bench_config/method/dop/dop.yaml +30 -0
- fusion_bench_config/method/dummy.yaml +6 -0
- fusion_bench_config/method/ensemble/max_model_predictor.yaml +6 -0
- fusion_bench_config/method/ensemble/simple_ensemble.yaml +8 -1
- fusion_bench_config/method/ensemble/weighted_ensemble.yaml +8 -0
- fusion_bench_config/method/linear/expo.yaml +5 -0
- fusion_bench_config/method/linear/linear_interpolation.yaml +8 -0
- fusion_bench_config/method/linear/llama_expo.yaml +5 -0
- fusion_bench_config/method/linear/llama_expo_with_dare.yaml +3 -0
- fusion_bench_config/method/linear/simple_average_for_causallm.yaml +5 -0
- fusion_bench_config/method/linear/task_arithmetic_for_causallm.yaml +3 -0
- fusion_bench_config/method/linear/ties_merging_for_causallm.yaml +5 -0
- fusion_bench_config/method/linear/weighted_average.yaml +3 -0
- fusion_bench_config/method/linear/weighted_average_for_llama.yaml +6 -1
- fusion_bench_config/method/mixtral_moe_merging.yaml +3 -0
- fusion_bench_config/method/mixtral_moe_upscaling.yaml +5 -0
- fusion_bench_config/method/model_recombination.yaml +8 -0
- fusion_bench_config/method/model_stock/model_stock.yaml +4 -1
- fusion_bench_config/method/opcm/opcm.yaml +5 -0
- fusion_bench_config/method/opcm/task_arithmetic.yaml +6 -0
- fusion_bench_config/method/opcm/ties_merging.yaml +5 -0
- fusion_bench_config/method/opcm/weight_average.yaml +5 -0
- fusion_bench_config/method/regmean/clip_regmean.yaml +3 -0
- fusion_bench_config/method/regmean/gpt2_regmean.yaml +3 -0
- fusion_bench_config/method/regmean/regmean.yaml +3 -0
- fusion_bench_config/method/regmean_plusplus/clip_regmean_plusplus.yaml +3 -0
- fusion_bench_config/method/simple_average.yaml +9 -0
- fusion_bench_config/method/slerp/slerp.yaml +9 -0
- fusion_bench_config/method/slerp/slerp_lm.yaml +5 -0
- fusion_bench_config/method/smile_upscaling/causal_lm_upscaling.yaml +6 -0
- fusion_bench_config/method/smile_upscaling/error_accumulation.yaml +5 -0
- fusion_bench_config/method/smile_upscaling/projected_energy.yaml +5 -0
- fusion_bench_config/method/smile_upscaling/singular_projection_merging.yaml +3 -0
- fusion_bench_config/method/smile_upscaling/smile_mistral_upscaling.yaml +5 -0
- fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +5 -0
- fusion_bench_config/method/smile_upscaling/smile_upscaling.yaml +3 -0
- fusion_bench_config/method/task_arithmetic.yaml +9 -0
- fusion_bench_config/method/ties_merging.yaml +3 -0
- fusion_bench_config/method/wudi/wudi.yaml +3 -0
- fusion_bench_config/model_fusion.yaml +2 -1
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/_generate_config.py +138 -0
- fusion_bench_config/modelpool/{ResNetForImageClassfication → ResNetForImageClassification}/transformers/resnet152_cifar10.yaml +1 -1
- fusion_bench_config/modelpool/{ResNetForImageClassfication → ResNetForImageClassification}/transformers/resnet152_cifar100.yaml +1 -1
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_dtd.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_emnist_letters.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_eurosat.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_fashion_mnist.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_fer2013.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_food101.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_gtsrb.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_kmnist.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_mnist.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_oxford-iiit-pet.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_oxford_flowers102.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_pcam.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_rendered-sst2.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_resisc45.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_stanford-cars.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_stl10.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_sun397.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_svhn.yaml +14 -0
- fusion_bench_config/modelpool/{ResNetForImageClassfication → ResNetForImageClassification}/transformers/resnet18_cifar10.yaml +1 -1
- fusion_bench_config/modelpool/{ResNetForImageClassfication → ResNetForImageClassification}/transformers/resnet18_cifar100.yaml +1 -1
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_dtd.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_emnist_letters.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_eurosat.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_fashion_mnist.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_fer2013.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_food101.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_gtsrb.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_kmnist.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_mnist.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_oxford-iiit-pet.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_oxford_flowers102.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_pcam.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_rendered-sst2.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_resisc45.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_stanford-cars.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_stl10.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_sun397.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_svhn.yaml +14 -0
- fusion_bench_config/modelpool/{ResNetForImageClassfication → ResNetForImageClassification}/transformers/resnet50_cifar10.yaml +1 -1
- fusion_bench_config/modelpool/{ResNetForImageClassfication → ResNetForImageClassification}/transformers/resnet50_cifar100.yaml +1 -1
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_dtd.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_emnist_letters.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_eurosat.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_fashion_mnist.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_fer2013.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_food101.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_gtsrb.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_kmnist.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_mnist.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_oxford-iiit-pet.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_oxford_flowers102.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_pcam.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_rendered-sst2.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_resisc45.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_stanford-cars.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_stl10.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_sun397.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_svhn.yaml +14 -0
- fusion_bench_config/method/clip_finetune.yaml +0 -26
- {fusion_bench-0.2.25.dist-info → fusion_bench-0.2.27.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.25.dist-info → fusion_bench-0.2.27.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.25.dist-info → fusion_bench-0.2.27.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.25.dist-info → fusion_bench-0.2.27.dist-info}/top_level.txt +0 -0
fusion_bench/modelpool/resnet_for_image_classification.py
CHANGED

@@ -1,3 +1,34 @@
+"""ResNet Model Pool for Image Classification.
+
+This module provides a flexible model pool implementation for ResNet models used in image
+classification tasks. It supports both torchvision and transformers implementations of ResNet
+architectures with configurable preprocessing, loading, and saving capabilities.
+
+Example Usage:
+    Create a pool with a torchvision ResNet model:
+
+    ```python
+    >>> # Torchvision ResNet pool
+    >>> pool = ResNetForImageClassificationPool(
+    ...     type="torchvision",
+    ...     models={"resnet18_cifar10": {"model_name": "resnet18", "dataset_name": "cifar10"}}
+    ... )
+    >>> model = pool.load_model("resnet18_cifar10")
+    >>> processor = pool.load_processor(stage="train")
+    ```
+
+    Create a pool with a transformers ResNet model:
+
+    ```python
+    >>> # Transformers ResNet pool
+    >>> pool = ResNetForImageClassificationPool(
+    ...     type="transformers",
+    ...     models={"resnet_model": {"config_path": "microsoft/resnet-50", "pretrained": True}}
+    ... )
+    ```
+"""
+
+import os
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -11,6 +42,7 @@ from typing import (
 )
 
 import torch
+from lightning_utilities.core.rank_zero import rank_zero_only
 from omegaconf import DictConfig
 from torch import nn
 
@@ -26,6 +58,31 @@ log = get_rankzero_logger(__name__)
 def load_torchvision_resnet(
     model_name: str, weights: Optional[str], num_classes: Optional[int]
 ) -> "TorchVisionResNet":
+    """Load a ResNet model from torchvision with optional custom classifier head.
+
+    This function creates a ResNet model using torchvision's model zoo and optionally
+    replaces the final classification layer to match the required number of classes.
+
+    Args:
+        model_name (str): Name of the ResNet model to load (e.g., 'resnet18', 'resnet50').
+            Must be a valid torchvision model name.
+        weights (Optional[str]): Pretrained weights to load. Can be 'DEFAULT', 'IMAGENET1K_V1',
+            or None for random initialization. See torchvision documentation for available options.
+        num_classes (Optional[int]): Number of output classes. If provided, replaces the final
+            fully connected layer. If None, keeps the original classifier (typically 1000 classes).
+
+    Returns:
+        TorchVisionResNet: The loaded ResNet model with appropriate classifier head.
+
+    Raises:
+        AttributeError: If model_name is not a valid torchvision model.
+
+    Example:
+        ```python
+        >>> model = load_torchvision_resnet("resnet18", "DEFAULT", 10)  # CIFAR-10
+        >>> model = load_torchvision_resnet("resnet50", None, 100)  # Random init, 100 classes
+        ```
+    """
     import torchvision.models
 
     model_fn = getattr(torchvision.models, model_name)
@@ -40,6 +97,31 @@ def load_torchvision_resnet(
 def load_transformers_resnet(
     config_path: str, pretrained: bool, dataset_name: Optional[str]
 ):
+    """Load a ResNet model from transformers with optional dataset-specific adaptation.
+
+    This function creates a ResNet model using the transformers library and optionally
+    adapts it for a specific dataset by updating the classifier head and label mappings.
+
+    Args:
+        config_path (str): Path or identifier for the model configuration. Can be a local path
+            or a Hugging Face model identifier (e.g., 'microsoft/resnet-50').
+        pretrained (bool): Whether to load pretrained weights. If True, loads from the
+            specified config_path. If False, initializes with random weights using the config.
+        dataset_name (Optional[str]): Name of the target dataset for adaptation. If provided,
+            updates the model's classifier and label mappings to match the dataset's classes.
+            If None, keeps the original model configuration.
+
+    Returns:
+        ResNetForImageClassification: The loaded and optionally adapted ResNet model.
+
+    Example:
+        ```python
+        >>> # Load pretrained model adapted for CIFAR-10
+        >>> model = load_transformers_resnet("microsoft/resnet-50", True, "cifar10")
+        >>> # Load random initialized model with default classes
+        >>> model = load_transformers_resnet("microsoft/resnet-50", False, None)
+        ```
+    """
     from transformers import AutoConfig, ResNetForImageClassification
 
     if pretrained:
@@ -70,13 +152,107 @@ def load_transformers_resnet(
 
 @auto_register_config
 class ResNetForImageClassificationPool(BaseModelPool):
-
-
-
+    """Model pool for ResNet-based image classification models.
+
+    This class provides a unified interface for managing ResNet models from different sources
+    (torchvision and transformers) with automatic preprocessing, loading, and saving capabilities.
+    It supports multiple ResNet architectures and can automatically adapt models to different
+    datasets by adjusting the number of output classes.
+
+    The pool supports two main types:
+    - "torchvision": Uses torchvision's ResNet implementations with standard ImageNet preprocessing
+    - "transformers": Uses Hugging Face transformers' ResNetForImageClassification with auto processors
+
+    Args:
+        type (str): Model source type, must be either "torchvision" or "transformers".
+        **kwargs: Additional arguments passed to the base BaseModelPool class.
+
+    Attributes:
+        type (str): The model source type specified during initialization.
+
+    Raises:
+        AssertionError: If type is not "torchvision" or "transformers".
+
+    Example:
+        Create a pool with a torchvision ResNet model:
+
+        ```python
+        >>> # Torchvision-based pool
+        >>> pool = ResNetForImageClassificationPool(
+        ...     type="torchvision",
+        ...     models={
+        ...         "resnet18_cifar10": {
+        ...             "model_name": "resnet18",
+        ...             "weights": "DEFAULT",
+        ...             "dataset_name": "cifar10"
+        ...         }
+        ...     }
+        ... )
+        ```
+
+        Create a pool with a transformers ResNet model:
+
+        ```python
+        >>> # Transformers-based pool
+        >>> pool = ResNetForImageClassificationPool(
+        ...     type="transformers",
+        ...     models={
+        ...         "resnet_model": {
+        ...             "config_path": "microsoft/resnet-50",
+        ...             "pretrained": True,
+        ...             "dataset_name": "imagenet"
+        ...         }
+        ...     }
+        ... )
+        ```
+    """
+
+    def __init__(self, models, type: str, **kwargs):
+        super().__init__(models=models, **kwargs)
+        assert type in [
+            "torchvision",
+            "transformers",
+        ], "type must be either 'torchvision' or 'transformers'"
 
     def load_processor(
         self, stage: Literal["train", "val", "test"] = "test", *args, **kwargs
     ):
+        """Load the appropriate image processor/transform for the specified training stage.
+
+        Creates stage-specific image preprocessing pipelines optimized for the model type:
+
+        For torchvision models:
+        - Train stage: Includes data augmentation (random resize crop, horizontal flip)
+        - Val/test stages: Standard preprocessing (resize, center crop) without augmentation
+        - All stages: Apply ImageNet normalization (mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+
+        For transformers models:
+        - Uses AutoImageProcessor from the pretrained model configuration
+        - Automatically handles model-specific preprocessing requirements
+
+        Args:
+            stage (Literal["train", "val", "test"]): The training stage determining preprocessing type.
+                - "train": Applies data augmentation for training
+                - "val"/"test": Uses standard preprocessing for evaluation
+            *args: Additional positional arguments (unused).
+            **kwargs: Additional keyword arguments (unused).
+
+        Returns:
+            Union[transforms.Compose, AutoImageProcessor]: The image processor/transform pipeline
+                appropriate for the specified stage and model type.
+
+        Raises:
+            ValueError: If no valid config_path can be found for transformers models.
+
+        Example:
+            ```python
+            >>> # Get training transforms for torchvision model
+            >>> train_transform = pool.load_processor(stage="train")
+            >>> # Get evaluation processor for transformers model
+            >>> eval_processor = pool.load_processor(stage="test")
+            ```
+        """
         if self.type == "torchvision":
             from torchvision import transforms
 
@@ -122,6 +298,58 @@ class ResNetForImageClassificationPool(BaseModelPool):
 
     @override
     def load_model(self, model_name_or_config: Union[str, DictConfig], *args, **kwargs):
+        """Load a ResNet model based on the provided configuration or model name.
+
+        This method supports flexible model loading from different sources and configurations:
+        - Direct model names (e.g., "resnet18", "resnet50") for standard architectures
+        - Model pool keys that map to configurations
+        - Dictionary/DictConfig objects with detailed model specifications
+        - Hugging Face model identifiers for transformers models
+
+        For torchvision models, supports:
+        - Standard ResNet architectures: resnet18, resnet34, resnet50, resnet101, resnet152
+        - Custom configurations with model_name, weights, and num_classes specifications
+        - Automatic dataset adaptation with class number inference
+
+        For transformers models:
+        - Loading from Hugging Face Hub or local paths
+        - Pretrained or randomly initialized models
+        - Automatic logits extraction by overriding forward method
+        - Dataset-specific label mapping configuration
+
+        Args:
+            model_name_or_config (Union[str, DictConfig]): Model specification that can be:
+                - A string model name (e.g., "resnet18") for standard architectures
+                - A model pool key referencing a stored configuration
+                - A dict/DictConfig with model parameters like:
+                    * For torchvision: {"model_name": "resnet18", "weights": "DEFAULT", "num_classes": 10}
+                    * For transformers: {"config_path": "microsoft/resnet-50", "pretrained": True, "dataset_name": "cifar10"}
+            *args: Additional positional arguments (unused).
+            **kwargs: Additional keyword arguments (unused).
+
+        Returns:
+            Union[TorchVisionResNet, ResNetForImageClassification]: The loaded ResNet model
+                configured for the specified task. For transformers models, the forward method
+                is modified to return logits directly instead of the full model output.
+
+        Raises:
+            ValueError: If model_name_or_config type is invalid or if model type is unknown.
+            AssertionError: If num_classes from dataset doesn't match explicit num_classes specification.
+
+        Example:
+            ```python
+            >>> # Load standard torchvision model
+            >>> model = pool.load_model("resnet18")
+
+            >>> # Load with custom configuration
+            >>> config = {"model_name": "resnet50", "weights": "DEFAULT", "dataset_name": "cifar10"}
+            >>> model = pool.load_model(config)

+            >>> # Load transformers model
+            >>> config = {"config_path": "microsoft/resnet-50", "pretrained": True}
+            >>> model = pool.load_model(config)
+            ```
+        """
         log.debug(f"Loading model: {model_name_or_config}", stacklevel=2)
         if (
             isinstance(model_name_or_config, str)
@@ -198,11 +426,64 @@ class ResNetForImageClassificationPool(BaseModelPool):
         return model
 
     @override
-    def save_model(
+    def save_model(
+        self,
+        model,
+        path,
+        algorithm_config: Optional[DictConfig] = None,
+        description: Optional[str] = None,
+        *args,
+        **kwargs,
+    ):
+        """Save a ResNet model to the specified path using the appropriate format.
+
+        This method handles model saving based on the model pool type:
+        - For torchvision models: Saves only the state_dict using torch.save()
+        - For transformers models: Saves the complete model and processor using save_pretrained()
+
+        The saving format ensures compatibility with the corresponding loading mechanisms
+        and preserves all necessary components for model restoration.
+
+        Args:
+            model: The ResNet model to save. Should be compatible with the pool's model type.
+            path (str): Destination path for saving the model. For torchvision models, this
+                should be a file path (e.g., "model.pth"). For transformers models, this
+                should be a directory path where model files will be stored.
+            *args: Additional positional arguments (unused).
+            **kwargs: Additional keyword arguments (unused).
+
+        Raises:
+            ValueError: If the model type is unknown or unsupported.
+
+        Note:
+            For transformers models, both the model weights and the associated image processor
+            are saved to ensure complete reproducibility of the preprocessing pipeline.
+
+        Example:
+            ```python
+            >>> # Save torchvision model
+            >>> pool.save_model(model, "checkpoints/resnet18_cifar10.pth")
+
+            >>> # Save transformers model (saves to directory)
+            >>> pool.save_model(model, "checkpoints/resnet50_model/")
+            ```
+        """
         if self.type == "torchvision":
+            os.makedirs(os.path.dirname(path), exist_ok=True)
             torch.save(model.state_dict(), path)
         elif self.type == "transformers":
             model.save_pretrained(path)
             self.load_processor().save_pretrained(path)
+
+            if algorithm_config is not None and rank_zero_only.rank == 0:
+                from fusion_bench.models.hf_utils import create_default_model_card
+
+                model_card_str = create_default_model_card(
+                    algorithm_config=algorithm_config,
+                    description=description,
+                    modelpool_config=self.config,
+                )
+                with open(os.path.join(path, "README.md"), "w") as f:
+                    f.write(model_card_str)
         else:
             raise ValueError(f"Unknown model type: {self.type}")
fusion_bench/models/hf_clip.py
CHANGED

@@ -1,5 +1,5 @@
 import logging
-from typing import TYPE_CHECKING, Callable, Iterable, List  # noqa: F401
+from typing import TYPE_CHECKING, Callable, Iterable, List, Optional  # noqa: F401
 
 import torch
 from torch import Tensor, nn
@@ -39,7 +39,6 @@ class HFCLIPClassifier(nn.Module):
         self,
         clip_model: CLIPModel,
         processor: CLIPProcessor,
-        extra_module=None,
     ):
         """
         Initialize the HFCLIPClassifier.
@@ -63,8 +62,6 @@ class HFCLIPClassifier(nn.Module):
             persistent=False,
         )
 
-        self.extra_module = extra_module
-
     @property
     def text_model(self):
         """Get the text model component of CLIP."""
@@ -123,9 +120,9 @@ class HFCLIPClassifier(nn.Module):
     def forward(
         self,
         images: Tensor,
-        return_image_embeds=False,
-        return_dict=False,
-        task_name=None,
+        return_image_embeds: bool = False,
+        return_dict: bool = False,
+        task_name: Optional[str] = None,
     ):
         """
         Perform forward pass for zero-shot image classification.
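The `HFCLIPClassifier` changes are API cleanups: the unused `extra_module` constructor argument is removed, and `forward`'s keyword arguments gain type annotations (hence the new `Optional` import). A sketch of the updated construction follows, assuming the classifier's existing zero-shot task setup helper `set_classification_task`, which is not part of this diff; the class names and dummy batch are illustrative.

```python
import torch
from transformers import CLIPModel, CLIPProcessor

from fusion_bench.models.hf_clip import HFCLIPClassifier

clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
classifier = HFCLIPClassifier(clip_model, processor)  # extra_module no longer accepted

# Assumed helper (pre-existing, not changed here): builds the zero-shot
# classification weights from class names before forward can be called.
classifier.set_classification_task(["cat", "dog"])

# forward now has typed keywords: return_image_embeds: bool, return_dict: bool,
# task_name: Optional[str].
images = torch.randn(2, 3, 224, 224)  # dummy preprocessed batch
logits = classifier(images, return_image_embeds=False, return_dict=False)
```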
fusion_bench/models/hf_utils.py
CHANGED

@@ -142,7 +142,7 @@ def save_pretrained_with_remote_code(
 
 
 def create_default_model_card(
-    models: list[str],
+    models: Optional[list[str]] = None,
     base_model: Optional[str] = None,
     title: str = "Deep Model Fusion",
     tags: list[str] = ["fusion-bench", "merge"],
@@ -152,6 +152,9 @@ def create_default_model_card(
 ):
     from jinja2 import Template
 
+    if models is None:
+        models = []
+
     template: Template = Template(load_model_card_template("default.md"))
     card = template.render(
         base_model=base_model,
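Making `models` optional keeps `create_default_model_card` usable from call sites that only have an algorithm config, such as the new `save_model` path above. A minimal sketch, with illustrative config values (the keyword names match the call in `save_model`):

```python
from omegaconf import DictConfig

from fusion_bench.models.hf_utils import create_default_model_card

# models now defaults to None and is normalized to [], so a card can be
# rendered purely from the algorithm/modelpool configs.
card = create_default_model_card(
    algorithm_config=DictConfig({"name": "simple_average"}),  # illustrative
    description="Example model card.",
    modelpool_config=DictConfig({"type": "transformers"}),  # illustrative
)
print(card.splitlines()[0])
```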
fusion_bench/taskpool/__init__.py
CHANGED

@@ -18,6 +18,7 @@ _import_structure = {
     "lm_eval_harness": ["LMEvalHarnessTaskPool"],
     "nyuv2_taskpool": ["NYUv2TaskPool"],
     "openclip_vision": ["OpenCLIPVisionModelTaskPool"],
+    "resnet_for_image_classification": ["ResNetForImageClassificationTaskPool"],
 }
 
 
@@ -34,6 +35,7 @@ if TYPE_CHECKING:
     from .lm_eval_harness import LMEvalHarnessTaskPool
     from .nyuv2_taskpool import NYUv2TaskPool
     from .openclip_vision import OpenCLIPVisionModelTaskPool
+    from .resnet_for_image_classification import ResNetForImageClassificationTaskPool
 
 else:
     sys.modules[__name__] = LazyImporter(
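With the entry registered in `_import_structure`, the new task pool resolves through the package's `LazyImporter`, so it is importable from the subpackage root without eagerly importing its heavy dependencies:

```python
from fusion_bench.taskpool import ResNetForImageClassificationTaskPool
```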
fusion_bench/taskpool/resnet_for_image_classification.py
ADDED

@@ -0,0 +1,231 @@
+import itertools
+import json
+import os
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    Literal,
+    Optional,
+    TypeVar,
+    Union,
+    override,
+)
+
+import lightning as L
+import torch
+from lightning_utilities.core.rank_zero import rank_zero_only
+from omegaconf import DictConfig
+from torch import Tensor, nn
+from torch.nn import functional as F
+from torch.utils.data import DataLoader
+from torchmetrics import Accuracy, MeanMetric
+from tqdm.auto import tqdm
+
+from fusion_bench import (
+    BaseTaskPool,
+    LightningFabricMixin,
+    RuntimeConstants,
+    auto_register_config,
+    get_rankzero_logger,
+)
+from fusion_bench.dataset import CLIPDataset
+from fusion_bench.modelpool.resnet_for_image_classification import (
+    ResNetForImageClassificationPool,
+    load_torchvision_resnet,
+    load_transformers_resnet,
+)
+from fusion_bench.tasks.clip_classification import get_classnames, get_num_classes
+from fusion_bench.utils import count_parameters
+
+if TYPE_CHECKING:
+    from torchvision.models import ResNet as TorchVisionResNet
+    from transformers import ResNetForImageClassification
+
+log = get_rankzero_logger(__name__)
+
+__all__ = ["ResNetForImageClassificationTaskPool"]
+
+
+@auto_register_config
+class ResNetForImageClassificationTaskPool(
+    BaseTaskPool,
+    LightningFabricMixin,
+    ResNetForImageClassificationPool,
+):
+
+    _is_setup = False
+
+    def __init__(
+        self,
+        type: str,
+        test_datasets: DictConfig,
+        dataloader_kwargs: DictConfig,
+        processor_config_path: str,
+        **kwargs,
+    ):
+        if type == "transformers":
+            super().__init__(
+                models=DictConfig(
+                    {"_pretrained_": {"config_path": processor_config_path}}
+                ),
+                type=type,
+                test_datasets=test_datasets,
+                **kwargs,
+            )
+        elif type == "torchvision":
+            super().__init__(type=type, test_datasets=test_datasets, **kwargs)
+        else:
+            raise ValueError(f"Unknown ResNet type: {type}")
+
+    def setup(self):
+        processor = self.load_processor(stage="test")
+
+        # Load test datasets
+        test_datasets = {
+            ds_name: CLIPDataset(self.load_test_dataset(ds_name), processor=processor)
+            for ds_name in self._test_datasets
+        }
+        self.test_dataloaders = {
+            ds_name: self.fabric.setup_dataloaders(
+                self.get_dataloader(ds, stage="test")
+            )
+            for ds_name, ds in test_datasets.items()
+        }
+
+    def _evaluate(
+        self,
+        classifier,
+        test_loader,
+        num_classes: int,
+        task_name: str = None,
+    ):
+        classifier.eval()
+        accuracy = Accuracy(task="multiclass", num_classes=num_classes)
+        loss_metric = MeanMetric()
+        if RuntimeConstants.debug:
+            log.info("Running under fast_dev_run mode, evaluating on a single batch.")
+            test_loader = itertools.islice(test_loader, 1)
+        else:
+            test_loader = test_loader
+
+        pbar = tqdm(
+            test_loader,
+            desc=f"Evaluating {task_name}" if task_name is not None else "Evaluating",
+            leave=False,
+            dynamic_ncols=True,
+        )
+        for batch in pbar:
+            inputs, targets = batch
+            outputs = classifier(inputs)
+            logits: Tensor = outputs["logits"]
+            if logits.device != targets.device:
+                targets = targets.to(logits.device)
+
+            loss = F.cross_entropy(logits, targets)
+            loss_metric.update(loss.detach().cpu())
+            acc = accuracy(logits.detach().cpu(), targets.detach().cpu())
+            pbar.set_postfix(
+                {
+                    "accuracy": accuracy.compute().item(),
+                    "loss": loss_metric.compute().item(),
+                }
+            )
+
+        acc = accuracy.compute().item()
+        loss = loss_metric.compute().item()
+        results = {"accuracy": acc, "loss": loss}
+        return results
+
+    def evaluate(
+        self,
+        model: Union["ResNetForImageClassification", "TorchVisionResNet"],
+        name: str = None,
+        **kwargs,
+    ) -> Dict[str, Any]:
+        assert isinstance(
+            model, nn.Module
+        ), f"Expected model to be an instance of nn.Module, but got {type(model)}"
+
+        if not self._is_setup:
+            self.setup()
+
+        classifier = self.fabric.to_device(model)
+        classifier.eval()
+        report = {}
+        # collect basic model information
+        training_params, all_params = count_parameters(model)
+        report["model_info"] = {
+            "trainable_params": training_params,
+            "all_params": all_params,
+            "trainable_percentage": training_params / all_params,
+        }
+        if name is not None:
+            report["model_info"]["name"] = name
+
+        # evaluate on each task
+        pbar = tqdm(
+            self.test_dataloaders.items(),
+            desc="Evaluating tasks",
+            total=len(self.test_dataloaders),
+        )
+        for task_name, test_dataloader in pbar:
+            num_classes = get_num_classes(task_name)
+            result = self._evaluate(
+                classifier,
+                test_dataloader,
+                num_classes=num_classes,
+                task_name=task_name,
+            )
+            report[task_name] = result
+
+        # calculate the average accuracy and loss
+        if "average" not in report:
+            report["average"] = {}
+        accuracies = [
+            value["accuracy"]
+            for key, value in report.items()
+            if "accuracy" in value
+        ]
+        if len(accuracies) > 0:
+            average_accuracy = sum(accuracies) / len(accuracies)
+            report["average"]["accuracy"] = average_accuracy
+        losses = [value["loss"] for key, value in report.items() if "loss" in value]
+        if len(losses) > 0:
+            average_loss = sum(losses) / len(losses)
+            report["average"]["loss"] = average_loss
+
+        log.info(f"Evaluation Result: {report}")
+        if self.fabric.is_global_zero and len(self.fabric._loggers) > 0:
+            save_path = os.path.join(self.log_dir, "report.json")
+            for version in itertools.count(1):
+                if not os.path.exists(save_path):
+                    break
+                # if the file already exists, increment the version to avoid overwriting
+                save_path = os.path.join(self.log_dir, f"report_{version}.json")
+            with open(save_path, "w") as fp:
+                json.dump(report, fp)
+            log.info(f"Evaluation report saved to {save_path}")
+        return report
+
+    def get_dataloader(self, dataset, stage: str):
+        """Create a DataLoader for the specified dataset and training stage.
+
+        Constructs a PyTorch DataLoader with stage-appropriate configurations:
+        - Training stage: shuffling enabled by default
+        - Validation/test stages: shuffling disabled by default
+
+        Args:
+            dataset: The dataset to wrap in a DataLoader.
+            stage (str): Training stage, must be one of "train", "val", or "test".
+                Determines default shuffling behavior.
+
+        Returns:
+            DataLoader: Configured DataLoader for the given dataset and stage.
+        """
+        assert stage in ["train", "val", "test"], f"Invalid stage: {stage}"
+        dataloader_kwargs = dict(self.dataloader_kwargs)
+        if "shuffle" not in dataloader_kwargs:
+            dataloader_kwargs["shuffle"] = stage == "train"
+        return DataLoader(dataset, **dataloader_kwargs)
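Taken together, the new task pool evaluates a single (possibly merged) ResNet across every configured test set and reports per-task plus average accuracy and loss. A minimal sketch of driving it directly, assuming Hydra-style configs like those added under `fusion_bench_config/`; the dataset and dataloader values here are illustrative placeholders, not taken from the diff.

```python
from omegaconf import DictConfig
from transformers import ResNetForImageClassification

from fusion_bench.taskpool import ResNetForImageClassificationTaskPool

taskpool = ResNetForImageClassificationTaskPool(
    type="transformers",
    # illustrative; real dataset configs live under fusion_bench_config/taskpool
    test_datasets=DictConfig({}),
    dataloader_kwargs=DictConfig({"batch_size": 64, "num_workers": 4}),  # illustrative
    processor_config_path="microsoft/resnet-50",
)

model = ResNetForImageClassification.from_pretrained("microsoft/resnet-50")

# evaluate() returns {"model_info": ..., "<task>": {"accuracy", "loss"}, "average": ...}
# and, when loggers are attached, writes report.json under the log directory.
report = taskpool.evaluate(model, name="resnet50_pretrained")
```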