fusion-bench 0.2.26__py3-none-any.whl → 0.2.28__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/__init__.py +4 -0
- fusion_bench/dataset/clip_dataset.py +1 -0
- fusion_bench/method/__init__.py +2 -0
- fusion_bench/method/adamerging/__init__.py +28 -5
- fusion_bench/method/adamerging/resnet_adamerging.py +279 -0
- fusion_bench/method/adamerging/task_wise_adamerging.py +2 -14
- fusion_bench/method/adamerging/utils.py +58 -0
- fusion_bench/method/classification/image_classification_finetune.py +168 -12
- fusion_bench/method/dare/simple_average.py +3 -2
- fusion_bench/method/dare/task_arithmetic.py +3 -2
- fusion_bench/method/simple_average.py +6 -4
- fusion_bench/method/task_arithmetic/task_arithmetic.py +4 -1
- fusion_bench/mixins/lightning_fabric.py +9 -0
- fusion_bench/modelpool/__init__.py +24 -2
- fusion_bench/modelpool/base_pool.py +8 -1
- fusion_bench/modelpool/causal_lm/causal_lm.py +2 -1
- fusion_bench/modelpool/convnext_for_image_classification.py +198 -0
- fusion_bench/modelpool/dinov2_for_image_classification.py +197 -0
- fusion_bench/modelpool/resnet_for_image_classification.py +289 -5
- fusion_bench/models/hf_clip.py +4 -7
- fusion_bench/models/hf_utils.py +4 -1
- fusion_bench/models/model_card_templates/default.md +1 -1
- fusion_bench/taskpool/__init__.py +2 -0
- fusion_bench/taskpool/clip_vision/taskpool.py +1 -1
- fusion_bench/taskpool/resnet_for_image_classification.py +231 -0
- fusion_bench/utils/json.py +49 -8
- fusion_bench/utils/state_dict_arithmetic.py +91 -10
- {fusion_bench-0.2.26.dist-info → fusion_bench-0.2.28.dist-info}/METADATA +2 -2
- {fusion_bench-0.2.26.dist-info → fusion_bench-0.2.28.dist-info}/RECORD +124 -62
- fusion_bench_config/fabric/auto.yaml +1 -1
- fusion_bench_config/fabric/loggers/swandb_logger.yaml +5 -0
- fusion_bench_config/fabric/loggers/wandb_logger.yaml +2 -0
- fusion_bench_config/fabric_model_fusion.yaml +1 -0
- fusion_bench_config/method/adamerging/resnet.yaml +18 -0
- fusion_bench_config/method/classification/clip_finetune.yaml +5 -0
- fusion_bench_config/method/classification/image_classification_finetune.yaml +9 -0
- fusion_bench_config/method/linear/expo.yaml +5 -0
- fusion_bench_config/method/linear/llama_expo.yaml +5 -0
- fusion_bench_config/method/linear/llama_expo_with_dare.yaml +3 -0
- fusion_bench_config/method/linear/simple_average_for_causallm.yaml +5 -0
- fusion_bench_config/method/linear/task_arithmetic_for_causallm.yaml +3 -0
- fusion_bench_config/method/linear/ties_merging_for_causallm.yaml +5 -0
- fusion_bench_config/method/linear/weighted_average_for_llama.yaml +5 -0
- fusion_bench_config/method/mixtral_moe_merging.yaml +3 -0
- fusion_bench_config/method/mixtral_moe_upscaling.yaml +5 -0
- fusion_bench_config/method/regmean/clip_regmean.yaml +3 -0
- fusion_bench_config/method/regmean/gpt2_regmean.yaml +3 -0
- fusion_bench_config/method/regmean/regmean.yaml +3 -0
- fusion_bench_config/method/regmean_plusplus/clip_regmean_plusplus.yaml +3 -0
- fusion_bench_config/method/smile_upscaling/causal_lm_upscaling.yaml +6 -0
- fusion_bench_config/method/smile_upscaling/error_accumulation.yaml +5 -0
- fusion_bench_config/method/smile_upscaling/projected_energy.yaml +5 -0
- fusion_bench_config/method/smile_upscaling/singular_projection_merging.yaml +3 -0
- fusion_bench_config/method/smile_upscaling/smile_mistral_upscaling.yaml +5 -0
- fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +5 -0
- fusion_bench_config/method/wudi/wudi.yaml +3 -0
- fusion_bench_config/model_fusion.yaml +2 -1
- fusion_bench_config/modelpool/ConvNextForImageClassification/convnext-base-224.yaml +10 -0
- fusion_bench_config/modelpool/Dinov2ForImageClassification/dinov2-base-imagenet1k-1-layer.yaml +10 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/_generate_config.py +138 -0
- fusion_bench_config/modelpool/{ResNetForImageClassfication → ResNetForImageClassification}/transformers/resnet152_cifar10.yaml +1 -1
- fusion_bench_config/modelpool/{ResNetForImageClassfication → ResNetForImageClassification}/transformers/resnet152_cifar100.yaml +1 -1
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_dtd.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_emnist_letters.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_eurosat.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_fashion_mnist.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_fer2013.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_food101.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_gtsrb.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_kmnist.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_mnist.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_oxford-iiit-pet.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_oxford_flowers102.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_pcam.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_rendered-sst2.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_resisc45.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_stanford-cars.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_stl10.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_sun397.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_svhn.yaml +14 -0
- fusion_bench_config/modelpool/{ResNetForImageClassfication → ResNetForImageClassification}/transformers/resnet18_cifar10.yaml +1 -1
- fusion_bench_config/modelpool/{ResNetForImageClassfication → ResNetForImageClassification}/transformers/resnet18_cifar100.yaml +1 -1
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_dtd.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_emnist_letters.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_eurosat.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_fashion_mnist.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_fer2013.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_food101.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_gtsrb.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_kmnist.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_mnist.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_oxford-iiit-pet.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_oxford_flowers102.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_pcam.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_rendered-sst2.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_resisc45.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_stanford-cars.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_stl10.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_sun397.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_svhn.yaml +14 -0
- fusion_bench_config/modelpool/{ResNetForImageClassfication → ResNetForImageClassification}/transformers/resnet50_cifar10.yaml +1 -1
- fusion_bench_config/modelpool/{ResNetForImageClassfication → ResNetForImageClassification}/transformers/resnet50_cifar100.yaml +1 -1
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_dtd.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_emnist_letters.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_eurosat.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_fashion_mnist.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_fer2013.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_food101.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_gtsrb.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_kmnist.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_mnist.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_oxford-iiit-pet.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_oxford_flowers102.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_pcam.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_rendered-sst2.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_resisc45.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_stanford-cars.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_stl10.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_sun397.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_svhn.yaml +14 -0
- fusion_bench_config/method/clip_finetune.yaml +0 -26
- {fusion_bench-0.2.26.dist-info → fusion_bench-0.2.28.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.26.dist-info → fusion_bench-0.2.28.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.26.dist-info → fusion_bench-0.2.28.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.26.dist-info → fusion_bench-0.2.28.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,198 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Hugging Face ConvNeXt image classification model pool.
|
|
3
|
+
|
|
4
|
+
This module provides a `BaseModelPool` implementation that loads and saves
|
|
5
|
+
ConvNeXt models for image classification via `transformers`. It optionally
|
|
6
|
+
reconfigures the classification head to match a dataset's class names and
|
|
7
|
+
overrides `forward` to return logits only for simpler downstream usage.
|
|
8
|
+
|
|
9
|
+
See also: `fusion_bench.modelpool.resnet_for_image_classification` for a
|
|
10
|
+
parallel implementation for ResNet-based classifiers.
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
import os
|
|
14
|
+
from typing import (
|
|
15
|
+
TYPE_CHECKING,
|
|
16
|
+
Any,
|
|
17
|
+
Callable,
|
|
18
|
+
Dict,
|
|
19
|
+
Literal,
|
|
20
|
+
Optional,
|
|
21
|
+
TypeVar,
|
|
22
|
+
Union,
|
|
23
|
+
override,
|
|
24
|
+
)
|
|
25
|
+
|
|
26
|
+
import torch
|
|
27
|
+
from lightning_utilities.core.rank_zero import rank_zero_only
|
|
28
|
+
from omegaconf import DictConfig
|
|
29
|
+
from torch import nn
|
|
30
|
+
|
|
31
|
+
from fusion_bench import BaseModelPool, auto_register_config, get_rankzero_logger
|
|
32
|
+
from fusion_bench.tasks.clip_classification import get_classnames, get_num_classes
|
|
33
|
+
|
|
34
|
+
log = get_rankzero_logger(__name__)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def load_transformers_convnext(
    config_path: str, pretrained: bool, dataset_name: Optional[str]
):
    """Build a `ConvNextForImageClassification`, optionally adapting its head.

    Args:
        config_path: Hugging Face model identifier or local path understood
            by `transformers` (e.g., "facebook/convnext-base-224").
        pretrained: When True, load weights via `from_pretrained`; when
            False, instantiate the architecture from its config only.
        dataset_name: Optional FusionBench dataset key. When given, the
            dataset's classnames (via `get_classnames`) are used to rewrite
            `config.id2label`/`config.label2id` and to resize the classifier
            head to the number of classes.

    Returns:
        A `transformers.ConvNextForImageClassification` instance. When
        `dataset_name` is provided, its label maps and classifier head are
        adapted to that dataset.

    Notes:
        Mirrors the ResNet implementation in
        `fusion_bench.modelpool.resnet_for_image_classification`.
    """
    from transformers import AutoConfig, ConvNextForImageClassification

    if pretrained:
        model = ConvNextForImageClassification.from_pretrained(config_path)
    else:
        model = ConvNextForImageClassification(
            AutoConfig.from_pretrained(config_path)
        )

    if dataset_name is None:
        return model

    # Rewrite the label maps so the model config reflects the target dataset.
    classnames = get_classnames(dataset_name)
    model.config.id2label = dict(enumerate(classnames))
    model.config.label2id = {name: idx for idx, name in enumerate(classnames)}
    # Keep the model attribute in sync with the (now updated) config.
    model.num_labels = model.config.num_labels

    # Resize the classification head; fall back to identity when the config
    # declares no labels.
    if model.config.num_labels > 0:
        new_head = nn.Linear(
            model.classifier.in_features,
            len(classnames),
            device=model.classifier.weight.device,
            dtype=model.classifier.weight.dtype,
        )
    else:
        new_head = nn.Identity()
    model.classifier = new_head
    return model
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
@auto_register_config
|
|
93
|
+
class ConvNextForImageClassificationPool(BaseModelPool):
|
|
94
|
+
"""Model pool for ConvNeXt image classification models (HF Transformers).
|
|
95
|
+
|
|
96
|
+
Responsibilities:
|
|
97
|
+
- Load an `AutoImageProcessor` compatible with the configured ConvNeXt model.
|
|
98
|
+
- Load ConvNeXt models either from a pretrained checkpoint or from config.
|
|
99
|
+
- Optionally adapt the classifier head to match dataset classnames.
|
|
100
|
+
- Override `forward` to return logits for consistent interfaces within
|
|
101
|
+
FusionBench.
|
|
102
|
+
|
|
103
|
+
See `fusion_bench.modelpool.resnet_for_image_classification` for a closely
|
|
104
|
+
related ResNet-based pool with analogous behavior.
|
|
105
|
+
"""
|
|
106
|
+
|
|
107
|
+
def load_processor(self, *args, **kwargs):
|
|
108
|
+
from transformers import AutoImageProcessor
|
|
109
|
+
|
|
110
|
+
if self.has_pretrained:
|
|
111
|
+
config_path = self._models["_pretrained_"].config_path
|
|
112
|
+
else:
|
|
113
|
+
for model_cfg in self._models.values():
|
|
114
|
+
if isinstance(model_cfg, str):
|
|
115
|
+
config_path = model_cfg
|
|
116
|
+
break
|
|
117
|
+
if "config_path" in model_cfg:
|
|
118
|
+
config_path = model_cfg["config_path"]
|
|
119
|
+
break
|
|
120
|
+
return AutoImageProcessor.from_pretrained(config_path)
|
|
121
|
+
|
|
122
|
+
@override
|
|
123
|
+
def load_model(self, model_name_or_config: Union[str, DictConfig], *args, **kwargs):
|
|
124
|
+
"""Load a ConvNeXt model described by a name, path, or DictConfig.
|
|
125
|
+
|
|
126
|
+
Accepts either a string (pretrained identifier or local path) or a
|
|
127
|
+
config mapping with keys: `config_path`, optional `pretrained` (bool),
|
|
128
|
+
and optional `dataset_name` to resize the classifier.
|
|
129
|
+
|
|
130
|
+
Returns:
|
|
131
|
+
A model whose `forward` is wrapped to return only logits to align
|
|
132
|
+
with FusionBench expectations.
|
|
133
|
+
"""
|
|
134
|
+
log.debug(f"Loading model: {model_name_or_config}", stacklevel=2)
|
|
135
|
+
if (
|
|
136
|
+
isinstance(model_name_or_config, str)
|
|
137
|
+
and model_name_or_config in self._models
|
|
138
|
+
):
|
|
139
|
+
model_name_or_config = self._models[model_name_or_config]
|
|
140
|
+
|
|
141
|
+
match model_name_or_config:
|
|
142
|
+
case str() as model_path:
|
|
143
|
+
from transformers import AutoModelForImageClassification
|
|
144
|
+
|
|
145
|
+
model = AutoModelForImageClassification.from_pretrained(model_path)
|
|
146
|
+
case dict() | DictConfig() as model_config:
|
|
147
|
+
model = load_transformers_convnext(
|
|
148
|
+
model_config["config_path"],
|
|
149
|
+
pretrained=model_config.get("pretrained", True),
|
|
150
|
+
dataset_name=model_config.get("dataset_name", None),
|
|
151
|
+
)
|
|
152
|
+
case _:
|
|
153
|
+
raise ValueError(
|
|
154
|
+
f"Unsupported model_name_or_config type: {type(model_name_or_config)}"
|
|
155
|
+
)
|
|
156
|
+
|
|
157
|
+
# override forward to return logits only
|
|
158
|
+
original_forward = model.forward
|
|
159
|
+
model.forward = lambda pixel_values, **kwargs: original_forward(
|
|
160
|
+
pixel_values=pixel_values, **kwargs
|
|
161
|
+
).logits
|
|
162
|
+
model.original_forward = original_forward
|
|
163
|
+
|
|
164
|
+
return model
|
|
165
|
+
|
|
166
|
+
@override
|
|
167
|
+
def save_model(
|
|
168
|
+
self,
|
|
169
|
+
model,
|
|
170
|
+
path,
|
|
171
|
+
algorithm_config: Optional[DictConfig] = None,
|
|
172
|
+
description: Optional[str] = None,
|
|
173
|
+
base_model: Optional[str] = None,
|
|
174
|
+
*args,
|
|
175
|
+
**kwargs,
|
|
176
|
+
):
|
|
177
|
+
"""Save the model, processor, and an optional model card to disk.
|
|
178
|
+
|
|
179
|
+
Artifacts written to `path`:
|
|
180
|
+
- The ConvNeXt model via `model.save_pretrained`.
|
|
181
|
+
- The paired image processor via `AutoImageProcessor.save_pretrained`.
|
|
182
|
+
- If `algorithm_config` is provided and on rank-zero, a README model card
|
|
183
|
+
documenting the FusionBench configuration.
|
|
184
|
+
"""
|
|
185
|
+
model.save_pretrained(path)
|
|
186
|
+
self.load_processor().save_pretrained(path)
|
|
187
|
+
|
|
188
|
+
if algorithm_config is not None and rank_zero_only.rank == 0:
|
|
189
|
+
from fusion_bench.models.hf_utils import create_default_model_card
|
|
190
|
+
|
|
191
|
+
model_card_str = create_default_model_card(
|
|
192
|
+
algorithm_config=algorithm_config,
|
|
193
|
+
description=description,
|
|
194
|
+
modelpool_config=self.config,
|
|
195
|
+
base_model=base_model,
|
|
196
|
+
)
|
|
197
|
+
with open(os.path.join(path, "README.md"), "w") as f:
|
|
198
|
+
f.write(model_card_str)
|
|
@@ -0,0 +1,197 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Hugging Face DINOv2 image classification model pool.
|
|
3
|
+
|
|
4
|
+
This module provides a `BaseModelPool` implementation that loads and saves
|
|
5
|
+
DINOv2 models for image classification via `transformers`. It optionally
|
|
6
|
+
reconfigures the classification head to match a dataset's class names and
|
|
7
|
+
overrides `forward` to return logits only for simpler downstream usage.
|
|
8
|
+
|
|
9
|
+
See also: `fusion_bench.modelpool.convnext_for_image_classification` for a
|
|
10
|
+
parallel implementation for ConvNeXt-based classifiers.
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
import os
|
|
14
|
+
from typing import (
|
|
15
|
+
TYPE_CHECKING,
|
|
16
|
+
Any,
|
|
17
|
+
Callable,
|
|
18
|
+
Dict,
|
|
19
|
+
Literal,
|
|
20
|
+
Optional,
|
|
21
|
+
TypeVar,
|
|
22
|
+
Union,
|
|
23
|
+
override,
|
|
24
|
+
)
|
|
25
|
+
|
|
26
|
+
import torch
|
|
27
|
+
from lightning_utilities.core.rank_zero import rank_zero_only
|
|
28
|
+
from omegaconf import DictConfig
|
|
29
|
+
from torch import nn
|
|
30
|
+
|
|
31
|
+
from fusion_bench import BaseModelPool, auto_register_config, get_rankzero_logger
|
|
32
|
+
from fusion_bench.tasks.clip_classification import get_classnames, get_num_classes
|
|
33
|
+
|
|
34
|
+
log = get_rankzero_logger(__name__)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def load_transformers_dinov2(
    config_path: str, pretrained: bool, dataset_name: Optional[str]
):
    """Build a `Dinov2ForImageClassification`, optionally adapting its head.

    Args:
        config_path: Hugging Face model identifier or local path understood
            by `transformers` (e.g., "facebook/dinov2-base").
        pretrained: When True, load weights via `from_pretrained`; when
            False, instantiate the architecture from its config only.
        dataset_name: Optional FusionBench dataset key. When given, the
            dataset's classnames (via `get_classnames`) are used to rewrite
            `config.id2label`/`config.label2id` and to resize the classifier
            head to the number of classes.

    Returns:
        A `transformers.Dinov2ForImageClassification` instance. When
        `dataset_name` is provided, its label maps and classifier head are
        adapted to that dataset.

    Notes:
        Mirrors the ConvNeXt implementation in
        `fusion_bench.modelpool.convnext_for_image_classification`.
    """
    from transformers import AutoConfig, Dinov2ForImageClassification

    if pretrained:
        model = Dinov2ForImageClassification.from_pretrained(config_path)
    else:
        model = Dinov2ForImageClassification(
            AutoConfig.from_pretrained(config_path)
        )

    if dataset_name is None:
        return model

    # Rewrite the label maps so the model config reflects the target dataset.
    classnames = get_classnames(dataset_name)
    model.config.id2label = dict(enumerate(classnames))
    model.config.label2id = {name: idx for idx, name in enumerate(classnames)}
    # Keep the model attribute in sync with the (now updated) config.
    model.num_labels = model.config.num_labels

    # If the model is configured with a positive number of labels, resize the
    # classifier to match the dataset classes; otherwise leave it as identity.
    if model.config.num_labels > 0:
        new_head = nn.Linear(
            model.classifier.in_features,
            len(classnames),
            device=model.classifier.weight.device,
            dtype=model.classifier.weight.dtype,
        )
    else:
        new_head = nn.Identity()
    model.classifier = new_head
    return model
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
@auto_register_config
|
|
95
|
+
class Dinov2ForImageClassificationPool(BaseModelPool):
|
|
96
|
+
"""Model pool for DINOv2 image classification models (HF Transformers)."""
|
|
97
|
+
|
|
98
|
+
def load_processor(self, *args, **kwargs):
|
|
99
|
+
"""Load the paired image processor for this model pool.
|
|
100
|
+
|
|
101
|
+
Uses the configured model's identifier or config path to retrieve the
|
|
102
|
+
appropriate `transformers.AutoImageProcessor` instance. If a pretrained
|
|
103
|
+
model entry exists in the pool configuration, it is preferred to derive
|
|
104
|
+
the processor to ensure tokenization/normalization parity.
|
|
105
|
+
"""
|
|
106
|
+
from transformers import AutoImageProcessor
|
|
107
|
+
|
|
108
|
+
if self.has_pretrained:
|
|
109
|
+
config_path = self._models["_pretrained_"].config_path
|
|
110
|
+
else:
|
|
111
|
+
for model_cfg in self._models.values():
|
|
112
|
+
if isinstance(model_cfg, str):
|
|
113
|
+
config_path = model_cfg
|
|
114
|
+
break
|
|
115
|
+
if "config_path" in model_cfg:
|
|
116
|
+
config_path = model_cfg["config_path"]
|
|
117
|
+
break
|
|
118
|
+
return AutoImageProcessor.from_pretrained(config_path)
|
|
119
|
+
|
|
120
|
+
@override
|
|
121
|
+
def load_model(self, model_name_or_config: Union[str, DictConfig], *args, **kwargs):
|
|
122
|
+
"""Load a DINOv2 model described by a name, path, or DictConfig.
|
|
123
|
+
|
|
124
|
+
Accepts either a string (pretrained identifier or local path) or a
|
|
125
|
+
config mapping with keys: `config_path`, optional `pretrained` (bool),
|
|
126
|
+
and optional `dataset_name` to resize the classifier.
|
|
127
|
+
|
|
128
|
+
Returns:
|
|
129
|
+
A model whose `forward` is wrapped to return only logits to align
|
|
130
|
+
with FusionBench expectations.
|
|
131
|
+
"""
|
|
132
|
+
log.debug(f"Loading model: {model_name_or_config}", stacklevel=2)
|
|
133
|
+
if (
|
|
134
|
+
isinstance(model_name_or_config, str)
|
|
135
|
+
and model_name_or_config in self._models
|
|
136
|
+
):
|
|
137
|
+
model_name_or_config = self._models[model_name_or_config]
|
|
138
|
+
|
|
139
|
+
match model_name_or_config:
|
|
140
|
+
case str() as model_path:
|
|
141
|
+
from transformers import AutoModelForImageClassification
|
|
142
|
+
|
|
143
|
+
model = AutoModelForImageClassification.from_pretrained(model_path)
|
|
144
|
+
case dict() | DictConfig() as model_config:
|
|
145
|
+
model = load_transformers_dinov2(
|
|
146
|
+
model_config["config_path"],
|
|
147
|
+
pretrained=model_config.get("pretrained", True),
|
|
148
|
+
dataset_name=model_config.get("dataset_name", None),
|
|
149
|
+
)
|
|
150
|
+
case _:
|
|
151
|
+
raise ValueError(
|
|
152
|
+
f"Unsupported model_name_or_config type: {type(model_name_or_config)}"
|
|
153
|
+
)
|
|
154
|
+
|
|
155
|
+
# Override forward to return logits only, to unify the interface across
|
|
156
|
+
# FusionBench model pools and simplify downstream usage.
|
|
157
|
+
original_forward = model.forward
|
|
158
|
+
model.forward = lambda pixel_values, **kwargs: original_forward(
|
|
159
|
+
pixel_values=pixel_values, **kwargs
|
|
160
|
+
).logits
|
|
161
|
+
model.original_forward = original_forward
|
|
162
|
+
|
|
163
|
+
return model
|
|
164
|
+
|
|
165
|
+
@override
|
|
166
|
+
def save_model(
|
|
167
|
+
self,
|
|
168
|
+
model,
|
|
169
|
+
path,
|
|
170
|
+
algorithm_config: Optional[DictConfig] = None,
|
|
171
|
+
description: Optional[str] = None,
|
|
172
|
+
base_model: Optional[str] = None,
|
|
173
|
+
*args,
|
|
174
|
+
**kwargs,
|
|
175
|
+
):
|
|
176
|
+
"""Save the model, processor, and an optional model card to disk.
|
|
177
|
+
|
|
178
|
+
Artifacts written to `path`:
|
|
179
|
+
- The DINOv2 model via `model.save_pretrained`.
|
|
180
|
+
- The paired image processor via `AutoImageProcessor.save_pretrained`.
|
|
181
|
+
- If `algorithm_config` is provided and on rank-zero, a README model card
|
|
182
|
+
documenting the FusionBench configuration.
|
|
183
|
+
"""
|
|
184
|
+
model.save_pretrained(path)
|
|
185
|
+
self.load_processor().save_pretrained(path)
|
|
186
|
+
|
|
187
|
+
if algorithm_config is not None and rank_zero_only.rank == 0:
|
|
188
|
+
from fusion_bench.models.hf_utils import create_default_model_card
|
|
189
|
+
|
|
190
|
+
model_card_str = create_default_model_card(
|
|
191
|
+
algorithm_config=algorithm_config,
|
|
192
|
+
description=description,
|
|
193
|
+
modelpool_config=self.config,
|
|
194
|
+
base_model=base_model,
|
|
195
|
+
)
|
|
196
|
+
with open(os.path.join(path, "README.md"), "w") as f:
|
|
197
|
+
f.write(model_card_str)
|