fusion-bench 0.2.29__py3-none-any.whl → 0.2.31__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/constants/runtime.py +4 -1
- fusion_bench/method/__init__.py +9 -1
- fusion_bench/method/base_algorithm.py +29 -19
- fusion_bench/method/classification/image_classification_finetune.py +1 -0
- fusion_bench/method/concrete_subspace/clip_concrete_tsvm.py +285 -0
- fusion_bench/method/task_singular_vector/TSVM.py +7 -6
- fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +0 -1
- fusion_bench/metrics/model_kinship/__init__.py +2 -0
- fusion_bench/metrics/model_kinship/calculate.py +77 -0
- fusion_bench/metrics/model_kinship/calculate_split.py +171 -0
- fusion_bench/metrics/model_kinship/utility.py +184 -0
- fusion_bench/mixins/lightning_fabric.py +2 -8
- fusion_bench/mixins/openclip_classification.py +155 -1
- fusion_bench/modelpool/base_pool.py +1 -0
- fusion_bench/modelpool/openclip_vision/modelpool.py +12 -3
- fusion_bench/models/masks/mask_model.py +8 -2
- fusion_bench/models/open_clip/modeling.py +68 -5
- fusion_bench/models/open_clip/utils.py +13 -2
- fusion_bench/models/wrappers/layer_wise_fusion.py +41 -3
- fusion_bench/models/wrappers/task_wise_fusion.py +14 -3
- fusion_bench/py.typed +1 -0
- fusion_bench/scripts/cli.py +21 -16
- fusion_bench/scripts/imgui.py +2 -2
- fusion_bench/scripts/webui.py +2 -2
- fusion_bench/utils/__init__.py +2 -0
- fusion_bench/utils/devices.py +3 -1
- fusion_bench/utils/hydra_utils.py +75 -0
- fusion_bench/utils/instantiate_utils.py +29 -18
- fusion_bench/utils/misc.py +16 -0
- fusion_bench/utils/parameters.py +33 -0
- fusion_bench/utils/rich_utils.py +165 -25
- {fusion_bench-0.2.29.dist-info → fusion_bench-0.2.31.dist-info}/METADATA +7 -7
- {fusion_bench-0.2.29.dist-info → fusion_bench-0.2.31.dist-info}/RECORD +41 -34
- fusion_bench_config/README.md +9 -0
- fusion_bench_config/fabric/auto.yaml +1 -0
- fusion_bench_config/hydra/default.yaml +3 -1
- fusion_bench_config/method/concrete_subspace/clip_concrete_tsvm.yaml +38 -0
- {fusion_bench-0.2.29.dist-info → fusion_bench-0.2.31.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.29.dist-info → fusion_bench-0.2.31.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.29.dist-info → fusion_bench-0.2.31.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.29.dist-info → fusion_bench-0.2.31.dist-info}/top_level.txt +0 -0
fusion_bench/metrics/model_kinship/utility.py (new file, @@ -0,0 +1,184 @@)

```python
import logging
from enum import Enum
from typing import List

import click
import torch
from tqdm import tqdm
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    PretrainedConfig,
)


class Metric(str, Enum):
    """Enumeration of supported metrics"""

    PCC = "pcc"
    ED = "ed"
    CS = "cs"

    @classmethod
    def list(cls) -> List[str]:
        """Return list of supported metric values"""
        return [metric.value for metric in cls]


def get_config(model: str, trust_remote_code: bool = False) -> PretrainedConfig:
    """
    Fetch the configuration of a pretrained model from HuggingFace.

    Args:
        model (str): The name or path of the model to load configuration for.
        trust_remote_code (bool, optional): Whether to trust remote code during loading.
            Defaults to False.

    Returns:
        PretrainedConfig: The configuration object of the specified model.
    """
    # Fetch the configuration from HuggingFace's model hub.
    config = AutoConfig.from_pretrained(
        model,
        trust_remote_code=trust_remote_code,  # Whether to allow remote code execution.
    )
    return config


def validate_models(model_1: str, model_2: str, base_model: str) -> None:
    """
    Validate model names to ensure they are different and exist.

    Args:
        model_1: Name of the first model
        model_2: Name of the second model
        base_model: Name of the base model

    Raises:
        click.BadParameter: If validation fails
    """
    if model_1 == model_2 or model_1 == base_model or model_2 == base_model:
        raise click.BadParameter("All model names must be different")


def quantize_8bit(x: torch.Tensor) -> torch.Tensor:
    # Get absolute min and max values
    abs_max = torch.max(torch.abs(x))

    # Scale to [-127, 127] range for 8-bit signed integers
    # Using 127 instead of 128 to keep zero exactly representable
    scaled = 127 * (x / abs_max)

    # Round to nearest integer
    quantized = torch.round(scaled)

    # Clamp values to ensure they stay in valid range
    quantized = torch.clamp(quantized, -127, 127)

    return quantized


def load_model_state_dict(model_name: str, device: str) -> dict:
    """
    Load a model and return its state dictionary.

    Args:
        model_name (str): Name or path of the model to load
        device (str): Device to load the model on ('cuda' or 'cpu')

    Returns:
        dict: State dictionary of the loaded model
    """
    logging.info(f"Loading model: {model_name}")
    model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
    state_dict = model.state_dict()
    del model  # Free memory
    return state_dict


def extract_delta_parameters(
    model_1_name: str,
    model_2_name: str,
    model_base_name: str,
    low_precision: bool,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Extract the delta parameters (weight differences) between two models
    relative to a base model.

    Args:
        model_1_name (str): Name or path of the first model.
        model_2_name (str): Name or path of the second model.
        model_base_name (str): Name or path of the base model for comparison.
        low_precision (bool): Whether to use low precision weights

    Returns:
        (torch.Tensor, torch.Tensor): Delta parameters of model_1 and model_2 relative to base model.
    """

    # Extract state dictionaries from models
    state_dict_1 = load_model_state_dict(model_1_name, device)
    state_dict_2 = load_model_state_dict(model_2_name, device)
    state_dict_base = load_model_state_dict(model_base_name, device)

    # Determine the number of layers
    num_layers = state_dict_base["lm_head.weight"].shape[0]

    # Check if model architectures match, log a warning if not
    if (
        state_dict_1["lm_head.weight"].shape[0]
        != state_dict_2["lm_head.weight"].shape[0]
    ):
        shape_1 = state_dict_1["lm_head.weight"].shape
        shape_2 = state_dict_2["lm_head.weight"].shape
        logging.warning(
            f"Warning: Model architectures do not match. "
            f"Using sub weight space instead.\n"
            f"Vocab sizes in model 1: {shape_1[0]}, "
            f"Vocab sizes in model 2: {shape_2[0]}"
        )

    # Initialize lists to store delta parameters for both models
    d_vector_1, d_vector_2 = [], []

    # Iterate over keys in the base model's state dictionary with tqdm
    for key, base_params in tqdm(
        state_dict_base.items(), desc="Processing keys", unit="key"
    ):
        # Only proceed if key exists in both models
        try:
            if key not in state_dict_1 or key not in state_dict_2:
                logging.warning(f"Key {key} not found in one of the models")
                continue
        except Exception as e:
            logging.error(f"Error processing key {key}: {str(e)}")

        # Get the parameters for each model (truncate to num_layers for consistency)
        params_1 = state_dict_1[key][:num_layers]
        params_2 = state_dict_2[key][:num_layers]

        # Compute the deltas relative to the base model
        delta_1 = (params_1 - base_params).view(-1)
        delta_2 = (params_2 - base_params).view(-1)

        # Accumulate deltas
        d_vector_1.append(delta_1)
        d_vector_2.append(delta_2)

    # Clear memory
    del state_dict_1, state_dict_2, state_dict_base

    logging.info("Concatenating delta vectors...")

    d_vector_1 = torch.cat(d_vector_1)
    d_vector_2 = torch.cat(d_vector_2)

    if low_precision:
        logging.info("Quantizing delta vectors to 8-bit precision...")
        d_vector_1 = quantize_8bit(d_vector_1)
        d_vector_2 = quantize_8bit(d_vector_2)
        logging.info("Quantization complete")

    return d_vector_1, d_vector_2
```
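For context, these model-kinship helpers compose as follows. This is a minimal usage sketch, not code from the package: the checkpoint names are placeholders, and computing cosine similarity directly with `torch.nn.functional` is only one way to realize the `CS` metric enumerated above.

```python
# Hypothetical usage of the model-kinship utilities added in 0.2.31.
# The model names below are illustrative placeholders.
import torch

from fusion_bench.metrics.model_kinship.utility import extract_delta_parameters

# Delta (task-vector) parameters of two fine-tuned models w.r.t. their base.
d1, d2 = extract_delta_parameters(
    "org/model-finetuned-a",  # placeholder checkpoint name
    "org/model-finetuned-b",  # placeholder checkpoint name
    "org/base-model",         # placeholder base checkpoint
    low_precision=True,       # quantize deltas to 8-bit via quantize_8bit
)

# A cosine-similarity style kinship score between the two delta vectors.
cs = torch.nn.functional.cosine_similarity(d1.float(), d2.float(), dim=0)
print(f"cosine similarity between deltas: {cs.item():.4f}")
```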
fusion_bench/mixins/lightning_fabric.py

```diff
@@ -10,6 +10,7 @@ from lightning.fabric.loggers import TensorBoardLogger
 from lightning.fabric.utilities.rank_zero import rank_zero_only
 from omegaconf import DictConfig, OmegaConf
 
+from fusion_bench.constants import RuntimeConstants
 from fusion_bench.utils import import_object
 from fusion_bench.utils.instantiate_utils import instantiate
 
@@ -206,14 +207,7 @@ class LightningFabricMixin:
         Returns:
             bool: True if fast_dev_run is enabled, False otherwise.
         """
-
-            return True
-        elif hasattr(self, "_program") and self._program.config.get(
-            "fast_dev_run", False
-        ):
-            return True
-        else:
-            return False
+        return RuntimeConstants().debug
 
     def log(self, name: str, value: Any, step: Optional[int] = None):
         """
```
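The refactor replaces per-object `fast_dev_run` inspection with the shared runtime constant imported above. As a rough sketch of how calling code might consult the same flag (assuming, as the diff suggests, that `RuntimeConstants` can be instantiated with no arguments and exposes a boolean `debug` attribute):

```python
from fusion_bench.constants import RuntimeConstants

# Illustrative only: limit work when the global debug flag is set,
# mirroring the simplified check in LightningFabricMixin above.
max_steps = 2 if RuntimeConstants().debug else 1000
```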
fusion_bench/mixins/openclip_classification.py

```diff
@@ -1,11 +1,165 @@
+import functools
 import logging
+from typing import TYPE_CHECKING, Callable, Dict, Iterator, List, Literal, Optional
 
+import torch
+from omegaconf import DictConfig
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+
+from fusion_bench.dataset.clip_dataset import CLIPDataset
 from fusion_bench.mixins import LightningFabricMixin
-from fusion_bench.
+from fusion_bench.modelpool import OpenCLIPVisionModelPool
+from fusion_bench.models.open_clip import (
+    ClassificationHead,
+    ImageClassifier,
+    ImageEncoder,
+)
+from fusion_bench.utils.data import InfiniteDataLoader
 
 log = logging.getLogger(__name__)
 
 
 class OpenCLIPClassificationMixin(LightningFabricMixin):
+
     _train_processor = None
     _test_processor = None
+    dataloader_kwargs: DictConfig
+    modelpool: OpenCLIPVisionModelPool
+    zero_shot_heads: Dict[str, ClassificationHead] = {}
+
+    def _init_processor(self, encoder: Optional["ImageEncoder"] = None):
+        """
+        Initialize the CLIP processors for training and testing.
+        """
+        if encoder is None:
+            encoder: "ImageEncoder" = self.modelpool.load_pretrained_or_first_model()
+        self._train_processor = encoder.train_preprocess
+        self._test_processor = encoder.val_preprocess
+        return self._train_processor, self._test_processor
+
+    def get_clip_processor(self, stage: Literal["train", "test"]):
+        """
+        Get the CLIP processor, loading it from the model pool if necessary.
+
+        Returns:
+            CLIPProcessor: The CLIP processor for image and text preprocessing.
+
+        Raises:
+            AssertionError: If the model pool is not set.
+        """
+        if stage == "train":
+            if self._train_processor is None:
+                self._init_processor()
+            return self._train_processor
+        elif stage == "test":
+            if self._test_processor is None:
+                self._init_processor()
+            return self._test_processor
+        else:
+            raise ValueError(f"Invalid stage: {stage}")
+
+    def setup_zero_shot_classification_head(
+        self,
+        task_names: Optional[List[str]] = None,
+        freeze: bool = True,
+        dtype: Optional[torch.dtype] = None,
+    ):
+        # check task names consistency across processes
+        _task_names = self.fabric.broadcast(task_names, src=0)
+        if not self.fabric.is_global_zero and task_names != _task_names:
+            raise ValueError("The `task_names` must be the same across all processes.")
+
+        for task in tqdm(
+            self.modelpool.model_names if task_names is None else task_names,
+            "Setting up zero-shot classification head",
+            disable=not self.fabric.is_global_zero,
+        ):
+            head = self.modelpool.load_classification_head(task)
+            if freeze:
+                head.requires_grad_(False)
+            if dtype is not None:
+                head = head.to(dtype=dtype)
+            self.zero_shot_heads[task] = self.to_device(head)
+
+    def set_clip_processor(self, stage: Literal["train", "test"], processor: Callable):
+        """
+        Set the CLIP processor for a specific stage.
+
+        Args:
+            stage (Literal["train", "test"]): The stage for which to set the processor.
+            processor (Callable): The CLIP processor to set.
+        """
+        if stage == "train":
+            self._train_processor = processor
+        elif stage == "test":
+            self._test_processor = processor
+        else:
+            raise ValueError(f"Invalid stage: {stage}")
+
+    @functools.cache
+    def get_shuffled_test_loader_iter(
+        self,
+        task: str,
+        batch_size: Optional[int] = None,
+        num_workers: Optional[int] = None,
+        **loader_kwargs,
+    ) -> Iterator:
+        """
+        Get an iterator for a shuffled test DataLoader.
+
+        This method creates a DataLoader for the test dataset of the specified task,
+        with shuffling enabled. It allows for optional customization of batch size,
+        number of workers, and other DataLoader keyword arguments.
+
+        Args:
+            task (str): The task identifier for which the test dataset is to be loaded.
+            batch_size (Optional[int]): The batch size to use for the DataLoader. If None, the default batch size is used.
+            num_workers (Optional[int]): The number of worker processes to use for data loading. If None, the default number of workers is used.
+            **loader_kwargs: Additional keyword arguments to pass to the DataLoader.
+
+        Returns:
+            Iterator: An iterator over the shuffled test DataLoader.
+        """
+        # get dataloader kwargs
+        dataloader_kwargs = self.dataloader_kwargs.copy()
+        dataloader_kwargs["shuffle"] = True
+        if batch_size is not None:
+            dataloader_kwargs["batch_size"] = batch_size
+        if num_workers is not None:
+            dataloader_kwargs["num_workers"] = num_workers
+        dataloader_kwargs.update(loader_kwargs)
+
+        # get the test dataset
+        clip_dataset = CLIPDataset(
+            self.modelpool.load_test_dataset(task),
+            processor=self.get_clip_processor(stage="test"),
+        )
+        # create the dataloader
+        loader = DataLoader(clip_dataset, **dataloader_kwargs)
+        loader = self.fabric.setup_dataloaders(loader)
+        return iter(InfiniteDataLoader(loader))
+
+    def compute_logits(
+        self,
+        module: ImageClassifier,
+        images,
+        task: str,
+    ):
+        """
+        Compute the logits for a batch of images using the provided module and task.
+
+        Args:
+            module (ImageClassifier): The image classification module to use for computing logits.
+            images (torch.Tensor): The batch of images for which to compute logits.
+            task (str): The task identifier to specify which classification head to use.
+
+        Returns:
+            torch.Tensor: The computed logits for the input images.
+        """
+        if len(self.zero_shot_heads) == 0:
+            self.setup_zero_shot_classification_head()
+        task_head = self.zero_shot_heads[task]
+        features = module(images)
+        logits = task_head(features)
+        return logits
```
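A minimal sketch of how a fusion algorithm might consume this mixin follows. `MyAlgorithm`, `evaluate_step`, the feature module passed in, and the `(images, labels)` batch layout are assumptions for illustration; the mixin methods themselves (`setup_zero_shot_classification_head`, `get_shuffled_test_loader_iter`, `compute_logits`) are the ones added above and require `modelpool`, `fabric`, and `dataloader_kwargs` to be set up.

```python
# Illustrative subclass; not part of fusion_bench itself.
class MyAlgorithm(OpenCLIPClassificationMixin):
    def evaluate_step(self, feature_module, task: str):
        # Builds self.zero_shot_heads for the requested task (cached afterwards).
        self.setup_zero_shot_classification_head(task_names=[task])

        # Cached, shuffled, infinite iterator over the task's test split.
        loader_iter = self.get_shuffled_test_loader_iter(task, batch_size=16)
        images, labels = next(loader_iter)  # assumed batch layout

        # Feature extractor -> task-specific zero-shot head -> logits.
        logits = self.compute_logits(feature_module, images, task)
        return (logits.argmax(dim=-1) == labels).float().mean()
```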
fusion_bench/modelpool/base_pool.py

```diff
@@ -7,6 +7,7 @@ from omegaconf import DictConfig, OmegaConf, UnsupportedValueType
 from torch import nn
 from torch.utils.data import Dataset
 
+from fusion_bench import TorchModelType
 from fusion_bench.mixins import BaseYAMLSerializable, HydraConfigMixin
 from fusion_bench.utils import (
     ValidationError,
```
fusion_bench/modelpool/openclip_vision/modelpool.py

```diff
@@ -1,7 +1,7 @@
 import logging
 import pickle
 import sys
-from typing import Callable, Optional, Union, cast
+from typing import Callable, Optional, Union, cast, override
 
 import torch
 from datasets import load_dataset
@@ -41,8 +41,8 @@ def _check_and_redirect_open_clip_modeling():
     )
 
     try:
-        import src
-        import src.modeling
+        import src  # type: ignore
+        import src.modeling  # type: ignore
     except ImportError:
         if "src" not in sys.modules:
             # redirect the import of `src` to `fusion_bench.models.open_clip`
@@ -114,6 +114,7 @@ class OpenCLIPVisionModelPool(BaseModelPool):
         self._test_processor = encoder.val_preprocess
         return self._test_processor
 
+    @override
     def load_model(
         self, model_name_or_config: Union[str, DictConfig], *args, **kwargs
     ) -> ImageEncoder:
@@ -210,6 +211,8 @@ class OpenCLIPVisionModelPool(BaseModelPool):
         - A string, which is the path to the model checkpoint in pickle format. Load directly using `torch.load`.
         - Default, load the model using `instantiate` from hydra.
         """
+        if self._classification_heads is None:
+            raise ValueError("No classification heads are defined in the model pool.")
         if (
             isinstance(model_name_or_config, str)
             and model_name_or_config in self._classification_heads
@@ -222,6 +225,8 @@ class OpenCLIPVisionModelPool(BaseModelPool):
             return head
 
     def load_train_dataset(self, dataset_name: str, *args, **kwargs):
+        if self._train_datasets is None:
+            raise ValueError("No train datasets are defined in the model pool.")
         dataset_config = self._train_datasets[dataset_name]
         if isinstance(dataset_config, str):
             log.info(
@@ -233,6 +238,8 @@ class OpenCLIPVisionModelPool(BaseModelPool):
         return dataset
 
     def load_val_dataset(self, dataset_name: str, *args, **kwargs):
+        if self._val_datasets is None:
+            raise ValueError("No val datasets are defined in the model pool.")
         dataset_config = self._val_datasets[dataset_name]
         if isinstance(dataset_config, str):
             log.info(
@@ -244,6 +251,8 @@ class OpenCLIPVisionModelPool(BaseModelPool):
         return dataset
 
     def load_test_dataset(self, dataset_name: str, *args, **kwargs):
+        if self._test_datasets is None:
+            raise ValueError("No test datasets are defined in the model pool.")
         dataset_config = self._test_datasets[dataset_name]
         if isinstance(dataset_config, str):
             log.info(
```
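These guard clauses replace the previously opaque `TypeError` from subscripting `None` with an explicit `ValueError` when a dataset or head section is missing from the pool configuration. A small, hypothetical handling sketch (the pool construction is elided and `"cars"` is a placeholder task name):

```python
# Illustrative handling of the new explicit errors; `pool` is assumed to be an
# already-constructed OpenCLIPVisionModelPool.
try:
    test_set = pool.load_test_dataset("cars")
except ValueError as err:
    # Raised when the pool config defines no test datasets at all.
    print(f"Skipping evaluation: {err}")
```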
fusion_bench/models/masks/mask_model.py

```diff
@@ -113,21 +113,27 @@ class MaskModel(ParameterDictModel):
     def get_distribution(
         self,
         mask_type: Literal["discrete", "continuous"],
+        temperature: float = 0.5,
         **kwargs,
     ):
         return {
-            name: self._param_to_distribution(
+            name: self._param_to_distribution(
+                param, mask_type=mask_type, temperature=temperature, **kwargs
+            )
             for name, param in self.named_parameters()
         }
 
     def sample_mask(
         self,
         mask_type: Literal["discrete", "continuous"] = "discrete",
+        temperature: float = 0.5,
         **kwargs,
     ):
         mask = {}
         for name, param in self.named_parameters():
-            dist = self._param_to_distribution(
+            dist = self._param_to_distribution(
+                param, mask_type, temperature=temperature, **kwargs
+            )
             if mask_type == "discrete":
                 mask[name] = dist.sample()
             elif mask_type == "continuous":
```
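The `temperature` argument now flows from the public API into `_param_to_distribution`, whose body is not shown in this diff. Below is a hedged sketch of the kind of relaxed-Bernoulli (concrete) construction such a helper typically performs; the function name and branching are assumptions, not the package's implementation.

```python
import torch
from torch.distributions import Bernoulli, RelaxedBernoulli

# Hypothetical stand-in for MaskModel._param_to_distribution: lower
# temperatures push the continuous (concrete) relaxation toward hard 0/1 masks.
def param_to_distribution(logits: torch.Tensor, mask_type: str, temperature: float = 0.5):
    if mask_type == "discrete":
        return Bernoulli(logits=logits)
    elif mask_type == "continuous":
        return RelaxedBernoulli(temperature=torch.tensor(temperature), logits=logits)
    raise ValueError(f"Unknown mask_type: {mask_type}")

dist = param_to_distribution(torch.zeros(4), "continuous", temperature=0.1)
print(dist.rsample())  # near-binary samples at low temperature
```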
fusion_bench/models/open_clip/modeling.py

```diff
@@ -1,3 +1,25 @@
+"""
+OpenCLIP model wrappers used by FusionBench.
+
+This module provides lightweight `torch.nn.Module` wrappers around OpenCLIP
+components that are commonly used throughout FusionBench experiments:
+
+- `ImageEncoder`: loads an OpenCLIP image encoder and exposes `encode_image`.
+- `ClassificationHead`: a linear head optionally normalizing inputs.
+- `ImageClassifier` / `MultiHeadImageClassifier`: convenience compositions.
+
+Note:
+    This module requires the optional dependency `open_clip_torch`.
+"""
+
+from fusion_bench.utils.packages import is_open_clip_available
+
+if not is_open_clip_available():
+    raise ImportError(
+        "open_clip is not installed. Please install it with `pip install open_clip_torch`."
+    )
+
+from pathlib import Path
 from typing import Callable, List
 
 import open_clip
@@ -10,6 +32,19 @@ from .variables_and_paths import CACHEDIR, MODELS, OPENCLIP_CACHEDIR
 
 class ImageEncoder(torch.nn.Module):
     R"""
+    OpenCLIP image encoder wrapper.
+
+    This class loads an OpenCLIP model by name and exposes a forward pass that
+    returns image embeddings via `model.encode_image`.
+
+    Args:
+        model_name: A model name supported by `open_clip`. FusionBench also
+            supports suffixes:
+            - ``"__pretrained__<tag>"`` to select a specific pretrained weights tag.
+            - ``"__init__"`` to use random initialization.
+        keep_lang: If False (default), removes the text encoder (when present)
+            to reduce memory usage.
+
     Examples:
 
     load the image encoder for a given model name
@@ -18,7 +53,7 @@ class ImageEncoder(torch.nn.Module):
     >>> image_encoder = ImageEncoder(model_name="ViT-B-32")
     """
 
-    def __init__(self, model_name: str, keep_lang=False):
+    def __init__(self, model_name: str, keep_lang: bool = False):
         super().__init__()
         assert (
             model_name in MODELS
@@ -42,22 +77,26 @@ class ImageEncoder(torch.nn.Module):
 
         self.cache_dir = CACHEDIR
 
+        # if `keep_lang` is False, remove the text encoder to save memory
         if not keep_lang and hasattr(self.model, "transformer"):
             delattr(self.model, "transformer")
 
-    def forward(self, images):
+    def forward(self, images: Tensor) -> Tensor:
+        """Encode a batch of images into embedding vectors."""
         assert self.model is not None
         return self.model.encode_image(images)
 
-    def __call__(self, inputs):
+    def __call__(self, inputs: Tensor) -> Tensor:
         return self.forward(inputs)
 
-    def save(self, filename):
+    def save(self, filename: str) -> None:
+        """Serialize this module to disk."""
         print(f"Saving image encoder to {filename}")
         utils.torch_save(self, filename)
 
     @classmethod
-    def load(cls, model_name, filename):
+    def load(cls, model_name: str, filename: str | Path):
+        """Load a saved encoder state dict into a freshly constructed encoder."""
         print(f"Loading image encoder from {filename}")
 
         state_dict = torch.load(filename, map_location="cpu")
@@ -68,6 +107,15 @@ class ImageEncoder(torch.nn.Module):
 
 
 class ClassificationHead(torch.nn.Linear):
+    """A linear classification head with optional input normalization.
+
+    Args:
+        normalize: If True, L2-normalize inputs along the last dimension before
+            applying the linear projection.
+        weights: Weight matrix of shape (num_classes, feature_dim).
+        biases: Optional bias vector of shape (num_classes,).
+    """
+
     def __init__(
         self,
         normalize: bool,
@@ -85,6 +133,7 @@ class ClassificationHead(torch.nn.Linear):
             self.bias = torch.nn.Parameter(torch.zeros_like(self.bias))
 
     def forward(self, inputs: Tensor):
+        """Compute logits from input features."""
         if self.normalize:
             inputs = inputs / inputs.norm(dim=-1, keepdim=True)
         return super().forward(inputs)
@@ -93,11 +142,13 @@ class ClassificationHead(torch.nn.Linear):
         return self.forward(inputs)
 
     def save(self, filename):
+        """Serialize this head to disk."""
         print(f"Saving classification head to {filename}")
         utils.torch_save(self, filename, save_state_dict=False)
 
     @classmethod
     def load(cls, filename):
+        """Load a serialized `ClassificationHead` instance from disk."""
         # print(f"Loading classification head from {filename}")
         return utils.torch_load(filename)
 
@@ -106,6 +157,8 @@ class ImageClassifier(torch.nn.Module):
     train_preprocess: Callable
     val_preprocess: Callable
 
+    """Convenience module combining an `ImageEncoder` and a `ClassificationHead`."""
+
     def __init__(
         self,
         image_encoder: ImageEncoder,
@@ -119,10 +172,12 @@ class ImageClassifier(torch.nn.Module):
         self.val_preprocess = self.image_encoder.val_preprocess
 
     def freeze_head(self):
+        """Disable gradient computation for the classification head."""
         self.classification_head.weight.requires_grad_(False)
         self.classification_head.bias.requires_grad_(False)
 
     def forward(self, inputs: Tensor):
+        """Run encoder then head and return logits."""
         features = self.image_encoder(inputs)
         outputs = self.classification_head(features)
         return outputs
@@ -131,16 +186,20 @@ class ImageClassifier(torch.nn.Module):
         return self.forward(inputs)
 
     def save(self, filename):
+        """Serialize this module to disk."""
         print(f"Saving image classifier to {filename}")
         utils.torch_save(self, filename)
 
     @classmethod
     def load(cls, filename):
+        """Load a serialized `ImageClassifier` instance from disk."""
         print(f"Loading image classifier from {filename}")
         return utils.torch_load(filename)
 
 
 class MultiHeadImageClassifier(torch.nn.Module):
+    """Image encoder with multiple task-specific classification heads."""
+
     def __init__(
         self,
         image_encoder: ImageEncoder,
@@ -154,11 +213,13 @@ class MultiHeadImageClassifier(torch.nn.Module):
         self.val_preprocess = self.image_encoder.val_preprocess
 
     def freeze_head(self):
+        """Disable gradient computation for all heads."""
         for idx in range(len(self.classification_heads)):
             self.classification_heads[idx].weight.requires_grad_(False)
             self.classification_heads[idx].bias.requires_grad_(False)
 
     def forward(self, inputs, head_idx):
+        """Run encoder then the selected head and return logits."""
         features = self.image_encoder(inputs)
         outputs = self.classification_heads[head_idx](features)
         return outputs
@@ -167,10 +228,12 @@ class MultiHeadImageClassifier(torch.nn.Module):
         return self.forward(inputs, head_idx)
 
     def save(self, filename):
+        """Serialize this module to disk."""
         print(f"Saving image classifier to {filename}")
         utils.torch_save(self, filename)
 
     @classmethod
     def load(cls, filename):
+        """Load a serialized `MultiHeadImageClassifier` instance from disk."""
         print(f"Loading image classifier from {filename}")
         return utils.torch_load(filename)
```
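As a rough composition sketch for the documented wrappers: the snippet below assumes `ImageClassifier` takes the encoder and head as its first two arguments and that `ClassificationHead` accepts a `weights` tensor of shape `(num_classes, feature_dim)` as described in the new docstring; the class count, feature dimension, and random weights are placeholders.

```python
import torch

from fusion_bench.models.open_clip import (
    ClassificationHead,
    ImageClassifier,
    ImageEncoder,
)

# Illustrative composition only; shapes and weights are placeholders.
encoder = ImageEncoder(model_name="ViT-B-32")   # from the class docstring example
head = ClassificationHead(
    normalize=True,                             # L2-normalize features before the projection
    weights=torch.randn(10, 512),               # (num_classes, feature_dim), placeholder values
)
classifier = ImageClassifier(encoder, head)     # assumed argument order
classifier.freeze_head()                        # train only the encoder

images = torch.randn(2, 3, 224, 224)            # dummy batch
logits = classifier(images)                     # encoder -> head -> logits
```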