fusion-bench 0.2.30__py3-none-any.whl → 0.2.32__py3-none-any.whl
This diff compares two publicly released versions of the package as published to a supported public registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in the registry.
- fusion_bench/__init__.py +6 -0
- fusion_bench/__main__.py +2 -2
- fusion_bench/constants/runtime.py +4 -1
- fusion_bench/dataset/__init__.py +2 -0
- fusion_bench/dataset/clip_dataset.py +4 -72
- fusion_bench/dataset/image_dataset.py +44 -18
- fusion_bench/method/base_algorithm.py +4 -0
- fusion_bench/method/classification/image_classification_finetune.py +1 -0
- fusion_bench/method/concrete_subspace/clip_concrete_tsvm.py +285 -0
- fusion_bench/method/dop/dop.py +0 -22
- fusion_bench/method/dop/dop_general.py +489 -0
- fusion_bench/method/dop/utils.py +24 -4
- fusion_bench/method/emr_merging/__init__.py +1 -0
- fusion_bench/method/emr_merging/emr_merging.py +53 -0
- fusion_bench/method/emr_merging/utils.py +162 -0
- fusion_bench/method/opcm/opcm.py +6 -2
- fusion_bench/method/opcm/opcm_general.py +356 -0
- fusion_bench/method/opcm/utils.py +1 -4
- fusion_bench/method/simple_average.py +52 -18
- fusion_bench/method/task_arithmetic/task_arithmetic.py +1 -1
- fusion_bench/method/task_singular_vector/TSVM.py +7 -6
- fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +0 -1
- fusion_bench/mixins/lightning_fabric.py +110 -11
- fusion_bench/mixins/openclip_classification.py +155 -1
- fusion_bench/mixins/serialization.py +1 -1
- fusion_bench/modelpool/base_pool.py +37 -0
- fusion_bench/modelpool/convnext_for_image_classification.py +5 -2
- fusion_bench/modelpool/openclip_vision/modelpool.py +12 -3
- fusion_bench/models/hf_clip.py +20 -0
- fusion_bench/models/modulator/__init__.py +1 -0
- fusion_bench/models/modulator/base.py +123 -0
- fusion_bench/models/open_clip/modeling.py +61 -5
- fusion_bench/models/open_clip/utils.py +13 -2
- fusion_bench/models/parameter_dict.py +119 -29
- fusion_bench/models/utils.py +190 -2
- fusion_bench/models/wrappers/switch.py +90 -0
- fusion_bench/programs/base_program.py +6 -0
- fusion_bench/programs/fabric_fusion_program.py +4 -0
- fusion_bench/py.typed +1 -0
- fusion_bench/scripts/cli.py +25 -23
- fusion_bench/scripts/imgui.py +2 -2
- fusion_bench/scripts/webui.py +2 -2
- fusion_bench/taskpool/image_classification.py +270 -0
- fusion_bench/utils/__init__.py +20 -1
- fusion_bench/utils/data.py +1 -1
- fusion_bench/utils/dict.py +19 -0
- fusion_bench/utils/dtype.py +19 -0
- fusion_bench/utils/hydra_utils.py +75 -0
- fusion_bench/utils/misc.py +1 -0
- fusion_bench/utils/packages.py +4 -0
- fusion_bench/utils/parameters.py +33 -0
- fusion_bench/utils/rich_utils.py +42 -19
- fusion_bench/utils/state_dict_arithmetic.py +183 -1
- fusion_bench/utils/tensorboard.py +21 -3
- {fusion_bench-0.2.30.dist-info → fusion_bench-0.2.32.dist-info}/METADATA +3 -1
- {fusion_bench-0.2.30.dist-info → fusion_bench-0.2.32.dist-info}/RECORD +70 -53
- {fusion_bench-0.2.30.dist-info → fusion_bench-0.2.32.dist-info}/WHEEL +1 -1
- {fusion_bench-0.2.30.dist-info → fusion_bench-0.2.32.dist-info}/entry_points.txt +1 -1
- fusion_bench_config/README.md +9 -0
- fusion_bench_config/fabric/auto.yaml +1 -0
- fusion_bench_config/fabric/loggers/mlflow_logger.yaml +4 -0
- fusion_bench_config/hydra/default.yaml +3 -1
- fusion_bench_config/method/concrete_subspace/clip_concrete_tsvm.yaml +38 -0
- fusion_bench_config/method/dop/dop_general.yaml +33 -0
- fusion_bench_config/method/emr_merging/emr_merging.yaml +1 -0
- fusion_bench_config/method/opcm/opcm_general.yaml +18 -0
- fusion_bench_config/modelpool/ConvNextForImageClassification/convnext-base-224_8-tasks.yaml +15 -0
- fusion_bench_config/taskpool/ImageClassificationTaskPool/convnext-base-224_8-tasks.yaml +17 -0
- {fusion_bench-0.2.30.dist-info → fusion_bench-0.2.32.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.30.dist-info → fusion_bench-0.2.32.dist-info}/top_level.txt +0 -0
fusion_bench/modelpool/openclip_vision/modelpool.py CHANGED

@@ -1,7 +1,7 @@
 import logging
 import pickle
 import sys
-from typing import Callable, Optional, Union, cast
+from typing import Callable, Optional, Union, cast, override

 import torch
 from datasets import load_dataset
@@ -41,8 +41,8 @@ def _check_and_redirect_open_clip_modeling():
     )

     try:
-        import src
-        import src.modeling
+        import src  # type: ignore
+        import src.modeling  # type: ignore
     except ImportError:
         if "src" not in sys.modules:
             # redirect the import of `src` to `fusion_bench.models.open_clip`
@@ -114,6 +114,7 @@ class OpenCLIPVisionModelPool(BaseModelPool):
         self._test_processor = encoder.val_preprocess
         return self._test_processor

+    @override
     def load_model(
         self, model_name_or_config: Union[str, DictConfig], *args, **kwargs
     ) -> ImageEncoder:
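Note: `typing.override` exists only on Python 3.12+ (PEP 698), so importing it from `typing` implies a floor on the supported interpreter version unless a fallback is used. A common compatibility pattern, not shown in this diff, looks like:

    try:
        from typing import override  # Python >= 3.12 (PEP 698)
    except ImportError:
        from typing_extensions import override  # backport for older interpreters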
@@ -210,6 +211,8 @@ class OpenCLIPVisionModelPool(BaseModelPool):
         - A string, which is the path to the model checkpoint in pickle format. Load directly using `torch.load`.
         - Default, load the model using `instantiate` from hydra.
         """
+        if self._classification_heads is None:
+            raise ValueError("No classification heads are defined in the model pool.")
         if (
             isinstance(model_name_or_config, str)
             and model_name_or_config in self._classification_heads
@@ -222,6 +225,8 @@ class OpenCLIPVisionModelPool(BaseModelPool):
         return head

     def load_train_dataset(self, dataset_name: str, *args, **kwargs):
+        if self._train_datasets is None:
+            raise ValueError("No train datasets are defined in the model pool.")
         dataset_config = self._train_datasets[dataset_name]
         if isinstance(dataset_config, str):
             log.info(
@@ -233,6 +238,8 @@ class OpenCLIPVisionModelPool(BaseModelPool):
         return dataset

     def load_val_dataset(self, dataset_name: str, *args, **kwargs):
+        if self._val_datasets is None:
+            raise ValueError("No val datasets are defined in the model pool.")
         dataset_config = self._val_datasets[dataset_name]
         if isinstance(dataset_config, str):
             log.info(
@@ -244,6 +251,8 @@ class OpenCLIPVisionModelPool(BaseModelPool):
         return dataset

     def load_test_dataset(self, dataset_name: str, *args, **kwargs):
+        if self._test_datasets is None:
+            raise ValueError("No test datasets are defined in the model pool.")
         dataset_config = self._test_datasets[dataset_name]
         if isinstance(dataset_config, str):
             log.info(
fusion_bench/models/hf_clip.py CHANGED

@@ -62,16 +62,36 @@ class HFCLIPClassifier(nn.Module):
             persistent=False,
         )

+    # NOTE:
+    # The property setters seem not to work properly with `nn.Module` attributes.
+    # So avoid using them in practice.
+    # To set the text or vision model, directly access the attributes.
+    # For example:
+    #     classifier.clip_model.text_model = new_text_model
+    # or
+    #     classifier.clip_model.vision_model = new_vision_model
+    # reference: https://github.com/pytorch/pytorch/issues/52664
+
     @property
     def text_model(self):
         """Get the text model component of CLIP."""
         return self.clip_model.text_model

+    @text_model.setter
+    def text_model(self, model: nn.Module):
+        """Set the text model component of CLIP."""
+        self.clip_model.text_model = model
+
     @property
     def vision_model(self):
         """Get the vision model component of CLIP."""
         return self.clip_model.vision_model

+    @vision_model.setter
+    def vision_model(self, model: nn.Module):
+        """Set the vision model component of CLIP."""
+        self.clip_model.vision_model = model
+
     def set_classification_task(
         self,
         classnames: List[str],
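Per the new NOTE in this hunk, submodules should be swapped by assigning to the underlying `clip_model` attributes rather than through the property setters. A minimal sketch of that usage, assuming the classifier wraps a Hugging Face `CLIPModel` and processor (the `merged_vision_model` name is hypothetical):

    from transformers import CLIPModel, CLIPProcessor

    clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    classifier = HFCLIPClassifier(clip_model, processor)

    # set the vision tower by direct attribute access, not via the property
    # setter, which interacts badly with nn.Module attribute registration
    classifier.clip_model.vision_model = merged_vision_model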
fusion_bench/models/modulator/__init__.py ADDED

@@ -0,0 +1 @@
+from .base import ModulatedModel, TaskModulator
fusion_bench/models/modulator/base.py ADDED

@@ -0,0 +1,123 @@
+import logging
+from abc import ABC, abstractmethod
+from typing import Any, Dict, Generic, Optional
+
+import torch
+from torch import nn
+
+from fusion_bench import TorchModelType
+
+log = logging.getLogger(__name__)
+
+
+class ModulatedModel(nn.Module, Generic[TorchModelType]):
+    """
+    A model wrapper that uses task-specific modulators to adapt a shared backbone
+    for different tasks.
+
+    The model maintains a shared backbone and task-specific modulators. During the
+    forward pass, the appropriate modulator is applied based on the current task.
+    """
+
+    _current_task: Optional[str] = None
+
+    def __init__(
+        self,
+        backbone: TorchModelType,
+        modulators: Dict[str, "TaskModulator[TorchModelType]"],
+    ):
+        super().__init__()
+        self.backbone = backbone
+        self.modulators = nn.ModuleDict(modulators)
+
+    def add_modulator(self, task_name: str, modulator: "TaskModulator[TorchModelType]"):
+        """Add a new task-specific modulator."""
+        if task_name in self.modulators:
+            raise ValueError(f"Modulator for task '{task_name}' already exists.")
+        self.modulators[task_name] = modulator
+
+    def remove_modulator(self, task_name: str):
+        """Remove an existing task-specific modulator."""
+        if task_name not in self.modulators:
+            raise ValueError(f"Modulator for task '{task_name}' does not exist.")
+        if self._current_task == task_name:
+            log.warning(
+                f"Removing modulator for current task '{task_name}'. "
+                "This will leave the current task in an unpredictable state."
+            )
+        del self.modulators[task_name]
+
+    def set_task(self, task_name: str):
+        """Set the current task for inference."""
+        if task_name not in self.modulators:
+            raise ValueError(
+                f"Task '{task_name}' not found in modulators. Available tasks: {list(self.modulators.keys())}"
+            )
+        if self._current_task == task_name:
+            return
+
+        # unset previous task
+        if self._current_task is not None:
+            self.modulators[self._current_task].remove(self)
+            assert (
+                self._current_task is None
+            ), "Current task should be None after removal."
+
+        # set new task
+        self.modulators[task_name].apply(self)
+        self._current_task = task_name
+
+    @property
+    def current_task(self) -> Optional[str]:
+        """Get the current task name."""
+        return self._current_task
+
+    def forward(self, *args, **kwargs) -> Any:
+        """
+        Forward pass with task-specific modulation.
+
+        Args:
+            *args: Positional arguments for the backbone model
+            **kwargs: Keyword arguments for the backbone model
+
+        Returns:
+            Model output after applying task-specific modulation
+        """
+        if self._current_task is None:
+            raise ValueError(
+                "No task specified. Set current_task or provide 'task' argument."
+            )
+
+        return self.backbone(*args, **kwargs)
+
+
+class TaskModulator(nn.Module, Generic[TorchModelType], ABC):
+    """
+    Lightweight, task-specific parameterization that modulates
+    a shared representation.
+
+    This is the base class for all task modulators. Subclasses should implement
+    the `apply` method to define how the modulator adapts the backbone model
+    for a specific task.
+    """
+
+    @abstractmethod
+    def apply(self, modulated_model: "ModulatedModel[TorchModelType]"):
+        """
+        Apply task-specific modulation to the backbone model.
+
+        Args:
+            modulated_model: The modulated model
+        """
+        raise NotImplementedError("Subclasses must implement the apply method.")
+
+    @abstractmethod
+    def remove(self, modulated_model: "ModulatedModel[TorchModelType]"):
+        """
+        Remove task-specific modulation from the backbone model.
+        This is called when switching tasks.
+
+        Args:
+            modulated_model: The modulated model
+        """
+        raise NotImplementedError("Subclasses must implement the remove method.")
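To make the `TaskModulator` contract concrete, here is a hedged sketch of a trivial subclass; `BiasModulator` is illustrative and not part of the package. Note that `set_task` asserts `_current_task is None` after calling `remove`, so a concrete `remove` is apparently expected to reset that flag, as done below:

    import torch
    from torch import nn

    from fusion_bench.models.modulator import ModulatedModel, TaskModulator


    class BiasModulator(TaskModulator):
        """Illustrative modulator: shifts the backbone's output bias per task."""

        def __init__(self, dim: int):
            super().__init__()
            self.bias = nn.Parameter(torch.zeros(dim))

        def apply(self, modulated_model):
            # add this task's bias to the shared backbone
            modulated_model.backbone.bias.data += self.bias.data

        def remove(self, modulated_model):
            # undo the shift and mark no task as active
            modulated_model.backbone.bias.data -= self.bias.data
            modulated_model._current_task = None


    backbone = nn.Linear(16, 4)
    model = ModulatedModel(
        backbone, {"task_a": BiasModulator(4), "task_b": BiasModulator(4)}
    )
    model.set_task("task_a")
    logits = model(torch.randn(2, 16))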
fusion_bench/models/open_clip/modeling.py CHANGED

@@ -1,3 +1,17 @@
+"""
+OpenCLIP model wrappers used by FusionBench.
+
+This module provides lightweight `torch.nn.Module` wrappers around OpenCLIP
+components that are commonly used throughout FusionBench experiments:
+
+- `ImageEncoder`: loads an OpenCLIP image encoder and exposes `encode_image`.
+- `ClassificationHead`: a linear head optionally normalizing inputs.
+- `ImageClassifier` / `MultiHeadImageClassifier`: convenience compositions.
+
+Note:
+    This module requires the optional dependency `open_clip_torch`.
+"""
+
 from fusion_bench.utils.packages import is_open_clip_available

 if not is_open_clip_available():
@@ -5,6 +19,7 @@ if not is_open_clip_available():
         "open_clip is not installed. Please install it with `pip install open_clip_torch`."
     )

+from pathlib import Path
 from typing import Callable, List

 import open_clip
@@ -17,6 +32,19 @@ from .variables_and_paths import CACHEDIR, MODELS, OPENCLIP_CACHEDIR

 class ImageEncoder(torch.nn.Module):
     R"""
+    OpenCLIP image encoder wrapper.
+
+    This class loads an OpenCLIP model by name and exposes a forward pass that
+    returns image embeddings via `model.encode_image`.
+
+    Args:
+        model_name: A model name supported by `open_clip`. FusionBench also
+            supports suffixes:
+            - ``"__pretrained__<tag>"`` to select a specific pretrained weights tag.
+            - ``"__init__"`` to use random initialization.
+        keep_lang: If False (default), removes the text encoder (when present)
+            to reduce memory usage.
+
     Examples:

     load the image encoder for a given model name
@@ -25,7 +53,7 @@ class ImageEncoder(torch.nn.Module):
     >>> image_encoder = ImageEncoder(model_name="ViT-B-32")
     """

-    def __init__(self, model_name: str, keep_lang=False):
+    def __init__(self, model_name: str, keep_lang: bool = False):
         super().__init__()
         assert (
             model_name in MODELS
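The suffix convention documented in the new docstring can be exercised directly; a small sketch (the pretrained tag is illustrative — any tag that `open_clip.list_pretrained()` reports for the architecture should work):

    # pick a specific pretrained weights tag for the architecture
    encoder = ImageEncoder("ViT-B-32__pretrained__laion2b_s34b_b79k")

    # or build a randomly initialized encoder
    encoder_scratch = ImageEncoder("ViT-B-32__init__")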
@@ -49,22 +77,26 @@ class ImageEncoder(torch.nn.Module):

         self.cache_dir = CACHEDIR

+        # if `keep_lang` is False, remove the text encoder to save memory
         if not keep_lang and hasattr(self.model, "transformer"):
             delattr(self.model, "transformer")

-    def forward(self, images):
+    def forward(self, images: Tensor) -> Tensor:
+        """Encode a batch of images into embedding vectors."""
         assert self.model is not None
         return self.model.encode_image(images)

-    def __call__(self, inputs):
+    def __call__(self, inputs: Tensor) -> Tensor:
         return self.forward(inputs)

-    def save(self, filename):
+    def save(self, filename: str) -> None:
+        """Serialize this module to disk."""
         print(f"Saving image encoder to {filename}")
         utils.torch_save(self, filename)

     @classmethod
-    def load(cls, model_name, filename):
+    def load(cls, model_name: str, filename: str | Path):
+        """Load a saved encoder state dict into a freshly constructed encoder."""
         print(f"Loading image encoder from {filename}")

         state_dict = torch.load(filename, map_location="cpu")
@@ -75,6 +107,15 @@ class ImageEncoder(torch.nn.Module):


 class ClassificationHead(torch.nn.Linear):
+    """A linear classification head with optional input normalization.
+
+    Args:
+        normalize: If True, L2-normalize inputs along the last dimension before
+            applying the linear projection.
+        weights: Weight matrix of shape (num_classes, feature_dim).
+        biases: Optional bias vector of shape (num_classes,).
+    """
+
     def __init__(
         self,
         normalize: bool,
@@ -92,6 +133,7 @@ class ClassificationHead(torch.nn.Linear):
             self.bias = torch.nn.Parameter(torch.zeros_like(self.bias))

     def forward(self, inputs: Tensor):
+        """Compute logits from input features."""
         if self.normalize:
             inputs = inputs / inputs.norm(dim=-1, keepdim=True)
         return super().forward(inputs)
@@ -100,11 +142,13 @@ class ClassificationHead(torch.nn.Linear):
         return self.forward(inputs)

     def save(self, filename):
+        """Serialize this head to disk."""
         print(f"Saving classification head to {filename}")
         utils.torch_save(self, filename, save_state_dict=False)

     @classmethod
     def load(cls, filename):
+        """Load a serialized `ClassificationHead` instance from disk."""
         # print(f"Loading classification head from {filename}")
         return utils.torch_load(filename)

@@ -113,6 +157,8 @@ class ImageClassifier(torch.nn.Module):
     train_preprocess: Callable
     val_preprocess: Callable

+    """Convenience module combining an `ImageEncoder` and a `ClassificationHead`."""
+
     def __init__(
         self,
         image_encoder: ImageEncoder,
@@ -126,10 +172,12 @@ class ImageClassifier(torch.nn.Module):
         self.val_preprocess = self.image_encoder.val_preprocess

     def freeze_head(self):
+        """Disable gradient computation for the classification head."""
         self.classification_head.weight.requires_grad_(False)
         self.classification_head.bias.requires_grad_(False)

     def forward(self, inputs: Tensor):
+        """Run encoder then head and return logits."""
         features = self.image_encoder(inputs)
         outputs = self.classification_head(features)
         return outputs
@@ -138,16 +186,20 @@ class ImageClassifier(torch.nn.Module):
         return self.forward(inputs)

     def save(self, filename):
+        """Serialize this module to disk."""
         print(f"Saving image classifier to {filename}")
         utils.torch_save(self, filename)

     @classmethod
     def load(cls, filename):
+        """Load a serialized `ImageClassifier` instance from disk."""
         print(f"Loading image classifier from {filename}")
         return utils.torch_load(filename)


 class MultiHeadImageClassifier(torch.nn.Module):
+    """Image encoder with multiple task-specific classification heads."""
+
     def __init__(
         self,
         image_encoder: ImageEncoder,
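A hedged sketch of driving the multi-head wrapper (the constructor argument and the `task_heads` name are assumptions; only `forward(inputs, head_idx)` is confirmed by this diff, and 512 is the ViT-B-32 embedding width):

    task_heads = [
        ClassificationHead(normalize=True, weights=torch.randn(10, 512)),
        ClassificationHead(normalize=True, weights=torch.randn(100, 512)),
    ]
    classifier = MultiHeadImageClassifier(ImageEncoder("ViT-B-32"), task_heads)
    logits_task0 = classifier(images, head_idx=0)  # select the first task's head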
@@ -161,11 +213,13 @@ class MultiHeadImageClassifier(torch.nn.Module):
         self.val_preprocess = self.image_encoder.val_preprocess

     def freeze_head(self):
+        """Disable gradient computation for all heads."""
         for idx in range(len(self.classification_heads)):
             self.classification_heads[idx].weight.requires_grad_(False)
             self.classification_heads[idx].bias.requires_grad_(False)

     def forward(self, inputs, head_idx):
+        """Run encoder then the selected head and return logits."""
         features = self.image_encoder(inputs)
         outputs = self.classification_heads[head_idx](features)
         return outputs
@@ -174,10 +228,12 @@ class MultiHeadImageClassifier(torch.nn.Module):
         return self.forward(inputs, head_idx)

     def save(self, filename):
+        """Serialize this module to disk."""
         print(f"Saving image classifier to {filename}")
         utils.torch_save(self, filename)

     @classmethod
     def load(cls, filename):
+        """Load a serialized `MultiHeadImageClassifier` instance from disk."""
         print(f"Loading image classifier from {filename}")
         return utils.torch_load(filename)
fusion_bench/models/open_clip/utils.py CHANGED

@@ -77,7 +77,16 @@ def torch_load_old(save_path: str, device=None):
     return classifier


-def torch_save(model, save_path, save_state_dict=True):
+def torch_save(model: torch.nn.Module, save_path: str, save_state_dict: bool = True):
+    """
+    Save a model to disk.
+
+    Args:
+        model: The model to save.
+        save_path (str): The path to save the model to.
+        save_state_dict (bool): Whether to save the state dict of the model (weights only).
+            If False, the entire model object is saved. Default is True.
+    """
     # TODO: hacky way to save state dict
     if save_state_dict and isinstance(model, torch.nn.Module):
         model = model.state_dict()
@@ -86,7 +95,9 @@ def torch_save(model, save_path, save_state_dict=True):
     torch.save(model, save_path)


-def torch_load(
+def torch_load(
+    save_path: str, device: Optional[torch.device] = None
+) -> torch.nn.Module:
     model = torch.load(save_path, map_location="cpu")
     if device is not None:
         model = model.to(device)
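A quick sketch of the intended round trip with the updated signatures (paths are illustrative):

    # default: persists only the state dict (weights)
    torch_save(encoder, "/tmp/encoder.pt")

    # save_state_dict=False pickles the whole module object, which matches
    # torch_load's `-> torch.nn.Module` return annotation
    torch_save(encoder, "/tmp/encoder_full.pt", save_state_dict=False)
    restored = torch_load("/tmp/encoder_full.pt")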
fusion_bench/models/parameter_dict.py CHANGED

@@ -1,12 +1,12 @@
-from typing import List, Mapping, Optional, Tuple
+from typing import Iterator, List, Mapping, Optional, Tuple, Union

 import torch
 from torch import nn

-__all__ = "
+__all__ = ["ParameterDictModel"]


-def _set_attr(
+def set_nested_attr(
     obj,
     names: List[str],
     val,
@@ -27,7 +27,7 @@ def _set_attr(
     else:
         if check_parent and not hasattr(obj, names[0]):
             setattr(obj, names[0], parent_builder())
-        _set_attr(
+        set_nested_attr(
             getattr(obj, names[0]),
             names[1:],
             val,
@@ -36,7 +36,7 @@ def _set_attr(
         )


-def has_attr(obj, names: List[str]):
+def has_nested_attr(obj, names: List[str]):
     """
     Checks if an attribute exists in an object recursively.

@@ -50,26 +50,49 @@ def has_attr(obj, names: List[str]):
     if len(names) == 1:
         return hasattr(obj, names[0])
     else:
-
+        if not hasattr(obj, names[0]):
+            return False
+        return has_nested_attr(getattr(obj, names[0]), names[1:])


 class ParameterDictModel(nn.Module):
     """
-
-
+    A module that stores parameters in a nested dictionary structure.
+
+    This model behaves similarly to `nn.ParameterDict`, but supports hierarchical keys
+    with dots (e.g., "layer1.weight"). Parameters are stored as nested attributes,
+    allowing for structured parameter access and manipulation.
+
+    Example:
+        >>> params = {
+        ...     "encoder.weight": nn.Parameter(torch.randn(10, 5)),
+        ...     "decoder.bias": nn.Parameter(torch.randn(5)),
+        ... }
+        >>> model = ParameterDictModel(params)
+        >>> model["encoder.weight"].shape
+        torch.Size([10, 5])
+        >>> "encoder.weight" in model
+        True
     """

     def __init__(
         self,
-        parameters: Optional[Mapping[str, nn.Parameter]] = None,
-    ):
+        parameters: Optional[Mapping[str, Union[nn.Parameter, torch.Tensor]]] = None,
+    ) -> None:
+        """
+        Args:
+            parameters: Optional mapping of parameter names to parameter tensors.
+                Keys can contain dots to create nested structures.
+                Values must be `nn.Parameter` or `nn.Buffer` instances.
+        """
+
         super().__init__()
         if parameters is not None:
             for name, param in parameters.items():
                 assert isinstance(
                     param, (nn.Parameter, nn.Buffer)
                 ), f"{name} is not a nn.Parameter or nn.Buffer"
-                _set_attr(
+                set_nested_attr(
                     self,
                     name.split("."),
                     param,
@@ -77,12 +100,13 @@ class ParameterDictModel(nn.Module):
                     parent_builder=__class__,
                 )

-    def __repr__(self):
+    def __repr__(self) -> str:
         """
         Generate a string representation of the model's parameters.

         Returns:
-
+            A string representation of the model's parameters in the format:
+            "ParameterDictModel(name1: shape1, name2: shape2, ...)"
         """
         param_reprs = []
         for name, param in self.named_parameters():
@@ -90,32 +114,98 @@ class ParameterDictModel(nn.Module):
             param_reprs.append(param_repr)
         return f"{self.__class__.__name__}({', '.join(param_reprs)})"

-    def
-
+    def __iter__(self) -> Iterator[str]:
+        """
+        Iterate over the model's parameters.
+
+        Yields:
+            Parameter names.
+        """
+        yield from self.keys()
+
+    def __getitem__(
+        self, key: str
+    ) -> Union[nn.Parameter, torch.Tensor, "ParameterDictModel"]:
+        """
+        Retrieve a parameter or nested submodule by key.
+
+        Args:
+            key: Parameter name, which can contain dots for nested access.
+
+        Returns:
+            The parameter, tensor, or nested ParameterDictModel at the specified key.
+
+        Raises:
+            KeyError: If the key is not found in the model.
+        """
+        assert isinstance(
+            key, str
+        ), f"Key must be a string, but got {type(key)}: {key}."
+        if not has_nested_attr(self, key.split(".")):
             raise KeyError(f"Key {key} not found in {self}")
-
+        key_parts = key.split(".")
         obj = self
-        for k in
+        for k in key_parts:
             obj = getattr(obj, k)
         return obj

-    def __setitem__(self, key: str, value: nn.Parameter):
-
-
+    def __setitem__(self, key: str, value: Union[nn.Parameter, torch.Tensor]) -> None:
+        """
+        Set a parameter at the specified key, creating nested structure if needed.
+
+        Args:
+            key: Parameter name, which can contain dots for nested assignment.
+            value: Parameter or tensor to assign.
+        """
+        if not has_nested_attr(self, key.split(".")):
+            set_nested_attr(self, key.split("."), value, check_parent=True)
         else:
-
+            set_nested_attr(self, key.split("."), value, check_parent=False)
+
+    def __contains__(self, key: str) -> bool:
+        """
+        Check if a parameter key exists in the model.

-
-
+        Args:
+            key: Parameter name, which can contain dots for nested checking.
+
+        Returns:
+            True if the key exists, False otherwise.
+        """
+        return has_nested_attr(self, key.split("."))

     def keys(self):
-
+        """
+        Return the parameter names in the model.
+
+        Returns:
+            The parameter names (including nested names with dots).
+        """
+        return self.state_dict().keys()
+
+    def items(self):
+        """
+        Iterate over (name, parameter) pairs.
+
+        Yields:
+            Tuples of parameter names and their corresponding tensors.
+        """
+        yield from self.state_dict().items()

-    def
-
+    def values(self):
+        """
+        Iterate over the parameter values in the model.

-
-
+        Yields:
+            Parameter tensors.
+        """
+        yield from self.state_dict().values()

-    def __len__(self):
+    def __len__(self) -> int:
+        """
+        Return the number of parameters in the model.
+
+        Returns:
+            The total number of parameters.
+        """
         return len(self.keys())