fusion-bench 0.2.30__py3-none-any.whl → 0.2.31__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. fusion_bench/constants/runtime.py +4 -1
  2. fusion_bench/method/classification/image_classification_finetune.py +1 -0
  3. fusion_bench/method/concrete_subspace/clip_concrete_tsvm.py +285 -0
  4. fusion_bench/method/task_singular_vector/TSVM.py +7 -6
  5. fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +0 -1
  6. fusion_bench/mixins/lightning_fabric.py +2 -8
  7. fusion_bench/mixins/openclip_classification.py +155 -1
  8. fusion_bench/modelpool/base_pool.py +1 -0
  9. fusion_bench/modelpool/openclip_vision/modelpool.py +12 -3
  10. fusion_bench/models/open_clip/modeling.py +61 -5
  11. fusion_bench/models/open_clip/utils.py +13 -2
  12. fusion_bench/py.typed +1 -0
  13. fusion_bench/scripts/cli.py +7 -16
  14. fusion_bench/scripts/imgui.py +2 -2
  15. fusion_bench/scripts/webui.py +2 -2
  16. fusion_bench/utils/__init__.py +2 -0
  17. fusion_bench/utils/hydra_utils.py +75 -0
  18. fusion_bench/utils/parameters.py +33 -0
  19. fusion_bench/utils/rich_utils.py +42 -19
  20. {fusion_bench-0.2.30.dist-info → fusion_bench-0.2.31.dist-info}/METADATA +1 -1
  21. {fusion_bench-0.2.30.dist-info → fusion_bench-0.2.31.dist-info}/RECORD +29 -26
  22. fusion_bench_config/README.md +9 -0
  23. fusion_bench_config/fabric/auto.yaml +1 -0
  24. fusion_bench_config/hydra/default.yaml +3 -1
  25. fusion_bench_config/method/concrete_subspace/clip_concrete_tsvm.yaml +38 -0
  26. {fusion_bench-0.2.30.dist-info → fusion_bench-0.2.31.dist-info}/WHEEL +0 -0
  27. {fusion_bench-0.2.30.dist-info → fusion_bench-0.2.31.dist-info}/entry_points.txt +0 -0
  28. {fusion_bench-0.2.30.dist-info → fusion_bench-0.2.31.dist-info}/licenses/LICENSE +0 -0
  29. {fusion_bench-0.2.30.dist-info → fusion_bench-0.2.31.dist-info}/top_level.txt +0 -0

fusion_bench/constants/runtime.py
@@ -89,7 +89,10 @@ class RuntimeConstants:
         self._initialized = True
 
     debug = False
-    """Global debug flag for enabling verbose logging and debugging features."""
+    """
+    Global debug flag for enabling verbose logging and debugging features.
+    Use `RuntimeConstants().debug` instead of `RuntimeConstants.debug`
+    """
 
     @property
     def cache_dir(self) -> Path:
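
The reworded docstring points the reader at instance access. A minimal sketch of why that matters, assuming `RuntimeConstants` is a singleton whose instance can carry its own `debug` value (the assignment below is illustrative, not code from this release):

    from fusion_bench.constants import RuntimeConstants

    RuntimeConstants().debug = True    # hypothetical: the instance attribute now shadows the class default
    print(RuntimeConstants().debug)    # True, and shared across call sites if the class is a singleton
    print(RuntimeConstants.debug)      # still False, the untouched class-level default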

fusion_bench/method/classification/image_classification_finetune.py
@@ -173,6 +173,7 @@ class ImageClassificationFineTuning(BaseAlgorithm):
                 ),
             },
         )
+        lit_module.train()
 
         log_dir = (
             self._program.path.log_dir

fusion_bench/method/concrete_subspace/clip_concrete_tsvm.py (new file)
@@ -0,0 +1,285 @@
+import logging
+import os
+from copy import deepcopy
+from typing import TYPE_CHECKING, Dict, Iterable, List, Literal, Optional
+
+import torch
+from omegaconf import DictConfig
+from tqdm import tqdm
+
+from fusion_bench import (
+    BaseAlgorithm,
+    OpenCLIPClassificationMixin,
+    OpenCLIPVisionModelPool,
+    SimpleProfilerMixin,
+    StateDictType,
+    auto_register_config,
+    get_rankzero_logger,
+    instantiate,
+)
+from fusion_bench.method.adamerging.entropy_loss import entropy_loss
+from fusion_bench.method.task_singular_vector import TaskSingularVectorMerging
+from fusion_bench.method.task_singular_vector.utils import (
+    TSVM_utils,
+    check_parameterNamesMatch,
+    check_state_dicts_equal,
+    state_dict_to_vector,
+    vector_to_state_dict,
+)
+from fusion_bench.models.masks import MaskModel, mask_sparsity
+from fusion_bench.models.open_clip import (
+    ClassificationHead,
+    ImageClassifier,
+    ImageEncoder,
+)
+from fusion_bench.models.wrappers.task_wise_fusion import (
+    TaskWiseMergedModel,
+    get_task_wise_weights,
+)
+from fusion_bench.utils.devices import clear_cuda_cache
+from fusion_bench.utils.dtype import parse_dtype
+from fusion_bench.utils.parameters import print_parameters, print_trainable_parameters
+from fusion_bench.utils.rich_utils import print_config_yaml
+from fusion_bench.utils.state_dict_arithmetic import (
+    _validate_state_dict_same_keys,
+    state_dict_add,
+    state_dict_hadamard_product,
+    state_dict_mul,
+    state_dict_sub,
+)
+
+log = get_rankzero_logger(__name__)
+
+
+@auto_register_config
+class ConcreteTSVMForOpenCLIP(
+    OpenCLIPClassificationMixin,
+    SimpleProfilerMixin,
+    BaseAlgorithm,
+):
+    def __init__(
+        self,
+        dataloader_kwargs: DictConfig,
+        optimizer: DictConfig,
+        lr_scheduler: DictConfig,
+        max_steps: int,
+        save_interval: int,
+        initial_logits: float,
+        temperature: float,
+        eval_mask_type: Literal["continuous", "discrete"],
+        mask_checkpoint: Optional[str],
+        merge_dtype: str,
+        clamp_weights: bool,
+        tie_weights: bool,
+        strict: bool,
+        skip_training: bool,
+        # === TSVM parameters ===
+        exclude_keys: Optional[List[str]],
+        alpha: float,
+        return_single_task_models: bool = True,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        if not return_single_task_models:
+            log.warning("return_single_task_models is forced to be True here.")
+            self.return_single_task_models = True
+
+    @torch.no_grad()
+    def setup_models(self):
+        """
+        Load the pre-trained model, task vectors, and construct the mask model.
+        """
+        merge_dtype = parse_dtype(self.merge_dtype)
+        modelpool = self.modelpool
+
+        # load the pre-trained model
+        pretrained_model = modelpool.load_pretrained_model()
+        self.set_clip_processor(stage="test", processor=pretrained_model.val_preprocess)
+
+        # construct the mask model
+        mask_model = MaskModel(
+            pretrained_model, ignore_untrained_params=True, parameter_type="logits"
+        )
+        if merge_dtype is not None:
+            mask_model.to(merge_dtype)
+        mask_model.fill_(self.initial_logits)
+
+        if self.fabric.is_global_zero:
+            print("summary of mask model:")
+            print_parameters(mask_model)
+
+        if self.fabric.is_global_zero:
+            tsvm_algo = TaskSingularVectorMerging(
+                alpha=self.alpha,
+                exclude_keys=self.exclude_keys,
+                return_single_task_models=self.return_single_task_models,
+            )
+            tsvm_algo._fabric_instance = self.fabric
+            models = tsvm_algo.run(modelpool)
+
+            finetuned_models = [models[name] for name in modelpool.model_names]
+
+            task_wise_weight = get_task_wise_weights(
+                num_models=len(modelpool.model_names),
+                init_values=self.alpha,
+            )
+
+            # create a wrapped model
+            module = TaskWiseMergedModel(
+                task_wise_weight=task_wise_weight,
+                pretrained_model=pretrained_model,
+                finetuned_models=finetuned_models,
+                clamp_weights=self.clamp_weights,
+                tie_weights=self.tie_weights,
+                strict=self.strict,
+                task_vector_dtype=merge_dtype,
+            )
+            module = module.to(dtype=merge_dtype)
+
+            print("trainable parameter summary of merged model (TaskWiseMergedModel):")
+            print_trainable_parameters(module)
+        else:
+            module = None
+
+        with torch.no_grad():
+            self.fabric.barrier()
+            module = self.fabric.broadcast(module, src=0)
+
+        return module, mask_model
+
+    def train_mask(self, module: TaskWiseMergedModel, mask_model: MaskModel):
+        """
+        Train the mask model using the provided module.
+
+        This method configures the optimizer, sets up the mask model, and performs test-time adaptation to train the mask model.
+
+        Args:
+            module (TaskWiseMergedModel): The wrapped model with task-wise weights.
+            mask_model (MaskModel): The mask model to be trained.
+        """
+        config = self.config
+        merge_dtype = parse_dtype(self.merge_dtype)
+        log.info(f"Using merge dtype: {merge_dtype}")
+
+        optimizer: "torch.optim.Optimizer" = instantiate(
+            self.optimizer,
+            params=filter(lambda p: p.requires_grad, mask_model.parameters()),
+        )
+        print(f"{optimizer=}")
+        if self.lr_scheduler is not None:
+            lr_scheduler = instantiate(
+                self.lr_scheduler,
+                optimizer=optimizer,
+            )
+            print(f"{lr_scheduler=}")
+        else:
+            lr_scheduler = None
+
+        log.info("Setup models and optimizer with Fabric.")
+        mask_model, optimizer = self.fabric.setup(mask_model, optimizer)
+
+        log.info("Move the merged module to the correct device and disable gradients.")
+        module.requires_grad_(False)
+        module.to(mask_model.device)
+
+        mask_model.train()
+        optimizer.zero_grad()
+        for step_idx in (
+            pbar := tqdm(
+                range(self.config.max_steps if not self.is_debug_mode else 5),
+                ("[DEBUG MODE] " if self.is_debug_mode else "")
+                + "Concrete TSVM Test-Time Adaptation",
+                dynamic_ncols=True,
+                disable=not self.fabric.is_global_zero,
+            )
+        ):
+            metrics = {}
+            # sample a shared mask and merge weights
+            with self.profile("sample mask"):
+                mask = mask_model.sample_mask(
+                    mask_type="continuous", temperature=config.temperature
+                )
+                metrics["train/sparsity"] = mask_sparsity(mask)
+            with self.profile("merge weights"):
+                # rescale mask
+                for name, m in mask.items():
+                    mask[name] = m / torch.mean(m)
+                module.merge_weights(task_vector_mask=mask)
+
+            # ------ inner optimization goes here ------
+            # NOTE:
+            #   Because the algorithmic parameters of TSVM are assumed to be chosen on a
+            #   validation set, no inner optimization is needed here, so the inner optimization step is skipped.
+            # ------------------------------------------
+
+            total_loss = None
+            for task in self.modelpool.model_names:
+                with self.profile("data loading"):
+                    batch = next(self.get_shuffled_test_loader_iter(task))
+                    # NOTE: The labels are not allowed to be used during test-time adaptation
+                    images = batch[0].to(dtype=merge_dtype)
+                with self.profile("forward pass"):
+                    logits = self.compute_logits(module, images, task)
+                    loss = entropy_loss(logits)
+                    total_loss = loss if total_loss is None else total_loss + loss
+
+            with self.profile("compute grad"):
+                self.fabric.backward(total_loss)
+
+            with self.profile("optimizer step"):
+                optimizer.step()
+                optimizer.zero_grad()
+
+                if lr_scheduler is not None:
+                    lr_scheduler.step()
+
+            metrics.update({"train/loss": loss.item()})
+            self.fabric.log_dict(metrics, step=step_idx)
+            pbar.set_postfix(metrics)
+
+            if (step_idx + 1) % self.config.save_interval == 0:
+                with self.profiler.profile("save checkpoint"):
+                    save_dir = os.path.join(self.fabric.logger.log_dir, "checkpoints")
+                    if not os.path.exists(save_dir):
+                        os.makedirs(save_dir, exist_ok=True)
+                    save_path = os.path.join(save_dir, f"mask_steps_{step_idx}.pt")
+                    print(f"saving checkpoint to {save_path}")
+                    state = {"model": mask_model}
+                    self.fabric.save(save_path, state)
+
+                    # Create or update a symbolic link to the latest checkpoint
+                    if self.fabric.is_global_zero:
+                        symlink_path = os.path.join(save_dir, "latest_checkpoint.pt")
+                        if os.path.exists(symlink_path):
+                            os.remove(symlink_path)
+                        os.link(os.path.abspath(save_path), symlink_path)
+
+        self.print_profile_summary()
+
+    def run(self, modelpool: OpenCLIPVisionModelPool):
+        self.modelpool = modelpool
+        merge_dtype = parse_dtype(self.merge_dtype)
+
+        with self.profile("setup models"):
+            module, mask_model = self.setup_models()
+            self.setup_zero_shot_classification_head(freeze=True, dtype=merge_dtype)
+
+        if self.mask_checkpoint is None:
+            if not self.skip_training:
+                clear_cuda_cache()
+                self.train_mask(module, mask_model=mask_model)
+        else:
+            if self.fabric.is_global_zero:
+                print("loading mask from checkpoint", self.mask_checkpoint)
+            self.fabric.load(self.mask_checkpoint, {"model": mask_model})
+
+        with torch.no_grad():
+            clear_cuda_cache()
+            mask = mask_model.sample_mask(
+                mask_type=self.eval_mask_type, temperature=self.temperature
+            )
+            # rescale mask
+            for name, m in mask.items():
+                mask[name] = m / torch.mean(m)
+            model = module.merge_and_unload(mask)
+        return model.to(dtype=torch.float32)
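
The new algorithm ships with a matching Hydra config (entry 25 in the file list above). A hedged sketch of how it would typically be launched through the `fusion_bench` command-line entry point; the modelpool and taskpool names below are placeholders, not values taken from this release:

    # hypothetical invocation; substitute real OpenCLIP modelpool/taskpool config names
    fusion_bench \
        method=concrete_subspace/clip_concrete_tsvm \
        modelpool=<an_openclip_vision_modelpool> \
        taskpool=<a_matching_taskpool>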

fusion_bench/method/task_singular_vector/TSVM.py
@@ -249,12 +249,13 @@ class TaskSingularVectorMerging(BaseAlgorithm, LightningFabricMixin):
         # - SVD finds the principal components (most important directions)
         # - Task vectors are reconstructed using only the most significant components
         # - The reconstructed vectors are merged (summed) to create a unified task vector
-        new_merged_tv = TSVM_utils.compute_and_sum_svd_mem_reduction(
-            task_vectors,
-            exclude_keys=self.exclude_keys,  # Skip certain parameters from SVD
-            accelerator=accelerator,  # Use GPU if available
-            return_single_task_models=self.return_single_task_models,
-        )
+        with torch.no_grad():
+            new_merged_tv = TSVM_utils.compute_and_sum_svd_mem_reduction(
+                task_vectors,
+                exclude_keys=self.exclude_keys,  # Skip certain parameters from SVD
+                accelerator=accelerator,  # Use GPU if available
+                return_single_task_models=self.return_single_task_models,
+            )
 
         # Handle the case where individual transformed task vectors are also returned
         if self.return_single_task_models:

fusion_bench/method/task_singular_vector/utils/TSVM_utils.py
@@ -311,7 +311,6 @@ def compute_and_sum_svd_mem_reduction_lossless_eigen(
 
 ###############
 #### TSV Merge Orthogonalization
-@torch.no_grad()
 def compute_and_sum_svd_mem_reduction(
     task_vectors: List[StateDictType],
     exclude_keys: Optional[List[str]] = None,
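
Taken together with the previous hunk, this moves gradient disabling from a decorator on the utility function to a `torch.no_grad()` context at the call site in `TSVM.py`, so `compute_and_sum_svd_mem_reduction` itself no longer forces autograd off. A minimal sketch of the equivalence, with `merge_fn` standing in for the utility (not the real function):

    import torch

    def merge_fn(tensors):
        # stand-in for the SVD-based merge; any tensor op works for the illustration
        return torch.stack(tensors).sum(dim=0)

    task_vectors = [torch.randn(4, 4, requires_grad=True) for _ in range(3)]

    with torch.no_grad():            # the caller now decides whether autograd tracks the merge
        merged = merge_fn(task_vectors)

    assert not merged.requires_grad  # same effect as the removed @torch.no_grad() decorator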

fusion_bench/mixins/lightning_fabric.py
@@ -10,6 +10,7 @@ from lightning.fabric.loggers import TensorBoardLogger
 from lightning.fabric.utilities.rank_zero import rank_zero_only
 from omegaconf import DictConfig, OmegaConf
 
+from fusion_bench.constants import RuntimeConstants
 from fusion_bench.utils import import_object
 from fusion_bench.utils.instantiate_utils import instantiate
 
@@ -206,14 +207,7 @@ class LightningFabricMixin:
         Returns:
             bool: True if fast_dev_run is enabled, False otherwise.
         """
-        if hasattr(self, "config") and self.config.get("fast_dev_run", False):
-            return True
-        elif hasattr(self, "_program") and self._program.config.get(
-            "fast_dev_run", False
-        ):
-            return True
-        else:
-            return False
+        return RuntimeConstants().debug
 
     def log(self, name: str, value: Any, step: Optional[int] = None):
         """

fusion_bench/mixins/openclip_classification.py
@@ -1,11 +1,165 @@
+import functools
 import logging
+from typing import TYPE_CHECKING, Callable, Dict, Iterator, List, Literal, Optional
 
+import torch
+from omegaconf import DictConfig
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+
+from fusion_bench.dataset.clip_dataset import CLIPDataset
 from fusion_bench.mixins import LightningFabricMixin
-from fusion_bench.models.open_clip import ImageClassifier, ImageEncoder
+from fusion_bench.modelpool import OpenCLIPVisionModelPool
+from fusion_bench.models.open_clip import (
+    ClassificationHead,
+    ImageClassifier,
+    ImageEncoder,
+)
+from fusion_bench.utils.data import InfiniteDataLoader
 
 log = logging.getLogger(__name__)
 
 
 class OpenCLIPClassificationMixin(LightningFabricMixin):
+
     _train_processor = None
     _test_processor = None
+    dataloader_kwargs: DictConfig
+    modelpool: OpenCLIPVisionModelPool
+    zero_shot_heads: Dict[str, ClassificationHead] = {}
+
+    def _init_processor(self, encoder: Optional["ImageEncoder"] = None):
+        """
+        Initialize the CLIP processors for training and testing.
+        """
+        if encoder is None:
+            encoder: "ImageEncoder" = self.modelpool.load_pretrained_or_first_model()
+        self._train_processor = encoder.train_preprocess
+        self._test_processor = encoder.val_preprocess
+        return self._train_processor, self._test_processor
+
+    def get_clip_processor(self, stage: Literal["train", "test"]):
+        """
+        Get the CLIP processor, loading it from the model pool if necessary.
+
+        Returns:
+            CLIPProcessor: The CLIP processor for image and text preprocessing.
+
+        Raises:
+            AssertionError: If the model pool is not set.
+        """
+        if stage == "train":
+            if self._train_processor is None:
+                self._init_processor()
+            return self._train_processor
+        elif stage == "test":
+            if self._test_processor is None:
+                self._init_processor()
+            return self._test_processor
+        else:
+            raise ValueError(f"Invalid stage: {stage}")
+
+    def setup_zero_shot_classification_head(
+        self,
+        task_names: Optional[List[str]] = None,
+        freeze: bool = True,
+        dtype: Optional[torch.dtype] = None,
+    ):
+        # check task names consistency across processes
+        _task_names = self.fabric.broadcast(task_names, src=0)
+        if not self.fabric.is_global_zero and task_names != _task_names:
+            raise ValueError("The `task_names` must be the same across all processes.")
+
+        for task in tqdm(
+            self.modelpool.model_names if task_names is None else task_names,
+            "Setting up zero-shot classification head",
+            disable=not self.fabric.is_global_zero,
+        ):
+            head = self.modelpool.load_classification_head(task)
+            if freeze:
+                head.requires_grad_(False)
+            if dtype is not None:
+                head = head.to(dtype=dtype)
+            self.zero_shot_heads[task] = self.to_device(head)
+
+    def set_clip_processor(self, stage: Literal["train", "test"], processor: Callable):
+        """
+        Set the CLIP processor for a specific stage.
+
+        Args:
+            stage (Literal["train", "test"]): The stage for which to set the processor.
+            processor (Callable): The CLIP processor to set.
+        """
+        if stage == "train":
+            self._train_processor = processor
+        elif stage == "test":
+            self._test_processor = processor
+        else:
+            raise ValueError(f"Invalid stage: {stage}")
+
+    @functools.cache
+    def get_shuffled_test_loader_iter(
+        self,
+        task: str,
+        batch_size: Optional[int] = None,
+        num_workers: Optional[int] = None,
+        **loader_kwargs,
+    ) -> Iterator:
+        """
+        Get an iterator for a shuffled test DataLoader.
+
+        This method creates a DataLoader for the test dataset of the specified task,
+        with shuffling enabled. It allows for optional customization of batch size,
+        number of workers, and other DataLoader keyword arguments.
+
+        Args:
+            task (str): The task identifier for which the test dataset is to be loaded.
+            batch_size (Optional[int]): The batch size to use for the DataLoader. If None, the default batch size is used.
+            num_workers (Optional[int]): The number of worker processes to use for data loading. If None, the default number of workers is used.
+            **loader_kwargs: Additional keyword arguments to pass to the DataLoader.
+
+        Returns:
+            Iterator: An iterator over the shuffled test DataLoader.
+        """
+        # get dataloader kwargs
+        dataloader_kwargs = self.dataloader_kwargs.copy()
+        dataloader_kwargs["shuffle"] = True
+        if batch_size is not None:
+            dataloader_kwargs["batch_size"] = batch_size
+        if num_workers is not None:
+            dataloader_kwargs["num_workers"] = num_workers
+        dataloader_kwargs.update(loader_kwargs)
+
+        # get the test dataset
+        clip_dataset = CLIPDataset(
+            self.modelpool.load_test_dataset(task),
+            processor=self.get_clip_processor(stage="test"),
+        )
+        # create the dataloader
+        loader = DataLoader(clip_dataset, **dataloader_kwargs)
+        loader = self.fabric.setup_dataloaders(loader)
+        return iter(InfiniteDataLoader(loader))
+
+    def compute_logits(
+        self,
+        module: ImageClassifier,
+        images,
+        task: str,
+    ):
+        """
+        Compute the logits for a batch of images using the provided module and task.
+
+        Args:
+            module (ImageClassifier): The image classification module to use for computing logits.
+            images (torch.Tensor): The batch of images for which to compute logits.
+            task (str): The task identifier to specify which classification head to use.
+
+        Returns:
+            torch.Tensor: The computed logits for the input images.
+        """
+        if len(self.zero_shot_heads) == 0:
+            self.setup_zero_shot_classification_head()
+        task_head = self.zero_shot_heads[task]
+        features = module(images)
+        logits = task_head(features)
+        return logits
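
Two details of `get_shuffled_test_loader_iter` matter to callers such as the Concrete TSVM training loop above: `functools.cache` returns the same iterator for repeated calls with identical arguments, and `InfiniteDataLoader` lets that iterator be advanced indefinitely. A hedged illustration, where `algo` is a hypothetical instance of a class using this mixin and "svhn" is an illustrative task name:

    it_a = algo.get_shuffled_test_loader_iter("svhn")
    it_b = algo.get_shuffled_test_loader_iter("svhn")
    assert it_a is it_b        # cached: one shuffled loader is reused across adaptation steps

    batch = next(it_a)         # never raises StopIteration; the underlying loader restarts
    images = batch[0]          # labels in batch[1] are deliberately unused at test time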

fusion_bench/modelpool/base_pool.py
@@ -7,6 +7,7 @@ from omegaconf import DictConfig, OmegaConf, UnsupportedValueType
 from torch import nn
 from torch.utils.data import Dataset
 
+from fusion_bench import TorchModelType
 from fusion_bench.mixins import BaseYAMLSerializable, HydraConfigMixin
 from fusion_bench.utils import (
     ValidationError,

fusion_bench/modelpool/openclip_vision/modelpool.py
@@ -1,7 +1,7 @@
 import logging
 import pickle
 import sys
-from typing import Callable, Optional, Union, cast
+from typing import Callable, Optional, Union, cast, override
 
 import torch
 from datasets import load_dataset
@@ -41,8 +41,8 @@ def _check_and_redirect_open_clip_modeling():
         )
 
     try:
-        import src
-        import src.modeling
+        import src  # type: ignore
+        import src.modeling  # type: ignore
    except ImportError:
        if "src" not in sys.modules:
            # redirect the import of `src` to `fusion_bench.models.open_clip`
@@ -114,6 +114,7 @@ class OpenCLIPVisionModelPool(BaseModelPool):
         self._test_processor = encoder.val_preprocess
         return self._test_processor
 
+    @override
     def load_model(
         self, model_name_or_config: Union[str, DictConfig], *args, **kwargs
     ) -> ImageEncoder:
@@ -210,6 +211,8 @@ class OpenCLIPVisionModelPool(BaseModelPool):
         - A string, which is the path to the model checkpoint in pickle format. Load directly using `torch.load`.
         - Default, load the model using `instantiate` from hydra.
         """
+        if self._classification_heads is None:
+            raise ValueError("No classification heads are defined in the model pool.")
         if (
             isinstance(model_name_or_config, str)
             and model_name_or_config in self._classification_heads
@@ -222,6 +225,8 @@ class OpenCLIPVisionModelPool(BaseModelPool):
         return head
 
     def load_train_dataset(self, dataset_name: str, *args, **kwargs):
+        if self._train_datasets is None:
+            raise ValueError("No train datasets are defined in the model pool.")
         dataset_config = self._train_datasets[dataset_name]
         if isinstance(dataset_config, str):
             log.info(
@@ -233,6 +238,8 @@ class OpenCLIPVisionModelPool(BaseModelPool):
         return dataset
 
     def load_val_dataset(self, dataset_name: str, *args, **kwargs):
+        if self._val_datasets is None:
+            raise ValueError("No val datasets are defined in the model pool.")
         dataset_config = self._val_datasets[dataset_name]
         if isinstance(dataset_config, str):
             log.info(
@@ -244,6 +251,8 @@ class OpenCLIPVisionModelPool(BaseModelPool):
         return dataset
 
     def load_test_dataset(self, dataset_name: str, *args, **kwargs):
+        if self._test_datasets is None:
+            raise ValueError("No test datasets are defined in the model pool.")
         dataset_config = self._test_datasets[dataset_name]
         if isinstance(dataset_config, str):
             log.info(