fusion-bench: fusion_bench-0.2.22-py3-none-any.whl → fusion_bench-0.2.23-py3-none-any.whl
This diff compares the contents of two publicly available versions of a package that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.
- fusion_bench/__init__.py +4 -0
- fusion_bench/compat/method/__init__.py +5 -2
- fusion_bench/compat/method/base_algorithm.py +3 -2
- fusion_bench/compat/modelpool/base_pool.py +3 -3
- fusion_bench/compat/taskpool/clip_image_classification.py +1 -1
- fusion_bench/dataset/gpt2_glue.py +1 -1
- fusion_bench/method/__init__.py +4 -2
- fusion_bench/method/analysis/task_vector_cos_similarity.py +95 -12
- fusion_bench/method/analysis/task_vector_violin_plot.py +160 -52
- fusion_bench/method/bitdelta/bitdelta.py +7 -23
- fusion_bench/method/expert_sparsity/mixtral/dynamic_skipping.py +2 -0
- fusion_bench/method/expert_sparsity/mixtral/layer_wise_pruning.py +2 -0
- fusion_bench/method/expert_sparsity/mixtral/progressive_pruning.py +2 -0
- fusion_bench/method/model_stock/__init__.py +1 -0
- fusion_bench/method/model_stock/model_stock.py +309 -0
- fusion_bench/method/regmean/clip_regmean.py +3 -6
- fusion_bench/method/regmean/regmean.py +27 -56
- fusion_bench/method/regmean/utils.py +56 -0
- fusion_bench/method/regmean_plusplus/regmean_plusplus.py +21 -60
- fusion_bench/method/slerp/__init__.py +1 -1
- fusion_bench/method/slerp/slerp.py +110 -14
- fusion_bench/method/we_moe/flan_t5_we_moe.py +9 -20
- fusion_bench/mixins/clip_classification.py +26 -6
- fusion_bench/mixins/serialization.py +25 -15
- fusion_bench/modelpool/base_pool.py +1 -1
- fusion_bench/modelpool/causal_lm/causal_lm.py +262 -43
- fusion_bench/modelpool/seq2seq_lm/modelpool.py +146 -0
- fusion_bench/models/hf_utils.py +9 -4
- fusion_bench/models/linearized/vision_model.py +6 -6
- fusion_bench/models/modeling_smile_mistral/__init__.py +1 -0
- fusion_bench/models/we_moe.py +8 -8
- fusion_bench/taskpool/base_pool.py +99 -17
- fusion_bench/taskpool/clip_vision/taskpool.py +1 -1
- fusion_bench/taskpool/dummy.py +101 -13
- fusion_bench/taskpool/lm_eval_harness/taskpool.py +80 -0
- fusion_bench/taskpool/nyuv2_taskpool.py +28 -0
- fusion_bench/utils/__init__.py +1 -0
- fusion_bench/utils/data.py +6 -4
- fusion_bench/utils/devices.py +7 -4
- fusion_bench/utils/dtype.py +3 -2
- fusion_bench/utils/lazy_state_dict.py +82 -19
- fusion_bench/utils/packages.py +3 -3
- fusion_bench/utils/parameters.py +0 -2
- fusion_bench/utils/timer.py +92 -10
- {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.23.dist-info}/METADATA +1 -1
- {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.23.dist-info}/RECORD +53 -47
- fusion_bench_config/_get_started/llm_slerp.yaml +12 -0
- fusion_bench_config/method/model_stock/model_stock.yaml +12 -0
- fusion_bench_config/method/slerp/slerp_lm.yaml +4 -0
- {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.23.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.23.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.23.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.23.dist-info}/top_level.txt +0 -0
|
@@ -1,16 +1,24 @@
|
|
|
1
1
|
import logging
|
|
2
|
-
|
|
2
|
+
import os
|
|
3
|
+
from copy import deepcopy
|
|
4
|
+
from typing import TYPE_CHECKING, Any, Dict, Optional
|
|
3
5
|
|
|
4
6
|
import torch
|
|
5
7
|
from torch import nn
|
|
8
|
+
from tqdm import tqdm
|
|
6
9
|
from typing_extensions import override
|
|
7
10
|
|
|
11
|
+
from fusion_bench import LazyStateDict, create_default_model_card, timeit_context
|
|
8
12
|
from fusion_bench.method import BaseAlgorithm
|
|
9
|
-
from fusion_bench.
|
|
13
|
+
from fusion_bench.mixins import SimpleProfilerMixin, auto_register_config
|
|
14
|
+
from fusion_bench.modelpool import BaseModelPool, CausalLMPool
|
|
10
15
|
from fusion_bench.utils.type import StateDictType
|
|
11
16
|
|
|
12
17
|
from .slerp_utils import slerp
|
|
13
18
|
|
|
19
|
+
if TYPE_CHECKING:
|
|
20
|
+
from transformers import PreTrainedModel
|
|
21
|
+
|
|
14
22
|
log = logging.getLogger(__name__)
|
|
15
23
|
|
|
16
24
|
|
|
@@ -21,6 +29,7 @@ def slerp_on_state_dicts(
|
|
|
21
29
|
*,
|
|
22
30
|
DOT_THRESHOLD: float = 0.9995,
|
|
23
31
|
epsilon: float = 1e-8,
|
|
32
|
+
show_pbar: bool = False,
|
|
24
33
|
) -> StateDictType:
|
|
25
34
|
"""
|
|
26
35
|
Perform spherical linear interpolation (slerp) on the state dictionaries of two models.
|
|
@@ -36,7 +45,8 @@ def slerp_on_state_dicts(
|
|
|
36
45
|
dict: The interpolated state dictionary.
|
|
37
46
|
"""
|
|
38
47
|
state_dict = {}
|
|
39
|
-
|
|
48
|
+
pbar = secondary_state_dict if not show_pbar else tqdm(secondary_state_dict)
|
|
49
|
+
for key in pbar:
|
|
40
50
|
v0 = primary_state_dict[key]
|
|
41
51
|
v1 = secondary_state_dict[key]
|
|
42
52
|
if v0.shape != v1.shape:
|
|
@@ -49,18 +59,19 @@ def slerp_on_state_dicts(
|
|
|
49
59
|
return state_dict
|
|
50
60
|
|
|
51
61
|
|
|
62
|
+
@auto_register_config
|
|
52
63
|
class SlerpMergeAlgorithm(BaseAlgorithm):
|
|
53
64
|
"""
|
|
54
65
|
General purpose implementation of Slerp (Spherical Linear Interpolation) for PyTorch models.
|
|
55
66
|
"""
|
|
56
67
|
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
68
|
+
def __init__(
|
|
69
|
+
self,
|
|
70
|
+
t: float,
|
|
71
|
+
DOT_THRESHOLD: float = 0.9995,
|
|
72
|
+
epsilon: float = 1e-8,
|
|
73
|
+
**kwargs,
|
|
74
|
+
):
|
|
64
75
|
"""
|
|
65
76
|
Initialize the SlerpMergeAlgorithm.
|
|
66
77
|
|
|
@@ -69,10 +80,7 @@ class SlerpMergeAlgorithm(BaseAlgorithm):
|
|
|
69
80
|
DOT_THRESHOLD (float, optional): The threshold for the dot product of the two vectors. Defaults to 0.9995.
|
|
70
81
|
epsilon (float, optional): The epsilon value for numerical stability. Defaults to 1e-8.
|
|
71
82
|
"""
|
|
72
|
-
|
|
73
|
-
self.DOT_THRESHOLD = DOT_THRESHOLD
|
|
74
|
-
self.epsilon = epsilon
|
|
75
|
-
super().__init__()
|
|
83
|
+
super().__init__(**kwargs)
|
|
76
84
|
|
|
77
85
|
@override
|
|
78
86
|
def run(self, modelpool: BaseModelPool) -> nn.Module:
|
|
@@ -102,3 +110,91 @@ class SlerpMergeAlgorithm(BaseAlgorithm):
|
|
|
102
110
|
|
|
103
111
|
primary_model.load_state_dict(state_dict)
|
|
104
112
|
return primary_model
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
@auto_register_config
|
|
116
|
+
class SlerpForCausalLM(
|
|
117
|
+
SimpleProfilerMixin,
|
|
118
|
+
BaseAlgorithm,
|
|
119
|
+
):
|
|
120
|
+
"""
|
|
121
|
+
Slerp (Spherical Linear Interpolation) for Causal Language Models.
|
|
122
|
+
"""
|
|
123
|
+
|
|
124
|
+
def __init__(
|
|
125
|
+
self,
|
|
126
|
+
t: float,
|
|
127
|
+
DOT_THRESHOLD: float = 0.9995,
|
|
128
|
+
epsilon: float = 1e-8,
|
|
129
|
+
model_save_path: Optional[str] = None,
|
|
130
|
+
show_pbar: bool = False,
|
|
131
|
+
**kwargs,
|
|
132
|
+
):
|
|
133
|
+
"""
|
|
134
|
+
Initialize the SlerpForCausalLM algorithm.
|
|
135
|
+
|
|
136
|
+
Args:
|
|
137
|
+
t (float): The interpolation parameter. Must be in the range [0, 1].
|
|
138
|
+
t=0 returns the first model, t=1 returns the second model,
|
|
139
|
+
t=0.5 provides balanced interpolation.
|
|
140
|
+
DOT_THRESHOLD (float, optional): The threshold for the dot product of normalized vectors.
|
|
141
|
+
When the absolute dot product exceeds this threshold,
|
|
142
|
+
vectors are considered nearly collinear and linear
|
|
143
|
+
interpolation (LERP) is used instead of SLERP for
|
|
144
|
+
numerical stability. Defaults to 0.9995.
|
|
145
|
+
epsilon (float, optional): Small value used for numerical stability to avoid
|
|
146
|
+
division by zero during vector normalization.
|
|
147
|
+
Defaults to 1e-8.
|
|
148
|
+
model_save_path (Optional[str], optional): Path where the merged model should be saved.
|
|
149
|
+
If None, the model is not saved to disk.
|
|
150
|
+
Defaults to None.
|
|
151
|
+
show_pbar (bool, optional): Whether to display a progress bar during the interpolation
|
|
152
|
+
process. Useful for debugging or monitoring progress with
|
|
153
|
+
large models. Defaults to False.
|
|
154
|
+
**kwargs: Additional keyword arguments passed to the parent BaseAlgorithm class.
|
|
155
|
+
"""
|
|
156
|
+
super().__init__(**kwargs)
|
|
157
|
+
|
|
158
|
+
@override
|
|
159
|
+
def run(self, modelpool: CausalLMPool):
|
|
160
|
+
assert len(modelpool.all_model_names) == 2, "Slerp expect exactly 2 models"
|
|
161
|
+
primary_model = modelpool.load_model(modelpool.all_model_names[0])
|
|
162
|
+
secondary_model = modelpool.load_model(modelpool.all_model_names[1])
|
|
163
|
+
|
|
164
|
+
with torch.no_grad():
|
|
165
|
+
primary_state_dict = primary_model.state_dict()
|
|
166
|
+
secondary_state_dict = secondary_model.state_dict()
|
|
167
|
+
state_dict = slerp_on_state_dicts(
|
|
168
|
+
self.t,
|
|
169
|
+
primary_state_dict,
|
|
170
|
+
secondary_state_dict,
|
|
171
|
+
DOT_THRESHOLD=self.DOT_THRESHOLD,
|
|
172
|
+
epsilon=self.epsilon,
|
|
173
|
+
)
|
|
174
|
+
|
|
175
|
+
if isinstance(primary_model, nn.Module):
|
|
176
|
+
model = primary_model
|
|
177
|
+
model.load_state_dict(state_dict)
|
|
178
|
+
elif isinstance(primary_model, LazyStateDict):
|
|
179
|
+
model: "PreTrainedModel" = deepcopy(primary_model.meta_module)
|
|
180
|
+
model.to(device=primary_model._device)
|
|
181
|
+
model.load_state_dict(state_dict)
|
|
182
|
+
else:
|
|
183
|
+
raise TypeError(
|
|
184
|
+
f"Unsupported model type: {type(primary_model)}. "
|
|
185
|
+
"Expected nn.Module or LazyStateDict."
|
|
186
|
+
)
|
|
187
|
+
if self.model_save_path is not None:
|
|
188
|
+
with timeit_context(f"Saving the model to {self.model_save_path}"):
|
|
189
|
+
tokenizer = modelpool.load_tokenizer()
|
|
190
|
+
tokenizer.save_pretrained(self.model_save_path)
|
|
191
|
+
model.save_pretrained(self.model_save_path)
|
|
192
|
+
model_card_str = create_default_model_card(
|
|
193
|
+
models=[modelpool.get_model_path(m) for m in modelpool.model_names],
|
|
194
|
+
description="Merged model using Slerp.",
|
|
195
|
+
algorithm_config=self.config,
|
|
196
|
+
modelpool_config=modelpool.config,
|
|
197
|
+
)
|
|
198
|
+
with open(os.path.join(self.model_save_path, "README.md"), "w") as f:
|
|
199
|
+
f.write(model_card_str)
|
|
200
|
+
return model
|
|
@@ -16,11 +16,14 @@ from transformers.data import default_data_collator
|
|
|
16
16
|
|
|
17
17
|
from fusion_bench.method import BaseAlgorithm
|
|
18
18
|
from fusion_bench.method.task_arithmetic.task_arithmetic import task_arithmetic_merge
|
|
19
|
-
from fusion_bench.mixins
|
|
20
|
-
|
|
19
|
+
from fusion_bench.mixins import (
|
|
20
|
+
LightningFabricMixin,
|
|
21
|
+
SimpleProfilerMixin,
|
|
22
|
+
auto_register_config,
|
|
23
|
+
)
|
|
21
24
|
from fusion_bench.modelpool import Seq2SeqLMPool
|
|
22
25
|
from fusion_bench.models.we_moe import WeightEnsemblingMoE
|
|
23
|
-
from fusion_bench.utils import timeit_context
|
|
26
|
+
from fusion_bench.utils import print_parameters, timeit_context
|
|
24
27
|
from fusion_bench.utils.data import InfiniteDataLoader, load_tensor_from_file
|
|
25
28
|
from fusion_bench.utils.instantiate_utils import instantiate
|
|
26
29
|
from fusion_bench.utils.parameters import print_parameters
|
|
@@ -31,10 +34,11 @@ from .utils import get_memory_usage
|
|
|
31
34
|
log = logging.getLogger(__name__)
|
|
32
35
|
|
|
33
36
|
|
|
37
|
+
@auto_register_config
|
|
34
38
|
class FlanT5WeightEnsemblingMoEAlgorithm(
|
|
35
|
-
BaseAlgorithm,
|
|
36
39
|
LightningFabricMixin,
|
|
37
40
|
SimpleProfilerMixin,
|
|
41
|
+
BaseAlgorithm,
|
|
38
42
|
):
|
|
39
43
|
"""
|
|
40
44
|
FlanT5WeightEnsemblingMoEAlgorithm is a class that implements the WeightEnsemblingMoEAlgorithm
|
|
@@ -60,7 +64,6 @@ class FlanT5WeightEnsemblingMoEAlgorithm(
|
|
|
60
64
|
num_workers: int = 0,
|
|
61
65
|
max_steps: int = 1000,
|
|
62
66
|
use_grad_accumulate: bool = True,
|
|
63
|
-
cache_dir: bool = "outputs",
|
|
64
67
|
fast_dev_run: bool = False,
|
|
65
68
|
**kwargs,
|
|
66
69
|
):
|
|
@@ -70,23 +73,9 @@ class FlanT5WeightEnsemblingMoEAlgorithm(
|
|
|
70
73
|
Args:
|
|
71
74
|
algorithm_config (DictConfig): The configuration for the algorithm.
|
|
72
75
|
"""
|
|
73
|
-
self.checkpoint = checkpoint
|
|
74
|
-
self.save_checkpoint = save_checkpoint
|
|
75
|
-
self.router_hidden_layers = router_hidden_layers
|
|
76
|
-
self.init_lambda = init_lambda
|
|
77
|
-
self.batch_reduce = batch_reduce
|
|
78
|
-
self.lr = lr
|
|
79
|
-
self.optimizer = optimizer
|
|
80
|
-
self.devices = devices
|
|
81
|
-
self.batch_size = batch_size
|
|
82
|
-
self.num_workers = num_workers
|
|
83
|
-
self.max_steps = max_steps
|
|
84
|
-
self.use_grad_accumulate = use_grad_accumulate
|
|
85
|
-
self.cache_dir = cache_dir
|
|
86
|
-
self.fast_dev_run = fast_dev_run
|
|
87
76
|
super().__init__(**kwargs)
|
|
88
77
|
|
|
89
|
-
def construct_moe_model(self)
|
|
78
|
+
def construct_moe_model(self):
|
|
90
79
|
"""
|
|
91
80
|
Construct the Mixture of Experts (MoE) model using the models in the model pool.
|
|
92
81
|
|
|
@@ -113,11 +113,27 @@ class CLIPClassificationMixin(LightningFabricMixin):
|
|
|
113
113
|
clip_model: Optional[CLIPModel] = None,
|
|
114
114
|
task_names: Optional[List[str]] = None,
|
|
115
115
|
):
|
|
116
|
+
"""
|
|
117
|
+
Initializes a zero-shot classification head.
|
|
118
|
+
|
|
119
|
+
This method constructs a zero-shot classification head by generating text embeddings for each class name using a set of templates.
|
|
120
|
+
These embeddings function as the weights of the classification layer. The method also extracts the `visual_projection` and `logit_scale`
|
|
121
|
+
from the provided CLIP model, which are necessary for calculating the final logits.
|
|
122
|
+
|
|
123
|
+
Args:
|
|
124
|
+
clip_processor (Optional[CLIPProcessor]): The processor for the CLIP model. If not provided, it is loaded from the model pool.
|
|
125
|
+
clip_model (Optional[CLIPModel]): The CLIP model to use. If not provided, a pretrained model is loaded from the model pool.
|
|
126
|
+
task_names (Optional[List[str]]): A list of task names to set up the classification head for. If not provided, all models in the model pool will be used.
|
|
127
|
+
"""
|
|
116
128
|
self.whether_setup_zero_shot_classification_head = True
|
|
129
|
+
# load clip model if not provided
|
|
117
130
|
if clip_model is None:
|
|
118
131
|
if self.modelpool.has_pretrained:
|
|
119
132
|
clip_model = self.modelpool.load_clip_model("_pretrained_")
|
|
120
133
|
else:
|
|
134
|
+
log.warning(
|
|
135
|
+
f"No pretrained CLIP model found, using the model from the model pool: {self.modelpool.model_names[0]}."
|
|
136
|
+
)
|
|
121
137
|
clip_model = self.modelpool.load_clip_model(
|
|
122
138
|
self.modelpool.model_names[0]
|
|
123
139
|
)
|
|
@@ -166,16 +182,20 @@ class CLIPClassificationMixin(LightningFabricMixin):
|
|
|
166
182
|
image_embeds: Optional[torch.Tensor] = None,
|
|
167
183
|
) -> torch.Tensor:
|
|
168
184
|
"""
|
|
169
|
-
|
|
185
|
+
Computes the classification logits for a batch of images for a specific task.
|
|
186
|
+
|
|
187
|
+
This method performs zero-shot classification by calculating the cosine similarity between image and text embeddings.
|
|
188
|
+
The image embeddings are obtained from the provided vision model, and the text embeddings (zero-shot weights) are pre-computed for the task.
|
|
189
|
+
The similarity scores are then scaled by the CLIP model's `logit_scale` to produce the final logits.
|
|
170
190
|
|
|
171
191
|
Args:
|
|
172
|
-
module (Union[nn.Module, CLIPVisionModel, "CLIPVisionTransformer"]): The
|
|
173
|
-
images (torch.Tensor):
|
|
174
|
-
task (str): The
|
|
175
|
-
image_embeds (Optional[torch.Tensor]):
|
|
192
|
+
module (Union[nn.Module, CLIPVisionModel, "CLIPVisionTransformer"]): The vision encoder part of the CLIP model.
|
|
193
|
+
images (torch.Tensor): A batch of images to classify.
|
|
194
|
+
task (str): The name of the classification task.
|
|
195
|
+
image_embeds (Optional[torch.Tensor]): Pre-computed image embeddings. If provided, the method skips the image encoding step.
|
|
176
196
|
|
|
177
197
|
Returns:
|
|
178
|
-
torch.Tensor:
|
|
198
|
+
torch.Tensor: A tensor of logits for each image, with shape (batch_size, num_classes).
|
|
179
199
|
"""
|
|
180
200
|
text_embeds = self.zeroshot_weights[task]
|
|
181
201
|
|
|
@@ -4,7 +4,7 @@ from copy import deepcopy
|
|
|
4
4
|
from functools import wraps
|
|
5
5
|
from inspect import Parameter, _ParameterKind
|
|
6
6
|
from pathlib import Path
|
|
7
|
-
from typing import Dict, Optional, Union
|
|
7
|
+
from typing import Dict, Mapping, Optional, Union
|
|
8
8
|
|
|
9
9
|
from omegaconf import DictConfig, OmegaConf
|
|
10
10
|
|
|
@@ -21,6 +21,20 @@ __all__ = [
|
|
|
21
21
|
]
|
|
22
22
|
|
|
23
23
|
|
|
24
|
+
def _get_attr_name(config_mapping: Mapping[str, str], param_name):
|
|
25
|
+
for attr_name, p in config_mapping.items():
|
|
26
|
+
if p == param_name:
|
|
27
|
+
return attr_name
|
|
28
|
+
else:
|
|
29
|
+
raise ValueError(f"Parameter {param_name} not found in config mapping.")
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def _set_attr(self, param_name: str, value):
|
|
33
|
+
attr_name = _get_attr_name(self._config_mapping, param_name)
|
|
34
|
+
log.debug(f"set {attr_name} to {value}. Parameter name: {param_name}")
|
|
35
|
+
setattr(self, attr_name, value)
|
|
36
|
+
|
|
37
|
+
|
|
24
38
|
def auto_register_config(cls):
|
|
25
39
|
"""
|
|
26
40
|
Decorator to automatically register __init__ parameters in _config_mapping.
|
|
@@ -56,8 +70,8 @@ def auto_register_config(cls):
|
|
|
56
70
|
```python
|
|
57
71
|
@auto_register_config
|
|
58
72
|
class MyAlgorithm(BaseYAMLSerializable):
|
|
59
|
-
def __init__(self, learning_rate: float = 0.001, batch_size: int = 32, model_name: str = "default"):
|
|
60
|
-
super().__init__()
|
|
73
|
+
def __init__(self, learning_rate: float = 0.001, batch_size: int = 32, model_name: str = "default", **kwargs):
|
|
74
|
+
super().__init__(**kwargs)
|
|
61
75
|
|
|
62
76
|
# All instantiation methods work automatically:
|
|
63
77
|
algo1 = MyAlgorithm(0.01, 64) # positional args
|
|
@@ -103,8 +117,7 @@ def auto_register_config(cls):
|
|
|
103
117
|
cls._config_mapping[param_name] = param_name
|
|
104
118
|
|
|
105
119
|
def __init__(self, *args, **kwargs):
|
|
106
|
-
|
|
107
|
-
|
|
120
|
+
log.debug(f"set attributes for {self.__class__.__name__} in {cls.__name__}")
|
|
108
121
|
# auto-register the attributes based on the signature
|
|
109
122
|
sig = inspect.signature(original_init)
|
|
110
123
|
param_names = list(sig.parameters.keys())[1:] # Skip 'self'
|
|
@@ -117,29 +130,26 @@ def auto_register_config(cls):
|
|
|
117
130
|
_ParameterKind.VAR_POSITIONAL,
|
|
118
131
|
_ParameterKind.VAR_KEYWORD,
|
|
119
132
|
]:
|
|
120
|
-
|
|
133
|
+
_set_attr(self, param_name, arg_value)
|
|
121
134
|
|
|
122
135
|
# Handle keyword arguments and defaults
|
|
123
136
|
for param_name in param_names:
|
|
124
|
-
if
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
_ParameterKind.VAR_KEYWORD,
|
|
129
|
-
]
|
|
130
|
-
) and (param_name not in registered_parameters):
|
|
137
|
+
if sig.parameters[param_name].kind not in [
|
|
138
|
+
_ParameterKind.VAR_POSITIONAL,
|
|
139
|
+
_ParameterKind.VAR_KEYWORD,
|
|
140
|
+
]:
|
|
131
141
|
# Skip if already set by positional argument
|
|
132
142
|
param_index = param_names.index(param_name)
|
|
133
143
|
if param_index >= 0 and param_index < len(args):
|
|
134
144
|
continue
|
|
135
145
|
|
|
136
146
|
if param_name in kwargs:
|
|
137
|
-
|
|
147
|
+
_set_attr(self, param_name, kwargs[param_name])
|
|
138
148
|
else:
|
|
139
149
|
# Set default value if available and attribute doesn't exist
|
|
140
150
|
default_value = sig.parameters[param_name].default
|
|
141
151
|
if default_value is not Parameter.empty:
|
|
142
|
-
|
|
152
|
+
_set_attr(self, param_name, default_value)
|
|
143
153
|
|
|
144
154
|
# Call the original __init__
|
|
145
155
|
result = original_init(self, *args, **kwargs)
|
|
@@ -277,7 +277,7 @@ class BaseModelPool(
|
|
|
277
277
|
for dataset_name in self.test_dataset_names:
|
|
278
278
|
yield self.load_test_dataset(dataset_name)
|
|
279
279
|
|
|
280
|
-
def save_model(self, model: nn.Module, path: str):
|
|
280
|
+
def save_model(self, model: nn.Module, path: str, *args, **kwargs):
|
|
281
281
|
"""
|
|
282
282
|
Save the state dictionary of the model to the specified path.
|
|
283
283
|
|