fusion-bench 0.2.22__py3-none-any.whl → 0.2.24__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. fusion_bench/__init__.py +4 -0
  2. fusion_bench/compat/method/__init__.py +5 -2
  3. fusion_bench/compat/method/base_algorithm.py +3 -2
  4. fusion_bench/compat/modelpool/base_pool.py +3 -3
  5. fusion_bench/compat/taskpool/clip_image_classification.py +1 -1
  6. fusion_bench/dataset/gpt2_glue.py +1 -1
  7. fusion_bench/method/__init__.py +12 -2
  8. fusion_bench/method/analysis/task_vector_cos_similarity.py +95 -12
  9. fusion_bench/method/analysis/task_vector_violin_plot.py +160 -52
  10. fusion_bench/method/bitdelta/bitdelta.py +7 -23
  11. fusion_bench/method/ensemble.py +17 -2
  12. fusion_bench/method/expert_sparsity/mixtral/dynamic_skipping.py +2 -0
  13. fusion_bench/method/expert_sparsity/mixtral/layer_wise_pruning.py +2 -0
  14. fusion_bench/method/expert_sparsity/mixtral/progressive_pruning.py +2 -0
  15. fusion_bench/method/linear/__init__.py +6 -2
  16. fusion_bench/method/linear/{simple_average_for_llama.py → simple_average_for_causallm.py} +8 -4
  17. fusion_bench/method/linear/{task_arithmetic_for_llama.py → task_arithmetic_for_causallm.py} +22 -12
  18. fusion_bench/method/linear/ties_merging_for_causallm.py +70 -0
  19. fusion_bench/method/model_stock/__init__.py +1 -0
  20. fusion_bench/method/model_stock/model_stock.py +309 -0
  21. fusion_bench/method/regmean/clip_regmean.py +3 -6
  22. fusion_bench/method/regmean/regmean.py +27 -56
  23. fusion_bench/method/regmean/utils.py +56 -0
  24. fusion_bench/method/regmean_plusplus/regmean_plusplus.py +21 -60
  25. fusion_bench/method/simple_average.py +2 -2
  26. fusion_bench/method/slerp/__init__.py +1 -1
  27. fusion_bench/method/slerp/slerp.py +110 -14
  28. fusion_bench/method/task_arithmetic/task_arithmetic.py +35 -10
  29. fusion_bench/method/ties_merging/ties_merging.py +22 -6
  30. fusion_bench/method/we_moe/flan_t5_we_moe.py +9 -20
  31. fusion_bench/method/wudi/__init__.py +1 -0
  32. fusion_bench/method/wudi/wudi.py +105 -0
  33. fusion_bench/mixins/clip_classification.py +26 -6
  34. fusion_bench/mixins/lightning_fabric.py +4 -0
  35. fusion_bench/mixins/serialization.py +40 -83
  36. fusion_bench/modelpool/base_pool.py +1 -1
  37. fusion_bench/modelpool/causal_lm/causal_lm.py +285 -44
  38. fusion_bench/modelpool/seq2seq_lm/modelpool.py +146 -0
  39. fusion_bench/models/hf_clip.py +4 -0
  40. fusion_bench/models/hf_utils.py +10 -4
  41. fusion_bench/models/linearized/vision_model.py +6 -6
  42. fusion_bench/models/model_card_templates/default.md +8 -1
  43. fusion_bench/models/modeling_smile_mistral/__init__.py +1 -0
  44. fusion_bench/models/we_moe.py +8 -8
  45. fusion_bench/models/wrappers/ensemble.py +136 -7
  46. fusion_bench/scripts/cli.py +2 -2
  47. fusion_bench/taskpool/base_pool.py +99 -17
  48. fusion_bench/taskpool/clip_vision/taskpool.py +12 -5
  49. fusion_bench/taskpool/dummy.py +101 -13
  50. fusion_bench/taskpool/lm_eval_harness/taskpool.py +80 -0
  51. fusion_bench/taskpool/nyuv2_taskpool.py +28 -0
  52. fusion_bench/utils/__init__.py +1 -0
  53. fusion_bench/utils/data.py +6 -4
  54. fusion_bench/utils/devices.py +36 -11
  55. fusion_bench/utils/dtype.py +3 -2
  56. fusion_bench/utils/lazy_state_dict.py +85 -19
  57. fusion_bench/utils/packages.py +3 -3
  58. fusion_bench/utils/parameters.py +0 -2
  59. fusion_bench/utils/rich_utils.py +7 -3
  60. fusion_bench/utils/timer.py +92 -10
  61. {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.24.dist-info}/METADATA +10 -3
  62. {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.24.dist-info}/RECORD +77 -64
  63. fusion_bench_config/_get_started/llm_slerp.yaml +12 -0
  64. fusion_bench_config/method/ensemble/simple_ensemble.yaml +1 -0
  65. fusion_bench_config/method/linear/{simple_average_for_llama.yaml → simple_average_for_causallm.yaml} +1 -1
  66. fusion_bench_config/method/linear/task_arithmetic_for_causallm.yaml +4 -0
  67. fusion_bench_config/method/linear/ties_merging_for_causallm.yaml +13 -0
  68. fusion_bench_config/method/model_stock/model_stock.yaml +12 -0
  69. fusion_bench_config/method/slerp/slerp_lm.yaml +4 -0
  70. fusion_bench_config/method/wudi/wudi.yaml +4 -0
  71. fusion_bench_config/modelpool/CausalLMPool/{Qwen2.5-1.5B_math_and_coder.yaml → Qwen2.5-1.5B_math_and_code.yaml} +1 -2
  72. fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_three_models.yaml +11 -0
  73. fusion_bench_config/modelpool/CausalLMPool/llama-7b_3-models_v1.yaml +11 -0
  74. fusion_bench_config/method/linear/task_arithmetic_for_llama.yaml +0 -4
  75. {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.24.dist-info}/WHEEL +0 -0
  76. {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.24.dist-info}/entry_points.txt +0 -0
  77. {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.24.dist-info}/licenses/LICENSE +0 -0
  78. {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.24.dist-info}/top_level.txt +0 -0
fusion_bench/method/we_moe/flan_t5_we_moe.py
@@ -16,11 +16,14 @@ from transformers.data import default_data_collator
 
 from fusion_bench.method import BaseAlgorithm
 from fusion_bench.method.task_arithmetic.task_arithmetic import task_arithmetic_merge
-from fusion_bench.mixins.lightning_fabric import LightningFabricMixin
-from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
+from fusion_bench.mixins import (
+    LightningFabricMixin,
+    SimpleProfilerMixin,
+    auto_register_config,
+)
 from fusion_bench.modelpool import Seq2SeqLMPool
 from fusion_bench.models.we_moe import WeightEnsemblingMoE
-from fusion_bench.utils import timeit_context
+from fusion_bench.utils import print_parameters, timeit_context
 from fusion_bench.utils.data import InfiniteDataLoader, load_tensor_from_file
 from fusion_bench.utils.instantiate_utils import instantiate
 from fusion_bench.utils.parameters import print_parameters
@@ -31,10 +34,11 @@ from .utils import get_memory_usage
 log = logging.getLogger(__name__)
 
 
+@auto_register_config
 class FlanT5WeightEnsemblingMoEAlgorithm(
-    BaseAlgorithm,
     LightningFabricMixin,
     SimpleProfilerMixin,
+    BaseAlgorithm,
 ):
     """
     FlanT5WeightEnsemblingMoEAlgorithm is a class that implements the WeightEnsemblingMoEAlgorithm
@@ -60,7 +64,6 @@ class FlanT5WeightEnsemblingMoEAlgorithm(
         num_workers: int = 0,
         max_steps: int = 1000,
         use_grad_accumulate: bool = True,
-        cache_dir: bool = "outputs",
         fast_dev_run: bool = False,
         **kwargs,
     ):
@@ -70,23 +73,9 @@ class FlanT5WeightEnsemblingMoEAlgorithm(
         Args:
             algorithm_config (DictConfig): The configuration for the algorithm.
         """
-        self.checkpoint = checkpoint
-        self.save_checkpoint = save_checkpoint
-        self.router_hidden_layers = router_hidden_layers
-        self.init_lambda = init_lambda
-        self.batch_reduce = batch_reduce
-        self.lr = lr
-        self.optimizer = optimizer
-        self.devices = devices
-        self.batch_size = batch_size
-        self.num_workers = num_workers
-        self.max_steps = max_steps
-        self.use_grad_accumulate = use_grad_accumulate
-        self.cache_dir = cache_dir
-        self.fast_dev_run = fast_dev_run
         super().__init__(**kwargs)
 
-    def construct_moe_model(self) -> WeightEnsemblingMoE:
+    def construct_moe_model(self):
        """
        Construct the Mixture of Experts (MoE) model using the models in the model pool.
 
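Note: the boilerplate deleted above is exactly what the new `@auto_register_config` decorator provides. It registers every non-variadic `__init__` parameter in `_config_mapping` and sets it as an instance attribute before the original `__init__` body runs. A minimal sketch of the pattern (the `ExampleConfigurable` class is hypothetical and only illustrates the decorator; assumes fusion-bench 0.2.24 is installed):

```python
from fusion_bench.mixins import auto_register_config


@auto_register_config
class ExampleConfigurable:
    """Hypothetical class used only to illustrate the decorator."""

    def __init__(self, lr: float = 1e-4, max_steps: int = 1000):
        # No manual `self.lr = lr` assignments are needed: the decorator
        # has already recorded both parameters in `_config_mapping` and
        # set them as instance attributes before this body runs.
        pass


algo = ExampleConfigurable(lr=1e-3)
print(algo.lr, algo.max_steps)  # 0.001 1000
```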
fusion_bench/method/wudi/__init__.py
@@ -0,0 +1 @@
+from .wudi import WUDIMerging, wudi_merging
fusion_bench/method/wudi/wudi.py
@@ -0,0 +1,105 @@
+"""
+Whoever Started the Interference Should End It: Guiding Data-Free Model Merging via Task Vectors
+Arxiv: http://arxiv.org/abs/2503.08099
+"""
+
+from typing import List
+
+import torch
+from tqdm import tqdm
+
+from fusion_bench import BaseAlgorithm, BaseModelPool, auto_register_config
+from fusion_bench.mixins import LightningFabricMixin
+from fusion_bench.utils import timeit_context
+from fusion_bench.utils.state_dict_arithmetic import state_dict_add, state_dict_sub
+
+
+def wudi_merging(
+    task_vectors: List[torch.Tensor],
+    accelerator="cuda",
+    iter_num: int = 300,
+    exclude_keys: List[str] = None,
+):
+    exclude_keys = [] if exclude_keys is None else exclude_keys
+
+    with timeit_context("WUDI Merging"):
+        new_vector = {}
+        for key in tqdm(task_vectors[0], desc="WUDI Merging", leave=False):
+            tqdm.write(f"key: {key}")
+            original_device = task_vectors[0][key].device
+            tvs = torch.stack(
+                [
+                    task_vector[key].to(device=accelerator, non_blocking=True)
+                    for task_vector in task_vectors
+                ]
+            )
+            num_tvs = len(tvs)
+            new_vector[key] = torch.nn.Parameter(torch.sum(tvs, dim=0))
+
+            if len(task_vectors[0][key].shape) == 2 and key not in exclude_keys:
+                optimizer = torch.optim.Adam([new_vector[key]], lr=1e-5, weight_decay=0)
+                l2_norms = torch.square(
+                    torch.norm(tvs.reshape(tvs.shape[0], -1), p=2, dim=-1)
+                )
+                for i in tqdm(
+                    range(iter_num),
+                ):
+                    disturbing_vectors = new_vector[key].unsqueeze(0) - tvs
+                    product = torch.matmul(disturbing_vectors, tvs.transpose(1, 2))
+                    loss = torch.sum(
+                        torch.square(product) / l2_norms.unsqueeze(-1).unsqueeze(-1)
+                    )
+                    optimizer.zero_grad()
+                    loss.backward()
+                    optimizer.step()
+            else:
+                new_vector[key] = new_vector[key] / num_tvs
+            new_vector[key] = new_vector[key].to(
+                device=original_device, non_blocking=True
+            )
+    return new_vector
+
+
+@auto_register_config
+class WUDIMerging(
+    LightningFabricMixin,
+    BaseAlgorithm,
+):
+    """
+    Whoever Started the Interference Should End It: Guiding Data-Free Model Merging via Task Vectors
+    """
+
+    def __init__(
+        self,
+        iter_num: int,
+        exclude_keys: List[str] = None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+    def run(self, modelpool: BaseModelPool):
+        # load the pretrained model and the task vectors of all the finetuned models
+        with torch.no_grad():
+            pretrained_model = modelpool.load_pretrained_model()
+            task_vectors = []
+            for model_name in modelpool.model_names:
+                finetuned_model = modelpool.load_model(model_name)
+                task_vectors.append(
+                    state_dict_sub(
+                        finetuned_model.state_dict(), pretrained_model.state_dict()
+                    )
+                )
+                del finetuned_model  # free memory
+
+        merged_tv = wudi_merging(
+            task_vectors,
+            accelerator=self.fabric.device,
+            iter_num=self.iter_num,
+            exclude_keys=self.exclude_keys,
+        )
+
+        pretrained_model.load_state_dict(
+            state_dict_add(pretrained_model.state_dict(), merged_tv)
+        )
+
+        return pretrained_model
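Note: for every 2-D weight, `wudi_merging` starts from the sum of the task vectors and then minimizes the squared projections of the disturbance (the merged vector minus each task vector) onto each task vector, normalized by that task vector's squared norm; other parameters are simply averaged. A toy call that only exercises the interface (the state dict key and tensor shapes are made up; runs on CPU with fusion-bench 0.2.24 installed):

```python
import torch

from fusion_bench.method.wudi import wudi_merging

# two hypothetical "task vectors": state dicts with a single 2-D weight
task_vectors = [
    {"layer.weight": torch.randn(8, 4)},
    {"layer.weight": torch.randn(8, 4)},
]

# optimize the merged task vector for a few iterations on CPU
merged = wudi_merging(task_vectors, accelerator="cpu", iter_num=10)
print(merged["layer.weight"].shape)  # torch.Size([8, 4])
```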
fusion_bench/mixins/clip_classification.py
@@ -113,11 +113,27 @@ class CLIPClassificationMixin(LightningFabricMixin):
         clip_model: Optional[CLIPModel] = None,
         task_names: Optional[List[str]] = None,
     ):
+        """
+        Initializes a zero-shot classification head.
+
+        This method constructs a zero-shot classification head by generating text embeddings for each class name using a set of templates.
+        These embeddings function as the weights of the classification layer. The method also extracts the `visual_projection` and `logit_scale`
+        from the provided CLIP model, which are necessary for calculating the final logits.
+
+        Args:
+            clip_processor (Optional[CLIPProcessor]): The processor for the CLIP model. If not provided, it is loaded from the model pool.
+            clip_model (Optional[CLIPModel]): The CLIP model to use. If not provided, a pretrained model is loaded from the model pool.
+            task_names (Optional[List[str]]): A list of task names to set up the classification head for. If not provided, all models in the model pool will be used.
+        """
         self.whether_setup_zero_shot_classification_head = True
+        # load clip model if not provided
         if clip_model is None:
             if self.modelpool.has_pretrained:
                 clip_model = self.modelpool.load_clip_model("_pretrained_")
             else:
+                log.warning(
+                    f"No pretrained CLIP model found, using the model from the model pool: {self.modelpool.model_names[0]}."
+                )
                 clip_model = self.modelpool.load_clip_model(
                     self.modelpool.model_names[0]
                 )
@@ -166,16 +182,20 @@
         image_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """
-        Compute the logits of the images for a given task.
+        Computes the classification logits for a batch of images for a specific task.
+
+        This method performs zero-shot classification by calculating the cosine similarity between image and text embeddings.
+        The image embeddings are obtained from the provided vision model, and the text embeddings (zero-shot weights) are pre-computed for the task.
+        The similarity scores are then scaled by the CLIP model's `logit_scale` to produce the final logits.
 
         Args:
-            module (Union[nn.Module, CLIPVisionModel, "CLIPVisionTransformer"]): The module to compute the logits.
-            images (torch.Tensor): The images to compute the logits.
-            task (str): The task to compute the logits.
-            image_embeds (Optional[torch.Tensor]): The precomputed image embeddings. If None, the image embeddings will be computed.
+            module (Union[nn.Module, CLIPVisionModel, "CLIPVisionTransformer"]): The vision encoder part of the CLIP model.
+            images (torch.Tensor): A batch of images to classify.
+            task (str): The name of the classification task.
+            image_embeds (Optional[torch.Tensor]): Pre-computed image embeddings. If provided, the method skips the image encoding step.
 
         Returns:
-            torch.Tensor: The logits of the images.
+            torch.Tensor: A tensor of logits for each image, with shape (batch_size, num_classes).
         """
         text_embeds = self.zeroshot_weights[task]
 
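Note: the scoring described in the updated docstring is the usual CLIP zero-shot recipe: L2-normalize the image and text embeddings, take their cosine similarities, and scale by `logit_scale`. A standalone illustration of that computation (not fusion_bench code; shapes and values are made up):

```python
import torch

# illustrative shapes: 16 images, 10 classes, 512-dim CLIP embedding space
image_embeds = torch.randn(16, 512)  # e.g. visual_projection of vision-model outputs
text_embeds = torch.randn(10, 512)   # pre-computed zero-shot weights, one row per class
logit_scale = torch.tensor(100.0)    # stands in for the CLIP model's exp(logit_scale)

image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
logits = logit_scale * image_embeds @ text_embeds.t()  # (batch_size, num_classes)
print(logits.shape)  # torch.Size([16, 10])
```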
fusion_bench/mixins/lightning_fabric.py
@@ -100,6 +100,10 @@ class LightningFabricMixin:
         self.setup_lightning_fabric(getattr(self, "config", DictConfig({})))
         return self._fabric_instance
 
+    @fabric.setter
+    def fabric(self, instance: L.Fabric):
+        self._fabric_instance = instance
+
     @property
     def log_dir(self):
         """
fusion_bench/mixins/serialization.py
@@ -4,8 +4,9 @@ from copy import deepcopy
 from functools import wraps
 from inspect import Parameter, _ParameterKind
 from pathlib import Path
-from typing import Dict, Optional, Union
+from typing import Dict, Mapping, Optional, Union
 
+from bidict import MutableBidict, bidict
 from omegaconf import DictConfig, OmegaConf
 
 from fusion_bench.constants import FUSION_BENCH_VERSION
@@ -15,12 +16,33 @@ from fusion_bench.utils.instantiate_utils import set_print_function_call
 log = logging.getLogger(__name__)
 
 __all__ = [
-    "YAMLSerializationMixin",
     "auto_register_config",
+    "YAMLSerializationMixin",
     "BaseYAMLSerializable",
 ]
 
 
+def _set_attr(self, param_name: str, value):
+    """
+    Set an attribute on the object using the parameter name from config mapping.
+
+    This function looks up the corresponding attribute name for the given parameter
+    name using the object's _config_mapping, then sets that attribute to the
+    specified value. It also logs the operation for debugging purposes.
+
+    Args:
+        self: The object instance to set the attribute on.
+        param_name (str): The parameter name (config key) to map to an attribute.
+        value: The value to assign to the attribute.
+
+    Raises:
+        ValueError: If the parameter name is not found in the config mapping.
+    """
+    attr_name = self._config_mapping.inverse[param_name]
+    log.debug(f"set {attr_name} to {value}. Parameter name: {param_name}")
+    setattr(self, attr_name, value)
+
+
 def auto_register_config(cls):
     """
     Decorator to automatically register __init__ parameters in _config_mapping.
@@ -45,37 +67,16 @@ def auto_register_config(cls):
         functionality and modified __init__ behavior.
 
     Behavior:
-        - **Parameter Registration**: All non-variadic parameters (excluding *args, **kwargs)
+        - **Parameter Registration**: All non-variadic parameters (excluding ``*args``, ``**kwargs``)
           from the __init__ method are automatically added to _config_mapping
         - **Positional Arguments**: Handled in order and mapped to corresponding parameter names
         - **Keyword Arguments**: Processed after positional arguments, overriding any conflicts
         - **Default Values**: Applied when parameters are not provided via arguments
        - **Attribute Setting**: All parameters become instance attributes accessible via dot notation
 
-    Example:
-        ```python
-        @auto_register_config
-        class MyAlgorithm(BaseYAMLSerializable):
-            def __init__(self, learning_rate: float = 0.001, batch_size: int = 32, model_name: str = "default"):
-                super().__init__()
-
-        # All instantiation methods work automatically:
-        algo1 = MyAlgorithm(0.01, 64)  # positional args
-        algo2 = MyAlgorithm(learning_rate=0.01, model_name="bert")  # keyword args
-        algo3 = MyAlgorithm(0.01, batch_size=128, model_name="gpt")  # mixed args
-
-        # Attributes are automatically set and can be serialized:
-        print(algo1.learning_rate)  # 0.01
-        print(algo1.batch_size)  # 64
-        print(algo1.model_name)  # "default" (from default value)
-
-        config = algo1.config
-        # DictConfig({'_target_': 'MyAlgorithm', 'learning_rate': 0.01, 'batch_size': 64, 'model_name': 'default'})
-        ```
-
     Note:
         - The decorator wraps the original __init__ method while preserving its signature for IDE support
-        - Parameters with *args or **kwargs signatures are ignored during registration
+        - Parameters with ``*args`` or ``**kwargs`` signatures are ignored during registration
         - The attributes are auto-registered, then the original __init__ method is called,
         - Type hints, method name, and other metadata are preserved using functools.wraps
         - This decorator is designed to work seamlessly with the YAML serialization system
@@ -89,7 +90,10 @@ def auto_register_config(cls):
 
     # Auto-register parameters in _config_mapping
     if not "_config_mapping" in cls.__dict__:
-        cls._config_mapping = deepcopy(getattr(cls, "_config_mapping", {}))
+        cls._config_mapping = deepcopy(getattr(cls, "_config_mapping", bidict()))
+    if not isinstance(cls._config_mapping, bidict):
+        cls._config_mapping = bidict(cls._config_mapping)
+
     registered_parameters = tuple(cls._config_mapping.values())
 
     for param_name in list(sig.parameters.keys())[1:]:  # Skip 'self'
@@ -102,9 +106,9 @@ def auto_register_config(cls):
         ) and (param_name not in registered_parameters):
             cls._config_mapping[param_name] = param_name
 
+    @wraps(original_init)
     def __init__(self, *args, **kwargs):
-        nonlocal original_init, registered_parameters
-
+        log.debug(f"set attributes for {self.__class__.__name__} in {cls.__name__}")
         # auto-register the attributes based on the signature
         sig = inspect.signature(original_init)
         param_names = list(sig.parameters.keys())[1:]  # Skip 'self'
@@ -117,29 +121,26 @@
                 _ParameterKind.VAR_POSITIONAL,
                 _ParameterKind.VAR_KEYWORD,
             ]:
-                setattr(self, param_name, arg_value)
+                _set_attr(self, param_name, arg_value)
 
         # Handle keyword arguments and defaults
         for param_name in param_names:
-            if (
-                sig.parameters[param_name].kind
-                not in [
-                    _ParameterKind.VAR_POSITIONAL,
-                    _ParameterKind.VAR_KEYWORD,
-                ]
-            ) and (param_name not in registered_parameters):
+            if sig.parameters[param_name].kind not in [
+                _ParameterKind.VAR_POSITIONAL,
+                _ParameterKind.VAR_KEYWORD,
+            ]:
                 # Skip if already set by positional argument
                 param_index = param_names.index(param_name)
                 if param_index >= 0 and param_index < len(args):
                     continue
 
                 if param_name in kwargs:
-                    setattr(self, param_name, kwargs[param_name])
+                    _set_attr(self, param_name, kwargs[param_name])
                 else:
                     # Set default value if available and attribute doesn't exist
                     default_value = sig.parameters[param_name].default
                     if default_value is not Parameter.empty:
-                        setattr(self, param_name, default_value)
+                        _set_attr(self, param_name, default_value)
 
         # Call the original __init__
         result = original_init(self, *args, **kwargs)
@@ -152,33 +153,10 @@ def auto_register_config(cls):
 
 class YAMLSerializationMixin:
     _config_key: Optional[str] = None
-    _config_mapping: Dict[str, str] = {}
+    _config_mapping: MutableBidict[str, str] = bidict()
     R"""
     `_config_mapping` is a dictionary mapping the attribute names of the class to the config option names. This is used to convert the class to a DictConfig.
 
-    For example, if an algorithm class is defined as follows:
-
-    ```python
-    class SomeModelFusionAlgorithm(BaseModelFusionAlgorithm):
-        hyper_parameter_1 = None
-        hyper_parameter_2 = None
-
-        _config_mapping = BaseModelFusionAlgorithm._config_mapping | {
-            "hyper_parameter_1": "hyper_param_1",
-            "hyper_parameter_2": "hyper_param_2",
-        }
-        def __init__(self, hyper_param_1: int, hyper_param_2: int):
-            self.hyper_parameter_1 = hyper_param_1
-            self.hyper_parameter_2 = hyper_param_2
-            super().__init__()
-    ```
-
-    The model pool will be converted to a DictConfig as follows:
-
-    ```python
-    algorithm = SomeModelFusionAlgorithm(hyper_param_1=1, hyper_param_2=2)
-    ```
-
     >>> algorithm.config
     DictCOnfig({'_target_': 'SomeModelFusionAlgorithm', 'hyper_param_1': 1, 'hyper_param_2': 2})
 
@@ -197,17 +175,6 @@ class YAMLSerializationMixin:
         This property converts the model pool instance into a dictionary
         configuration, which can be used for serialization or other purposes.
 
-        Example:
-
-        ```python
-        model = SomeModelFusionAlgorithm(hyper_param_1=1, hyper_param_2=2)
-        config = model.config
-        print(config)
-        # DictConfig({'_target_': 'SomeModelFusionAlgorithm', 'hyper_param_1': 1, 'hyper_param_2': 2})
-        ```
-
-        This is useful for serializing the object to a YAML file or for debugging.
-
         Returns:
             DictConfig: The configuration of the model pool.
         """
@@ -272,16 +239,6 @@ class YAMLSerializationMixin:
             serialization. This is how the attribute will appear in YAML output.
             value: The value to assign to the attribute.
 
-        Example:
-            ```python
-            model = BaseYAMLSerializable()
-            model.set_option("learning_rate", "lr", 0.001)
-
-            # This sets model.learning_rate = 0.001
-            # and maps it to "lr" in the config output
-            config = model.config
-            # config will contain: {"lr": 0.001, ...}
-            ```
         """
         setattr(self, attr_name, value)
         self._config_mapping[attr_name] = param_name
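Note: `_config_mapping` is now a `bidict`, which is what allows the new `_set_attr` helper to resolve a config parameter name back to the attribute name through `.inverse`. The lookup in isolation (mapping contents are illustrative, taken from the docstring example that this hunk removes):

```python
from bidict import bidict

# attribute name -> config option name, as stored in _config_mapping
config_mapping = bidict({"hyper_parameter_1": "hyper_param_1"})

print(config_mapping["hyper_parameter_1"])      # forward: attribute -> config key
print(config_mapping.inverse["hyper_param_1"])  # inverse lookup used by _set_attr
```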
fusion_bench/modelpool/base_pool.py
@@ -277,7 +277,7 @@ class BaseModelPool(
         for dataset_name in self.test_dataset_names:
             yield self.load_test_dataset(dataset_name)
 
-    def save_model(self, model: nn.Module, path: str):
+    def save_model(self, model: nn.Module, path: str, *args, **kwargs):
         """
         Save the state dictionary of the model to the specified path.
 