fusion-bench 0.2.21__py3-none-any.whl → 0.2.23__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/__init__.py +25 -2
- fusion_bench/compat/method/__init__.py +5 -2
- fusion_bench/compat/method/base_algorithm.py +3 -2
- fusion_bench/compat/modelpool/base_pool.py +3 -3
- fusion_bench/compat/taskpool/clip_image_classification.py +1 -1
- fusion_bench/constants/__init__.py +1 -0
- fusion_bench/constants/runtime.py +57 -0
- fusion_bench/dataset/gpt2_glue.py +1 -1
- fusion_bench/method/__init__.py +12 -4
- fusion_bench/method/analysis/task_vector_cos_similarity.py +95 -12
- fusion_bench/method/analysis/task_vector_violin_plot.py +160 -52
- fusion_bench/method/bitdelta/__init__.py +1 -0
- fusion_bench/method/bitdelta/bitdelta.py +7 -23
- fusion_bench/method/classification/clip_finetune.py +1 -1
- fusion_bench/method/expert_sparsity/mixtral/dynamic_skipping.py +2 -0
- fusion_bench/method/expert_sparsity/mixtral/layer_wise_pruning.py +2 -0
- fusion_bench/method/expert_sparsity/mixtral/progressive_pruning.py +2 -0
- fusion_bench/method/fisher_merging/clip_fisher_merging.py +0 -4
- fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +2 -2
- fusion_bench/method/linear/simple_average_for_llama.py +16 -11
- fusion_bench/method/model_stock/__init__.py +1 -0
- fusion_bench/method/model_stock/model_stock.py +309 -0
- fusion_bench/method/regmean/clip_regmean.py +3 -6
- fusion_bench/method/regmean/regmean.py +27 -56
- fusion_bench/method/regmean/utils.py +56 -0
- fusion_bench/method/regmean_plusplus/regmean_plusplus.py +21 -60
- fusion_bench/method/simple_average.py +7 -7
- fusion_bench/method/slerp/__init__.py +1 -1
- fusion_bench/method/slerp/slerp.py +110 -14
- fusion_bench/method/smile_upscaling/causal_lm_upscaling.py +371 -0
- fusion_bench/method/smile_upscaling/projected_energy.py +1 -2
- fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py +5 -1
- fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +40 -31
- fusion_bench/method/smile_upscaling/smile_upscaling.py +1 -1
- fusion_bench/method/we_moe/__init__.py +1 -0
- fusion_bench/method/we_moe/entropy_loss.py +25 -0
- fusion_bench/method/we_moe/flan_t5_we_moe.py +320 -0
- fusion_bench/method/we_moe/utils.py +15 -0
- fusion_bench/method/weighted_average/llama.py +1 -1
- fusion_bench/mixins/clip_classification.py +37 -48
- fusion_bench/mixins/serialization.py +30 -10
- fusion_bench/modelpool/base_pool.py +1 -1
- fusion_bench/modelpool/causal_lm/causal_lm.py +293 -75
- fusion_bench/modelpool/seq2seq_lm/modelpool.py +146 -0
- fusion_bench/models/__init__.py +5 -0
- fusion_bench/models/hf_utils.py +69 -86
- fusion_bench/models/linearized/vision_model.py +6 -6
- fusion_bench/models/model_card_templates/default.md +46 -0
- fusion_bench/models/modeling_smile_llama/__init__.py +7 -0
- fusion_bench/models/modeling_smile_llama/modeling_smile_llama.py +1 -8
- fusion_bench/models/modeling_smile_mistral/__init__.py +2 -1
- fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +1 -5
- fusion_bench/models/we_moe.py +8 -8
- fusion_bench/programs/fabric_fusion_program.py +29 -60
- fusion_bench/scripts/cli.py +34 -1
- fusion_bench/taskpool/base_pool.py +99 -17
- fusion_bench/taskpool/clip_vision/taskpool.py +10 -5
- fusion_bench/taskpool/dummy.py +101 -13
- fusion_bench/taskpool/lm_eval_harness/taskpool.py +80 -0
- fusion_bench/taskpool/nyuv2_taskpool.py +28 -0
- fusion_bench/utils/__init__.py +2 -0
- fusion_bench/utils/cache_utils.py +101 -1
- fusion_bench/utils/data.py +6 -4
- fusion_bench/utils/devices.py +7 -4
- fusion_bench/utils/dtype.py +3 -2
- fusion_bench/utils/fabric.py +2 -2
- fusion_bench/utils/lazy_imports.py +23 -0
- fusion_bench/utils/lazy_state_dict.py +117 -19
- fusion_bench/utils/modelscope.py +3 -3
- fusion_bench/utils/packages.py +3 -3
- fusion_bench/utils/parameters.py +0 -2
- fusion_bench/utils/path.py +56 -0
- fusion_bench/utils/pylogger.py +1 -1
- fusion_bench/utils/timer.py +92 -10
- {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/METADATA +1 -23
- {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/RECORD +89 -75
- fusion_bench_config/_get_started/llm_slerp.yaml +12 -0
- fusion_bench_config/method/fisher_merging/clip_fisher_merging.yaml +0 -1
- fusion_bench_config/method/linear/simple_average_for_llama.yaml +3 -2
- fusion_bench_config/method/model_stock/model_stock.yaml +12 -0
- fusion_bench_config/method/slerp/slerp_lm.yaml +4 -0
- fusion_bench_config/method/smile_upscaling/causal_lm_upscaling.yaml +21 -0
- fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +1 -1
- fusion_bench_config/method/wemoe/flan_t5_weight_ensembling_moe.yaml +20 -0
- fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml +1 -1
- {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/top_level.txt +0 -0
fusion_bench/modelpool/causal_lm/causal_lm.py CHANGED

@@ -8,7 +8,7 @@ from copy import deepcopy
 from typing import Any, Dict, Optional, TypeAlias, Union, cast  # noqa: F401

 import peft
-from omegaconf import DictConfig, flag_override
+from omegaconf import DictConfig, OmegaConf, flag_override
 from torch import nn
 from torch.nn.modules import Module
 from transformers import (
@@ -19,45 +19,86 @@ from transformers import (
 )
 from typing_extensions import override

-from fusion_bench
-
-
+from fusion_bench import (
+    BaseModelPool,
+    auto_register_config,
+    import_object,
+    instantiate,
+    parse_dtype,
+)
 from fusion_bench.utils.lazy_state_dict import LazyStateDict
-from fusion_bench.utils.packages import import_object

 log = logging.getLogger(__name__)


+@auto_register_config
 class CausalLMPool(BaseModelPool):
-
-
-
-
-
+    """A model pool for managing and loading causal language models.
+
+    This class provides a unified interface for loading and managing multiple
+    causal language models, typically used in model fusion and ensemble scenarios.
+    It supports both eager and lazy loading strategies, and handles model
+    configuration through YAML configs or direct instantiation.
+
+    The pool can manage models from Hugging Face Hub, local paths, or custom
+    configurations. It also provides tokenizer management and model saving
+    capabilities with optional Hugging Face Hub integration.
+
+    Args:
+        models: Dictionary or configuration specifying the models to be managed.
+            Can contain model names mapped to paths or detailed configurations.
+        tokenizer: Tokenizer configuration, either a string path/name or
+            a DictConfig with detailed tokenizer settings.
+        model_kwargs: Additional keyword arguments passed to model loading.
+            Common options include torch_dtype, device_map, etc.
+        enable_lazy_loading: Whether to use lazy loading for models. When True,
+            models are loaded as LazyStateDict objects instead of actual models,
+            which can save memory for large model collections.
+        **kwargs: Additional arguments passed to the parent BaseModelPool.
+
+    Example:
+        ```python
+        >>> pool = CausalLMPool(
+        ...     models={
+        ...         "model_a": "microsoft/DialoGPT-medium",
+        ...         "model_b": "/path/to/local/model"
+        ...     },
+        ...     tokenizer="microsoft/DialoGPT-medium",
+        ...     model_kwargs={"torch_dtype": "bfloat16"}
+        ... )
+        >>> model = pool.load_model("model_a")
+        >>> tokenizer = pool.load_tokenizer()
+        ```
+    """

     def __init__(
         self,
         models,
         *,
-        tokenizer: Optional[DictConfig],
+        tokenizer: Optional[DictConfig | str],
         model_kwargs: Optional[DictConfig] = None,
-
+        enable_lazy_loading: bool = False,
         **kwargs,
     ):
         super().__init__(models, **kwargs)
-
-
-        self._model_kwargs = model_kwargs
-        if self._model_kwargs is None:
-            self._model_kwargs = DictConfig({})
-        with flag_override(self._model_kwargs, "allow_objects", True):
-            if hasattr(self._model_kwargs, "torch_dtype"):
-                self._model_kwargs.torch_dtype = parse_dtype(
-                    self._model_kwargs.torch_dtype
-                )
-        self.load_lazy = load_lazy
+        if model_kwargs is None:
+            self.model_kwargs = DictConfig({})

     def get_model_path(self, model_name: str):
+        """Extract the model path from the model configuration.
+
+        Args:
+            model_name: The name of the model as defined in the models configuration.
+
+        Returns:
+            str: The path or identifier for the model. For string configurations,
+                returns the string directly. For dict configurations, extracts
+                the 'pretrained_model_name_or_path' field.
+
+        Raises:
+            RuntimeError: If the model configuration is invalid or the model
+                name is not found in the configuration.
+        """
         model_name_or_config = self._models[model_name]
         if isinstance(model_name_or_config, str):
             return model_name_or_config
@@ -66,39 +107,80 @@ class CausalLMPool(BaseModelPool):
         else:
             raise RuntimeError("Invalid model configuration")

+    def get_model_kwargs(self):
+        """Get processed model keyword arguments for model loading.
+
+        Converts the stored `model_kwargs` from DictConfig to a regular dictionary
+        and processes special arguments like torch_dtype for proper model loading.
+
+        Returns:
+            dict: Processed keyword arguments ready to be passed to model
+                loading functions. The torch_dtype field, if present, is
+                converted from string to the appropriate torch dtype object.
+        """
+        model_kwargs = (
+            OmegaConf.to_container(self.model_kwargs, resolve=True)
+            if isinstance(self.model_kwargs, DictConfig)
+            else self.model_kwargs
+        )
+        if "torch_dtype" in model_kwargs:
+            model_kwargs["torch_dtype"] = parse_dtype(model_kwargs["torch_dtype"])
+        return model_kwargs
+
     @override
     def load_model(
         self,
         model_name_or_config: str | DictConfig,
         *args,
         **kwargs,
-    ) -> PreTrainedModel:
-        """
-        Example of YAML config:
+    ) -> Union[PreTrainedModel, LazyStateDict]:
+        """Load a causal language model from the model pool.

-
-
-
-
-          model_b: path_to_model_b
-        ```
+        This method supports multiple loading strategies:
+        1. Loading by model name from the configured model pool
+        2. Loading from a direct configuration dictionary
+        3. Lazy loading using LazyStateDict for memory efficiency

-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        The method automatically handles different model configuration formats
+        and applies the appropriate loading strategy based on the enable_lazy_loading flag.
+
+        Args:
+            model_name_or_config: Either a string model name that exists in the
+                model pool configuration, or a DictConfig/dict containing the
+                model configuration directly.
+            *args: Additional positional arguments passed to the model constructor.
+            **kwargs: Additional keyword arguments passed to the model constructor.
+                These will be merged with the pool's model_kwargs.
+
+        Returns:
+            Union[PreTrainedModel, LazyStateDict]: The loaded model. Returns a
+                PreTrainedModel for normal loading or a LazyStateDict for lazy loading.
+
+        Raises:
+            RuntimeError: If the model configuration is invalid.
+            KeyError: If the model name is not found in the model pool.
+
+        Example YAML configurations:
+        Simple string configuration:
+        ```yaml
+        models:
+          _pretrained_: path_to_pretrained_model
+          model_a: path_to_model_a
+          model_b: path_to_model_b
+        ```
+
+        Detailed configuration:
+        ```yaml
+        models:
+          _pretrained_:
+            _target_: transformers.AutoModelForCausalLM
+            pretrained_model_name_or_path: path_to_pretrained_model
+          model_a:
+            _target_: transformers.AutoModelForCausalLM
+            pretrained_model_name_or_path: path_to_model_a
+        ```
         """
-        model_kwargs =
+        model_kwargs = self.get_model_kwargs()
         model_kwargs.update(kwargs)

         if isinstance(model_name_or_config, str):
@@ -108,7 +190,7 @@ class CausalLMPool(BaseModelPool):
             model_config = self._models[model_name_or_config]
             if isinstance(model_config, str):
                 # model_config is a string
-                if not self.
+                if not self.enable_lazy_loading:
                     model = AutoModelForCausalLM.from_pretrained(
                         model_config,
                         *args,
@@ -126,7 +208,7 @@ class CausalLMPool(BaseModelPool):
         elif isinstance(model_name_or_config, (DictConfig, Dict)):
             model_config = model_name_or_config

-            if not self.
+            if not self.enable_lazy_loading:
                 model = instantiate(model_config, *args, **model_kwargs)
             else:
                 meta_module_class = model_config.pop("_target_")
@@ -140,30 +222,43 @@ class CausalLMPool(BaseModelPool):
         return model

     def load_tokenizer(self, *args, **kwargs) -> PreTrainedTokenizer:
-        """
-        Example of YAML config:
+        """Load the tokenizer associated with this model pool.

-
-
-
+        Loads a tokenizer based on the tokenizer configuration provided during
+        pool initialization. Supports both simple string paths and detailed
+        configuration dictionaries.

-
-
-
-        tokenizer:
-          _target_: transformers.AutoTokenizer  # any callable that returns a tokenizer
-          pretrained_model_name_or_path: google/gemma-2-2b-it
-        ```
+        Args:
+            *args: Additional positional arguments passed to the tokenizer constructor.
+            **kwargs: Additional keyword arguments passed to the tokenizer constructor.

         Returns:
-            PreTrainedTokenizer: The tokenizer.
+            PreTrainedTokenizer: The loaded tokenizer instance.
+
+        Raises:
+            AssertionError: If no tokenizer is defined in the configuration.
+
+        Example YAML configurations:
+        Simple string configuration:
+        ```yaml
+        tokenizer: google/gemma-2-2b-it
+        ```
+
+        Detailed configuration:
+        ```yaml
+        tokenizer:
+          _target_: transformers.AutoTokenizer
+          pretrained_model_name_or_path: google/gemma-2-2b-it
+          use_fast: true
+          padding_side: left
+        ```
         """
-        assert self.
+        assert self.tokenizer is not None, "Tokenizer is not defined in the config"
         log.info("Loading tokenizer.", stacklevel=2)
-        if isinstance(self.
-            tokenizer = AutoTokenizer.from_pretrained(self.
+        if isinstance(self.tokenizer, str):
+            tokenizer = AutoTokenizer.from_pretrained(self.tokenizer, *args, **kwargs)
         else:
-            tokenizer = instantiate(self.
+            tokenizer = instantiate(self.tokenizer, *args, **kwargs)
         return tokenizer

     @override
@@ -178,15 +273,49 @@ class CausalLMPool(BaseModelPool):
         tokenizer: Optional[PreTrainedTokenizer] = None,
         **kwargs,
     ):
-        """
-
+        """Save a model to the specified path with optional tokenizer and Hub upload.
+
+        This method provides comprehensive model saving capabilities including
+        optional tokenizer saving, dtype conversion, and Hugging Face Hub upload.
+        The model is saved in the standard Hugging Face format.

         Args:
-            model
-            path
-
-
-
+            model: The PreTrainedModel instance to be saved.
+            path: The local path where the model will be saved. Supports tilde
+                expansion for home directory paths.
+            push_to_hub: Whether to push the saved model to the Hugging Face Hub.
+                Requires proper authentication and repository permissions.
+            model_dtype: Optional string specifying the target dtype for the model
+                before saving (e.g., "float16", "bfloat16"). The model will be
+                converted to this dtype before saving.
+            save_tokenizer: Whether to save the tokenizer alongside the model.
+                If True, the tokenizer will be loaded using the pool's tokenizer
+                configuration and saved to the same path.
+            tokenizer_kwargs: Additional keyword arguments for tokenizer loading
+                when save_tokenizer is True.
+            tokenizer: Optional pre-loaded tokenizer instance. If provided, this
+                tokenizer will be saved regardless of the save_tokenizer flag.
+            **kwargs: Additional keyword arguments passed to the model's
+                save_pretrained method.
+
+        Side Effects:
+            - Creates model files in the specified directory
+            - Optionally creates tokenizer files in the same directory
+            - May convert the model to a different dtype
+            - May upload files to Hugging Face Hub
+
+        Example:
+            ```python
+            >>> pool = CausalLMPool(models=..., tokenizer=...)
+            >>> model = pool.load_model("my_model")
+            >>> pool.save_model(
+            ...     model,
+            ...     "/path/to/save",
+            ...     save_tokenizer=True,
+            ...     model_dtype="float16",
+            ...     push_to_hub=True
+            ... )
+            ```
         """
         path = os.path.expanduser(path)
         # NOTE: if tokenizer is provided, it will be saved regardless of `save_tokenizer`
@@ -210,15 +339,61 @@ class CausalLMPool(BaseModelPool):


 class CausalLMBackbonePool(CausalLMPool):
+    """A specialized model pool that loads only the transformer backbone layers.
+
+    This class extends CausalLMPool to provide access to just the transformer
+    layers (backbone) of causal language models, excluding the language modeling
+    head and embeddings. This is useful for model fusion scenarios where only
+    the core transformer layers are needed.
+
+    The class automatically extracts the `model.layers` component from loaded
+    AutoModelForCausalLM instances, providing direct access to the transformer
+    blocks. Lazy loading is not supported for this pool type.
+
+    Note:
+        This pool automatically disables lazy loading as it needs to access
+        the internal structure of the model to extract the backbone layers.
+
+    Example:
+        ```python
+        >>> backbone_pool = CausalLMBackbonePool(
+        ...     models={"model_a": "microsoft/DialoGPT-medium"},
+        ...     tokenizer="microsoft/DialoGPT-medium"
+        ... )
+        >>> layers = backbone_pool.load_model("model_a")  # Returns nn.ModuleList of transformer layers
+        ```
+    """
+
     def load_model(
         self, model_name_or_config: str | DictConfig, *args, **kwargs
     ) -> Module:
-
+        """Load only the transformer backbone layers from a causal language model.
+
+        This method loads a complete causal language model and then extracts
+        only the transformer layers (backbone), discarding the embedding layers
+        and language modeling head. This is useful for model fusion scenarios
+        where only the core transformer computation is needed.
+
+        Args:
+            model_name_or_config: Either a string model name from the pool
+                configuration or a DictConfig with model loading parameters.
+            *args: Additional positional arguments passed to the parent load_model method.
+            **kwargs: Additional keyword arguments passed to the parent load_model method.
+
+        Returns:
+            Module: The transformer layers (typically a nn.ModuleList) containing
+                the core transformer blocks without embeddings or output heads.
+
+        Note:
+            Lazy loading is automatically disabled for this method as it needs
+            to access the internal model structure to extract the layers.
+        """
+        if self.enable_lazy_loading:
             log.warning(
                 "CausalLMBackbonePool does not support lazy loading. "
                 "Falling back to normal loading."
             )
-            self.
+            self.enable_lazy_loading = False
         model: AutoModelForCausalLM = super().load_model(
             model_name_or_config, *args, **kwargs
         )
@@ -232,6 +407,49 @@ def load_peft_causal_lm(
     is_trainable: bool = True,
     merge_and_unload: bool = False,
 ):
+    """Load a causal language model with PEFT (Parameter-Efficient Fine-Tuning) adapters.
+
+    This function loads a base causal language model and applies PEFT adapters
+    (such as LoRA, AdaLoRA, or other parameter-efficient fine-tuning methods)
+    to create a fine-tuned model. It supports both keeping the adapters separate
+    or merging them into the base model.
+
+    Args:
+        base_model_path: Path or identifier for the base causal language model.
+            Can be a Hugging Face model name or local path.
+        peft_model_path: Path to the PEFT adapter configuration and weights.
+            This should contain the adapter_config.json and adapter weights.
+        torch_dtype: The torch data type to use for the model. Common options
+            include "float16", "bfloat16", "float32". Defaults to "bfloat16".
+        is_trainable: Whether the loaded PEFT model should be trainable.
+            Set to False for inference-only usage to save memory.
+        merge_and_unload: Whether to merge the PEFT adapters into the base model
+            and unload the adapter weights. When True, returns a standard
+            PreTrainedModel instead of a PeftModel.
+
+    Returns:
+        Union[PeftModel, PreTrainedModel]: The loaded model with PEFT adapters.
+            Returns a PeftModel if merge_and_unload is False, or a PreTrainedModel
+            if the adapters are merged and unloaded.
+
+    Example:
+        ```python
+        >>> # Load model with adapters for training
+        >>> model = load_peft_causal_lm(
+        ...     "microsoft/DialoGPT-medium",
+        ...     "/path/to/lora/adapters",
+        ...     is_trainable=True
+        ... )
+
+        >>> # Load and merge adapters for inference
+        >>> merged_model = load_peft_causal_lm(
+        ...     "microsoft/DialoGPT-medium",
+        ...     "/path/to/lora/adapters",
+        ...     merge_and_unload=True,
+        ...     is_trainable=False
+        ... )
+        ```
+    """
     base_model = AutoModelForCausalLM.from_pretrained(
         base_model_path, torch_dtype=torch_dtype
     )
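Taken together, the changes above replace the old `load_lazy` flag with `enable_lazy_loading` and route keyword-argument handling through the new `get_model_kwargs()` helper. The sketch below restates the usage pattern from the added docstrings; it is illustrative only, the model paths are placeholders, and the top-level import path for `CausalLMPool` is assumed from the package layout rather than stated in this diff.

```python
# Illustrative sketch based on the docstrings added in this release; paths are placeholders.
from fusion_bench.modelpool import CausalLMPool  # import path assumed from the package layout

pool = CausalLMPool(
    models={
        "model_a": "path_to_model_a",
        "model_b": "path_to_model_b",
    },
    tokenizer="path_to_model_a",
    model_kwargs={"torch_dtype": "bfloat16"},  # get_model_kwargs() parses this string into a torch dtype
    enable_lazy_loading=True,  # new flag: load_model() returns LazyStateDict objects instead of full models
)

weights = pool.load_model("model_a")   # LazyStateDict here; a PreTrainedModel when lazy loading is off
tokenizer = pool.load_tokenizer()
```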
fusion_bench/modelpool/seq2seq_lm/modelpool.py CHANGED

@@ -18,6 +18,48 @@ def load_lora_model(
     is_trainable: bool = True,
     merge_and_unload: bool = True,
 ):
+    """Load a sequence-to-sequence model with LoRA (Low-Rank Adaptation) fine-tuning.
+
+    This function loads a base sequence-to-sequence language model and applies
+    LoRA adapters for parameter-efficient fine-tuning. LoRA allows for efficient
+    adaptation of large models by adding trainable low-rank matrices to the
+    existing weights without modifying the original parameters.
+
+    Args:
+        base_model_path: Path or identifier for the base sequence-to-sequence model.
+            Can be a Hugging Face model name (e.g., "t5-base") or local path.
+        peft_model_path: Path to the directory containing LoRA adapter weights
+            and configuration. Should include adapter_config.json and adapter weights.
+        is_trainable: Whether the loaded model should be trainable. Set to False
+            for inference-only usage to save memory and computation.
+        merge_and_unload: Whether to merge the LoRA weights into the base model
+            and unload the adapter. When True, returns a standard model instead
+            of a PeftModel, which can be more efficient for inference.
+
+    Returns:
+        Union[PeftModel, AutoModelForSeq2SeqLM]: The loaded model with LoRA
+            adapters. Returns a PeftModel if merge_and_unload is False, or
+            a standard AutoModelForSeq2SeqLM if adapters are merged.
+
+    Example:
+        ```python
+        >>> # Load model with separate adapters for training
+        >>> model = load_lora_model(
+        ...     "t5-base",
+        ...     "/path/to/lora/adapters",
+        ...     is_trainable=True,
+        ...     merge_and_unload=False
+        ... )
+
+        >>> # Load and merge adapters for efficient inference
+        >>> merged_model = load_lora_model(
+        ...     "t5-base",
+        ...     "/path/to/lora/adapters",
+        ...     is_trainable=False,
+        ...     merge_and_unload=True
+        ... )
+        ```
+    """
     base_model = AutoModelForSeq2SeqLM.from_pretrained(base_model_path)
     model = PeftModel.from_pretrained(
         base_model,
@@ -30,6 +72,46 @@ def load_lora_model(


 class Seq2SeqLMPool(BaseModelPool):
+    """A model pool specialized for sequence-to-sequence language models.
+
+    This model pool provides management and loading capabilities for sequence-to-sequence
+    (seq2seq) language models such as T5, BART, and mT5. It extends the base model pool
+    functionality with seq2seq-specific features including tokenizer management and
+    model configuration handling.
+
+    Seq2seq models are particularly useful for tasks that require generating output
+    sequences from input sequences, such as translation, summarization, question
+    answering, and text generation. This pool streamlines the process of loading
+    and configuring multiple seq2seq models for fusion and ensemble scenarios.
+
+    Key Features:
+        - Specialized loading for AutoModelForSeq2SeqLM models
+        - Integrated tokenizer management
+        - Support for model-specific keyword arguments
+        - Automatic dtype parsing and configuration
+        - Compatible with PEFT (Parameter-Efficient Fine-Tuning) adapters
+
+    Attributes:
+        _tokenizer: Configuration for the tokenizer associated with the models
+        _model_kwargs: Default keyword arguments applied to all model loading operations
+
+    Example:
+        ```python
+        pool = Seq2SeqLMPool(
+            models={
+                "t5_base": "t5-base",
+                "t5_large": "t5-large",
+                "custom_model": "/path/to/local/model"
+            },
+            tokenizer={"_target_": "transformers.T5Tokenizer",
+                       "pretrained_model_name_or_path": "t5-base"},
+            model_kwargs={"torch_dtype": "float16", "device_map": "auto"}
+        )
+        model = pool.load_model("t5_base")
+        tokenizer = pool.load_tokenizer()
+        ```
+    """
+
     _config_mapping = BaseModelPool._config_mapping | {
         "_tokenizer": "tokenizer",
         "_model_kwargs": "model_kwargs",
@@ -43,6 +125,35 @@ class Seq2SeqLMPool(BaseModelPool):
         model_kwargs: Optional[DictConfig] = None,
         **kwargs,
     ):
+        """Initialize the sequence-to-sequence language model pool.
+
+        Sets up the model pool with configurations for models, tokenizer, and
+        default model loading parameters. Automatically processes model kwargs
+        to handle special configurations like torch_dtype parsing.
+
+        Args:
+            models: Configuration dictionary specifying the seq2seq models to manage.
+                Keys are model names, values can be model paths/names or detailed configs.
+            tokenizer: Configuration for the tokenizer to use with the models.
+                Can be a simple path/name or detailed configuration with _target_.
+            model_kwargs: Default keyword arguments applied to all model loading
+                operations. Common options include torch_dtype, device_map, etc.
+                The torch_dtype field is automatically parsed from string to dtype.
+            **kwargs: Additional arguments passed to the parent BaseModelPool.
+
+        Example:
+            ```python
+            pool = Seq2SeqLMPool(
+                models={
+                    "base": "t5-base",
+                    "large": {"_target_": "transformers.AutoModelForSeq2SeqLM",
+                              "pretrained_model_name_or_path": "t5-large"}
+                },
+                tokenizer="t5-base",
+                model_kwargs={"torch_dtype": "bfloat16"}
+            )
+            ```
+        """
         super().__init__(models, **kwargs)
         self._tokenizer = tokenizer
         self._model_kwargs = model_kwargs
@@ -55,11 +166,46 @@ class Seq2SeqLMPool(BaseModelPool):
             )

     def load_model(self, model_name_or_config: str | DictConfig, *args, **kwargs):
+        """Load a sequence-to-sequence language model from the pool.
+
+        Loads a seq2seq model using the parent class loading mechanism while
+        automatically applying the pool's default model kwargs. The method
+        merges the pool's model_kwargs with any additional kwargs provided,
+        giving priority to the explicitly provided kwargs.
+
+        Args:
+            model_name_or_config: Either a string model name from the pool
+                configuration or a DictConfig containing model loading parameters.
+            *args: Additional positional arguments passed to the parent load_model method.
+            **kwargs: Additional keyword arguments that override the pool's default
+                model_kwargs. Common options include device, torch_dtype, etc.
+
+        Returns:
+            AutoModelForSeq2SeqLM: The loaded sequence-to-sequence language model.
+        """
         model_kwargs = deepcopy(self._model_kwargs)
         model_kwargs.update(kwargs)
         return super().load_model(model_name_or_config, *args, **model_kwargs)

     def load_tokenizer(self, *args, **kwargs):
+        """Load the tokenizer associated with the sequence-to-sequence models.
+
+        Loads a tokenizer based on the tokenizer configuration provided during
+        pool initialization. The tokenizer should be compatible with the seq2seq
+        models in the pool and is typically used for preprocessing input text
+        and postprocessing generated output.
+
+        Args:
+            *args: Additional positional arguments passed to the tokenizer constructor.
+            **kwargs: Additional keyword arguments passed to the tokenizer constructor.
+
+        Returns:
+            PreTrainedTokenizer: The loaded tokenizer instance compatible with
+                the seq2seq models in this pool.
+
+        Raises:
+            AssertionError: If no tokenizer configuration is provided.
+        """
         assert self._tokenizer is not None, "Tokenizer is not defined in the config"
         tokenizer = isinstance(self._tokenizer, *args, **kwargs)
         return tokenizer
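For the seq2seq pool, the added docstrings describe the same pattern for encoder-decoder models such as T5. The following sketch mirrors that documentation; the checkpoint names are placeholders and the import path is again an assumption based on the package layout, not something stated in this diff.

```python
# Sketch mirroring the Seq2SeqLMPool docstring added above; checkpoints are placeholders.
from fusion_bench.modelpool import Seq2SeqLMPool  # import path assumed from the package layout

pool = Seq2SeqLMPool(
    models={
        "base": "t5-base",
        "finetuned": "/path/to/local/model",
    },
    tokenizer="t5-base",
    model_kwargs={"torch_dtype": "bfloat16"},  # defaults merged into every load_model() call
)

model = pool.load_model("base")       # kwargs passed here override the pool-level model_kwargs
tokenizer = pool.load_tokenizer()
```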
fusion_bench/models/__init__.py CHANGED