fusion-bench 0.2.30__py3-none-any.whl → 0.2.32__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70)
  1. fusion_bench/__init__.py +6 -0
  2. fusion_bench/__main__.py +2 -2
  3. fusion_bench/constants/runtime.py +4 -1
  4. fusion_bench/dataset/__init__.py +2 -0
  5. fusion_bench/dataset/clip_dataset.py +4 -72
  6. fusion_bench/dataset/image_dataset.py +44 -18
  7. fusion_bench/method/base_algorithm.py +4 -0
  8. fusion_bench/method/classification/image_classification_finetune.py +1 -0
  9. fusion_bench/method/concrete_subspace/clip_concrete_tsvm.py +285 -0
  10. fusion_bench/method/dop/dop.py +0 -22
  11. fusion_bench/method/dop/dop_general.py +489 -0
  12. fusion_bench/method/dop/utils.py +24 -4
  13. fusion_bench/method/emr_merging/__init__.py +1 -0
  14. fusion_bench/method/emr_merging/emr_merging.py +53 -0
  15. fusion_bench/method/emr_merging/utils.py +162 -0
  16. fusion_bench/method/opcm/opcm.py +6 -2
  17. fusion_bench/method/opcm/opcm_general.py +356 -0
  18. fusion_bench/method/opcm/utils.py +1 -4
  19. fusion_bench/method/simple_average.py +52 -18
  20. fusion_bench/method/task_arithmetic/task_arithmetic.py +1 -1
  21. fusion_bench/method/task_singular_vector/TSVM.py +7 -6
  22. fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +0 -1
  23. fusion_bench/mixins/lightning_fabric.py +110 -11
  24. fusion_bench/mixins/openclip_classification.py +155 -1
  25. fusion_bench/mixins/serialization.py +1 -1
  26. fusion_bench/modelpool/base_pool.py +37 -0
  27. fusion_bench/modelpool/convnext_for_image_classification.py +5 -2
  28. fusion_bench/modelpool/openclip_vision/modelpool.py +12 -3
  29. fusion_bench/models/hf_clip.py +20 -0
  30. fusion_bench/models/modulator/__init__.py +1 -0
  31. fusion_bench/models/modulator/base.py +123 -0
  32. fusion_bench/models/open_clip/modeling.py +61 -5
  33. fusion_bench/models/open_clip/utils.py +13 -2
  34. fusion_bench/models/parameter_dict.py +119 -29
  35. fusion_bench/models/utils.py +190 -2
  36. fusion_bench/models/wrappers/switch.py +90 -0
  37. fusion_bench/programs/base_program.py +6 -0
  38. fusion_bench/programs/fabric_fusion_program.py +4 -0
  39. fusion_bench/py.typed +1 -0
  40. fusion_bench/scripts/cli.py +25 -23
  41. fusion_bench/scripts/imgui.py +2 -2
  42. fusion_bench/scripts/webui.py +2 -2
  43. fusion_bench/taskpool/image_classification.py +270 -0
  44. fusion_bench/utils/__init__.py +20 -1
  45. fusion_bench/utils/data.py +1 -1
  46. fusion_bench/utils/dict.py +19 -0
  47. fusion_bench/utils/dtype.py +19 -0
  48. fusion_bench/utils/hydra_utils.py +75 -0
  49. fusion_bench/utils/misc.py +1 -0
  50. fusion_bench/utils/packages.py +4 -0
  51. fusion_bench/utils/parameters.py +33 -0
  52. fusion_bench/utils/rich_utils.py +42 -19
  53. fusion_bench/utils/state_dict_arithmetic.py +183 -1
  54. fusion_bench/utils/tensorboard.py +21 -3
  55. {fusion_bench-0.2.30.dist-info → fusion_bench-0.2.32.dist-info}/METADATA +3 -1
  56. {fusion_bench-0.2.30.dist-info → fusion_bench-0.2.32.dist-info}/RECORD +70 -53
  57. {fusion_bench-0.2.30.dist-info → fusion_bench-0.2.32.dist-info}/WHEEL +1 -1
  58. {fusion_bench-0.2.30.dist-info → fusion_bench-0.2.32.dist-info}/entry_points.txt +1 -1
  59. fusion_bench_config/README.md +9 -0
  60. fusion_bench_config/fabric/auto.yaml +1 -0
  61. fusion_bench_config/fabric/loggers/mlflow_logger.yaml +4 -0
  62. fusion_bench_config/hydra/default.yaml +3 -1
  63. fusion_bench_config/method/concrete_subspace/clip_concrete_tsvm.yaml +38 -0
  64. fusion_bench_config/method/dop/dop_general.yaml +33 -0
  65. fusion_bench_config/method/emr_merging/emr_merging.yaml +1 -0
  66. fusion_bench_config/method/opcm/opcm_general.yaml +18 -0
  67. fusion_bench_config/modelpool/ConvNextForImageClassification/convnext-base-224_8-tasks.yaml +15 -0
  68. fusion_bench_config/taskpool/ImageClassificationTaskPool/convnext-base-224_8-tasks.yaml +17 -0
  69. {fusion_bench-0.2.30.dist-info → fusion_bench-0.2.32.dist-info}/licenses/LICENSE +0 -0
  70. {fusion_bench-0.2.30.dist-info → fusion_bench-0.2.32.dist-info}/top_level.txt +0 -0
fusion_bench/modelpool/openclip_vision/modelpool.py

@@ -1,7 +1,7 @@
 import logging
 import pickle
 import sys
-from typing import Callable, Optional, Union, cast
+from typing import Callable, Optional, Union, cast, override
 
 import torch
 from datasets import load_dataset
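Note that `typing.override` exists only on Python 3.12+, so this import fails at module import time on older interpreters. A common version-guarded pattern (a sketch, not necessarily what fusion-bench ships) is:

    import sys

    if sys.version_info >= (3, 12):
        from typing import override
    else:
        # `typing_extensions` backports `override` for older Pythons
        from typing_extensions import override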
@@ -41,8 +41,8 @@ def _check_and_redirect_open_clip_modeling():
     )
 
     try:
-        import src
-        import src.modeling
+        import src  # type: ignore
+        import src.modeling  # type: ignore
     except ImportError:
         if "src" not in sys.modules:
             # redirect the import of `src` to `fusion_bench.models.open_clip`
@@ -114,6 +114,7 @@ class OpenCLIPVisionModelPool(BaseModelPool):
         self._test_processor = encoder.val_preprocess
         return self._test_processor
 
+    @override
     def load_model(
         self, model_name_or_config: Union[str, DictConfig], *args, **kwargs
     ) -> ImageEncoder:
@@ -210,6 +211,8 @@ class OpenCLIPVisionModelPool(BaseModelPool):
         - A string, which is the path to the model checkpoint in pickle format. Load directly using `torch.load`.
         - Default, load the model using `instantiate` from hydra.
         """
+        if self._classification_heads is None:
+            raise ValueError("No classification heads are defined in the model pool.")
         if (
             isinstance(model_name_or_config, str)
             and model_name_or_config in self._classification_heads
@@ -222,6 +225,8 @@ class OpenCLIPVisionModelPool(BaseModelPool):
         return head
 
     def load_train_dataset(self, dataset_name: str, *args, **kwargs):
+        if self._train_datasets is None:
+            raise ValueError("No train datasets are defined in the model pool.")
         dataset_config = self._train_datasets[dataset_name]
         if isinstance(dataset_config, str):
             log.info(
@@ -233,6 +238,8 @@ class OpenCLIPVisionModelPool(BaseModelPool):
         return dataset
 
     def load_val_dataset(self, dataset_name: str, *args, **kwargs):
+        if self._val_datasets is None:
+            raise ValueError("No val datasets are defined in the model pool.")
         dataset_config = self._val_datasets[dataset_name]
         if isinstance(dataset_config, str):
             log.info(
@@ -244,6 +251,8 @@ class OpenCLIPVisionModelPool(BaseModelPool):
         return dataset
 
     def load_test_dataset(self, dataset_name: str, *args, **kwargs):
+        if self._test_datasets is None:
+            raise ValueError("No test datasets are defined in the model pool.")
         dataset_config = self._test_datasets[dataset_name]
         if isinstance(dataset_config, str):
             log.info(
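These guards replace an unhelpful `TypeError: 'NoneType' object is not subscriptable` (raised when the pool config defines no datasets at all, so e.g. `self._train_datasets` is `None`) with an explicit `ValueError`. The general shape of the pattern, as a standalone sketch with a hypothetical helper:

    from typing import Mapping, Optional

    def lookup_dataset(datasets: Optional[Mapping[str, object]], name: str) -> object:
        # hypothetical helper mirroring the guards added above
        if datasets is None:
            raise ValueError("No datasets are defined in the model pool.")
        return datasets[name]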
fusion_bench/models/hf_clip.py

@@ -62,16 +62,36 @@ class HFCLIPClassifier(nn.Module):
             persistent=False,
         )
 
+    # NOTE:
+    # The property setters seem not to work properly with `nn.Module` attributes,
+    # so avoid using them in practice.
+    # To set the text or vision model, directly access the attributes.
+    # For example:
+    #     classifier.clip_model.text_model = new_text_model
+    # or
+    #     classifier.clip_model.vision_model = new_vision_model
+    # reference: https://github.com/pytorch/pytorch/issues/52664
+
     @property
     def text_model(self):
         """Get the text model component of CLIP."""
         return self.clip_model.text_model
 
+    @text_model.setter
+    def text_model(self, model: nn.Module):
+        """Set the text model component of CLIP."""
+        self.clip_model.text_model = model
+
     @property
     def vision_model(self):
         """Get the vision model component of CLIP."""
         return self.clip_model.vision_model
 
+    @vision_model.setter
+    def vision_model(self, model: nn.Module):
+        """Set the vision model component of CLIP."""
+        self.clip_model.vision_model = model
+
     def set_classification_task(
         self,
         classnames: List[str],
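The pitfall behind the NOTE above: `nn.Module.__setattr__` special-cases values that are `nn.Module` (or `nn.Parameter`) instances and registers them in `self._modules` directly, without consulting data descriptors such as property setters defined on the class. A minimal reproduction sketch (class and attribute names are illustrative):

    import torch
    from torch import nn

    class Wrapper(nn.Module):
        def __init__(self):
            super().__init__()
            self.inner = nn.Linear(4, 4)

        @property
        def head(self):
            return self.inner

        @head.setter
        def head(self, module: nn.Module):
            print("setter called")  # never printed when assigning a Module
            self.inner = module

    w = Wrapper()
    w.head = nn.Linear(4, 4)     # registered as w._modules["head"]; setter bypassed
    print("head" in w._modules)  # True
    print(w.head is w.inner)     # True: reads still go through the property getter,
                                 # so the newly assigned module is effectively orphaned

This is why the NOTE recommends assigning through `classifier.clip_model.text_model` directly instead of relying on the setters.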
fusion_bench/models/modulator/__init__.py (new file)

@@ -0,0 +1 @@
+from .base import ModulatedModel, TaskModulator
fusion_bench/models/modulator/base.py (new file)

@@ -0,0 +1,123 @@
+import logging
+from abc import ABC, abstractmethod
+from typing import Any, Dict, Generic, Optional
+
+import torch
+from torch import nn
+
+from fusion_bench import TorchModelType
+
+log = logging.getLogger(__name__)
+
+
+class ModulatedModel(nn.Module, Generic[TorchModelType]):
+    """
+    A model wrapper that uses task-specific modulators to adapt a shared backbone
+    for different tasks.
+
+    The model maintains a shared backbone and task-specific modulators. During the
+    forward pass, the appropriate modulator is applied based on the current task.
+    """
+
+    _current_task: Optional[str] = None
+
+    def __init__(
+        self,
+        backbone: TorchModelType,
+        modulators: Dict[str, "TaskModulator[TorchModelType]"],
+    ):
+        super().__init__()
+        self.backbone = backbone
+        self.modulators = nn.ModuleDict(modulators)
+
+    def add_modulator(self, task_name: str, modulator: "TaskModulator[TorchModelType]"):
+        """Add a new task-specific modulator."""
+        if task_name in self.modulators:
+            raise ValueError(f"Modulator for task '{task_name}' already exists.")
+        self.modulators[task_name] = modulator
+
+    def remove_modulator(self, task_name: str):
+        """Remove an existing task-specific modulator."""
+        if task_name not in self.modulators:
+            raise ValueError(f"Modulator for task '{task_name}' does not exist.")
+        if self._current_task == task_name:
+            log.warning(
+                f"Removing modulator for current task '{task_name}'. "
+                "Its modulation remains applied, and the current task will refer "
+                "to a modulator that no longer exists."
+            )
+        del self.modulators[task_name]
+
+    def set_task(self, task_name: str):
+        """Set the current task for inference."""
+        if task_name not in self.modulators:
+            raise ValueError(
+                f"Task '{task_name}' not found in modulators. Available tasks: {list(self.modulators.keys())}"
+            )
+        if self._current_task == task_name:
+            return
+
+        # unset previous task; a modulator's `remove` is expected to reset
+        # `_current_task` to None (see `TaskModulator.remove`)
+        if self._current_task is not None:
+            self.modulators[self._current_task].remove(self)
+            assert (
+                self._current_task is None
+            ), "Current task should be None after removal."
+
+        # set new task
+        self.modulators[task_name].apply(self)
+        self._current_task = task_name
+
+    @property
+    def current_task(self) -> Optional[str]:
+        """Get the current task name."""
+        return self._current_task
+
+    def forward(self, *args, **kwargs) -> Any:
+        """
+        Forward pass with task-specific modulation.
+
+        Args:
+            *args: Positional arguments for the backbone model
+            **kwargs: Keyword arguments for the backbone model
+
+        Returns:
+            Model output after applying task-specific modulation
+        """
+        if self._current_task is None:
+            raise ValueError(
+                "No task specified. Call `set_task` before the forward pass."
+            )
+
+        return self.backbone(*args, **kwargs)
+
+
+class TaskModulator(nn.Module, Generic[TorchModelType], ABC):
+    """
+    Lightweight, task-specific parameterization that modulates
+    a shared representation.
+
+    This is the base class for all task modulators. Subclasses should implement
+    the `apply` method to define how the modulator adapts the backbone model
+    for a specific task.
+    """
+
+    @abstractmethod
+    def apply(self, modulated_model: "ModulatedModel[TorchModelType]"):
+        """
+        Apply task-specific modulation to the backbone model.
+
+        Args:
+            modulated_model: The modulated model
+        """
+        raise NotImplementedError("Subclasses must implement the apply method.")
+
+    @abstractmethod
+    def remove(self, modulated_model: "ModulatedModel[TorchModelType]"):
+        """
+        Remove task-specific modulation from the backbone model.
+        This is called when switching tasks.
+
+        Args:
+            modulated_model: The modulated model
+        """
+        raise NotImplementedError("Subclasses must implement the remove method.")
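As a usage sketch, here is a hypothetical modulator that swaps the backbone's classification head per task. It is not part of the package; it only illustrates the `apply`/`remove` contract, including `remove` resetting `_current_task` so the assertion in `set_task` holds:

    import torch
    from torch import nn

    from fusion_bench.models.modulator import ModulatedModel, TaskModulator

    class HeadSwapModulator(TaskModulator):
        """Hypothetical modulator: replaces `backbone.head` for its task."""

        def __init__(self, head: nn.Module):
            super().__init__()
            self.head = head
            self._saved_head = None

        def apply(self, modulated_model):
            self._saved_head = modulated_model.backbone.head
            modulated_model.backbone.head = self.head

        def remove(self, modulated_model):
            modulated_model.backbone.head = self._saved_head
            self._saved_head = None
            # per the contract checked in `set_task`, removal unsets the task
            modulated_model._current_task = None

    backbone = nn.Sequential()
    backbone.features = nn.Linear(8, 8)
    backbone.head = nn.Linear(8, 2)

    model = ModulatedModel(
        backbone,
        {
            "task_a": HeadSwapModulator(nn.Linear(8, 2)),
            "task_b": HeadSwapModulator(nn.Linear(8, 5)),
        },
    )
    model.set_task("task_a")
    out = model(torch.randn(1, 8))  # runs the backbone with task_a's head
    model.set_task("task_b")        # task_a's head is restored, task_b's applied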
fusion_bench/models/open_clip/modeling.py

@@ -1,3 +1,17 @@
+"""
+OpenCLIP model wrappers used by FusionBench.
+
+This module provides lightweight `torch.nn.Module` wrappers around OpenCLIP
+components that are commonly used throughout FusionBench experiments:
+
+- `ImageEncoder`: loads an OpenCLIP image encoder and exposes `encode_image`.
+- `ClassificationHead`: a linear head optionally normalizing inputs.
+- `ImageClassifier` / `MultiHeadImageClassifier`: convenience compositions.
+
+Note:
+    This module requires the optional dependency `open_clip_torch`.
+"""
+
 from fusion_bench.utils.packages import is_open_clip_available
 
 if not is_open_clip_available():
@@ -5,6 +19,7 @@ if not is_open_clip_available():
         "open_clip is not installed. Please install it with `pip install open_clip_torch`."
     )
 
+from pathlib import Path
 from typing import Callable, List
 
 import open_clip
@@ -17,6 +32,19 @@ from .variables_and_paths import CACHEDIR, MODELS, OPENCLIP_CACHEDIR
 
 class ImageEncoder(torch.nn.Module):
     R"""
+    OpenCLIP image encoder wrapper.
+
+    This class loads an OpenCLIP model by name and exposes a forward pass that
+    returns image embeddings via `model.encode_image`.
+
+    Args:
+        model_name: A model name supported by `open_clip`. FusionBench also
+            supports suffixes:
+            - ``"__pretrained__<tag>"`` to select a specific pretrained weights tag.
+            - ``"__init__"`` to use random initialization.
+        keep_lang: If False (default), removes the text encoder (when present)
+            to reduce memory usage.
+
     Examples:
 
     load the image encoder for a given model name
@@ -25,7 +53,7 @@ class ImageEncoder(torch.nn.Module):
    >>> image_encoder = ImageEncoder(model_name="ViT-B-32")
    """
 
-    def __init__(self, model_name: str, keep_lang=False):
+    def __init__(self, model_name: str, keep_lang: bool = False):
         super().__init__()
         assert (
             model_name in MODELS
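For reference, the suffix convention documented above can be exercised as below. The pretrained tag is illustrative; valid tags come from `open_clip.list_pretrained()`, and the suffixed name is assumed to pass the model-name validation:

    from fusion_bench.models.open_clip.modeling import ImageEncoder

    encoder = ImageEncoder("ViT-B-32")                                # default pretrained weights
    tagged = ImageEncoder("ViT-B-32__pretrained__laion2b_s34b_b79k")  # specific pretrained tag
    random_init = ImageEncoder("ViT-B-32__init__")                    # random initialization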
@@ -49,22 +77,26 @@ class ImageEncoder(torch.nn.Module):
 
         self.cache_dir = CACHEDIR
 
+        # if `keep_lang` is False, remove the text encoder to save memory
         if not keep_lang and hasattr(self.model, "transformer"):
             delattr(self.model, "transformer")
 
-    def forward(self, images):
+    def forward(self, images: Tensor) -> Tensor:
+        """Encode a batch of images into embedding vectors."""
         assert self.model is not None
         return self.model.encode_image(images)
 
-    def __call__(self, inputs):
+    def __call__(self, inputs: Tensor) -> Tensor:
         return self.forward(inputs)
 
-    def save(self, filename):
+    def save(self, filename: str) -> None:
+        """Serialize this module to disk."""
         print(f"Saving image encoder to {filename}")
         utils.torch_save(self, filename)
 
     @classmethod
-    def load(cls, model_name, filename):
+    def load(cls, model_name: str, filename: str | Path):
+        """Load a saved encoder state dict into a freshly constructed encoder."""
         print(f"Loading image encoder from {filename}")
 
         state_dict = torch.load(filename, map_location="cpu")
@@ -75,6 +107,15 @@ class ImageEncoder(torch.nn.Module):
 
 
 class ClassificationHead(torch.nn.Linear):
+    """A linear classification head with optional input normalization.
+
+    Args:
+        normalize: If True, L2-normalize inputs along the last dimension before
+            applying the linear projection.
+        weights: Weight matrix of shape (num_classes, feature_dim).
+        biases: Optional bias vector of shape (num_classes,).
+    """
+
     def __init__(
         self,
         normalize: bool,
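A construction sketch, assuming the constructor takes `normalize`, a `weights` matrix, and optional `biases` as the docstring describes (in practice the weights are typically CLIP text embeddings of the class names, stood in for by random tensors here):

    import torch
    from fusion_bench.models.open_clip.modeling import ClassificationHead

    num_classes, feature_dim = 10, 512
    class_embeddings = torch.randn(num_classes, feature_dim)  # stand-in for text embeddings

    head = ClassificationHead(normalize=True, weights=class_embeddings)
    features = torch.randn(4, feature_dim)  # a batch of image features
    logits = head(features)                 # shape: (4, num_classes)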
@@ -92,6 +133,7 @@ class ClassificationHead(torch.nn.Linear):
             self.bias = torch.nn.Parameter(torch.zeros_like(self.bias))
 
     def forward(self, inputs: Tensor):
+        """Compute logits from input features."""
         if self.normalize:
             inputs = inputs / inputs.norm(dim=-1, keepdim=True)
         return super().forward(inputs)
@@ -100,11 +142,13 @@ class ClassificationHead(torch.nn.Linear):
         return self.forward(inputs)
 
     def save(self, filename):
+        """Serialize this head to disk."""
         print(f"Saving classification head to {filename}")
         utils.torch_save(self, filename, save_state_dict=False)
 
     @classmethod
     def load(cls, filename):
+        """Load a serialized `ClassificationHead` instance from disk."""
         # print(f"Loading classification head from {filename}")
         return utils.torch_load(filename)
 
@@ -113,6 +157,8 @@ class ImageClassifier(torch.nn.Module):
     train_preprocess: Callable
     val_preprocess: Callable
 
+    """Convenience module combining an `ImageEncoder` and a `ClassificationHead`."""
+
     def __init__(
         self,
         image_encoder: ImageEncoder,
@@ -126,10 +172,12 @@ class ImageClassifier(torch.nn.Module):
         self.val_preprocess = self.image_encoder.val_preprocess
 
     def freeze_head(self):
+        """Disable gradient computation for the classification head."""
         self.classification_head.weight.requires_grad_(False)
         self.classification_head.bias.requires_grad_(False)
 
     def forward(self, inputs: Tensor):
+        """Run encoder then head and return logits."""
         features = self.image_encoder(inputs)
         outputs = self.classification_head(features)
         return outputs
@@ -138,16 +186,20 @@ class ImageClassifier(torch.nn.Module):
         return self.forward(inputs)
 
     def save(self, filename):
+        """Serialize this module to disk."""
         print(f"Saving image classifier to {filename}")
         utils.torch_save(self, filename)
 
     @classmethod
     def load(cls, filename):
+        """Load a serialized `ImageClassifier` instance from disk."""
         print(f"Loading image classifier from {filename}")
         return utils.torch_load(filename)
 
 
 class MultiHeadImageClassifier(torch.nn.Module):
+    """Image encoder with multiple task-specific classification heads."""
+
     def __init__(
         self,
         image_encoder: ImageEncoder,
@@ -161,11 +213,13 @@ class MultiHeadImageClassifier(torch.nn.Module):
         self.val_preprocess = self.image_encoder.val_preprocess
 
     def freeze_head(self):
+        """Disable gradient computation for all heads."""
         for idx in range(len(self.classification_heads)):
             self.classification_heads[idx].weight.requires_grad_(False)
             self.classification_heads[idx].bias.requires_grad_(False)
 
     def forward(self, inputs, head_idx):
+        """Run encoder then the selected head and return logits."""
         features = self.image_encoder(inputs)
         outputs = self.classification_heads[head_idx](features)
         return outputs
@@ -174,10 +228,12 @@ class MultiHeadImageClassifier(torch.nn.Module):
         return self.forward(inputs, head_idx)
 
     def save(self, filename):
+        """Serialize this module to disk."""
         print(f"Saving image classifier to {filename}")
         utils.torch_save(self, filename)
 
     @classmethod
     def load(cls, filename):
+        """Load a serialized `MultiHeadImageClassifier` instance from disk."""
         print(f"Loading image classifier from {filename}")
         return utils.torch_load(filename)
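A usage sketch for the multi-head variant, assuming the constructor accepts the encoder plus a list of heads (suggested by the indexed `classification_heads` above; head sizes and the 512-dim ViT-B-32 feature size are illustrative):

    import torch
    from fusion_bench.models.open_clip.modeling import (
        ClassificationHead,
        ImageEncoder,
        MultiHeadImageClassifier,
    )

    encoder = ImageEncoder("ViT-B-32")
    heads = [
        ClassificationHead(normalize=True, weights=torch.randn(10, 512)),   # task 0
        ClassificationHead(normalize=True, weights=torch.randn(100, 512)),  # task 1
    ]
    model = MultiHeadImageClassifier(encoder, heads)
    model.freeze_head()  # heads stay fixed; only the encoder would be trained

    images = torch.randn(4, 3, 224, 224)
    logits_task0 = model(images, head_idx=0)  # shape: (4, 10)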
fusion_bench/models/open_clip/utils.py

@@ -77,7 +77,16 @@ def torch_load_old(save_path: str, device=None):
     return classifier
 
 
-def torch_save(model, save_path, save_state_dict=True):
+def torch_save(model: torch.nn.Module, save_path: str, save_state_dict: bool = True):
+    """
+    Save a model to disk.
+
+    Args:
+        model: The model to save.
+        save_path (str): The path to save the model to.
+        save_state_dict (bool): Whether to save the state dict of the model (weights only).
+            If False, the entire model object is saved. Default is True.
+    """
     # TODO: hacky way to save state dict
     if save_state_dict and isinstance(model, torch.nn.Module):
         model = model.state_dict()
@@ -86,7 +95,9 @@ def torch_save(model, save_path, save_state_dict=True):
     torch.save(model, save_path)
 
 
-def torch_load(save_path, device=None):
+def torch_load(
+    save_path: str, device: Optional[torch.device] = None
+) -> torch.nn.Module:
     model = torch.load(save_path, map_location="cpu")
     if device is not None:
         model = model.to(device)
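The `save_state_dict` flag chooses between weights-only and whole-object pickling. A round-trip sketch using plain `torch` calls that mirrors both modes:

    import torch
    from torch import nn

    model = nn.Linear(4, 2)

    # save_state_dict=True: weights only; loading requires rebuilding the module
    torch.save(model.state_dict(), "/tmp/linear_sd.pt")
    restored = nn.Linear(4, 2)
    restored.load_state_dict(torch.load("/tmp/linear_sd.pt", map_location="cpu"))

    # save_state_dict=False: the whole module object is pickled; loading returns it
    # directly, but requires the defining class to be importable (and, on recent
    # PyTorch versions, weights_only=False)
    torch.save(model, "/tmp/linear_full.pt")
    restored_full = torch.load(
        "/tmp/linear_full.pt", map_location="cpu", weights_only=False
    )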
fusion_bench/models/parameter_dict.py

@@ -1,12 +1,12 @@
-from typing import List, Mapping, Optional, Tuple
+from typing import Iterator, List, Mapping, Optional, Tuple, Union
 
 import torch
 from torch import nn
 
-__all__ = "ParamterDictModel"
+__all__ = ["ParameterDictModel"]
 
 
-def _set_attr(
+def set_nested_attr(
     obj,
     names: List[str],
     val,
@@ -27,7 +27,7 @@ def _set_attr(
     else:
         if check_parent and not hasattr(obj, names[0]):
             setattr(obj, names[0], parent_builder())
-        _set_attr(
+        set_nested_attr(
             getattr(obj, names[0]),
             names[1:],
             val,
@@ -36,7 +36,7 @@ def _set_attr(
         )
 
 
-def has_attr(obj, names: List[str]):
+def has_nested_attr(obj, names: List[str]):
     """
     Checks if an attribute exists in an object recursively.
 
@@ -50,26 +50,49 @@ def has_attr(obj, names: List[str]):
     if len(names) == 1:
         return hasattr(obj, names[0])
     else:
-        return has_attr(getattr(obj, names[0]), names[1:])
+        if not hasattr(obj, names[0]):
+            return False
+        return has_nested_attr(getattr(obj, names[0]), names[1:])
 
 
 class ParameterDictModel(nn.Module):
     """
-    This model is used to create a model with parameters from a dictionary.
-    It behaves like a normal `nn.ParameterDict`, but support keys with dots.
+    A module that stores parameters in a nested dictionary structure.
+
+    This model behaves similarly to `nn.ParameterDict`, but supports hierarchical keys
+    with dots (e.g., "layer1.weight"). Parameters are stored as nested attributes,
+    allowing for structured parameter access and manipulation.
+
+    Example:
+        >>> params = {
+        ...     "encoder.weight": nn.Parameter(torch.randn(10, 5)),
+        ...     "decoder.bias": nn.Parameter(torch.randn(5)),
+        ... }
+        >>> model = ParameterDictModel(params)
+        >>> model["encoder.weight"].shape
+        torch.Size([10, 5])
+        >>> "encoder.weight" in model
+        True
     """
 
     def __init__(
         self,
-        parameters: Optional[Mapping[str, nn.Parameter]] = None,
-    ):
+        parameters: Optional[Mapping[str, Union[nn.Parameter, torch.Tensor]]] = None,
+    ) -> None:
+        """
+        Args:
+            parameters: Optional mapping of parameter names to parameter tensors.
+                Keys can contain dots to create nested structures.
+                Values must be `nn.Parameter` or `nn.Buffer` instances.
+        """
+
         super().__init__()
         if parameters is not None:
             for name, param in parameters.items():
                 assert isinstance(
                     param, (nn.Parameter, nn.Buffer)
                 ), f"{name} is not a nn.Parameter or nn.Buffer"
-                _set_attr(
+                set_nested_attr(
                     self,
                     name.split("."),
                     param,
@@ -77,12 +100,13 @@ class ParameterDictModel(nn.Module):
                     parent_builder=__class__,
                 )
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         """
         Generate a string representation of the model's parameters.
 
         Returns:
-            str: A string representation of the model's parameters.
+            A string representation of the model's parameters in the format:
+            "ParameterDictModel(name1: shape1, name2: shape2, ...)"
         """
         param_reprs = []
         for name, param in self.named_parameters():
@@ -90,32 +114,98 @@ class ParameterDictModel(nn.Module):
             param_reprs.append(param_repr)
         return f"{self.__class__.__name__}({', '.join(param_reprs)})"
 
-    def __getitem__(self, key: str):
-        if not has_attr(self, key.split(".")):
+    def __iter__(self) -> Iterator[str]:
+        """
+        Iterate over the model's parameter names.
+
+        Yields:
+            Parameter names (including nested names with dots).
+        """
+        yield from self.keys()
+
+    def __getitem__(
+        self, key: str
+    ) -> Union[nn.Parameter, torch.Tensor, "ParameterDictModel"]:
+        """
+        Retrieve a parameter or nested submodule by key.
+
+        Args:
+            key: Parameter name, which can contain dots for nested access.
+
+        Returns:
+            The parameter, tensor, or nested ParameterDictModel at the specified key.
+
+        Raises:
+            KeyError: If the key is not found in the model.
+        """
+        assert isinstance(
+            key, str
+        ), f"Key must be a string, but got {type(key)}: {key}."
+        if not has_nested_attr(self, key.split(".")):
             raise KeyError(f"Key {key} not found in {self}")
-        key = key.split(".")
+        key_parts = key.split(".")
         obj = self
-        for k in key:
+        for k in key_parts:
             obj = getattr(obj, k)
         return obj
 
-    def __setitem__(self, key: str, value: nn.Parameter):
-        if not has_attr(self, key.split(".")):
-            _set_attr(self, key.split("."), value, check_parent=True)
+    def __setitem__(self, key: str, value: Union[nn.Parameter, torch.Tensor]) -> None:
+        """
+        Set a parameter at the specified key, creating nested structure if needed.
+
+        Args:
+            key: Parameter name, which can contain dots for nested assignment.
+            value: Parameter or tensor to assign.
+        """
+        if not has_nested_attr(self, key.split(".")):
+            set_nested_attr(self, key.split("."), value, check_parent=True)
         else:
-            _set_attr(self, key.split("."), value, check_parent=False)
+            set_nested_attr(self, key.split("."), value, check_parent=False)
+
+    def __contains__(self, key: str) -> bool:
+        """
+        Check if a parameter key exists in the model.
 
-    def __contains__(self, key: str):
-        return has_attr(self, key.split("."))
+        Args:
+            key: Parameter name, which can contain dots for nested checking.
+
+        Returns:
+            True if the key exists, False otherwise.
+        """
+        return has_nested_attr(self, key.split("."))
 
     def keys(self):
-        return [name for name, _ in self.named_parameters()]
+        """
+        Return a view of all parameter names in the model.
+
+        Returns:
+            The keys of the model's state dict (including nested names with dots).
+        """
+        return self.state_dict().keys()
+
+    def items(self):
+        """
+        Yield (name, parameter) pairs.
+
+        Yields:
+            Tuples containing parameter names and their corresponding tensors.
+        """
+        yield from self.state_dict().items()
 
-    def items(self) -> List[Tuple[str, nn.Parameter]]:
-        return [(name, self[name]) for name in self.keys()]
+    def values(self):
+        """
+        Yield all parameter values in the model.
 
-    def values(self) -> List[nn.Parameter]:
-        return [self[name] for name in self.keys()]
+        Yields:
+            Parameter tensors.
+        """
+        yield from self.state_dict().values()
 
-    def __len__(self):
+    def __len__(self) -> int:
+        """
+        Return the number of parameters in the model.
+
+        Returns:
+            The total number of parameters.
+        """
         return len(self.keys())
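An end-to-end sketch of the dotted-key behavior documented above (shapes are arbitrary):

    import torch
    from torch import nn

    from fusion_bench.models.parameter_dict import ParameterDictModel

    model = ParameterDictModel(
        {
            "encoder.layer1.weight": nn.Parameter(torch.randn(10, 5)),
            "encoder.layer1.bias": nn.Parameter(torch.randn(10)),
        }
    )

    # dotted keys resolve through nested ParameterDictModel instances
    print(model["encoder.layer1.weight"].shape)  # torch.Size([10, 5])
    print(list(model))  # ['encoder.layer1.weight', 'encoder.layer1.bias']

    # assignment creates intermediate levels on demand
    model["decoder.bias"] = nn.Parameter(torch.zeros(5))
    print("decoder.bias" in model)  # True

    # dict-style iteration is backed by state_dict()
    for name, tensor in model.items():
        print(name, tuple(tensor.shape))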