fusion-bench 0.2.22__py3-none-any.whl → 0.2.24__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/__init__.py +4 -0
- fusion_bench/compat/method/__init__.py +5 -2
- fusion_bench/compat/method/base_algorithm.py +3 -2
- fusion_bench/compat/modelpool/base_pool.py +3 -3
- fusion_bench/compat/taskpool/clip_image_classification.py +1 -1
- fusion_bench/dataset/gpt2_glue.py +1 -1
- fusion_bench/method/__init__.py +12 -2
- fusion_bench/method/analysis/task_vector_cos_similarity.py +95 -12
- fusion_bench/method/analysis/task_vector_violin_plot.py +160 -52
- fusion_bench/method/bitdelta/bitdelta.py +7 -23
- fusion_bench/method/ensemble.py +17 -2
- fusion_bench/method/expert_sparsity/mixtral/dynamic_skipping.py +2 -0
- fusion_bench/method/expert_sparsity/mixtral/layer_wise_pruning.py +2 -0
- fusion_bench/method/expert_sparsity/mixtral/progressive_pruning.py +2 -0
- fusion_bench/method/linear/__init__.py +6 -2
- fusion_bench/method/linear/{simple_average_for_llama.py → simple_average_for_causallm.py} +8 -4
- fusion_bench/method/linear/{task_arithmetic_for_llama.py → task_arithmetic_for_causallm.py} +22 -12
- fusion_bench/method/linear/ties_merging_for_causallm.py +70 -0
- fusion_bench/method/model_stock/__init__.py +1 -0
- fusion_bench/method/model_stock/model_stock.py +309 -0
- fusion_bench/method/regmean/clip_regmean.py +3 -6
- fusion_bench/method/regmean/regmean.py +27 -56
- fusion_bench/method/regmean/utils.py +56 -0
- fusion_bench/method/regmean_plusplus/regmean_plusplus.py +21 -60
- fusion_bench/method/simple_average.py +2 -2
- fusion_bench/method/slerp/__init__.py +1 -1
- fusion_bench/method/slerp/slerp.py +110 -14
- fusion_bench/method/task_arithmetic/task_arithmetic.py +35 -10
- fusion_bench/method/ties_merging/ties_merging.py +22 -6
- fusion_bench/method/we_moe/flan_t5_we_moe.py +9 -20
- fusion_bench/method/wudi/__init__.py +1 -0
- fusion_bench/method/wudi/wudi.py +105 -0
- fusion_bench/mixins/clip_classification.py +26 -6
- fusion_bench/mixins/lightning_fabric.py +4 -0
- fusion_bench/mixins/serialization.py +40 -83
- fusion_bench/modelpool/base_pool.py +1 -1
- fusion_bench/modelpool/causal_lm/causal_lm.py +285 -44
- fusion_bench/modelpool/seq2seq_lm/modelpool.py +146 -0
- fusion_bench/models/hf_clip.py +4 -0
- fusion_bench/models/hf_utils.py +10 -4
- fusion_bench/models/linearized/vision_model.py +6 -6
- fusion_bench/models/model_card_templates/default.md +8 -1
- fusion_bench/models/modeling_smile_mistral/__init__.py +1 -0
- fusion_bench/models/we_moe.py +8 -8
- fusion_bench/models/wrappers/ensemble.py +136 -7
- fusion_bench/scripts/cli.py +2 -2
- fusion_bench/taskpool/base_pool.py +99 -17
- fusion_bench/taskpool/clip_vision/taskpool.py +12 -5
- fusion_bench/taskpool/dummy.py +101 -13
- fusion_bench/taskpool/lm_eval_harness/taskpool.py +80 -0
- fusion_bench/taskpool/nyuv2_taskpool.py +28 -0
- fusion_bench/utils/__init__.py +1 -0
- fusion_bench/utils/data.py +6 -4
- fusion_bench/utils/devices.py +36 -11
- fusion_bench/utils/dtype.py +3 -2
- fusion_bench/utils/lazy_state_dict.py +85 -19
- fusion_bench/utils/packages.py +3 -3
- fusion_bench/utils/parameters.py +0 -2
- fusion_bench/utils/rich_utils.py +7 -3
- fusion_bench/utils/timer.py +92 -10
- {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.24.dist-info}/METADATA +10 -3
- {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.24.dist-info}/RECORD +77 -64
- fusion_bench_config/_get_started/llm_slerp.yaml +12 -0
- fusion_bench_config/method/ensemble/simple_ensemble.yaml +1 -0
- fusion_bench_config/method/linear/{simple_average_for_llama.yaml → simple_average_for_causallm.yaml} +1 -1
- fusion_bench_config/method/linear/task_arithmetic_for_causallm.yaml +4 -0
- fusion_bench_config/method/linear/ties_merging_for_causallm.yaml +13 -0
- fusion_bench_config/method/model_stock/model_stock.yaml +12 -0
- fusion_bench_config/method/slerp/slerp_lm.yaml +4 -0
- fusion_bench_config/method/wudi/wudi.yaml +4 -0
- fusion_bench_config/modelpool/CausalLMPool/{Qwen2.5-1.5B_math_and_coder.yaml → Qwen2.5-1.5B_math_and_code.yaml} +1 -2
- fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_three_models.yaml +11 -0
- fusion_bench_config/modelpool/CausalLMPool/llama-7b_3-models_v1.yaml +11 -0
- fusion_bench_config/method/linear/task_arithmetic_for_llama.yaml +0 -4
- {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.24.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.24.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.24.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.24.dist-info}/top_level.txt +0 -0
fusion_bench/method/we_moe/flan_t5_we_moe.py

```diff
@@ -16,11 +16,14 @@ from transformers.data import default_data_collator
 
 from fusion_bench.method import BaseAlgorithm
 from fusion_bench.method.task_arithmetic.task_arithmetic import task_arithmetic_merge
-from fusion_bench.mixins
-
+from fusion_bench.mixins import (
+    LightningFabricMixin,
+    SimpleProfilerMixin,
+    auto_register_config,
+)
 from fusion_bench.modelpool import Seq2SeqLMPool
 from fusion_bench.models.we_moe import WeightEnsemblingMoE
-from fusion_bench.utils import timeit_context
+from fusion_bench.utils import print_parameters, timeit_context
 from fusion_bench.utils.data import InfiniteDataLoader, load_tensor_from_file
 from fusion_bench.utils.instantiate_utils import instantiate
 from fusion_bench.utils.parameters import print_parameters
@@ -31,10 +34,11 @@ from .utils import get_memory_usage
 log = logging.getLogger(__name__)
 
 
+@auto_register_config
 class FlanT5WeightEnsemblingMoEAlgorithm(
-    BaseAlgorithm,
     LightningFabricMixin,
     SimpleProfilerMixin,
+    BaseAlgorithm,
 ):
     """
     FlanT5WeightEnsemblingMoEAlgorithm is a class that implements the WeightEnsemblingMoEAlgorithm
@@ -60,7 +64,6 @@ class FlanT5WeightEnsemblingMoEAlgorithm(
         num_workers: int = 0,
         max_steps: int = 1000,
         use_grad_accumulate: bool = True,
-        cache_dir: bool = "outputs",
         fast_dev_run: bool = False,
         **kwargs,
     ):
@@ -70,23 +73,9 @@ class FlanT5WeightEnsemblingMoEAlgorithm(
         Args:
             algorithm_config (DictConfig): The configuration for the algorithm.
         """
-        self.checkpoint = checkpoint
-        self.save_checkpoint = save_checkpoint
-        self.router_hidden_layers = router_hidden_layers
-        self.init_lambda = init_lambda
-        self.batch_reduce = batch_reduce
-        self.lr = lr
-        self.optimizer = optimizer
-        self.devices = devices
-        self.batch_size = batch_size
-        self.num_workers = num_workers
-        self.max_steps = max_steps
-        self.use_grad_accumulate = use_grad_accumulate
-        self.cache_dir = cache_dir
-        self.fast_dev_run = fast_dev_run
         super().__init__(**kwargs)
 
-    def construct_moe_model(self)
+    def construct_moe_model(self):
         """
         Construct the Mixture of Experts (MoE) model using the models in the model pool.
 
```
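The block of `self.x = x` assignments removed above is exactly what the newly applied `@auto_register_config` decorator now takes care of. A minimal, self-contained sketch of that mechanism (a toy stand-in written for illustration, not fusion_bench's implementation):

```python
import inspect
from functools import wraps


def auto_set_attrs(cls):
    """Toy stand-in for fusion_bench's @auto_register_config: copy the bound
    __init__ arguments (with defaults applied) onto the instance as attributes."""
    original_init = cls.__init__
    params = inspect.signature(original_init).parameters

    @wraps(original_init)
    def __init__(self, *args, **kwargs):
        bound = inspect.signature(original_init).bind(self, *args, **kwargs)
        bound.apply_defaults()
        for name, value in list(bound.arguments.items())[1:]:  # skip 'self'
            if params[name].kind not in (
                inspect.Parameter.VAR_POSITIONAL,
                inspect.Parameter.VAR_KEYWORD,
            ):
                setattr(self, name, value)
        original_init(self, *args, **kwargs)

    cls.__init__ = __init__
    return cls


@auto_set_attrs
class Demo:
    def __init__(self, lr: float = 1e-4, max_steps: int = 1000, **kwargs):
        pass  # no manual self.lr = lr / self.max_steps = max_steps needed


demo = Demo(lr=3e-5)
print(demo.lr, demo.max_steps)  # 3e-05 1000
```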
fusion_bench/method/wudi/__init__.py

```diff
@@ -0,0 +1 @@
+from .wudi import WUDIMerging, wudi_merging
```
fusion_bench/method/wudi/wudi.py

```diff
@@ -0,0 +1,105 @@
+"""
+Whoever Started the Interference Should End It: Guiding Data-Free Model Merging via Task Vectors
+Arxiv: http://arxiv.org/abs/2503.08099
+"""
+
+from typing import List
+
+import torch
+from tqdm import tqdm
+
+from fusion_bench import BaseAlgorithm, BaseModelPool, auto_register_config
+from fusion_bench.mixins import LightningFabricMixin
+from fusion_bench.utils import timeit_context
+from fusion_bench.utils.state_dict_arithmetic import state_dict_add, state_dict_sub
+
+
+def wudi_merging(
+    task_vectors: List[torch.Tensor],
+    accelerator="cuda",
+    iter_num: int = 300,
+    exclude_keys: List[str] = None,
+):
+    exclude_keys = [] if exclude_keys is None else exclude_keys
+
+    with timeit_context("WUDI Merging"):
+        new_vector = {}
+        for key in tqdm(task_vectors[0], desc="WUDI Merging", leave=False):
+            tqdm.write(f"key: {key}")
+            original_device = task_vectors[0][key].device
+            tvs = torch.stack(
+                [
+                    task_vector[key].to(device=accelerator, non_blocking=True)
+                    for task_vector in task_vectors
+                ]
+            )
+            num_tvs = len(tvs)
+            new_vector[key] = torch.nn.Parameter(torch.sum(tvs, dim=0))
+
+            if len(task_vectors[0][key].shape) == 2 and key not in exclude_keys:
+                optimizer = torch.optim.Adam([new_vector[key]], lr=1e-5, weight_decay=0)
+                l2_norms = torch.square(
+                    torch.norm(tvs.reshape(tvs.shape[0], -1), p=2, dim=-1)
+                )
+                for i in tqdm(
+                    range(iter_num),
+                ):
+                    disturbing_vectors = new_vector[key].unsqueeze(0) - tvs
+                    product = torch.matmul(disturbing_vectors, tvs.transpose(1, 2))
+                    loss = torch.sum(
+                        torch.square(product) / l2_norms.unsqueeze(-1).unsqueeze(-1)
+                    )
+                    optimizer.zero_grad()
+                    loss.backward()
+                    optimizer.step()
+            else:
+                new_vector[key] = new_vector[key] / num_tvs
+            new_vector[key] = new_vector[key].to(
+                device=original_device, non_blocking=True
+            )
+    return new_vector
+
+
+@auto_register_config
+class WUDIMerging(
+    LightningFabricMixin,
+    BaseAlgorithm,
+):
+    """
+    Whoever Started the Interference Should End It: Guiding Data-Free Model Merging via Task Vectors
+    """
+
+    def __init__(
+        self,
+        iter_num: int,
+        exclude_keys: List[str] = None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+    def run(self, modelpool: BaseModelPool):
+        # load the pretrained model and the task vectors of all the finetuned models
+        with torch.no_grad():
+            pretrained_model = modelpool.load_pretrained_model()
+            task_vectors = []
+            for model_name in modelpool.model_names:
+                finetuned_model = modelpool.load_model(model_name)
+                task_vectors.append(
+                    state_dict_sub(
+                        finetuned_model.state_dict(), pretrained_model.state_dict()
+                    )
+                )
+                del finetuned_model  # free memory
+
+        merged_tv = wudi_merging(
+            task_vectors,
+            accelerator=self.fabric.device,
+            iter_num=self.iter_num,
+            exclude_keys=self.exclude_keys,
+        )
+
+        pretrained_model.load_state_dict(
+            state_dict_add(pretrained_model.state_dict(), merged_tv)
+        )
+
+        return pretrained_model
```
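Reading `wudi_merging` above: for every 2-D parameter the merged task vector starts from the sum of the per-task vectors and is optimized with Adam for `iter_num` steps to minimize the squared projection of each disturbance (merged minus task vector) onto that task vector, normalized by the task vector's squared L2 norm; all other tensors simply receive the average. A toy CPU-only call, assuming fusion-bench 0.2.24 is installed (the shapes, key names, and values below are made up for illustration):

```python
import torch

from fusion_bench.method.wudi import wudi_merging

torch.manual_seed(0)
# Two "tasks", each a state dict with one 2-D weight and one 1-D bias.
task_vectors = [
    {"linear.weight": torch.randn(4, 8), "linear.bias": torch.randn(4)}
    for _ in range(2)
]

merged = wudi_merging(task_vectors, accelerator="cpu", iter_num=10)
# 2-D tensors are optimized to reduce interference; 1-D tensors fall back to the mean.
print(merged["linear.weight"].shape, merged["linear.bias"].shape)
```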
fusion_bench/mixins/clip_classification.py

```diff
@@ -113,11 +113,27 @@ class CLIPClassificationMixin(LightningFabricMixin):
         clip_model: Optional[CLIPModel] = None,
         task_names: Optional[List[str]] = None,
     ):
+        """
+        Initializes a zero-shot classification head.
+
+        This method constructs a zero-shot classification head by generating text embeddings for each class name using a set of templates.
+        These embeddings function as the weights of the classification layer. The method also extracts the `visual_projection` and `logit_scale`
+        from the provided CLIP model, which are necessary for calculating the final logits.
+
+        Args:
+            clip_processor (Optional[CLIPProcessor]): The processor for the CLIP model. If not provided, it is loaded from the model pool.
+            clip_model (Optional[CLIPModel]): The CLIP model to use. If not provided, a pretrained model is loaded from the model pool.
+            task_names (Optional[List[str]]): A list of task names to set up the classification head for. If not provided, all models in the model pool will be used.
+        """
         self.whether_setup_zero_shot_classification_head = True
+        # load clip model if not provided
         if clip_model is None:
             if self.modelpool.has_pretrained:
                 clip_model = self.modelpool.load_clip_model("_pretrained_")
             else:
+                log.warning(
+                    f"No pretrained CLIP model found, using the model from the model pool: {self.modelpool.model_names[0]}."
+                )
                 clip_model = self.modelpool.load_clip_model(
                     self.modelpool.model_names[0]
                 )
```
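The new docstring describes the standard CLIP zero-shot recipe: per-class text embeddings built from prompt templates become the weights of the classification head. A generic sketch of that recipe with Hugging Face `transformers` (a hedged illustration of the idea, not fusion_bench's exact implementation; the model name and templates are placeholders):

```python
import torch
from transformers import CLIPModel, CLIPProcessor

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

classnames = ["cat", "dog"]
templates = ["a photo of a {}.", "a blurry photo of a {}."]

weights = []
with torch.no_grad():
    for name in classnames:
        texts = [t.format(name) for t in templates]
        inputs = processor(text=texts, return_tensors="pt", padding=True)
        embeds = model.get_text_features(**inputs)            # (num_templates, dim)
        embeds = embeds / embeds.norm(dim=-1, keepdim=True)   # normalize each template
        class_embed = embeds.mean(dim=0)                      # average over templates
        weights.append(class_embed / class_embed.norm())      # re-normalize
zeroshot_weights = torch.stack(weights)                       # (num_classes, dim)
```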
```diff
@@ -166,16 +182,20 @@ class CLIPClassificationMixin(LightningFabricMixin):
         image_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """
-
+        Computes the classification logits for a batch of images for a specific task.
+
+        This method performs zero-shot classification by calculating the cosine similarity between image and text embeddings.
+        The image embeddings are obtained from the provided vision model, and the text embeddings (zero-shot weights) are pre-computed for the task.
+        The similarity scores are then scaled by the CLIP model's `logit_scale` to produce the final logits.
 
         Args:
-            module (Union[nn.Module, CLIPVisionModel, "CLIPVisionTransformer"]): The
-            images (torch.Tensor):
-            task (str): The
-            image_embeds (Optional[torch.Tensor]):
+            module (Union[nn.Module, CLIPVisionModel, "CLIPVisionTransformer"]): The vision encoder part of the CLIP model.
+            images (torch.Tensor): A batch of images to classify.
+            task (str): The name of the classification task.
+            image_embeds (Optional[torch.Tensor]): Pre-computed image embeddings. If provided, the method skips the image encoding step.
 
         Returns:
-            torch.Tensor:
+            torch.Tensor: A tensor of logits for each image, with shape (batch_size, num_classes).
         """
         text_embeds = self.zeroshot_weights[task]
 
```
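And the logit computation the rewritten docstring describes: cosine similarity between normalized image and text embeddings, scaled by `logit_scale`. This continues the sketch above (`model` and `zeroshot_weights` come from there; the random `pixel_values` batch stands in for real preprocessed images):

```python
# Continues the previous sketch: `model` and `zeroshot_weights` come from there.
pixel_values = torch.randn(2, 3, 224, 224)  # stand-in for a preprocessed image batch

with torch.no_grad():
    image_embeds = model.get_image_features(pixel_values=pixel_values)     # (B, dim)
    image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)  # normalize
    logits = model.logit_scale.exp() * image_embeds @ zeroshot_weights.T   # (B, num_classes)
    predictions = logits.argmax(dim=-1)
```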
fusion_bench/mixins/lightning_fabric.py

```diff
@@ -100,6 +100,10 @@ class LightningFabricMixin:
             self.setup_lightning_fabric(getattr(self, "config", DictConfig({})))
         return self._fabric_instance
 
+    @fabric.setter
+    def fabric(self, instance: L.Fabric):
+        self._fabric_instance = instance
+
     @property
     def log_dir(self):
         """
```
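The new setter complements the lazy `fabric` property: a caller can now inject an already-configured `Fabric` instead of letting the mixin build one from `self.config`. A minimal sketch, assuming `lightning` and fusion-bench 0.2.24 are installed and that the mixin can be instantiated standalone for illustration:

```python
import lightning as L

from fusion_bench.mixins import LightningFabricMixin


class Dummy(LightningFabricMixin):
    pass


fabric = L.Fabric(accelerator="cpu", devices=1)
algo = Dummy()
algo.fabric = fabric           # new in 0.2.24: inject a pre-built Fabric instance
assert algo.fabric is fabric   # the lazy getter returns the injected instance
```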
fusion_bench/mixins/serialization.py

```diff
@@ -4,8 +4,9 @@ from copy import deepcopy
 from functools import wraps
 from inspect import Parameter, _ParameterKind
 from pathlib import Path
-from typing import Dict, Optional, Union
+from typing import Dict, Mapping, Optional, Union
 
+from bidict import MutableBidict, bidict
 from omegaconf import DictConfig, OmegaConf
 
 from fusion_bench.constants import FUSION_BENCH_VERSION
```
```diff
@@ -15,12 +16,33 @@ from fusion_bench.utils.instantiate_utils import set_print_function_call
 log = logging.getLogger(__name__)
 
 __all__ = [
-    "YAMLSerializationMixin",
     "auto_register_config",
+    "YAMLSerializationMixin",
     "BaseYAMLSerializable",
 ]
 
 
+def _set_attr(self, param_name: str, value):
+    """
+    Set an attribute on the object using the parameter name from config mapping.
+
+    This function looks up the corresponding attribute name for the given parameter
+    name using the object's _config_mapping, then sets that attribute to the
+    specified value. It also logs the operation for debugging purposes.
+
+    Args:
+        self: The object instance to set the attribute on.
+        param_name (str): The parameter name (config key) to map to an attribute.
+        value: The value to assign to the attribute.
+
+    Raises:
+        ValueError: If the parameter name is not found in the config mapping.
+    """
+    attr_name = self._config_mapping.inverse[param_name]
+    log.debug(f"set {attr_name} to {value}. Parameter name: {param_name}")
+    setattr(self, attr_name, value)
+
+
 def auto_register_config(cls):
     """
     Decorator to automatically register __init__ parameters in _config_mapping.
```
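The new `_set_attr` helper resolves attribute names through the inverse view of `_config_mapping`, which is why the mapping is migrated from a plain dict to a `bidict` in the hunks below. A standalone illustration of that inverse lookup (`pip install bidict`; the names are made up):

```python
from bidict import bidict

config_mapping = bidict({"_learning_rate": "lr"})  # attribute name -> config key

print(config_mapping["_learning_rate"])   # "lr"
print(config_mapping.inverse["lr"])       # "_learning_rate", as used by _set_attr
```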
````diff
@@ -45,37 +67,16 @@ def auto_register_config(cls):
         functionality and modified __init__ behavior.
 
     Behavior:
-        - **Parameter Registration**: All non-variadic parameters (excluding
+        - **Parameter Registration**: All non-variadic parameters (excluding ``*args``, ``**kwargs``)
           from the __init__ method are automatically added to _config_mapping
         - **Positional Arguments**: Handled in order and mapped to corresponding parameter names
         - **Keyword Arguments**: Processed after positional arguments, overriding any conflicts
         - **Default Values**: Applied when parameters are not provided via arguments
         - **Attribute Setting**: All parameters become instance attributes accessible via dot notation
 
-    Example:
-        ```python
-        @auto_register_config
-        class MyAlgorithm(BaseYAMLSerializable):
-            def __init__(self, learning_rate: float = 0.001, batch_size: int = 32, model_name: str = "default"):
-                super().__init__()
-
-        # All instantiation methods work automatically:
-        algo1 = MyAlgorithm(0.01, 64)  # positional args
-        algo2 = MyAlgorithm(learning_rate=0.01, model_name="bert")  # keyword args
-        algo3 = MyAlgorithm(0.01, batch_size=128, model_name="gpt")  # mixed args
-
-        # Attributes are automatically set and can be serialized:
-        print(algo1.learning_rate)  # 0.01
-        print(algo1.batch_size)  # 64
-        print(algo1.model_name)  # "default" (from default value)
-
-        config = algo1.config
-        # DictConfig({'_target_': 'MyAlgorithm', 'learning_rate': 0.01, 'batch_size': 64, 'model_name': 'default'})
-        ```
-
     Note:
         - The decorator wraps the original __init__ method while preserving its signature for IDE support
-        - Parameters with
+        - Parameters with ``*args`` or ``**kwargs`` signatures are ignored during registration
         - The attributes are auto-registered, then the original __init__ method is called,
         - Type hints, method name, and other metadata are preserved using functools.wraps
         - This decorator is designed to work seamlessly with the YAML serialization system
````
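The usage example dropped from the docstring still shows how the decorator behaves; here it is again in a lightly hedged, near-verbatim form (requires fusion-bench; importing `BaseYAMLSerializable` from this module follows its `__all__` list):

```python
from fusion_bench.mixins import auto_register_config
from fusion_bench.mixins.serialization import BaseYAMLSerializable


@auto_register_config
class MyAlgorithm(BaseYAMLSerializable):
    def __init__(self, learning_rate: float = 0.001, batch_size: int = 32, model_name: str = "default"):
        super().__init__()


algo = MyAlgorithm(0.01, batch_size=64)      # mixed positional / keyword args
print(algo.learning_rate, algo.batch_size)   # 0.01 64, set by the decorator
print(algo.model_name)                       # "default", from the default value
print(algo.config)                           # DictConfig with _target_ and the three parameters
```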
```diff
@@ -89,7 +90,10 @@ def auto_register_config(cls):
 
     # Auto-register parameters in _config_mapping
     if not "_config_mapping" in cls.__dict__:
-        cls._config_mapping = deepcopy(getattr(cls, "_config_mapping",
+        cls._config_mapping = deepcopy(getattr(cls, "_config_mapping", bidict()))
+    if not isinstance(cls._config_mapping, bidict):
+        cls._config_mapping = bidict(cls._config_mapping)
+
     registered_parameters = tuple(cls._config_mapping.values())
 
     for param_name in list(sig.parameters.keys())[1:]:  # Skip 'self'
```
```diff
@@ -102,9 +106,9 @@ def auto_register_config(cls):
         ) and (param_name not in registered_parameters):
             cls._config_mapping[param_name] = param_name
 
+    @wraps(original_init)
     def __init__(self, *args, **kwargs):
-
-
+        log.debug(f"set attributes for {self.__class__.__name__} in {cls.__name__}")
         # auto-register the attributes based on the signature
         sig = inspect.signature(original_init)
         param_names = list(sig.parameters.keys())[1:]  # Skip 'self'
```
```diff
@@ -117,29 +121,26 @@ def auto_register_config(cls):
                 _ParameterKind.VAR_POSITIONAL,
                 _ParameterKind.VAR_KEYWORD,
             ]:
-
+                _set_attr(self, param_name, arg_value)
 
         # Handle keyword arguments and defaults
         for param_name in param_names:
-            if
-
-
-
-                _ParameterKind.VAR_KEYWORD,
-            ]
-            ) and (param_name not in registered_parameters):
+            if sig.parameters[param_name].kind not in [
+                _ParameterKind.VAR_POSITIONAL,
+                _ParameterKind.VAR_KEYWORD,
+            ]:
                 # Skip if already set by positional argument
                 param_index = param_names.index(param_name)
                 if param_index >= 0 and param_index < len(args):
                     continue
 
                 if param_name in kwargs:
-
+                    _set_attr(self, param_name, kwargs[param_name])
                 else:
                     # Set default value if available and attribute doesn't exist
                     default_value = sig.parameters[param_name].default
                     if default_value is not Parameter.empty:
-
+                        _set_attr(self, param_name, default_value)
 
         # Call the original __init__
         result = original_init(self, *args, **kwargs)
```
````diff
@@ -152,33 +153,10 @@ def auto_register_config(cls):
 
 class YAMLSerializationMixin:
     _config_key: Optional[str] = None
-    _config_mapping:
+    _config_mapping: MutableBidict[str, str] = bidict()
     R"""
     `_config_mapping` is a dictionary mapping the attribute names of the class to the config option names. This is used to convert the class to a DictConfig.
 
-    For example, if an algorithm class is defined as follows:
-
-    ```python
-    class SomeModelFusionAlgorithm(BaseModelFusionAlgorithm):
-        hyper_parameter_1 = None
-        hyper_parameter_2 = None
-
-        _config_mapping = BaseModelFusionAlgorithm._config_mapping | {
-            "hyper_parameter_1" : "hyper_param_1",
-            "hyper_parameter_2" : "hyper_param_2",
-        }
-        def __init__(self, hyper_param_1: int, hyper_param_2: int):
-            self.hyper_parameter_1 = hyper_param_1
-            self.hyper_parameter_2 = hyper_param_2
-            super().__init__()
-    ```
-
-    The model pool will be converted to a DictConfig as follows:
-
-    ```python
-    algorithm = SomeModelFusionAlgorithm(hyper_param_1=1, hyper_param_2=2)
-    ```
-
     >>> algorithm.config
     DictCOnfig({'_target_': 'SomeModelFusionAlgorithm', 'hyper_param_1': 1, 'hyper_param_2': 2})
 
````
````diff
@@ -197,17 +175,6 @@ class YAMLSerializationMixin:
         This property converts the model pool instance into a dictionary
         configuration, which can be used for serialization or other purposes.
 
-        Example:
-
-        ```python
-        model = SomeModelFusionAlgorithm(hyper_param_1=1, hyper_param_2=2)
-        config = model.config
-        print(config)
-        # DictConfig({'_target_': 'SomeModelFusionAlgorithm', 'hyper_param_1': 1, 'hyper_param_2': 2})
-        ```
-
-        This is useful for serializing the object to a YAML file or for debugging.
-
         Returns:
             DictConfig: The configuration of the model pool.
         """
````
````diff
@@ -272,16 +239,6 @@ class YAMLSerializationMixin:
                 serialization. This is how the attribute will appear in YAML output.
             value: The value to assign to the attribute.
 
-        Example:
-            ```python
-            model = BaseYAMLSerializable()
-            model.set_option("learning_rate", "lr", 0.001)
-
-            # This sets model.learning_rate = 0.001
-            # and maps it to "lr" in the config output
-            config = model.config
-            # config will contain: {"lr": 0.001, ...}
-            ```
         """
         setattr(self, attr_name, value)
         self._config_mapping[attr_name] = param_name
````
@@ -277,7 +277,7 @@ class BaseModelPool(
|
|
|
277
277
|
for dataset_name in self.test_dataset_names:
|
|
278
278
|
yield self.load_test_dataset(dataset_name)
|
|
279
279
|
|
|
280
|
-
def save_model(self, model: nn.Module, path: str):
|
|
280
|
+
def save_model(self, model: nn.Module, path: str, *args, **kwargs):
|
|
281
281
|
"""
|
|
282
282
|
Save the state dictionary of the model to the specified path.
|
|
283
283
|
|