fusion-bench 0.2.30__py3-none-any.whl → 0.2.31__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only.
- fusion_bench/constants/runtime.py +4 -1
- fusion_bench/method/classification/image_classification_finetune.py +1 -0
- fusion_bench/method/concrete_subspace/clip_concrete_tsvm.py +285 -0
- fusion_bench/method/task_singular_vector/TSVM.py +7 -6
- fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +0 -1
- fusion_bench/mixins/lightning_fabric.py +2 -8
- fusion_bench/mixins/openclip_classification.py +155 -1
- fusion_bench/modelpool/base_pool.py +1 -0
- fusion_bench/modelpool/openclip_vision/modelpool.py +12 -3
- fusion_bench/models/open_clip/modeling.py +61 -5
- fusion_bench/models/open_clip/utils.py +13 -2
- fusion_bench/py.typed +1 -0
- fusion_bench/scripts/cli.py +7 -16
- fusion_bench/scripts/imgui.py +2 -2
- fusion_bench/scripts/webui.py +2 -2
- fusion_bench/utils/__init__.py +2 -0
- fusion_bench/utils/hydra_utils.py +75 -0
- fusion_bench/utils/parameters.py +33 -0
- fusion_bench/utils/rich_utils.py +42 -19
- {fusion_bench-0.2.30.dist-info → fusion_bench-0.2.31.dist-info}/METADATA +1 -1
- {fusion_bench-0.2.30.dist-info → fusion_bench-0.2.31.dist-info}/RECORD +29 -26
- fusion_bench_config/README.md +9 -0
- fusion_bench_config/fabric/auto.yaml +1 -0
- fusion_bench_config/hydra/default.yaml +3 -1
- fusion_bench_config/method/concrete_subspace/clip_concrete_tsvm.yaml +38 -0
- {fusion_bench-0.2.30.dist-info → fusion_bench-0.2.31.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.30.dist-info → fusion_bench-0.2.31.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.30.dist-info → fusion_bench-0.2.31.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.30.dist-info → fusion_bench-0.2.31.dist-info}/top_level.txt +0 -0
`fusion_bench/constants/runtime.py`:

```diff
@@ -89,7 +89,10 @@ class RuntimeConstants:
         self._initialized = True
 
     debug = False
-    """
+    """
+    Global debug flag for enabling verbose logging and debugging features.
+    Use `RuntimeConstants().debug` instead of `RuntimeConstants.debug`
+    """
 
     @property
     def cache_dir(self) -> Path:
```
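
The new docstring's advice is worth heeding: the `_initialized` guard above suggests `RuntimeConstants` behaves as a process-wide singleton, so instance access reads the live, shared value while class access only ever sees the class-level default. A minimal sketch of the distinction (hypothetical usage, not taken from the package):

```python
from fusion_bench.constants import RuntimeConstants

RuntimeConstants().debug = True  # flip the flag on the shared singleton instance

RuntimeConstants().debug  # True  -- instance access sees the live value
RuntimeConstants.debug    # False -- class access still sees the default
```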
`fusion_bench/method/concrete_subspace/clip_concrete_tsvm.py` (new file):

```diff
@@ -0,0 +1,285 @@
+import logging
+import os
+from copy import deepcopy
+from typing import TYPE_CHECKING, Dict, Iterable, List, Literal, Optional
+
+import torch
+from omegaconf import DictConfig
+from tqdm import tqdm
+
+from fusion_bench import (
+    BaseAlgorithm,
+    OpenCLIPClassificationMixin,
+    OpenCLIPVisionModelPool,
+    SimpleProfilerMixin,
+    StateDictType,
+    auto_register_config,
+    get_rankzero_logger,
+    instantiate,
+)
+from fusion_bench.method.adamerging.entropy_loss import entropy_loss
+from fusion_bench.method.task_singular_vector import TaskSingularVectorMerging
+from fusion_bench.method.task_singular_vector.utils import (
+    TSVM_utils,
+    check_parameterNamesMatch,
+    check_state_dicts_equal,
+    state_dict_to_vector,
+    vector_to_state_dict,
+)
+from fusion_bench.models.masks import MaskModel, mask_sparsity
+from fusion_bench.models.open_clip import (
+    ClassificationHead,
+    ImageClassifier,
+    ImageEncoder,
+)
+from fusion_bench.models.wrappers.task_wise_fusion import (
+    TaskWiseMergedModel,
+    get_task_wise_weights,
+)
+from fusion_bench.utils.devices import clear_cuda_cache
+from fusion_bench.utils.dtype import parse_dtype
+from fusion_bench.utils.parameters import print_parameters, print_trainable_parameters
+from fusion_bench.utils.rich_utils import print_config_yaml
+from fusion_bench.utils.state_dict_arithmetic import (
+    _validate_state_dict_same_keys,
+    state_dict_add,
+    state_dict_hadamard_product,
+    state_dict_mul,
+    state_dict_sub,
+)
+
+log = get_rankzero_logger(__name__)
+
+
+@auto_register_config
+class ConcreteTSVMForOpenCLIP(
+    OpenCLIPClassificationMixin,
+    SimpleProfilerMixin,
+    BaseAlgorithm,
+):
+    def __init__(
+        self,
+        dataloader_kwargs: DictConfig,
+        optimizer: DictConfig,
+        lr_scheduler: DictConfig,
+        max_steps: int,
+        save_interval: int,
+        initial_logits: float,
+        temperature: float,
+        eval_mask_type: Literal["continuous", "discrete"],
+        mask_checkpoint: Optional[str],
+        merge_dtype: str,
+        clamp_weights: bool,
+        tie_weights: bool,
+        strict: bool,
+        skip_training: bool,
+        # === TSVM parameters ===
+        exclude_keys: Optional[List[str]],
+        alpha: float,
+        return_single_task_models: bool = True,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        if not return_single_task_models:
+            log.warning("return_single_task_models is forced to be True here.")
+            self.return_single_task_models = True
+
+    @torch.no_grad()
+    def setup_models(self):
+        """
+        load the pre-trained model, task vectors, and construct the mask model.
+        """
+        merge_dtype = parse_dtype(self.merge_dtype)
+        modelpool = self.modelpool
+
+        # load the pre-trained model
+        pretrained_model = modelpool.load_pretrained_model()
+        self.set_clip_processor(stage="test", processor=pretrained_model.val_preprocess)
+
+        # constrcute mask model
+        mask_model = MaskModel(
+            pretrained_model, ignore_untrained_params=True, parameter_type="logits"
+        )
+        if merge_dtype is not None:
+            mask_model.to(merge_dtype)
+        mask_model.fill_(self.initial_logits)
+
+        if self.fabric.is_global_zero:
+            print("summary of mask model:")
+            print_parameters(mask_model)
+
+        if self.fabric.is_global_zero:
+            tsvm_algo = TaskSingularVectorMerging(
+                alpha=self.alpha,
+                exclude_keys=self.exclude_keys,
+                return_single_task_models=self.return_single_task_models,
+            )
+            tsvm_algo._fabric_instance = self.fabric
+            models = tsvm_algo.run(modelpool)
+
+            finetuned_models = [models[name] for name in modelpool.model_names]
+
+            task_wise_weight = get_task_wise_weights(
+                num_models=len(modelpool.model_names),
+                init_values=self.alpha,
+            )
+
+            # create a wrapped model
+            module = TaskWiseMergedModel(
+                task_wise_weight=task_wise_weight,
+                pretrained_model=pretrained_model,
+                finetuned_models=finetuned_models,
+                clamp_weights=self.clamp_weights,
+                tie_weights=self.tie_weights,
+                strict=self.strict,
+                task_vector_dtype=merge_dtype,
+            )
+            module = module.to(dtype=merge_dtype)
+
+            print("trainable parameter summary of merged model (TaskWiseMergedModel):")
+            print_trainable_parameters(module)
+        else:
+            module = None
+
+        with torch.no_grad():
+            self.fabric.barrier()
+            module = self.fabric.broadcast(module, src=0)
+
+        return module, mask_model
+
+    def train_mask(self, module: TaskWiseMergedModel, mask_model: MaskModel):
+        """
+        Train the mask model using the provided module.
+
+        This method configures the optimizer, sets up the mask model, and performs test-time adaptation to train the mask model.
+
+        Args:
+            module (TaskWiseMergedModel): The wrapped model with task-wise weights.
+            mask_model (MaskModel): The mask model to be trained.
+        """
+        config = self.config
+        merge_dtype = parse_dtype(self.merge_dtype)
+        log.info(f"Using merge dtype: {merge_dtype}")
+
+        optimizer: "torch.optim.Optimizer" = instantiate(
+            self.optimizer,
+            params=filter(lambda p: p.requires_grad, mask_model.parameters()),
+        )
+        print(f"{optimizer=}")
+        if self.lr_scheduler is not None:
+            lr_scheduler = instantiate(
+                self.lr_scheduler,
+                optimizer=optimizer,
+            )
+            print(f"{lr_scheduler=}")
+        else:
+            lr_scheduler = None
+
+        log.info("Setup models and optimizer with Fabric.")
+        mask_model, optimizer = self.fabric.setup(mask_model, optimizer)
+
+        log.info("Move the merged module to the correct device and disable gradients.")
+        module.requires_grad_(False)
+        module.to(mask_model.device)
+
+        mask_model.train()
+        optimizer.zero_grad()
+        for step_idx in (
+            pbar := tqdm(
+                range(self.config.max_steps if not self.is_debug_mode else 5),
+                ("[DEBUG MODE] " if self.is_debug_mode else "")
+                + "Concrete TSVM Test-Time Adaptation",
+                dynamic_ncols=True,
+                disable=not self.fabric.is_global_zero,
+            )
+        ):
+            metrics = {}
+            # sample a shared mask and merge weights
+            with self.profile("sample mask"):
+                mask = mask_model.sample_mask(
+                    mask_type="continuous", temperature=config.temperature
+                )
+                metrics["train/sparsity"] = mask_sparsity(mask)
+            with self.profile("merge weights"):
+                # rescale mask
+                for name, m in mask.items():
+                    mask[name] = m / torch.mean(m)
+                module.merge_weights(task_vector_mask=mask)
+
+            # ------ inner optimization goes here ------
+            # NOTE:
+            # Because the algorithmic parameters of TSVM are assumed to be chosen on a validation test
+            # set, we do not need to perform inner optimization here. So here we skip the inner optimization step.
+            # ------------------------------------------
+
+            total_loss = None
+            for task in self.modelpool.model_names:
+                with self.profile("data loading"):
+                    batch = next(self.get_shuffled_test_loader_iter(task))
+                    # NOTE: The labels are not allowed to be used during test-time adaptation
+                    images = batch[0].to(dtype=merge_dtype)
+                with self.profile("forward pass"):
+                    logits = self.compute_logits(module, images, task)
+                    loss = entropy_loss(logits)
+                    total_loss = loss if total_loss is None else total_loss + loss
+
+            with self.profile("compute grad"):
+                self.fabric.backward(total_loss)
+
+            with self.profile("optimizer step"):
+                optimizer.step()
+                optimizer.zero_grad()
+
+                if lr_scheduler is not None:
+                    lr_scheduler.step()
+
+            metrics.update({"train/loss": loss.item()})
+            self.fabric.log_dict(metrics, step=step_idx)
+            pbar.set_postfix(metrics)
+
+            if (step_idx + 1) % self.config.save_interval == 0:
+                with self.profiler.profile("save checkpoint"):
+                    save_dir = os.path.join(self.fabric.logger.log_dir, "checkpoints")
+                    if not os.path.exists(save_dir):
+                        os.makedirs(save_dir, exist_ok=True)
+                    save_path = os.path.join(save_dir, f"mask_steps_{step_idx}.pt")
+                    print(f"saving checkpoint to {save_path}")
+                    state = {"model": mask_model}
+                    self.fabric.save(save_path, state)
+
+                    # Create or update a symbolic link to the latest checkpoint
+                    if self.fabric.is_global_zero:
+                        symlink_path = os.path.join(save_dir, "latest_checkpoint.pt")
+                        if os.path.exists(symlink_path):
+                            os.remove(symlink_path)
+                        os.link(os.path.abspath(save_path), symlink_path)
+
+        self.print_profile_summary()
+
+    def run(self, modelpool: OpenCLIPVisionModelPool):
+        self.modelpool = modelpool
+        merge_dtype = parse_dtype(self.merge_dtype)
+
+        with self.profile("setup models"):
+            module, mask_model = self.setup_models()
+            self.setup_zero_shot_classification_head(freeze=True, dtype=merge_dtype)
+
+        if self.mask_checkpoint is None:
+            if not self.skip_training:
+                clear_cuda_cache()
+                self.train_mask(module, mask_model=mask_model)
+        else:
+            if self.fabric.is_global_zero:
+                print("loading mask from checkpoint", self.mask_checkpoint)
+            self.fabric.load(self.mask_checkpoint, {"model": mask_model})
+
+        with torch.no_grad():
+            clear_cuda_cache()
+            mask = mask_model.sample_mask(
+                mask_type=self.eval_mask_type, temperature=self.temperature
+            )
+            # rescale mask
+            for name, m in mask.items():
+                mask[name] = m / torch.mean(m)
+            model = module.merge_and_unload(mask)
+        return model.to(dtype=torch.float32)
```
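
The new algorithm is configured through the Hydra config added in this release (`fusion_bench_config/method/concrete_subspace/clip_concrete_tsvm.yaml`). For orientation, a programmatic sketch follows; it is untested, every value is a placeholder, and `modelpool` stands in for an already-configured `OpenCLIPVisionModelPool`:

```python
from omegaconf import DictConfig

from fusion_bench.method.concrete_subspace.clip_concrete_tsvm import (
    ConcreteTSVMForOpenCLIP,
)

algorithm = ConcreteTSVMForOpenCLIP(
    dataloader_kwargs=DictConfig({"batch_size": 16, "num_workers": 2}),
    optimizer=DictConfig({"_target_": "torch.optim.Adam", "lr": 1e-3}),
    lr_scheduler=None,
    max_steps=1000,            # test-time adaptation steps (placeholder)
    save_interval=100,         # checkpoint the mask every N steps
    initial_logits=0.0,        # initial value for the mask logits
    temperature=0.5,           # Concrete-distribution temperature
    eval_mask_type="continuous",
    mask_checkpoint=None,      # or a path to resume a trained mask
    merge_dtype="float32",
    clamp_weights=False,
    tie_weights=True,
    strict=False,
    skip_training=False,
    exclude_keys=None,         # TSVM: parameters excluded from SVD
    alpha=1.0,                 # TSVM scaling coefficient
)
merged_encoder = algorithm.run(modelpool)  # returns the merged image encoder
```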

`fusion_bench/method/task_singular_vector/TSVM.py`:

```diff
@@ -249,12 +249,13 @@ class TaskSingularVectorMerging(BaseAlgorithm, LightningFabricMixin):
         # - SVD finds the principal components (most important directions)
         # - Task vectors are reconstructed using only the most significant components
         # - The reconstructed vectors are merged (summed) to create a unified task vector
-        new_merged_tv = TSVM_utils.compute_and_sum_svd_mem_reduction(
-            task_vectors,
-            exclude_keys=self.exclude_keys,  # Skip certain parameters from SVD
-            accelerator=accelerator,  # Use GPU if available
-            return_single_task_models=self.return_single_task_models,
-        )
+        with torch.no_grad():
+            new_merged_tv = TSVM_utils.compute_and_sum_svd_mem_reduction(
+                task_vectors,
+                exclude_keys=self.exclude_keys,  # Skip certain parameters from SVD
+                accelerator=accelerator,  # Use GPU if available
+                return_single_task_models=self.return_single_task_models,
+            )
 
         # Handle the case where individual transformed task vectors are also returned
         if self.return_single_task_models:
```
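
The context comments above summarize the merging recipe: run SVD on each task vector, reconstruct it from the most significant singular directions, then sum the reconstructions. A conceptual, self-contained illustration for a single 2-D weight follows; this is not the library's memory-reduced implementation, only the basic recipe it describes:

```python
import torch

def merge_task_vectors(task_vectors: list[torch.Tensor], k: int) -> torch.Tensor:
    """Sum rank-k SVD reconstructions of the task vectors for one weight matrix."""
    merged = torch.zeros_like(task_vectors[0])
    for tv in task_vectors:
        U, S, Vh = torch.linalg.svd(tv, full_matrices=False)
        merged += (U[:, :k] * S[:k]) @ Vh[:k, :]  # keep the top-k singular directions
    return merged

merged = merge_task_vectors([torch.randn(8, 8) for _ in range(3)], k=2)
```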

`fusion_bench/method/task_singular_vector/utils/TSVM_utils.py`:

```diff
@@ -311,7 +311,6 @@ def compute_and_sum_svd_mem_reduction_lossless_eigen(
 
 ###############
 #### TSV Merge Orthogonalization
-@torch.no_grad()
 def compute_and_sum_svd_mem_reduction(
     task_vectors: List[StateDictType],
     exclude_keys: Optional[List[str]] = None,
```
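
Taken together, these two hunks move `torch.no_grad()` from a decorator on `compute_and_sum_svd_mem_reduction` to the call site in `TaskSingularVectorMerging`, leaving gradient tracking under the caller's control. A standalone sketch of why that matters:

```python
import torch

def double(v: torch.Tensor) -> torch.Tensor:
    return 2 * v

x = torch.ones(3, requires_grad=True)

with torch.no_grad():   # call-site form: this particular call builds no graph
    y = double(x)
assert not y.requires_grad

z = double(x)           # the same function can still participate in autograd
assert z.requires_grad
```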
`fusion_bench/mixins/lightning_fabric.py`:

```diff
@@ -10,6 +10,7 @@ from lightning.fabric.loggers import TensorBoardLogger
 from lightning.fabric.utilities.rank_zero import rank_zero_only
 from omegaconf import DictConfig, OmegaConf
 
+from fusion_bench.constants import RuntimeConstants
 from fusion_bench.utils import import_object
 from fusion_bench.utils.instantiate_utils import instantiate
 
```

```diff
@@ -206,14 +207,7 @@ class LightningFabricMixin:
         Returns:
             bool: True if fast_dev_run is enabled, False otherwise.
         """
-
-            return True
-        elif hasattr(self, "_program") and self._program.config.get(
-            "fast_dev_run", False
-        ):
-            return True
-        else:
-            return False
+        return RuntimeConstants().debug
 
     def log(self, name: str, value: Any, step: Optional[int] = None):
         """
```

`fusion_bench/mixins/openclip_classification.py`:

```diff
@@ -1,11 +1,165 @@
+import functools
 import logging
+from typing import TYPE_CHECKING, Callable, Dict, Iterator, List, Literal, Optional
 
+import torch
+from omegaconf import DictConfig
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+
+from fusion_bench.dataset.clip_dataset import CLIPDataset
 from fusion_bench.mixins import LightningFabricMixin
-from fusion_bench.
+from fusion_bench.modelpool import OpenCLIPVisionModelPool
+from fusion_bench.models.open_clip import (
+    ClassificationHead,
+    ImageClassifier,
+    ImageEncoder,
+)
+from fusion_bench.utils.data import InfiniteDataLoader
 
 log = logging.getLogger(__name__)
 
 
 class OpenCLIPClassificationMixin(LightningFabricMixin):
+
     _train_processor = None
     _test_processor = None
+    dataloader_kwargs: DictConfig
+    modelpool: OpenCLIPVisionModelPool
+    zero_shot_heads: Dict[str, ClassificationHead] = {}
+
+    def _init_processor(self, encoder: Optional["ImageEncoder"] = None):
+        """
+        Initialize the CLIP processors for training and testing.
+        """
+        if encoder is None:
+            encoder: "ImageEncoder" = self.modelpool.load_pretrained_or_first_model()
+        self._train_processor = encoder.train_preprocess
+        self._test_processor = encoder.val_preprocess
+        return self._train_processor, self._test_processor
+
+    def get_clip_processor(self, stage: Literal["train", "test"]):
+        """
+        Get the CLIP processor, loading it from the model pool if necessary.
+
+        Returns:
+            CLIPProcessor: The CLIP processor for image and text preprocessing.
+
+        Raises:
+            AssertionError: If the model pool is not set.
+        """
+        if stage == "train":
+            if self._train_processor is None:
+                self._init_processor()
+            return self._train_processor
+        elif stage == "test":
+            if self._test_processor is None:
+                self._init_processor()
+            return self._test_processor
+        else:
+            raise ValueError(f"Invalid stage: {stage}")
+
+    def setup_zero_shot_classification_head(
+        self,
+        task_names: Optional[List[str]] = None,
+        freeze: bool = True,
+        dtype: Optional[torch.dtype] = None,
+    ):
+        # check task names consistency across processes
+        _task_names = self.fabric.broadcast(task_names, src=0)
+        if not self.fabric.is_global_zero and task_names != _task_names:
+            raise ValueError("The `task_names` must be the same across all processes.")
+
+        for task in tqdm(
+            self.modelpool.model_names if task_names is None else task_names,
+            "Setting up zero-shot classification head",
+            disable=not self.fabric.is_global_zero,
+        ):
+            head = self.modelpool.load_classification_head(task)
+            if freeze:
+                head.requires_grad_(False)
+            if dtype is not None:
+                head = head.to(dtype=dtype)
+            self.zero_shot_heads[task] = self.to_device(head)
+
+    def set_clip_processor(self, stage: Literal["train", "test"], processor: Callable):
+        """
+        Set the CLIP processor for a specific stage.
+
+        Args:
+            stage (Literal["train", "test"]): The stage for which to set the processor.
+            processor (Callable): The CLIP processor to set.
+        """
+        if stage == "train":
+            self._train_processor = processor
+        elif stage == "test":
+            self._test_processor = processor
+        else:
+            raise ValueError(f"Invalid stage: {stage}")
+
+    @functools.cache
+    def get_shuffled_test_loader_iter(
+        self,
+        task: str,
+        batch_size: Optional[int] = None,
+        num_workers: Optional[int] = None,
+        **loader_kwargs,
+    ) -> Iterator:
+        """
+        Get an iterator for a shuffled test DataLoader.
+
+        This method creates a DataLoader for the test dataset of the specified task,
+        with shuffling enabled. It allows for optional customization of batch size,
+        number of workers, and other DataLoader keyword arguments.
+
+        Args:
+            task (str): The task identifier for which the test dataset is to be loaded.
+            batch_size (Optional[int]): The batch size to use for the DataLoader. If None, the default batch size is used.
+            num_workers (Optional[int]): The number of worker processes to use for data loading. If None, the default number of workers is used.
+            **loader_kwargs: Additional keyword arguments to pass to the DataLoader.
+
+        Returns:
+            Iterator: An iterator over the shuffled test DataLoader.
+        """
+        # get dataloader kwargs
+        dataloader_kwargs = self.dataloader_kwargs.copy()
+        dataloader_kwargs["shuffle"] = True
+        if batch_size is not None:
+            dataloader_kwargs["batch_size"] = batch_size
+        if num_workers is not None:
+            dataloader_kwargs["num_workers"] = num_workers
+        dataloader_kwargs.update(loader_kwargs)
+
+        # get the test dataset
+        clip_dataset = CLIPDataset(
+            self.modelpool.load_test_dataset(task),
+            processor=self.get_clip_processor(stage="test"),
+        )
+        # create the dataloader
+        loader = DataLoader(clip_dataset, **dataloader_kwargs)
+        loader = self.fabric.setup_dataloaders(loader)
+        return iter(InfiniteDataLoader(loader))
+
+    def compute_logits(
+        self,
+        module: ImageClassifier,
+        images,
+        task: str,
+    ):
+        """
+        Compute the logits for a batch of images using the provided module and task.
+
+        Args:
+            module (ImageClassifier): The image classification module to use for computing logits.
+            images (torch.Tensor): The batch of images for which to compute logits.
+            task (str): The task identifier to specify which classification head to use.
+
+        Returns:
+            torch.Tensor: The computed logits for the input images.
+        """
+        if len(self.zero_shot_heads) == 0:
+            self.setup_zero_shot_classification_head()
+        task_head = self.zero_shot_heads[task]
+        features = module(images)
+        logits = task_head(features)
+        return logits
```
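
Note that `get_shuffled_test_loader_iter` is wrapped in `@functools.cache`: repeated calls with the same arguments return the same infinite iterator instead of rebuilding the DataLoader, which is what lets the test-time adaptation loop in `clip_concrete_tsvm.py` keep drawing fresh batches step after step. A sketch of that behavior, with `mixin` a hypothetical instance:

```python
it_a = mixin.get_shuffled_test_loader_iter("task_a")
it_b = mixin.get_shuffled_test_loader_iter("task_a")
assert it_a is it_b  # @functools.cache returns the same iterator object

batch = next(it_a)   # successive next() calls advance one shared, endless stream
```

One trade-off of caching a bound method this way is that the cache keeps references to `self` and the underlying DataLoader alive for the life of the process.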

`fusion_bench/modelpool/base_pool.py`:

```diff
@@ -7,6 +7,7 @@ from omegaconf import DictConfig, OmegaConf, UnsupportedValueType
 from torch import nn
 from torch.utils.data import Dataset
 
+from fusion_bench import TorchModelType
 from fusion_bench.mixins import BaseYAMLSerializable, HydraConfigMixin
 from fusion_bench.utils import (
     ValidationError,
```

`fusion_bench/modelpool/openclip_vision/modelpool.py`:

```diff
@@ -1,7 +1,7 @@
 import logging
 import pickle
 import sys
-from typing import Callable, Optional, Union, cast
+from typing import Callable, Optional, Union, cast, override
 
 import torch
 from datasets import load_dataset
```
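
`typing.override` only exists on Python 3.12 and newer (PEP 698), so this import implies a floor on the supported interpreter version. If broader compatibility were needed, the usual fallback pattern (an assumption, not what the package does) is:

```python
import sys

if sys.version_info >= (3, 12):
    from typing import override
else:
    from typing_extensions import override  # backport of the same decorator
```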

```diff
@@ -41,8 +41,8 @@ def _check_and_redirect_open_clip_modeling():
     )
 
     try:
-        import src
-        import src.modeling
+        import src  # type: ignore
+        import src.modeling  # type: ignore
     except ImportError:
         if "src" not in sys.modules:
             # redirect the import of `src` to `fusion_bench.models.open_clip`
```

```diff
@@ -114,6 +114,7 @@ class OpenCLIPVisionModelPool(BaseModelPool):
         self._test_processor = encoder.val_preprocess
         return self._test_processor
 
+    @override
     def load_model(
         self, model_name_or_config: Union[str, DictConfig], *args, **kwargs
     ) -> ImageEncoder:
```

```diff
@@ -210,6 +211,8 @@ class OpenCLIPVisionModelPool(BaseModelPool):
         - A string, which is the path to the model checkpoint in pickle format. Load directly using `torch.load`.
         - Default, load the model using `instantiate` from hydra.
         """
+        if self._classification_heads is None:
+            raise ValueError("No classification heads are defined in the model pool.")
         if (
             isinstance(model_name_or_config, str)
             and model_name_or_config in self._classification_heads
```

```diff
@@ -222,6 +225,8 @@ class OpenCLIPVisionModelPool(BaseModelPool):
         return head
 
     def load_train_dataset(self, dataset_name: str, *args, **kwargs):
+        if self._train_datasets is None:
+            raise ValueError("No train datasets are defined in the model pool.")
         dataset_config = self._train_datasets[dataset_name]
         if isinstance(dataset_config, str):
             log.info(
```

```diff
@@ -233,6 +238,8 @@ class OpenCLIPVisionModelPool(BaseModelPool):
         return dataset
 
     def load_val_dataset(self, dataset_name: str, *args, **kwargs):
+        if self._val_datasets is None:
+            raise ValueError("No val datasets are defined in the model pool.")
         dataset_config = self._val_datasets[dataset_name]
         if isinstance(dataset_config, str):
             log.info(
```

```diff
@@ -244,6 +251,8 @@ class OpenCLIPVisionModelPool(BaseModelPool):
         return dataset
 
     def load_test_dataset(self, dataset_name: str, *args, **kwargs):
+        if self._test_datasets is None:
+            raise ValueError("No test datasets are defined in the model pool.")
         dataset_config = self._test_datasets[dataset_name]
         if isinstance(dataset_config, str):
             log.info(
```
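
The four guards added in this file turn a lookup on `None` (which would surface as `TypeError: 'NoneType' object is not subscriptable`) into an explicit `ValueError` with an actionable message. A sketch of the new failure mode, with `pool` a hypothetical `OpenCLIPVisionModelPool` configured without test datasets:

```python
try:
    pool.load_test_dataset("mnist")
except ValueError as err:
    print(err)  # -> No test datasets are defined in the model pool.
```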