fusion-bench 0.2.29__py3-none-any.whl → 0.2.31__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. fusion_bench/constants/runtime.py +4 -1
  2. fusion_bench/method/__init__.py +9 -1
  3. fusion_bench/method/base_algorithm.py +29 -19
  4. fusion_bench/method/classification/image_classification_finetune.py +1 -0
  5. fusion_bench/method/concrete_subspace/clip_concrete_tsvm.py +285 -0
  6. fusion_bench/method/task_singular_vector/TSVM.py +7 -6
  7. fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +0 -1
  8. fusion_bench/metrics/model_kinship/__init__.py +2 -0
  9. fusion_bench/metrics/model_kinship/calculate.py +77 -0
  10. fusion_bench/metrics/model_kinship/calculate_split.py +171 -0
  11. fusion_bench/metrics/model_kinship/utility.py +184 -0
  12. fusion_bench/mixins/lightning_fabric.py +2 -8
  13. fusion_bench/mixins/openclip_classification.py +155 -1
  14. fusion_bench/modelpool/base_pool.py +1 -0
  15. fusion_bench/modelpool/openclip_vision/modelpool.py +12 -3
  16. fusion_bench/models/masks/mask_model.py +8 -2
  17. fusion_bench/models/open_clip/modeling.py +68 -5
  18. fusion_bench/models/open_clip/utils.py +13 -2
  19. fusion_bench/models/wrappers/layer_wise_fusion.py +41 -3
  20. fusion_bench/models/wrappers/task_wise_fusion.py +14 -3
  21. fusion_bench/py.typed +1 -0
  22. fusion_bench/scripts/cli.py +21 -16
  23. fusion_bench/scripts/imgui.py +2 -2
  24. fusion_bench/scripts/webui.py +2 -2
  25. fusion_bench/utils/__init__.py +2 -0
  26. fusion_bench/utils/devices.py +3 -1
  27. fusion_bench/utils/hydra_utils.py +75 -0
  28. fusion_bench/utils/instantiate_utils.py +29 -18
  29. fusion_bench/utils/misc.py +16 -0
  30. fusion_bench/utils/parameters.py +33 -0
  31. fusion_bench/utils/rich_utils.py +165 -25
  32. {fusion_bench-0.2.29.dist-info → fusion_bench-0.2.31.dist-info}/METADATA +7 -7
  33. {fusion_bench-0.2.29.dist-info → fusion_bench-0.2.31.dist-info}/RECORD +41 -34
  34. fusion_bench_config/README.md +9 -0
  35. fusion_bench_config/fabric/auto.yaml +1 -0
  36. fusion_bench_config/hydra/default.yaml +3 -1
  37. fusion_bench_config/method/concrete_subspace/clip_concrete_tsvm.yaml +38 -0
  38. {fusion_bench-0.2.29.dist-info → fusion_bench-0.2.31.dist-info}/WHEEL +0 -0
  39. {fusion_bench-0.2.29.dist-info → fusion_bench-0.2.31.dist-info}/entry_points.txt +0 -0
  40. {fusion_bench-0.2.29.dist-info → fusion_bench-0.2.31.dist-info}/licenses/LICENSE +0 -0
  41. {fusion_bench-0.2.29.dist-info → fusion_bench-0.2.31.dist-info}/top_level.txt +0 -0

fusion_bench/metrics/model_kinship/utility.py
@@ -0,0 +1,184 @@
+ import logging
+ from enum import Enum
+ from typing import List
+
+ import click
+ import torch
+ from tqdm import tqdm
+ from transformers import (
+     AutoConfig,
+     AutoModelForCausalLM,
+     AutoTokenizer,
+     PretrainedConfig,
+ )
+
+
+ class Metric(str, Enum):
+     """Enumeration of supported metrics"""
+
+     PCC = "pcc"
+     ED = "ed"
+     CS = "cs"
+
+     @classmethod
+     def list(cls) -> List[str]:
+         """Return list of supported metric values"""
+         return [metric.value for metric in cls]
+
+
+ def get_config(model: str, trust_remote_code: bool = False) -> PretrainedConfig:
+     """
+     Fetch the configuration of a pretrained model from HuggingFace.
+
+     Args:
+         model (str): The name or path of the model to load configuration for.
+         trust_remote_code (bool, optional): Whether to trust remote code during loading.
+             Defaults to False.
+
+     Returns:
+         PretrainedConfig: The configuration object of the specified model.
+     """
+     # Fetch the configuration from HuggingFace's model hub.
+     config = AutoConfig.from_pretrained(
+         model,
+         trust_remote_code=trust_remote_code,  # Whether to allow remote code execution.
+     )
+     return config
+
+
+ def validate_models(model_1: str, model_2: str, base_model: str) -> None:
+     """
+     Validate model names to ensure they are different and exist.
+
+     Args:
+         model_1: Name of the first model
+         model_2: Name of the second model
+         base_model: Name of the base model
+
+     Raises:
+         click.BadParameter: If validation fails
+     """
+     if model_1 == model_2 or model_1 == base_model or model_2 == base_model:
+         raise click.BadParameter("All model names must be different")
+
+
+ def quantize_8bit(x: torch.Tensor) -> torch.Tensor:
+     # Get absolute min and max values
+     abs_max = torch.max(torch.abs(x))
+
+     # Scale to [-127, 127] range for 8-bit signed integers
+     # Using 127 instead of 128 to keep zero exactly representable
+     scaled = 127 * (x / abs_max)
+
+     # Round to nearest integer
+     quantized = torch.round(scaled)
+
+     # Clamp values to ensure they stay in valid range
+     quantized = torch.clamp(quantized, -127, 127)
+
+     return quantized
+
+
+ def load_model_state_dict(model_name: str, device: str) -> dict:
+     """
+     Load a model and return its state dictionary.
+
+     Args:
+         model_name (str): Name or path of the model to load
+         device (str): Device to load the model on ('cuda' or 'cpu')
+
+     Returns:
+         dict: State dictionary of the loaded model
+     """
+     logging.info(f"Loading model: {model_name}")
+     model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
+     state_dict = model.state_dict()
+     del model  # Free memory
+     return state_dict
+
+
+ def extract_delta_parameters(
+     model_1_name: str,
+     model_2_name: str,
+     model_base_name: str,
+     low_precision: bool,
+     device: str = "cuda" if torch.cuda.is_available() else "cpu",
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+     """
+     Extract the delta parameters (weight differences) between two models
+     relative to a base model.
+
+     Args:
+         model_1_name (str): Name or path of the first model.
+         model_2_name (str): Name or path of the second model.
+         model_base_name (str): Name or path of the base model for comparison.
+         low_precision (bool): Whether to use low precision weights
+
+     Returns:
+         (torch.Tensor, torch.Tensor): Delta parameters of model_1 and model_2 relative to base model.
+     """
+
+     # Extract state dictionaries from models
+     state_dict_1 = load_model_state_dict(model_1_name, device)
+     state_dict_2 = load_model_state_dict(model_2_name, device)
+     state_dict_base = load_model_state_dict(model_base_name, device)
+
+     # Determine the number of layers
+     num_layers = state_dict_base["lm_head.weight"].shape[0]
+
+     # Check if model architectures match, log a warning if not
+     if (
+         state_dict_1["lm_head.weight"].shape[0]
+         != state_dict_2["lm_head.weight"].shape[0]
+     ):
+         shape_1 = state_dict_1["lm_head.weight"].shape
+         shape_2 = state_dict_2["lm_head.weight"].shape
+         logging.warning(
+             f"Warning: Model architectures do not match. "
+             f"Using sub weight space instead.\n"
+             f"Vocab sizes in model 1: {shape_1[0]}, "
+             f"Vocab sizes in model 2: {shape_2[0]}"
+         )
+
+     # Initialize lists to store delta parameters for both models
+     d_vector_1, d_vector_2 = [], []
+
+     # Iterate over keys in the base model's state dictionary with tqdm
+     for key, base_params in tqdm(
+         state_dict_base.items(), desc="Processing keys", unit="key"
+     ):
+         # Only proceed if key exists in both models
+         try:
+             if key not in state_dict_1 or key not in state_dict_2:
+                 logging.warning(f"Key {key} not found in one of the models")
+                 continue
+         except Exception as e:
+             logging.error(f"Error processing key {key}: {str(e)}")
+
+         # Get the parameters for each model (truncate to num_layers for consistency)
+         params_1 = state_dict_1[key][:num_layers]
+         params_2 = state_dict_2[key][:num_layers]
+
+         # Compute the deltas relative to the base model
+         delta_1 = (params_1 - base_params).view(-1)
+         delta_2 = (params_2 - base_params).view(-1)
+
+         # Accumulate deltas
+         d_vector_1.append(delta_1)
+         d_vector_2.append(delta_2)
+
+     # Clear memory
+     del state_dict_1, state_dict_2, state_dict_base
+
+     logging.info("Concatenating delta vectors...")
+
+     d_vector_1 = torch.cat(d_vector_1)
+     d_vector_2 = torch.cat(d_vector_2)
+
+     if low_precision:
+         logging.info("Quantizing delta vectors to 8-bit precision...")
+         d_vector_1 = quantize_8bit(d_vector_1)
+         d_vector_2 = quantize_8bit(d_vector_2)
+         logging.info("Quantization complete")
+
+     return d_vector_1, d_vector_2
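
These delta vectors feed the kinship metrics enumerated in `Metric` (PCC, ED, CS); the actual computations live in the new `calculate.py` and `calculate_split.py` modules, which are not shown in this diff. As a rough illustration only (model names are placeholders, and the formulas below are the standard definitions, not necessarily the package's exact code):

    import torch
    from fusion_bench.metrics.model_kinship.utility import extract_delta_parameters

    # Hypothetical model names, for illustration only.
    d1, d2 = extract_delta_parameters(
        "org/model-a", "org/model-b", "org/base-model", low_precision=False
    )
    d1, d2 = d1.float(), d2.float()

    # Pearson correlation coefficient ("pcc") between the delta vectors.
    d1c, d2c = d1 - d1.mean(), d2 - d2.mean()
    pcc = (d1c @ d2c) / (d1c.norm() * d2c.norm())

    # Euclidean distance ("ed") and cosine similarity ("cs").
    ed = torch.dist(d1, d2)
    cs = torch.nn.functional.cosine_similarity(d1, d2, dim=0)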

fusion_bench/mixins/lightning_fabric.py
@@ -10,6 +10,7 @@ from lightning.fabric.loggers import TensorBoardLogger
  from lightning.fabric.utilities.rank_zero import rank_zero_only
  from omegaconf import DictConfig, OmegaConf

+ from fusion_bench.constants import RuntimeConstants
  from fusion_bench.utils import import_object
  from fusion_bench.utils.instantiate_utils import instantiate

@@ -206,14 +207,7 @@ class LightningFabricMixin:
          Returns:
              bool: True if fast_dev_run is enabled, False otherwise.
          """
-         if hasattr(self, "config") and self.config.get("fast_dev_run", False):
-             return True
-         elif hasattr(self, "_program") and self._program.config.get(
-             "fast_dev_run", False
-         ):
-             return True
-         else:
-             return False
+         return RuntimeConstants().debug

      def log(self, name: str, value: Any, step: Optional[int] = None):
          """

fusion_bench/mixins/openclip_classification.py
@@ -1,11 +1,165 @@
+ import functools
  import logging
+ from typing import TYPE_CHECKING, Callable, Dict, Iterator, List, Literal, Optional

+ import torch
+ from omegaconf import DictConfig
+ from torch.utils.data import DataLoader
+ from tqdm import tqdm
+
+ from fusion_bench.dataset.clip_dataset import CLIPDataset
  from fusion_bench.mixins import LightningFabricMixin
- from fusion_bench.models.open_clip import ImageClassifier, ImageEncoder
+ from fusion_bench.modelpool import OpenCLIPVisionModelPool
+ from fusion_bench.models.open_clip import (
+     ClassificationHead,
+     ImageClassifier,
+     ImageEncoder,
+ )
+ from fusion_bench.utils.data import InfiniteDataLoader

  log = logging.getLogger(__name__)


  class OpenCLIPClassificationMixin(LightningFabricMixin):
+
      _train_processor = None
      _test_processor = None
+     dataloader_kwargs: DictConfig
+     modelpool: OpenCLIPVisionModelPool
+     zero_shot_heads: Dict[str, ClassificationHead] = {}
+
+     def _init_processor(self, encoder: Optional["ImageEncoder"] = None):
+         """
+         Initialize the CLIP processors for training and testing.
+         """
+         if encoder is None:
+             encoder: "ImageEncoder" = self.modelpool.load_pretrained_or_first_model()
+         self._train_processor = encoder.train_preprocess
+         self._test_processor = encoder.val_preprocess
+         return self._train_processor, self._test_processor
+
+     def get_clip_processor(self, stage: Literal["train", "test"]):
+         """
+         Get the CLIP processor, loading it from the model pool if necessary.
+
+         Returns:
+             CLIPProcessor: The CLIP processor for image and text preprocessing.
+
+         Raises:
+             AssertionError: If the model pool is not set.
+         """
+         if stage == "train":
+             if self._train_processor is None:
+                 self._init_processor()
+             return self._train_processor
+         elif stage == "test":
+             if self._test_processor is None:
+                 self._init_processor()
+             return self._test_processor
+         else:
+             raise ValueError(f"Invalid stage: {stage}")
+
+     def setup_zero_shot_classification_head(
+         self,
+         task_names: Optional[List[str]] = None,
+         freeze: bool = True,
+         dtype: Optional[torch.dtype] = None,
+     ):
+         # check task names consistency across processes
+         _task_names = self.fabric.broadcast(task_names, src=0)
+         if not self.fabric.is_global_zero and task_names != _task_names:
+             raise ValueError("The `task_names` must be the same across all processes.")
+
+         for task in tqdm(
+             self.modelpool.model_names if task_names is None else task_names,
+             "Setting up zero-shot classification head",
+             disable=not self.fabric.is_global_zero,
+         ):
+             head = self.modelpool.load_classification_head(task)
+             if freeze:
+                 head.requires_grad_(False)
+             if dtype is not None:
+                 head = head.to(dtype=dtype)
+             self.zero_shot_heads[task] = self.to_device(head)
+
+     def set_clip_processor(self, stage: Literal["train", "test"], processor: Callable):
+         """
+         Set the CLIP processor for a specific stage.
+
+         Args:
+             stage (Literal["train", "test"]): The stage for which to set the processor.
+             processor (Callable): The CLIP processor to set.
+         """
+         if stage == "train":
+             self._train_processor = processor
+         elif stage == "test":
+             self._test_processor = processor
+         else:
+             raise ValueError(f"Invalid stage: {stage}")
+
+     @functools.cache
+     def get_shuffled_test_loader_iter(
+         self,
+         task: str,
+         batch_size: Optional[int] = None,
+         num_workers: Optional[int] = None,
+         **loader_kwargs,
+     ) -> Iterator:
+         """
+         Get an iterator for a shuffled test DataLoader.
+
+         This method creates a DataLoader for the test dataset of the specified task,
+         with shuffling enabled. It allows for optional customization of batch size,
+         number of workers, and other DataLoader keyword arguments.
+
+         Args:
+             task (str): The task identifier for which the test dataset is to be loaded.
+             batch_size (Optional[int]): The batch size to use for the DataLoader. If None, the default batch size is used.
+             num_workers (Optional[int]): The number of worker processes to use for data loading. If None, the default number of workers is used.
+             **loader_kwargs: Additional keyword arguments to pass to the DataLoader.
+
+         Returns:
+             Iterator: An iterator over the shuffled test DataLoader.
+         """
+         # get dataloader kwargs
+         dataloader_kwargs = self.dataloader_kwargs.copy()
+         dataloader_kwargs["shuffle"] = True
+         if batch_size is not None:
+             dataloader_kwargs["batch_size"] = batch_size
+         if num_workers is not None:
+             dataloader_kwargs["num_workers"] = num_workers
+         dataloader_kwargs.update(loader_kwargs)
+
+         # get the test dataset
+         clip_dataset = CLIPDataset(
+             self.modelpool.load_test_dataset(task),
+             processor=self.get_clip_processor(stage="test"),
+         )
+         # create the dataloader
+         loader = DataLoader(clip_dataset, **dataloader_kwargs)
+         loader = self.fabric.setup_dataloaders(loader)
+         return iter(InfiniteDataLoader(loader))
+
+     def compute_logits(
+         self,
+         module: ImageClassifier,
+         images,
+         task: str,
+     ):
+         """
+         Compute the logits for a batch of images using the provided module and task.
+
+         Args:
+             module (ImageClassifier): The image classification module to use for computing logits.
+             images (torch.Tensor): The batch of images for which to compute logits.
+             task (str): The task identifier to specify which classification head to use.
+
+         Returns:
+             torch.Tensor: The computed logits for the input images.
+         """
+         if len(self.zero_shot_heads) == 0:
+             self.setup_zero_shot_classification_head()
+         task_head = self.zero_shot_heads[task]
+         features = module(images)
+         logits = task_head(features)
+         return logits
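
Taken together, the mixin lets an algorithm pull endlessly from a shuffled test loader and score batches against the frozen zero-shot heads. A hypothetical sketch of how a method class using this mixin might call it (task names are illustrative, and it assumes `CLIPDataset` yields `(image, label)` pairs):

    # Inside a method of a class that mixes in OpenCLIPClassificationMixin:
    self.setup_zero_shot_classification_head(task_names=["cars", "dtd"])
    loader_iter = self.get_shuffled_test_loader_iter("cars", batch_size=16)

    images, labels = next(loader_iter)  # InfiniteDataLoader: next() never exhausts
    logits = self.compute_logits(merged_encoder, images, task="cars")
    loss = torch.nn.functional.cross_entropy(logits, labels)

Note that `get_shuffled_test_loader_iter` is wrapped in `functools.cache`, so repeated calls with the same arguments return the same iterator rather than rebuilding the DataLoader.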

fusion_bench/modelpool/base_pool.py
@@ -7,6 +7,7 @@ from omegaconf import DictConfig, OmegaConf, UnsupportedValueType
  from torch import nn
  from torch.utils.data import Dataset

+ from fusion_bench import TorchModelType
  from fusion_bench.mixins import BaseYAMLSerializable, HydraConfigMixin
  from fusion_bench.utils import (
      ValidationError,

fusion_bench/modelpool/openclip_vision/modelpool.py
@@ -1,7 +1,7 @@
  import logging
  import pickle
  import sys
- from typing import Callable, Optional, Union, cast
+ from typing import Callable, Optional, Union, cast, override

  import torch
  from datasets import load_dataset
@@ -41,8 +41,8 @@ def _check_and_redirect_open_clip_modeling():
      )

      try:
-         import src
-         import src.modeling
+         import src  # type: ignore
+         import src.modeling  # type: ignore
      except ImportError:
          if "src" not in sys.modules:
              # redirect the import of `src` to `fusion_bench.models.open_clip`
@@ -114,6 +114,7 @@ class OpenCLIPVisionModelPool(BaseModelPool):
          self._test_processor = encoder.val_preprocess
          return self._test_processor

+     @override
      def load_model(
          self, model_name_or_config: Union[str, DictConfig], *args, **kwargs
      ) -> ImageEncoder:
@@ -210,6 +211,8 @@ class OpenCLIPVisionModelPool(BaseModelPool):
          - A string, which is the path to the model checkpoint in pickle format. Load directly using `torch.load`.
          - Default, load the model using `instantiate` from hydra.
          """
+         if self._classification_heads is None:
+             raise ValueError("No classification heads are defined in the model pool.")
          if (
              isinstance(model_name_or_config, str)
              and model_name_or_config in self._classification_heads
@@ -222,6 +225,8 @@ class OpenCLIPVisionModelPool(BaseModelPool):
          return head

      def load_train_dataset(self, dataset_name: str, *args, **kwargs):
+         if self._train_datasets is None:
+             raise ValueError("No train datasets are defined in the model pool.")
          dataset_config = self._train_datasets[dataset_name]
          if isinstance(dataset_config, str):
              log.info(
@@ -233,6 +238,8 @@ class OpenCLIPVisionModelPool(BaseModelPool):
          return dataset

      def load_val_dataset(self, dataset_name: str, *args, **kwargs):
+         if self._val_datasets is None:
+             raise ValueError("No val datasets are defined in the model pool.")
          dataset_config = self._val_datasets[dataset_name]
          if isinstance(dataset_config, str):
              log.info(
@@ -244,6 +251,8 @@ class OpenCLIPVisionModelPool(BaseModelPool):
          return dataset

      def load_test_dataset(self, dataset_name: str, *args, **kwargs):
+         if self._test_datasets is None:
+             raise ValueError("No test datasets are defined in the model pool.")
          dataset_config = self._test_datasets[dataset_name]
          if isinstance(dataset_config, str):
              log.info(
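
Each loader now fails fast with a descriptive `ValueError` when the corresponding section is missing from the pool configuration, instead of crashing on a `None` subscript. A sketch of the new behavior (pool construction elided):

    try:
        modelpool.load_test_dataset("cars")  # `modelpool` configured without test datasets
    except ValueError as err:
        print(err)  # -> "No test datasets are defined in the model pool."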

fusion_bench/models/masks/mask_model.py
@@ -113,21 +113,27 @@ class MaskModel(ParameterDictModel):
      def get_distribution(
          self,
          mask_type: Literal["discrete", "continuous"],
+         temperature: float = 0.5,
          **kwargs,
      ):
          return {
-             name: self._param_to_distribution(param, mask_type=mask_type, **kwargs)
+             name: self._param_to_distribution(
+                 param, mask_type=mask_type, temperature=temperature, **kwargs
+             )
              for name, param in self.named_parameters()
          }

      def sample_mask(
          self,
          mask_type: Literal["discrete", "continuous"] = "discrete",
+         temperature: float = 0.5,
          **kwargs,
      ):
          mask = {}
          for name, param in self.named_parameters():
-             dist = self._param_to_distribution(param, mask_type, **kwargs)
+             dist = self._param_to_distribution(
+                 param, mask_type, temperature=temperature, **kwargs
+             )
              if mask_type == "discrete":
                  mask[name] = dist.sample()
              elif mask_type == "continuous":
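
`temperature` now threads explicitly through `get_distribution` and `sample_mask` into `_param_to_distribution`, whose body is not part of this diff. In concrete/Gumbel-style mask learning, the temperature controls how sharply the relaxed distribution approximates discrete 0/1 sampling. A hedged sketch of what such a mapping typically looks like (an assumption, not the package's actual implementation):

    import torch
    from torch.distributions import Bernoulli, RelaxedBernoulli

    def param_to_distribution(param: torch.Tensor, mask_type: str, temperature: float = 0.5):
        """Hypothetical stand-in for `MaskModel._param_to_distribution`."""
        if mask_type == "discrete":
            # Discrete masks: an ordinary Bernoulli over sigmoid(param).
            return Bernoulli(logits=param)
        elif mask_type == "continuous":
            # Continuous masks: Concrete/Gumbel relaxation; lower temperature
            # pushes samples toward {0, 1}.
            return RelaxedBernoulli(temperature, logits=param)
        raise ValueError(f"Invalid mask_type: {mask_type}")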

fusion_bench/models/open_clip/modeling.py
@@ -1,3 +1,25 @@
+ """
+ OpenCLIP model wrappers used by FusionBench.
+
+ This module provides lightweight `torch.nn.Module` wrappers around OpenCLIP
+ components that are commonly used throughout FusionBench experiments:
+
+ - `ImageEncoder`: loads an OpenCLIP image encoder and exposes `encode_image`.
+ - `ClassificationHead`: a linear head optionally normalizing inputs.
+ - `ImageClassifier` / `MultiHeadImageClassifier`: convenience compositions.
+
+ Note:
+     This module requires the optional dependency `open_clip_torch`.
+ """
+
+ from fusion_bench.utils.packages import is_open_clip_available
+
+ if not is_open_clip_available():
+     raise ImportError(
+         "open_clip is not installed. Please install it with `pip install open_clip_torch`."
+     )
+
+ from pathlib import Path
  from typing import Callable, List

  import open_clip
@@ -10,6 +32,19 @@ from .variables_and_paths import CACHEDIR, MODELS, OPENCLIP_CACHEDIR

  class ImageEncoder(torch.nn.Module):
      R"""
+     OpenCLIP image encoder wrapper.
+
+     This class loads an OpenCLIP model by name and exposes a forward pass that
+     returns image embeddings via `model.encode_image`.
+
+     Args:
+         model_name: A model name supported by `open_clip`. FusionBench also
+             supports suffixes:
+             - ``"__pretrained__<tag>"`` to select a specific pretrained weights tag.
+             - ``"__init__"`` to use random initialization.
+         keep_lang: If False (default), removes the text encoder (when present)
+             to reduce memory usage.
+
      Examples:

      load the image encoder for a given model name
@@ -18,7 +53,7 @@ class ImageEncoder(torch.nn.Module):
      >>> image_encoder = ImageEncoder(model_name="ViT-B-32")
      """

-     def __init__(self, model_name: str, keep_lang=False):
+     def __init__(self, model_name: str, keep_lang: bool = False):
          super().__init__()
          assert (
              model_name in MODELS
@@ -42,22 +77,26 @@ class ImageEncoder(torch.nn.Module):

          self.cache_dir = CACHEDIR

+         # if `keep_lang` is False, remove the text encoder to save memory
          if not keep_lang and hasattr(self.model, "transformer"):
              delattr(self.model, "transformer")

-     def forward(self, images):
+     def forward(self, images: Tensor) -> Tensor:
+         """Encode a batch of images into embedding vectors."""
          assert self.model is not None
          return self.model.encode_image(images)

-     def __call__(self, inputs):
+     def __call__(self, inputs: Tensor) -> Tensor:
          return self.forward(inputs)

-     def save(self, filename):
+     def save(self, filename: str) -> None:
+         """Serialize this module to disk."""
          print(f"Saving image encoder to {filename}")
          utils.torch_save(self, filename)

      @classmethod
-     def load(cls, model_name, filename):
+     def load(cls, model_name: str, filename: str | Path):
+         """Load a saved encoder state dict into a freshly constructed encoder."""
          print(f"Loading image encoder from {filename}")

          state_dict = torch.load(filename, map_location="cpu")
@@ -68,6 +107,15 @@ class ImageEncoder(torch.nn.Module):


  class ClassificationHead(torch.nn.Linear):
+     """A linear classification head with optional input normalization.
+
+     Args:
+         normalize: If True, L2-normalize inputs along the last dimension before
+             applying the linear projection.
+         weights: Weight matrix of shape (num_classes, feature_dim).
+         biases: Optional bias vector of shape (num_classes,).
+     """
+
      def __init__(
          self,
          normalize: bool,
@@ -85,6 +133,7 @@ class ClassificationHead(torch.nn.Linear):
              self.bias = torch.nn.Parameter(torch.zeros_like(self.bias))

      def forward(self, inputs: Tensor):
+         """Compute logits from input features."""
          if self.normalize:
              inputs = inputs / inputs.norm(dim=-1, keepdim=True)
          return super().forward(inputs)
@@ -93,11 +142,13 @@
          return self.forward(inputs)

      def save(self, filename):
+         """Serialize this head to disk."""
          print(f"Saving classification head to {filename}")
          utils.torch_save(self, filename, save_state_dict=False)

      @classmethod
      def load(cls, filename):
+         """Load a serialized `ClassificationHead` instance from disk."""
          # print(f"Loading classification head from {filename}")
          return utils.torch_load(filename)

@@ -106,6 +157,8 @@ class ImageClassifier(torch.nn.Module):
      train_preprocess: Callable
      val_preprocess: Callable

+     """Convenience module combining an `ImageEncoder` and a `ClassificationHead`."""
+
      def __init__(
          self,
          image_encoder: ImageEncoder,
@@ -119,10 +172,12 @@
          self.val_preprocess = self.image_encoder.val_preprocess

      def freeze_head(self):
+         """Disable gradient computation for the classification head."""
          self.classification_head.weight.requires_grad_(False)
          self.classification_head.bias.requires_grad_(False)

      def forward(self, inputs: Tensor):
+         """Run encoder then head and return logits."""
          features = self.image_encoder(inputs)
          outputs = self.classification_head(features)
          return outputs
@@ -131,16 +186,20 @@
          return self.forward(inputs)

      def save(self, filename):
+         """Serialize this module to disk."""
          print(f"Saving image classifier to {filename}")
          utils.torch_save(self, filename)

      @classmethod
      def load(cls, filename):
+         """Load a serialized `ImageClassifier` instance from disk."""
          print(f"Loading image classifier from {filename}")
          return utils.torch_load(filename)


  class MultiHeadImageClassifier(torch.nn.Module):
+     """Image encoder with multiple task-specific classification heads."""
+
      def __init__(
          self,
          image_encoder: ImageEncoder,
@@ -154,11 +213,13 @@
          self.val_preprocess = self.image_encoder.val_preprocess

      def freeze_head(self):
+         """Disable gradient computation for all heads."""
          for idx in range(len(self.classification_heads)):
              self.classification_heads[idx].weight.requires_grad_(False)
              self.classification_heads[idx].bias.requires_grad_(False)

      def forward(self, inputs, head_idx):
+         """Run encoder then the selected head and return logits."""
          features = self.image_encoder(inputs)
          outputs = self.classification_heads[head_idx](features)
          return outputs
@@ -167,10 +228,12 @@
          return self.forward(inputs, head_idx)

      def save(self, filename):
+         """Serialize this module to disk."""
          print(f"Saving image classifier to {filename}")
          utils.torch_save(self, filename)

      @classmethod
      def load(cls, filename):
+         """Load a serialized `MultiHeadImageClassifier` instance from disk."""
          print(f"Loading image classifier from {filename}")
          return utils.torch_load(filename)
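
The documented surface is small and stable. A minimal usage sketch, assuming `open_clip_torch` is installed; the head weights here are random placeholders (real heads come from the model pool), and the `ClassificationHead` call is inferred from the docstring added above:

    import torch
    from fusion_bench.models.open_clip import (
        ClassificationHead,
        ImageClassifier,
        ImageEncoder,
    )

    encoder = ImageEncoder(model_name="ViT-B-32")  # doctest example from the module
    head = ClassificationHead(normalize=True, weights=torch.randn(10, 512))
    classifier = ImageClassifier(encoder, head)

    images = torch.randn(4, 3, 224, 224)  # dummy batch
    logits = classifier(images)           # -> shape (4, 10)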