fusion-bench 0.2.21 → 0.2.23 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89)
  1. fusion_bench/__init__.py +25 -2
  2. fusion_bench/compat/method/__init__.py +5 -2
  3. fusion_bench/compat/method/base_algorithm.py +3 -2
  4. fusion_bench/compat/modelpool/base_pool.py +3 -3
  5. fusion_bench/compat/taskpool/clip_image_classification.py +1 -1
  6. fusion_bench/constants/__init__.py +1 -0
  7. fusion_bench/constants/runtime.py +57 -0
  8. fusion_bench/dataset/gpt2_glue.py +1 -1
  9. fusion_bench/method/__init__.py +12 -4
  10. fusion_bench/method/analysis/task_vector_cos_similarity.py +95 -12
  11. fusion_bench/method/analysis/task_vector_violin_plot.py +160 -52
  12. fusion_bench/method/bitdelta/__init__.py +1 -0
  13. fusion_bench/method/bitdelta/bitdelta.py +7 -23
  14. fusion_bench/method/classification/clip_finetune.py +1 -1
  15. fusion_bench/method/expert_sparsity/mixtral/dynamic_skipping.py +2 -0
  16. fusion_bench/method/expert_sparsity/mixtral/layer_wise_pruning.py +2 -0
  17. fusion_bench/method/expert_sparsity/mixtral/progressive_pruning.py +2 -0
  18. fusion_bench/method/fisher_merging/clip_fisher_merging.py +0 -4
  19. fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +2 -2
  20. fusion_bench/method/linear/simple_average_for_llama.py +16 -11
  21. fusion_bench/method/model_stock/__init__.py +1 -0
  22. fusion_bench/method/model_stock/model_stock.py +309 -0
  23. fusion_bench/method/regmean/clip_regmean.py +3 -6
  24. fusion_bench/method/regmean/regmean.py +27 -56
  25. fusion_bench/method/regmean/utils.py +56 -0
  26. fusion_bench/method/regmean_plusplus/regmean_plusplus.py +21 -60
  27. fusion_bench/method/simple_average.py +7 -7
  28. fusion_bench/method/slerp/__init__.py +1 -1
  29. fusion_bench/method/slerp/slerp.py +110 -14
  30. fusion_bench/method/smile_upscaling/causal_lm_upscaling.py +371 -0
  31. fusion_bench/method/smile_upscaling/projected_energy.py +1 -2
  32. fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py +5 -1
  33. fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +40 -31
  34. fusion_bench/method/smile_upscaling/smile_upscaling.py +1 -1
  35. fusion_bench/method/we_moe/__init__.py +1 -0
  36. fusion_bench/method/we_moe/entropy_loss.py +25 -0
  37. fusion_bench/method/we_moe/flan_t5_we_moe.py +320 -0
  38. fusion_bench/method/we_moe/utils.py +15 -0
  39. fusion_bench/method/weighted_average/llama.py +1 -1
  40. fusion_bench/mixins/clip_classification.py +37 -48
  41. fusion_bench/mixins/serialization.py +30 -10
  42. fusion_bench/modelpool/base_pool.py +1 -1
  43. fusion_bench/modelpool/causal_lm/causal_lm.py +293 -75
  44. fusion_bench/modelpool/seq2seq_lm/modelpool.py +146 -0
  45. fusion_bench/models/__init__.py +5 -0
  46. fusion_bench/models/hf_utils.py +69 -86
  47. fusion_bench/models/linearized/vision_model.py +6 -6
  48. fusion_bench/models/model_card_templates/default.md +46 -0
  49. fusion_bench/models/modeling_smile_llama/__init__.py +7 -0
  50. fusion_bench/models/modeling_smile_llama/modeling_smile_llama.py +1 -8
  51. fusion_bench/models/modeling_smile_mistral/__init__.py +2 -1
  52. fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +1 -5
  53. fusion_bench/models/we_moe.py +8 -8
  54. fusion_bench/programs/fabric_fusion_program.py +29 -60
  55. fusion_bench/scripts/cli.py +34 -1
  56. fusion_bench/taskpool/base_pool.py +99 -17
  57. fusion_bench/taskpool/clip_vision/taskpool.py +10 -5
  58. fusion_bench/taskpool/dummy.py +101 -13
  59. fusion_bench/taskpool/lm_eval_harness/taskpool.py +80 -0
  60. fusion_bench/taskpool/nyuv2_taskpool.py +28 -0
  61. fusion_bench/utils/__init__.py +2 -0
  62. fusion_bench/utils/cache_utils.py +101 -1
  63. fusion_bench/utils/data.py +6 -4
  64. fusion_bench/utils/devices.py +7 -4
  65. fusion_bench/utils/dtype.py +3 -2
  66. fusion_bench/utils/fabric.py +2 -2
  67. fusion_bench/utils/lazy_imports.py +23 -0
  68. fusion_bench/utils/lazy_state_dict.py +117 -19
  69. fusion_bench/utils/modelscope.py +3 -3
  70. fusion_bench/utils/packages.py +3 -3
  71. fusion_bench/utils/parameters.py +0 -2
  72. fusion_bench/utils/path.py +56 -0
  73. fusion_bench/utils/pylogger.py +1 -1
  74. fusion_bench/utils/timer.py +92 -10
  75. {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/METADATA +1 -23
  76. {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/RECORD +89 -75
  77. fusion_bench_config/_get_started/llm_slerp.yaml +12 -0
  78. fusion_bench_config/method/fisher_merging/clip_fisher_merging.yaml +0 -1
  79. fusion_bench_config/method/linear/simple_average_for_llama.yaml +3 -2
  80. fusion_bench_config/method/model_stock/model_stock.yaml +12 -0
  81. fusion_bench_config/method/slerp/slerp_lm.yaml +4 -0
  82. fusion_bench_config/method/smile_upscaling/causal_lm_upscaling.yaml +21 -0
  83. fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +1 -1
  84. fusion_bench_config/method/wemoe/flan_t5_weight_ensembling_moe.yaml +20 -0
  85. fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml +1 -1
  86. {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/WHEEL +0 -0
  87. {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/entry_points.txt +0 -0
  88. {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/licenses/LICENSE +0 -0
  89. {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/top_level.txt +0 -0
fusion_bench/modelpool/causal_lm/causal_lm.py
@@ -8,7 +8,7 @@ from copy import deepcopy
  from typing import Any, Dict, Optional, TypeAlias, Union, cast # noqa: F401

  import peft
- from omegaconf import DictConfig, flag_override
+ from omegaconf import DictConfig, OmegaConf, flag_override
  from torch import nn
  from torch.nn.modules import Module
  from transformers import (
@@ -19,45 +19,86 @@ from transformers import (
  )
  from typing_extensions import override

- from fusion_bench.modelpool import BaseModelPool
- from fusion_bench.utils import instantiate
- from fusion_bench.utils.dtype import parse_dtype
+ from fusion_bench import (
+ BaseModelPool,
+ auto_register_config,
+ import_object,
+ instantiate,
+ parse_dtype,
+ )
  from fusion_bench.utils.lazy_state_dict import LazyStateDict
- from fusion_bench.utils.packages import import_object

  log = logging.getLogger(__name__)


+ @auto_register_config
  class CausalLMPool(BaseModelPool):
- _config_mapping = BaseModelPool._config_mapping | {
- "_tokenizer": "tokenizer",
- "_model_kwargs": "model_kwargs",
- "load_lazy": "load_lazy",
- }
+ """A model pool for managing and loading causal language models.
+
+ This class provides a unified interface for loading and managing multiple
+ causal language models, typically used in model fusion and ensemble scenarios.
+ It supports both eager and lazy loading strategies, and handles model
+ configuration through YAML configs or direct instantiation.
+
+ The pool can manage models from Hugging Face Hub, local paths, or custom
+ configurations. It also provides tokenizer management and model saving
+ capabilities with optional Hugging Face Hub integration.
+
+ Args:
+ models: Dictionary or configuration specifying the models to be managed.
+ Can contain model names mapped to paths or detailed configurations.
+ tokenizer: Tokenizer configuration, either a string path/name or
+ a DictConfig with detailed tokenizer settings.
+ model_kwargs: Additional keyword arguments passed to model loading.
+ Common options include torch_dtype, device_map, etc.
+ enable_lazy_loading: Whether to use lazy loading for models. When True,
+ models are loaded as LazyStateDict objects instead of actual models,
+ which can save memory for large model collections.
+ **kwargs: Additional arguments passed to the parent BaseModelPool.
+
+ Example:
+ ```python
+ >>> pool = CausalLMPool(
+ ... models={
+ ... "model_a": "microsoft/DialoGPT-medium",
+ ... "model_b": "/path/to/local/model"
+ ... },
+ ... tokenizer="microsoft/DialoGPT-medium",
+ ... model_kwargs={"torch_dtype": "bfloat16"}
+ ... )
+ >>> model = pool.load_model("model_a")
+ >>> tokenizer = pool.load_tokenizer()
+ ```
+ """

  def __init__(
  self,
  models,
  *,
- tokenizer: Optional[DictConfig],
+ tokenizer: Optional[DictConfig | str],
  model_kwargs: Optional[DictConfig] = None,
- load_lazy: bool = False,
+ enable_lazy_loading: bool = False,
  **kwargs,
  ):
  super().__init__(models, **kwargs)
- # process `model_kwargs`
- self._tokenizer = tokenizer
- self._model_kwargs = model_kwargs
- if self._model_kwargs is None:
- self._model_kwargs = DictConfig({})
- with flag_override(self._model_kwargs, "allow_objects", True):
- if hasattr(self._model_kwargs, "torch_dtype"):
- self._model_kwargs.torch_dtype = parse_dtype(
- self._model_kwargs.torch_dtype
- )
- self.load_lazy = load_lazy
+ if model_kwargs is None:
+ self.model_kwargs = DictConfig({})

  def get_model_path(self, model_name: str):
+ """Extract the model path from the model configuration.
+
+ Args:
+ model_name: The name of the model as defined in the models configuration.
+
+ Returns:
+ str: The path or identifier for the model. For string configurations,
+ returns the string directly. For dict configurations, extracts
+ the 'pretrained_model_name_or_path' field.
+
+ Raises:
+ RuntimeError: If the model configuration is invalid or the model
+ name is not found in the configuration.
+ """
  model_name_or_config = self._models[model_name]
  if isinstance(model_name_or_config, str):
  return model_name_or_config
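The hunk above renames the constructor flag `load_lazy` to `enable_lazy_loading` and replaces the hand-written `_config_mapping` with the `@auto_register_config` decorator. A minimal usage sketch of the new signature follows; the model ids are the placeholders from the docstring above, and the import path assumes `CausalLMPool` is still exported from `fusion_bench.modelpool` as in earlier releases:

```python
# Sketch only: constructing the pool with the renamed flag (was `load_lazy` in 0.2.21).
from fusion_bench.modelpool import CausalLMPool  # assumed export path

pool = CausalLMPool(
    models={
        "model_a": "microsoft/DialoGPT-medium",  # placeholder ids from the docstring
        "model_b": "/path/to/local/model",
    },
    tokenizer="microsoft/DialoGPT-medium",
    model_kwargs={"torch_dtype": "bfloat16"},
    enable_lazy_loading=False,  # renamed from `load_lazy`
)
model = pool.load_model("model_a")  # PreTrainedModel (or LazyStateDict when lazy loading is enabled)
tokenizer = pool.load_tokenizer()
```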
@@ -66,39 +107,80 @@ class CausalLMPool(BaseModelPool):
  else:
  raise RuntimeError("Invalid model configuration")

+ def get_model_kwargs(self):
+ """Get processed model keyword arguments for model loading.
+
+ Converts the stored `model_kwargs` from DictConfig to a regular dictionary
+ and processes special arguments like torch_dtype for proper model loading.
+
+ Returns:
+ dict: Processed keyword arguments ready to be passed to model
+ loading functions. The torch_dtype field, if present, is
+ converted from string to the appropriate torch dtype object.
+ """
+ model_kwargs = (
+ OmegaConf.to_container(self.model_kwargs, resolve=True)
+ if isinstance(self.model_kwargs, DictConfig)
+ else self.model_kwargs
+ )
+ if "torch_dtype" in model_kwargs:
+ model_kwargs["torch_dtype"] = parse_dtype(model_kwargs["torch_dtype"])
+ return model_kwargs
+
  @override
  def load_model(
  self,
  model_name_or_config: str | DictConfig,
  *args,
  **kwargs,
- ) -> PreTrainedModel:
- """
- Example of YAML config:
+ ) -> Union[PreTrainedModel, LazyStateDict]:
+ """Load a causal language model from the model pool.

- ```yaml
- models:
- _pretrained_: path_to_pretrained_model # if a plain string, it will be passed to AutoModelForCausalLM.from_pretrained
- model_a: path_to_model_a
- model_b: path_to_model_b
- ```
+ This method supports multiple loading strategies:
+ 1. Loading by model name from the configured model pool
+ 2. Loading from a direct configuration dictionary
+ 3. Lazy loading using LazyStateDict for memory efficiency

- or equivalently,
-
- ```yaml
- models:
- _pretrained_:
- _target_: transformers.AutoModelForCausalLM # any callable that returns a model
- pretrained_model_name_or_path: path_to_pretrained_model
- model_a:
- _target_: transformers.AutoModelForCausalLM
- pretrained_model_name_or_path: path_to_model_a
- model_b:
- _target_: transformers.AutoModelForCausalLM
- pretrained_model_name_or_path: path_to_model_b
- ```
+ The method automatically handles different model configuration formats
+ and applies the appropriate loading strategy based on the enable_lazy_loading flag.
+
+ Args:
+ model_name_or_config: Either a string model name that exists in the
+ model pool configuration, or a DictConfig/dict containing the
+ model configuration directly.
+ *args: Additional positional arguments passed to the model constructor.
+ **kwargs: Additional keyword arguments passed to the model constructor.
+ These will be merged with the pool's model_kwargs.
+
+ Returns:
+ Union[PreTrainedModel, LazyStateDict]: The loaded model. Returns a
+ PreTrainedModel for normal loading or a LazyStateDict for lazy loading.
+
+ Raises:
+ RuntimeError: If the model configuration is invalid.
+ KeyError: If the model name is not found in the model pool.
+
+ Example YAML configurations:
+ Simple string configuration:
+ ```yaml
+ models:
+ _pretrained_: path_to_pretrained_model
+ model_a: path_to_model_a
+ model_b: path_to_model_b
+ ```
+
+ Detailed configuration:
+ ```yaml
+ models:
+ _pretrained_:
+ _target_: transformers.AutoModelForCausalLM
+ pretrained_model_name_or_path: path_to_pretrained_model
+ model_a:
+ _target_: transformers.AutoModelForCausalLM
+ pretrained_model_name_or_path: path_to_model_a
+ ```
  """
- model_kwargs = deepcopy(self._model_kwargs)
+ model_kwargs = self.get_model_kwargs()
  model_kwargs.update(kwargs)

  if isinstance(model_name_or_config, str):
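The new `get_model_kwargs` above replaces the `flag_override`-based mutation of `_model_kwargs` in `__init__`: the DictConfig is converted to a plain dict at load time and the `torch_dtype` string is resolved then. A small sketch of that conversion, assuming `parse_dtype` maps dtype strings such as "bfloat16" to the corresponding torch dtype:

```python
# Sketch of the torch_dtype handling described in get_model_kwargs above.
from omegaconf import DictConfig, OmegaConf

from fusion_bench import parse_dtype  # re-exported at the package top level in 0.2.23

model_kwargs = DictConfig({"torch_dtype": "bfloat16", "device_map": "auto"})
resolved = OmegaConf.to_container(model_kwargs, resolve=True)   # DictConfig -> plain dict
resolved["torch_dtype"] = parse_dtype(resolved["torch_dtype"])  # dtype string -> torch dtype
print(resolved)  # e.g. {'torch_dtype': torch.bfloat16, 'device_map': 'auto'}
```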
@@ -108,7 +190,7 @@ class CausalLMPool(BaseModelPool):
  model_config = self._models[model_name_or_config]
  if isinstance(model_config, str):
  # model_config is a string
- if not self.load_lazy:
+ if not self.enable_lazy_loading:
  model = AutoModelForCausalLM.from_pretrained(
  model_config,
  *args,
@@ -126,7 +208,7 @@
  elif isinstance(model_name_or_config, (DictConfig, Dict)):
  model_config = model_name_or_config

- if not self.load_lazy:
+ if not self.enable_lazy_loading:
  model = instantiate(model_config, *args, **model_kwargs)
  else:
  meta_module_class = model_config.pop("_target_")
@@ -140,30 +222,43 @@
  return model

  def load_tokenizer(self, *args, **kwargs) -> PreTrainedTokenizer:
- """
- Example of YAML config:
+ """Load the tokenizer associated with this model pool.

- ```yaml
- tokenizer: google/gemma-2-2b-it # if a plain string, it will be passed to AutoTokenizer.from_pretrained
- ```
+ Loads a tokenizer based on the tokenizer configuration provided during
+ pool initialization. Supports both simple string paths and detailed
+ configuration dictionaries.

- or equivalently,
-
- ```yaml
- tokenizer:
- _target_: transformers.AutoTokenizer # any callable that returns a tokenizer
- pretrained_model_name_or_path: google/gemma-2-2b-it
- ```
+ Args:
+ *args: Additional positional arguments passed to the tokenizer constructor.
+ **kwargs: Additional keyword arguments passed to the tokenizer constructor.

  Returns:
- PreTrainedTokenizer: The tokenizer.
+ PreTrainedTokenizer: The loaded tokenizer instance.
+
+ Raises:
+ AssertionError: If no tokenizer is defined in the configuration.
+
+ Example YAML configurations:
+ Simple string configuration:
+ ```yaml
+ tokenizer: google/gemma-2-2b-it
+ ```
+
+ Detailed configuration:
+ ```yaml
+ tokenizer:
+ _target_: transformers.AutoTokenizer
+ pretrained_model_name_or_path: google/gemma-2-2b-it
+ use_fast: true
+ padding_side: left
+ ```
  """
- assert self._tokenizer is not None, "Tokenizer is not defined in the config"
+ assert self.tokenizer is not None, "Tokenizer is not defined in the config"
  log.info("Loading tokenizer.", stacklevel=2)
- if isinstance(self._tokenizer, str):
- tokenizer = AutoTokenizer.from_pretrained(self._tokenizer, *args, **kwargs)
+ if isinstance(self.tokenizer, str):
+ tokenizer = AutoTokenizer.from_pretrained(self.tokenizer, *args, **kwargs)
  else:
- tokenizer = instantiate(self._tokenizer, *args, **kwargs)
+ tokenizer = instantiate(self.tokenizer, *args, **kwargs)
  return tokenizer

  @override
@@ -178,15 +273,49 @@ class CausalLMPool(BaseModelPool):
  tokenizer: Optional[PreTrainedTokenizer] = None,
  **kwargs,
  ):
- """
- Save the model to the specified path.
+ """Save a model to the specified path with optional tokenizer and Hub upload.
+
+ This method provides comprehensive model saving capabilities including
+ optional tokenizer saving, dtype conversion, and Hugging Face Hub upload.
+ The model is saved in the standard Hugging Face format.

  Args:
- model (PreTrainedModel): The model to be saved.
- path (str): The path where the model will be saved.
- push_to_hub (bool, optional): Whether to push the model to the Hugging Face Hub. Defaults to False.
- save_tokenizer (bool, optional): Whether to save the tokenizer along with the model. Defaults to False.
- **kwargs: Additional keyword arguments passed to the `save_pretrained` method.
+ model: The PreTrainedModel instance to be saved.
+ path: The local path where the model will be saved. Supports tilde
+ expansion for home directory paths.
+ push_to_hub: Whether to push the saved model to the Hugging Face Hub.
+ Requires proper authentication and repository permissions.
+ model_dtype: Optional string specifying the target dtype for the model
+ before saving (e.g., "float16", "bfloat16"). The model will be
+ converted to this dtype before saving.
+ save_tokenizer: Whether to save the tokenizer alongside the model.
+ If True, the tokenizer will be loaded using the pool's tokenizer
+ configuration and saved to the same path.
+ tokenizer_kwargs: Additional keyword arguments for tokenizer loading
+ when save_tokenizer is True.
+ tokenizer: Optional pre-loaded tokenizer instance. If provided, this
+ tokenizer will be saved regardless of the save_tokenizer flag.
+ **kwargs: Additional keyword arguments passed to the model's
+ save_pretrained method.
+
+ Side Effects:
+ - Creates model files in the specified directory
+ - Optionally creates tokenizer files in the same directory
+ - May convert the model to a different dtype
+ - May upload files to Hugging Face Hub
+
+ Example:
+ ```python
+ >>> pool = CausalLMPool(models=..., tokenizer=...)
+ >>> model = pool.load_model("my_model")
+ >>> pool.save_model(
+ ... model,
+ ... "/path/to/save",
+ ... save_tokenizer=True,
+ ... model_dtype="float16",
+ ... push_to_hub=True
+ ... )
+ ```
  """
  path = os.path.expanduser(path)
  # NOTE: if tokenizer is provided, it will be saved regardless of `save_tokenizer`
@@ -210,15 +339,61 @@ class CausalLMPool(BaseModelPool):


  class CausalLMBackbonePool(CausalLMPool):
+ """A specialized model pool that loads only the transformer backbone layers.
+
+ This class extends CausalLMPool to provide access to just the transformer
+ layers (backbone) of causal language models, excluding the language modeling
+ head and embeddings. This is useful for model fusion scenarios where only
+ the core transformer layers are needed.
+
+ The class automatically extracts the `model.layers` component from loaded
+ AutoModelForCausalLM instances, providing direct access to the transformer
+ blocks. Lazy loading is not supported for this pool type.
+
+ Note:
+ This pool automatically disables lazy loading as it needs to access
+ the internal structure of the model to extract the backbone layers.
+
+ Example:
+ ```python
+ >>> backbone_pool = CausalLMBackbonePool(
+ ... models={"model_a": "microsoft/DialoGPT-medium"},
+ ... tokenizer="microsoft/DialoGPT-medium"
+ ... )
+ >>> layers = backbone_pool.load_model("model_a") # Returns nn.ModuleList of transformer layers
+ ```
+ """
+
  def load_model(
  self, model_name_or_config: str | DictConfig, *args, **kwargs
  ) -> Module:
- if self.load_lazy:
+ """Load only the transformer backbone layers from a causal language model.
+
+ This method loads a complete causal language model and then extracts
+ only the transformer layers (backbone), discarding the embedding layers
+ and language modeling head. This is useful for model fusion scenarios
+ where only the core transformer computation is needed.
+
+ Args:
+ model_name_or_config: Either a string model name from the pool
+ configuration or a DictConfig with model loading parameters.
+ *args: Additional positional arguments passed to the parent load_model method.
+ **kwargs: Additional keyword arguments passed to the parent load_model method.
+
+ Returns:
+ Module: The transformer layers (typically a nn.ModuleList) containing
+ the core transformer blocks without embeddings or output heads.
+
+ Note:
+ Lazy loading is automatically disabled for this method as it needs
+ to access the internal model structure to extract the layers.
+ """
+ if self.enable_lazy_loading:
  log.warning(
  "CausalLMBackbonePool does not support lazy loading. "
  "Falling back to normal loading."
  )
- self.load_lazy = False
+ self.enable_lazy_loading = False
  model: AutoModelForCausalLM = super().load_model(
  model_name_or_config, *args, **kwargs
  )
@@ -232,6 +407,49 @@ def load_peft_causal_lm(
  is_trainable: bool = True,
  merge_and_unload: bool = False,
  ):
+ """Load a causal language model with PEFT (Parameter-Efficient Fine-Tuning) adapters.
+
+ This function loads a base causal language model and applies PEFT adapters
+ (such as LoRA, AdaLoRA, or other parameter-efficient fine-tuning methods)
+ to create a fine-tuned model. It supports both keeping the adapters separate
+ or merging them into the base model.
+
+ Args:
+ base_model_path: Path or identifier for the base causal language model.
+ Can be a Hugging Face model name or local path.
+ peft_model_path: Path to the PEFT adapter configuration and weights.
+ This should contain the adapter_config.json and adapter weights.
+ torch_dtype: The torch data type to use for the model. Common options
+ include "float16", "bfloat16", "float32". Defaults to "bfloat16".
+ is_trainable: Whether the loaded PEFT model should be trainable.
+ Set to False for inference-only usage to save memory.
+ merge_and_unload: Whether to merge the PEFT adapters into the base model
+ and unload the adapter weights. When True, returns a standard
+ PreTrainedModel instead of a PeftModel.
+
+ Returns:
+ Union[PeftModel, PreTrainedModel]: The loaded model with PEFT adapters.
+ Returns a PeftModel if merge_and_unload is False, or a PreTrainedModel
+ if the adapters are merged and unloaded.
+
+ Example:
+ ```python
+ >>> # Load model with adapters for training
+ >>> model = load_peft_causal_lm(
+ ... "microsoft/DialoGPT-medium",
+ ... "/path/to/lora/adapters",
+ ... is_trainable=True
+ ... )
+
+ >>> # Load and merge adapters for inference
+ >>> merged_model = load_peft_causal_lm(
+ ... "microsoft/DialoGPT-medium",
+ ... "/path/to/lora/adapters",
+ ... merge_and_unload=True,
+ ... is_trainable=False
+ ... )
+ ```
+ """
  base_model = AutoModelForCausalLM.from_pretrained(
  base_model_path, torch_dtype=torch_dtype
  )
fusion_bench/modelpool/seq2seq_lm/modelpool.py
@@ -18,6 +18,48 @@ def load_lora_model(
  is_trainable: bool = True,
  merge_and_unload: bool = True,
  ):
+ """Load a sequence-to-sequence model with LoRA (Low-Rank Adaptation) fine-tuning.
+
+ This function loads a base sequence-to-sequence language model and applies
+ LoRA adapters for parameter-efficient fine-tuning. LoRA allows for efficient
+ adaptation of large models by adding trainable low-rank matrices to the
+ existing weights without modifying the original parameters.
+
+ Args:
+ base_model_path: Path or identifier for the base sequence-to-sequence model.
+ Can be a Hugging Face model name (e.g., "t5-base") or local path.
+ peft_model_path: Path to the directory containing LoRA adapter weights
+ and configuration. Should include adapter_config.json and adapter weights.
+ is_trainable: Whether the loaded model should be trainable. Set to False
+ for inference-only usage to save memory and computation.
+ merge_and_unload: Whether to merge the LoRA weights into the base model
+ and unload the adapter. When True, returns a standard model instead
+ of a PeftModel, which can be more efficient for inference.
+
+ Returns:
+ Union[PeftModel, AutoModelForSeq2SeqLM]: The loaded model with LoRA
+ adapters. Returns a PeftModel if merge_and_unload is False, or
+ a standard AutoModelForSeq2SeqLM if adapters are merged.
+
+ Example:
+ ```python
+ >>> # Load model with separate adapters for training
+ >>> model = load_lora_model(
+ ... "t5-base",
+ ... "/path/to/lora/adapters",
+ ... is_trainable=True,
+ ... merge_and_unload=False
+ ... )
+
+ >>> # Load and merge adapters for efficient inference
+ >>> merged_model = load_lora_model(
+ ... "t5-base",
+ ... "/path/to/lora/adapters",
+ ... is_trainable=False,
+ ... merge_and_unload=True
+ ... )
+ ```
+ """
  base_model = AutoModelForSeq2SeqLM.from_pretrained(base_model_path)
  model = PeftModel.from_pretrained(
  base_model,
@@ -30,6 +72,46 @@ def load_lora_model(


  class Seq2SeqLMPool(BaseModelPool):
+ """A model pool specialized for sequence-to-sequence language models.
+
+ This model pool provides management and loading capabilities for sequence-to-sequence
+ (seq2seq) language models such as T5, BART, and mT5. It extends the base model pool
+ functionality with seq2seq-specific features including tokenizer management and
+ model configuration handling.
+
+ Seq2seq models are particularly useful for tasks that require generating output
+ sequences from input sequences, such as translation, summarization, question
+ answering, and text generation. This pool streamlines the process of loading
+ and configuring multiple seq2seq models for fusion and ensemble scenarios.
+
+ Key Features:
+ - Specialized loading for AutoModelForSeq2SeqLM models
+ - Integrated tokenizer management
+ - Support for model-specific keyword arguments
+ - Automatic dtype parsing and configuration
+ - Compatible with PEFT (Parameter-Efficient Fine-Tuning) adapters
+
+ Attributes:
+ _tokenizer: Configuration for the tokenizer associated with the models
+ _model_kwargs: Default keyword arguments applied to all model loading operations
+
+ Example:
+ ```python
+ pool = Seq2SeqLMPool(
+ models={
+ "t5_base": "t5-base",
+ "t5_large": "t5-large",
+ "custom_model": "/path/to/local/model"
+ },
+ tokenizer={"_target_": "transformers.T5Tokenizer",
+ "pretrained_model_name_or_path": "t5-base"},
+ model_kwargs={"torch_dtype": "float16", "device_map": "auto"}
+ )
+ model = pool.load_model("t5_base")
+ tokenizer = pool.load_tokenizer()
+ ```
+ """
+
  _config_mapping = BaseModelPool._config_mapping | {
  "_tokenizer": "tokenizer",
  "_model_kwargs": "model_kwargs",
@@ -43,6 +125,35 @@ class Seq2SeqLMPool(BaseModelPool):
  model_kwargs: Optional[DictConfig] = None,
  **kwargs,
  ):
+ """Initialize the sequence-to-sequence language model pool.
+
+ Sets up the model pool with configurations for models, tokenizer, and
+ default model loading parameters. Automatically processes model kwargs
+ to handle special configurations like torch_dtype parsing.
+
+ Args:
+ models: Configuration dictionary specifying the seq2seq models to manage.
+ Keys are model names, values can be model paths/names or detailed configs.
+ tokenizer: Configuration for the tokenizer to use with the models.
+ Can be a simple path/name or detailed configuration with _target_.
+ model_kwargs: Default keyword arguments applied to all model loading
+ operations. Common options include torch_dtype, device_map, etc.
+ The torch_dtype field is automatically parsed from string to dtype.
+ **kwargs: Additional arguments passed to the parent BaseModelPool.
+
+ Example:
+ ```python
+ pool = Seq2SeqLMPool(
+ models={
+ "base": "t5-base",
+ "large": {"_target_": "transformers.AutoModelForSeq2SeqLM",
+ "pretrained_model_name_or_path": "t5-large"}
+ },
+ tokenizer="t5-base",
+ model_kwargs={"torch_dtype": "bfloat16"}
+ )
+ ```
+ """
  super().__init__(models, **kwargs)
  self._tokenizer = tokenizer
  self._model_kwargs = model_kwargs
@@ -55,11 +166,46 @@ class Seq2SeqLMPool(BaseModelPool):
  )

  def load_model(self, model_name_or_config: str | DictConfig, *args, **kwargs):
+ """Load a sequence-to-sequence language model from the pool.
+
+ Loads a seq2seq model using the parent class loading mechanism while
+ automatically applying the pool's default model kwargs. The method
+ merges the pool's model_kwargs with any additional kwargs provided,
+ giving priority to the explicitly provided kwargs.
+
+ Args:
+ model_name_or_config: Either a string model name from the pool
+ configuration or a DictConfig containing model loading parameters.
+ *args: Additional positional arguments passed to the parent load_model method.
+ **kwargs: Additional keyword arguments that override the pool's default
+ model_kwargs. Common options include device, torch_dtype, etc.
+
+ Returns:
+ AutoModelForSeq2SeqLM: The loaded sequence-to-sequence language model.
+ """
  model_kwargs = deepcopy(self._model_kwargs)
  model_kwargs.update(kwargs)
  return super().load_model(model_name_or_config, *args, **model_kwargs)

  def load_tokenizer(self, *args, **kwargs):
+ """Load the tokenizer associated with the sequence-to-sequence models.
+
+ Loads a tokenizer based on the tokenizer configuration provided during
+ pool initialization. The tokenizer should be compatible with the seq2seq
+ models in the pool and is typically used for preprocessing input text
+ and postprocessing generated output.
+
+ Args:
+ *args: Additional positional arguments passed to the tokenizer constructor.
+ **kwargs: Additional keyword arguments passed to the tokenizer constructor.
+
+ Returns:
+ PreTrainedTokenizer: The loaded tokenizer instance compatible with
+ the seq2seq models in this pool.
+
+ Raises:
+ AssertionError: If no tokenizer configuration is provided.
+ """
  assert self._tokenizer is not None, "Tokenizer is not defined in the config"
  tokenizer = isinstance(self._tokenizer, *args, **kwargs)
  return tokenizer
fusion_bench/models/__init__.py
@@ -2,4 +2,9 @@
  from fusion_bench.utils import LazyStateDict

  from . import separate_io, utils
+ from .hf_utils import (
+ create_default_model_card,
+ load_model_card_template,
+ save_pretrained_with_remote_code,
+ )
  from .parameter_dict import ParameterDictModel
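The last hunk re-exports the new model-card and remote-code helpers from `hf_utils` at the `fusion_bench.models` level. A tiny sketch showing only the import surface added here; the helpers' signatures are not part of this diff, so they are introspected rather than called:

```python
# Only the names and the import path come from the hunk above; signatures are not
# shown in this diff, so the helpers are not invoked in this sketch.
from fusion_bench.models import (
    create_default_model_card,
    load_model_card_template,
    save_pretrained_with_remote_code,
)

for helper in (create_default_model_card, load_model_card_template, save_pretrained_with_remote_code):
    print(helper.__module__, helper.__name__)
```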