fusion-bench 0.2.23__py3-none-any.whl → 0.2.25__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (77)
  1. fusion_bench/__init__.py +152 -42
  2. fusion_bench/dataset/__init__.py +27 -4
  3. fusion_bench/dataset/clip_dataset.py +2 -2
  4. fusion_bench/method/__init__.py +18 -1
  5. fusion_bench/method/classification/__init__.py +27 -2
  6. fusion_bench/method/classification/image_classification_finetune.py +214 -0
  7. fusion_bench/method/ensemble.py +17 -2
  8. fusion_bench/method/linear/__init__.py +6 -2
  9. fusion_bench/method/linear/{simple_average_for_llama.py → simple_average_for_causallm.py} +8 -4
  10. fusion_bench/method/linear/{task_arithmetic_for_llama.py → task_arithmetic_for_causallm.py} +22 -12
  11. fusion_bench/method/linear/ties_merging_for_causallm.py +70 -0
  12. fusion_bench/method/opcm/opcm.py +1 -0
  13. fusion_bench/method/pwe_moe/module.py +0 -2
  14. fusion_bench/method/simple_average.py +2 -2
  15. fusion_bench/method/tall_mask/task_arithmetic.py +2 -2
  16. fusion_bench/method/task_arithmetic/task_arithmetic.py +35 -10
  17. fusion_bench/method/ties_merging/ties_merging.py +22 -6
  18. fusion_bench/method/wudi/__init__.py +1 -0
  19. fusion_bench/method/wudi/wudi.py +105 -0
  20. fusion_bench/mixins/__init__.py +2 -0
  21. fusion_bench/mixins/lightning_fabric.py +4 -0
  22. fusion_bench/mixins/pyinstrument.py +174 -0
  23. fusion_bench/mixins/serialization.py +25 -78
  24. fusion_bench/mixins/simple_profiler.py +106 -23
  25. fusion_bench/modelpool/__init__.py +2 -0
  26. fusion_bench/modelpool/base_pool.py +77 -14
  27. fusion_bench/modelpool/causal_lm/causal_lm.py +32 -10
  28. fusion_bench/modelpool/clip_vision/modelpool.py +56 -19
  29. fusion_bench/modelpool/resnet_for_image_classification.py +208 -0
  30. fusion_bench/models/__init__.py +35 -9
  31. fusion_bench/models/hf_clip.py +4 -0
  32. fusion_bench/models/hf_utils.py +2 -1
  33. fusion_bench/models/model_card_templates/default.md +8 -1
  34. fusion_bench/models/wrappers/ensemble.py +136 -7
  35. fusion_bench/optim/__init__.py +40 -2
  36. fusion_bench/optim/lr_scheduler/__init__.py +27 -1
  37. fusion_bench/optim/muon.py +339 -0
  38. fusion_bench/programs/__init__.py +2 -0
  39. fusion_bench/programs/fabric_fusion_program.py +2 -2
  40. fusion_bench/programs/fusion_program.py +271 -0
  41. fusion_bench/scripts/cli.py +2 -2
  42. fusion_bench/taskpool/clip_vision/taskpool.py +11 -4
  43. fusion_bench/tasks/clip_classification/__init__.py +15 -0
  44. fusion_bench/utils/__init__.py +167 -21
  45. fusion_bench/utils/devices.py +30 -8
  46. fusion_bench/utils/lazy_imports.py +91 -12
  47. fusion_bench/utils/lazy_state_dict.py +58 -5
  48. fusion_bench/utils/misc.py +104 -13
  49. fusion_bench/utils/packages.py +4 -0
  50. fusion_bench/utils/path.py +7 -0
  51. fusion_bench/utils/pylogger.py +6 -0
  52. fusion_bench/utils/rich_utils.py +8 -3
  53. fusion_bench/utils/state_dict_arithmetic.py +935 -162
  54. {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/METADATA +10 -3
  55. {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/RECORD +76 -55
  56. fusion_bench_config/method/classification/image_classification_finetune.yaml +16 -0
  57. fusion_bench_config/method/classification/image_classification_finetune_test.yaml +6 -0
  58. fusion_bench_config/method/ensemble/simple_ensemble.yaml +1 -0
  59. fusion_bench_config/method/linear/{simple_average_for_llama.yaml → simple_average_for_causallm.yaml} +1 -1
  60. fusion_bench_config/method/linear/task_arithmetic_for_causallm.yaml +4 -0
  61. fusion_bench_config/method/linear/ties_merging_for_causallm.yaml +13 -0
  62. fusion_bench_config/method/wudi/wudi.yaml +4 -0
  63. fusion_bench_config/model_fusion.yaml +45 -0
  64. fusion_bench_config/modelpool/CausalLMPool/{Qwen2.5-1.5B_math_and_coder.yaml → Qwen2.5-1.5B_math_and_code.yaml} +1 -2
  65. fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_three_models.yaml +11 -0
  66. fusion_bench_config/modelpool/CausalLMPool/llama-7b_3-models_v1.yaml +11 -0
  67. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet152_cifar10.yaml +14 -0
  68. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet152_cifar100.yaml +14 -0
  69. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet18_cifar10.yaml +14 -0
  70. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet18_cifar100.yaml +14 -0
  71. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet50_cifar10.yaml +14 -0
  72. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet50_cifar100.yaml +14 -0
  73. fusion_bench_config/method/linear/task_arithmetic_for_llama.yaml +0 -4
  74. {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/WHEEL +0 -0
  75. {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/entry_points.txt +0 -0
  76. {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/licenses/LICENSE +0 -0
  77. {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/top_level.txt +0 -0
@@ -6,6 +6,7 @@ from inspect import Parameter, _ParameterKind
 from pathlib import Path
 from typing import Dict, Mapping, Optional, Union
 
+from bidict import MutableBidict, bidict
 from omegaconf import DictConfig, OmegaConf
 
 from fusion_bench.constants import FUSION_BENCH_VERSION
@@ -15,22 +16,29 @@ from fusion_bench.utils.instantiate_utils import set_print_function_call
 log = logging.getLogger(__name__)
 
 __all__ = [
-    "YAMLSerializationMixin",
     "auto_register_config",
+    "YAMLSerializationMixin",
     "BaseYAMLSerializable",
 ]
 
 
-def _get_attr_name(config_mapping: Mapping[str, str], param_name):
-    for attr_name, p in config_mapping.items():
-        if p == param_name:
-            return attr_name
-    else:
-        raise ValueError(f"Parameter {param_name} not found in config mapping.")
+def _set_attr(self, param_name: str, value):
+    """
+    Set an attribute on the object using the parameter name from config mapping.
+
+    This function looks up the corresponding attribute name for the given parameter
+    name using the object's _config_mapping, then sets that attribute to the
+    specified value. It also logs the operation for debugging purposes.
 
+    Args:
+        self: The object instance to set the attribute on.
+        param_name (str): The parameter name (config key) to map to an attribute.
+        value: The value to assign to the attribute.
 
-def _set_attr(self, param_name: str, value):
-    attr_name = _get_attr_name(self._config_mapping, param_name)
+    Raises:
+        ValueError: If the parameter name is not found in the config mapping.
+    """
+    attr_name = self._config_mapping.inverse[param_name]
     log.debug(f"set {attr_name} to {value}. Parameter name: {param_name}")
     setattr(self, attr_name, value)
 
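The hunk above replaces the old `_get_attr_name` linear scan with a reverse lookup on a `bidict`-backed `_config_mapping`. A minimal sketch of that lookup pattern, using the real `bidict` API but an illustrative mapping:

```python
from bidict import bidict

# Illustrative attribute-name -> config-key mapping.
config_mapping = bidict({"hyper_parameter_1": "hyper_param_1"})

# Forward lookup: attribute name -> config key.
assert config_mapping["hyper_parameter_1"] == "hyper_param_1"

# Reverse lookup via the inverse view, as _set_attr now does; note that a
# missing key raises KeyError here, rather than the ValueError the new
# docstring still mentions.
assert config_mapping.inverse["hyper_param_1"] == "hyper_parameter_1"
```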
@@ -59,37 +67,16 @@ def auto_register_config(cls):
         functionality and modified __init__ behavior.
 
     Behavior:
-        - **Parameter Registration**: All non-variadic parameters (excluding *args, **kwargs)
+        - **Parameter Registration**: All non-variadic parameters (excluding ``*args``, ``**kwargs``)
           from the __init__ method are automatically added to _config_mapping
         - **Positional Arguments**: Handled in order and mapped to corresponding parameter names
         - **Keyword Arguments**: Processed after positional arguments, overriding any conflicts
         - **Default Values**: Applied when parameters are not provided via arguments
         - **Attribute Setting**: All parameters become instance attributes accessible via dot notation
 
-    Example:
-        ```python
-        @auto_register_config
-        class MyAlgorithm(BaseYAMLSerializable):
-            def __init__(self, learning_rate: float = 0.001, batch_size: int = 32, model_name: str = "default", **kwargs):
-                super().__init__(**kwargs)
-
-        # All instantiation methods work automatically:
-        algo1 = MyAlgorithm(0.01, 64) # positional args
-        algo2 = MyAlgorithm(learning_rate=0.01, model_name="bert") # keyword args
-        algo3 = MyAlgorithm(0.01, batch_size=128, model_name="gpt") # mixed args
-
-        # Attributes are automatically set and can be serialized:
-        print(algo1.learning_rate) # 0.01
-        print(algo1.batch_size) # 64
-        print(algo1.model_name) # "default" (from default value)
-
-        config = algo1.config
-        # DictConfig({'_target_': 'MyAlgorithm', 'learning_rate': 0.01, 'batch_size': 64, 'model_name': 'default'})
-        ```
-
     Note:
         - The decorator wraps the original __init__ method while preserving its signature for IDE support
-        - Parameters with *args or **kwargs signatures are ignored during registration
+        - Parameters with ``*args`` or ``**kwargs`` signatures are ignored during registration
         - The attributes are auto-registered, then the original __init__ method is called,
         - Type hints, method name, and other metadata are preserved using functools.wraps
         - This decorator is designed to work seamlessly with the YAML serialization system
@@ -103,7 +90,10 @@ def auto_register_config(cls):
 
     # Auto-register parameters in _config_mapping
     if not "_config_mapping" in cls.__dict__:
-        cls._config_mapping = deepcopy(getattr(cls, "_config_mapping", {}))
+        cls._config_mapping = deepcopy(getattr(cls, "_config_mapping", bidict()))
+    if not isinstance(cls._config_mapping, bidict):
+        cls._config_mapping = bidict(cls._config_mapping)
+
     registered_parameters = tuple(cls._config_mapping.values())
 
     for param_name in list(sig.parameters.keys())[1:]: # Skip 'self'
@@ -116,6 +106,7 @@ def auto_register_config(cls):
         ) and (param_name not in registered_parameters):
             cls._config_mapping[param_name] = param_name
 
+    @wraps(original_init)
     def __init__(self, *args, **kwargs):
         log.debug(f"set attributes for {self.__class__.__name__} in {cls.__name__}")
         # auto-register the attributes based on the signature
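The usage example deleted from the decorator's docstring above still describes how it behaves after this change; a condensed sketch under the same assumptions (the class, parameters, and import path are illustrative, mirroring the removed example rather than any fusion_bench class):

```python
from fusion_bench.mixins.serialization import BaseYAMLSerializable, auto_register_config


@auto_register_config
class MyAlgorithm(BaseYAMLSerializable):
    def __init__(self, learning_rate: float = 0.001, batch_size: int = 32, **kwargs):
        super().__init__(**kwargs)


# Positional and keyword arguments are both registered as instance attributes.
algo = MyAlgorithm(0.01, batch_size=64)
print(algo.learning_rate)  # 0.01
print(algo.batch_size)     # 64
print(algo.config)         # DictConfig with _target_, learning_rate, batch_size
```

The parameter registry is now a `bidict`, so the reverse lookup used by `_set_attr` is a constant-time operation instead of a scan.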
@@ -162,33 +153,10 @@ def auto_register_config(cls):
 
 class YAMLSerializationMixin:
     _config_key: Optional[str] = None
-    _config_mapping: Dict[str, str] = {}
+    _config_mapping: MutableBidict[str, str] = bidict()
     R"""
     `_config_mapping` is a dictionary mapping the attribute names of the class to the config option names. This is used to convert the class to a DictConfig.
 
-    For example, if an algorithm class is defined as follows:
-
-    ```python
-    class SomeModelFusionAlgorithm(BaseModelFusionAlgorithm):
-        hyper_parameter_1 = None
-        hyper_parameter_2 = None
-
-        _config_mapping = BaseModelFusionAlgorithm._config_mapping | {
-            "hyper_parameter_1" : "hyper_param_1",
-            "hyper_parameter_2" : "hyper_param_2",
-        }
-        def __init__(self, hyper_param_1: int, hyper_param_2: int):
-            self.hyper_parameter_1 = hyper_param_1
-            self.hyper_parameter_2 = hyper_param_2
-            super().__init__()
-    ```
-
-    The model pool will be converted to a DictConfig as follows:
-
-    ```python
-    algorithm = SomeModelFusionAlgorithm(hyper_param_1=1, hyper_param_2=2)
-    ```
-
     >>> algorithm.config
     DictCOnfig({'_target_': 'SomeModelFusionAlgorithm', 'hyper_param_1': 1, 'hyper_param_2': 2})
 
@@ -207,17 +175,6 @@ class YAMLSerializationMixin:
         This property converts the model pool instance into a dictionary
         configuration, which can be used for serialization or other purposes.
 
-        Example:
-
-        ```python
-        model = SomeModelFusionAlgorithm(hyper_param_1=1, hyper_param_2=2)
-        config = model.config
-        print(config)
-        # DictConfig({'_target_': 'SomeModelFusionAlgorithm', 'hyper_param_1': 1, 'hyper_param_2': 2})
-        ```
-
-        This is useful for serializing the object to a YAML file or for debugging.
-
         Returns:
             DictConfig: The configuration of the model pool.
         """
@@ -282,16 +239,6 @@ class YAMLSerializationMixin:
                 serialization. This is how the attribute will appear in YAML output.
             value: The value to assign to the attribute.
 
-        Example:
-            ```python
-            model = BaseYAMLSerializable()
-            model.set_option("learning_rate", "lr", 0.001)
-
-            # This sets model.learning_rate = 0.001
-            # and maps it to "lr" in the config output
-            config = model.config
-            # config will contain: {"lr": 0.001, ...}
-            ```
         """
         setattr(self, attr_name, value)
         self._config_mapping[attr_name] = param_name
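The deleted `set_option` example likewise still matches the retained implementation (set the attribute, then record the attribute-to-parameter mapping); a small illustrative round trip, with the names taken from the removed docstring rather than any shipped config:

```python
from fusion_bench.mixins.serialization import BaseYAMLSerializable

obj = BaseYAMLSerializable()
# Store the value under `learning_rate`, exposed as `lr` in the serialized config.
obj.set_option("learning_rate", "lr", 0.001)

assert obj.learning_rate == 0.001
print(obj.config)  # expected to include {"lr": 0.001, ...}
```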
@@ -9,27 +9,33 @@ __all__ = ["SimpleProfilerMixin"]
 
 class SimpleProfilerMixin:
     """
-    A mixin class that provides simple profiling capabilities.
+    A mixin class that provides simple profiling capabilities using Lightning's SimpleProfiler.
 
-    This mixin allows for easy profiling of code blocks using a context manager.
-    It also provides methods to start and stop profiling actions, and to print
-    a summary of the profiling results.
+    This mixin allows for easy profiling of code blocks using a context manager or manual
+    start/stop methods. It measures the execution time of named actions and provides
+    a summary of the profiling results. Unlike statistical profilers, this provides
+    precise timing measurements for specific code blocks.
 
-    Examples:
-
-    ```python
-    class MyClass(SimpleProfilerMixin):
-        def do_something(self):
-            with self.profile("work"):
-                # do some work here
-                ...
-            with self.profile("more work"):
-                # do more work here
-                ...
+    Note:
+        This mixin uses Lightning's SimpleProfiler which measures wall-clock time
+        for named actions. It's suitable for timing discrete operations rather than
+        detailed function-level profiling.
 
-            # print the profiling summary
-            self.print_profile_summary()
-    ```
+    Examples:
+        ```python
+        class MyClass(SimpleProfilerMixin):
+            def do_something(self):
+                with self.profile("data_loading"):
+                    # Load data here
+                    data = load_data()
+
+                with self.profile("model_training"):
+                    # Train model here
+                    model.train(data)
+
+                # Print the profiling summary
+                self.print_profile_summary("Training Profile")
+        ```
 
     Attributes:
         _profiler (SimpleProfiler): An instance of the SimpleProfiler class used for profiling.
@@ -38,7 +44,13 @@ class SimpleProfilerMixin:
     _profiler: SimpleProfiler = None
 
     @property
-    def profiler(self):
+    def profiler(self) -> SimpleProfiler:
+        """
+        Get the SimpleProfiler instance, creating it if necessary.
+
+        Returns:
+            SimpleProfiler: The profiler instance used for timing measurements.
+        """
         # Lazy initialization of the profiler instance
         if self._profiler is None:
             self._profiler = SimpleProfiler()
@@ -47,14 +59,24 @@ class SimpleProfilerMixin:
     @contextmanager
     def profile(self, action_name: str) -> Generator:
         """
-        Context manager for profiling a code block
+        Context manager for profiling a code block.
+
+        This context manager automatically starts profiling when entering the block
+        and stops profiling when exiting the block (even if an exception occurs).
+
+        Args:
+            action_name: A descriptive name for the action being profiled.
+                This name will appear in the profiling summary.
+
+        Yields:
+            str: The action name that was provided.
 
         Example:
 
         ```python
-        with self.profile("work"):
-            # do some work here
-            ...
+        with self.profile("data_processing"):
+            # Process data here
+            result = process_large_dataset()
        ```
        """
        try:
@@ -64,18 +86,79 @@
             self.stop_profile(action_name)
 
     def start_profile(self, action_name: str):
+        """
+        Start profiling for a named action.
+
+        This method begins timing for the specified action. You must call
+        stop_profile() with the same action name to complete the measurement.
+
+        Args:
+            action_name: A descriptive name for the action being profiled.
+                This name will appear in the profiling summary.
+
+        Example:
+            ```python
+            self.start_profile("model_inference")
+            result = model.predict(data)
+            self.stop_profile("model_inference")
+            ```
+        """
         self.profiler.start(action_name)
 
     def stop_profile(self, action_name: str):
+        """
+        Stop profiling for a named action.
+
+        This method ends timing for the specified action that was previously
+        started with start_profile().
+
+        Args:
+            action_name: The name of the action to stop profiling.
+                Must match the name used in start_profile().
+
+        Example:
+            ```python
+            self.start_profile("data_loading")
+            data = load_data()
+            self.stop_profile("data_loading")
+            ```
+        """
         self.profiler.stop(action_name)
 
     @rank_zero_only
     def print_profile_summary(self, title: Optional[str] = None):
+        """
+        Print a summary of all profiled actions.
+
+        This method outputs a formatted summary showing the timing information
+        for all actions that have been profiled. The output includes action names
+        and their execution times.
+
+        Args:
+            title: Optional title to print before the profiling summary.
+                If provided, this will be printed as a header.
+
+        Note:
+            This method is decorated with @rank_zero_only, meaning it will only
+            execute on the main process in distributed training scenarios.
+
+        Example:
+            ```python
+            # After profiling some actions
+            self.print_profile_summary("Training Performance Summary")
+            ```
+        """
         if title is not None:
             print(title)
         print(self.profiler.summary())
 
     def __del__(self):
+        """
+        Cleanup when the object is destroyed.
+
+        Ensures that the profiler instance is properly cleaned up to prevent
+        memory leaks when the mixin instance is garbage collected.
+        """
         if self._profiler is not None:
             del self._profiler
             self._profiler = None
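A compact sketch of the documented API combining the context-manager and manual start/stop styles (the import path and the timed work are assumptions, not taken from the diff):

```python
import time

from fusion_bench.mixins import SimpleProfilerMixin  # import path assumed


class Worker(SimpleProfilerMixin):
    def run(self):
        # Context-manager style for a block of work.
        with self.profile("sleep"):
            time.sleep(0.1)

        # Manual start/stop style; the action name must match on both calls.
        self.start_profile("compute")
        _ = sum(i * i for i in range(100_000))
        self.stop_profile("compute")

        self.print_profile_summary("Worker timings")


Worker().run()
```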
@@ -18,6 +18,7 @@ _import_structure = {
         "GPT2ForSequenceClassificationPool",
     ],
     "seq_classification_lm": ["SequenceClassificationModelPool"],
+    "resnet_for_image_classification": ["ResNetForImageClassificationPool"],
 }
 
 
@@ -33,6 +34,7 @@ if TYPE_CHECKING:
     from .nyuv2_modelpool import NYUv2ModelPool
     from .openclip_vision import OpenCLIPVisionModelPool
     from .PeftModelForSeq2SeqLM import PeftModelForSeq2SeqLMPool
+    from .resnet_for_image_classification import ResNetForImageClassificationPool
     from .seq2seq_lm import Seq2SeqLMPool
     from .seq_classification_lm import SequenceClassificationModelPool
 
@@ -180,26 +180,59 @@ class BaseModelPool(
 
         Args:
             model_name_or_config (Union[str, DictConfig]): The model name or configuration.
+                - If str: should be a key in self._models
+                - If DictConfig: should be a configuration dict for instantiation
+            *args: Additional positional arguments passed to model instantiation.
+            **kwargs: Additional keyword arguments passed to model instantiation.
 
         Returns:
-            nn.Module: The instantiated model.
+            nn.Module: The instantiated or retrieved model.
         """
         log.debug(f"Loading model: {model_name_or_config}", stacklevel=2)
-        if isinstance(self._models, DictConfig):
-            model_config = (
-                self._models[model_name_or_config]
-                if isinstance(model_name_or_config, str)
-                else model_name_or_config
-            )
-            model = instantiate(model_config, *args, **kwargs)
-        elif isinstance(self._models, Dict) and isinstance(model_name_or_config, str):
-            model = self._models[model_name_or_config]
+
+        if isinstance(model_name_or_config, str):
+            model_name = model_name_or_config
+            # Handle string model names - lookup in the model pool
+            if model_name not in self._models:
+                raise KeyError(
+                    f"Model '{model_name}' not found in model pool. "
+                    f"Available models: {list(self._models.keys())}"
+                )
+            model_config = self._models[model_name]
+
+            # Handle different types of model configurations
+            match model_config:
+                case dict() | DictConfig() as config:
+                    # Configuration that needs instantiation
+                    log.debug(f"Instantiating model '{model_name}' from configuration")
+                    return instantiate(config, *args, **kwargs)
+
+                case nn.Module() as model:
+                    # Pre-instantiated model - return directly
+                    log.debug(
+                        f"Returning pre-instantiated model '{model_name}' of type {type(model)}"
+                    )
+                    return model
+
+                case _:
+                    # Unsupported model configuration type
+                    raise ValueError(
+                        f"Unsupported model configuration type for '{model_name}': {type(model_config)}. "
+                        f"Expected nn.Module, dict, or DictConfig."
+                    )
+
+        elif isinstance(model_name_or_config, (dict, DictConfig)):
+            # Direct configuration - instantiate directly
+            log.debug("Instantiating model from direct DictConfig")
+            model_config = model_name_or_config
+            return instantiate(model_config, *args, **kwargs)
+
         else:
-            raise ValueError(
-                "The model pool configuration is not in the expected format."
-                f"We expected a DictConfig or Dict, but got {type(self._models)}."
+            # Unsupported input type
+            raise TypeError(
+                f"Unsupported input type: {type(model_name_or_config)}. "
+                f"Expected str or DictConfig."
             )
-        return model
 
     def load_pretrained_model(self, *args, **kwargs):
         assert (
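The rewritten `load_model` accepts either a registered model name or an inline configuration; a hedged sketch of both call styles (the model name and `_target_` below are illustrative, not from the shipped configs):

```python
from omegaconf import OmegaConf

# Assuming `pool` is a BaseModelPool whose config registers a model named "vit_base".
model = pool.load_model("vit_base")  # name lookup; dict/DictConfig entries are instantiated

# A DictConfig passed directly is instantiated as-is.
inline_cfg = OmegaConf.create({"_target_": "torchvision.models.resnet18", "num_classes": 10})
model = pool.load_model(inline_cfg)

# Unknown names now raise KeyError listing the available models;
# other input types raise TypeError instead of the old generic ValueError.
```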
@@ -229,6 +262,36 @@
         for model_name in self.model_names:
             yield model_name, self.load_model(model_name)
 
+    @property
+    def has_train_dataset(self) -> bool:
+        """
+        Check if the model pool contains training datasets.
+
+        Returns:
+            bool: True if training datasets are available, False otherwise.
+        """
+        return self._train_datasets is not None and len(self._train_datasets) > 0
+
+    @property
+    def has_val_dataset(self) -> bool:
+        """
+        Check if the model pool contains validation datasets.
+
+        Returns:
+            bool: True if validation datasets are available, False otherwise.
+        """
+        return self._val_datasets is not None and len(self._val_datasets) > 0
+
+    @property
+    def has_test_dataset(self) -> bool:
+        """
+        Check if the model pool contains testing datasets.
+
+        Returns:
+            bool: True if testing datasets are available, False otherwise.
+        """
+        return self._test_datasets is not None and len(self._test_datasets) > 0
+
     def load_train_dataset(self, dataset_name: str, *args, **kwargs) -> Dataset:
         """
         Load the training dataset for the specified model.
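The new availability properties make it easy to guard optional dataset handling; a short sketch, assuming the pool also exposes a `load_val_dataset` counterpart to the `load_train_dataset` shown above (the dataset name is illustrative):

```python
# Only touch validation data when the pool actually defines it.
val_set = None
if pool.has_val_dataset:
    val_set = pool.load_val_dataset("cifar10")  # assumed analogue of load_train_dataset
```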
@@ -1,5 +1,5 @@
 """
-Online documentation for this module: https://tanganke.github.io/fusion_bench/modelpool/causal_lm
+Online documentation for this module: https://tanganke.github.io/fusion_bench/modelpool/llm
 """
 
 import logging
@@ -26,6 +26,7 @@ from fusion_bench import (
     instantiate,
     parse_dtype,
 )
+from fusion_bench.models.hf_utils import create_default_model_card
 from fusion_bench.utils.lazy_state_dict import LazyStateDict
 
 log = logging.getLogger(__name__)
@@ -271,13 +272,16 @@
         save_tokenizer: bool = False,
         tokenizer_kwargs=None,
         tokenizer: Optional[PreTrainedTokenizer] = None,
+        algorithm_config: Optional[DictConfig] = None,
+        description: Optional[str] = None,
+        base_model_in_modelcard: bool = True,
         **kwargs,
     ):
         """Save a model to the specified path with optional tokenizer and Hub upload.
 
         This method provides comprehensive model saving capabilities including
-        optional tokenizer saving, dtype conversion, and Hugging Face Hub upload.
-        The model is saved in the standard Hugging Face format.
+        optional tokenizer saving, dtype conversion, model card creation, and
+        Hugging Face Hub upload. The model is saved in the standard Hugging Face format.
 
         Args:
             model: The PreTrainedModel instance to be saved.
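The three parameters added to the signature above feed the model-card generation introduced later in this file; a hedged sketch of a call site (the path, merged model, and algorithm object are illustrative):

```python
# `pool` is a CausalLMPool; `merged_model` and `algorithm` come from a fusion run.
pool.save_model(
    merged_model,
    "outputs/merged-model",
    save_tokenizer=True,
    algorithm_config=algorithm.config,       # triggers README.md model card creation
    description="Merged with FusionBench",   # otherwise a default description is used
)
```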
@@ -295,15 +299,13 @@
                 when save_tokenizer is True.
             tokenizer: Optional pre-loaded tokenizer instance. If provided, this
                 tokenizer will be saved regardless of the save_tokenizer flag.
+            algorithm_config: Optional DictConfig containing algorithm configuration.
+                If provided, a model card will be created with algorithm details.
+            description: Optional description for the model card. If not provided
+                and algorithm_config is given, a default description will be generated.
             **kwargs: Additional keyword arguments passed to the model's
                 save_pretrained method.
 
-        Side Effects:
-            - Creates model files in the specified directory
-            - Optionally creates tokenizer files in the same directory
-            - May convert the model to a different dtype
-            - May upload files to Hugging Face Hub
-
         Example:
             ```python
             >>> pool = CausalLMPool(models=..., tokenizer=...)
@@ -313,7 +315,9 @@
             ... "/path/to/save",
             ... save_tokenizer=True,
             ... model_dtype="float16",
-            ... push_to_hub=True
+            ... push_to_hub=True,
+            ... algorithm_config=algorithm_config,
+            ... description="Custom merged model"
             ... )
             ```
         """
@@ -337,6 +341,24 @@
                 **kwargs,
             )
 
+        # Create and save model card if algorithm_config is provided
+        if algorithm_config is not None:
+            if description is None:
+                description = "Model created using FusionBench."
+            model_card_str = create_default_model_card(
+                base_model=(
+                    self.get_model_path("_pretrained_")
+                    if base_model_in_modelcard and self.has_pretrained
+                    else None
+                ),
+                models=[self.get_model_path(m) for m in self.model_names],
+                description=description,
+                algorithm_config=algorithm_config,
+                modelpool_config=self.config,
+            )
+            with open(os.path.join(path, "README.md"), "w") as f:
+                f.write(model_card_str)
+
 
 class CausalLMBackbonePool(CausalLMPool):
     """A specialized model pool that loads only the transformer backbone layers.
@@ -93,42 +93,79 @@ class CLIPVisionModelPool(BaseModelPool):
         self, model_name_or_config: Union[str, DictConfig], *args, **kwargs
     ) -> CLIPVisionModel:
         """
-        This method is used to load a CLIPVisionModel from the model pool.
+        Load a CLIPVisionModel from the model pool with support for various configuration formats.
 
-        Example configuration could be:
+        This method provides flexible model loading capabilities, handling different types of model
+        configurations including string paths, pre-instantiated models, and complex configurations.
 
+        Supported configuration formats:
+        1. String model paths (e.g., Hugging Face model IDs)
+        2. Pre-instantiated nn.Module objects
+        3. DictConfig objects for complex configurations
+
+        Example configuration:
         ```yaml
         models:
+          # Simple string paths to Hugging Face models
          cifar10: tanganke/clip-vit-base-patch32_cifar10
          sun397: tanganke/clip-vit-base-patch32_sun397
          stanford-cars: tanganke/clip-vit-base-patch32_stanford-cars
+
+          # Complex configuration with additional parameters
+          custom_model:
+            _target_: transformers.CLIPVisionModel.from_pretrained
+            pretrained_model_name_or_path: openai/clip-vit-base-patch32
+            torch_dtype: float16
        ```
 
         Args:
-            model_name_or_config (Union[str, DictConfig]): The name of the model or the model configuration.
+            model_name_or_config (Union[str, DictConfig]): Either a model name from the pool
+                or a configuration dictionary for instantiating the model.
+            *args: Additional positional arguments passed to model loading/instantiation.
+            **kwargs: Additional keyword arguments passed to model loading/instantiation.
 
         Returns:
-            CLIPVisionModel: The loaded CLIPVisionModel.
+            CLIPVisionModel: The loaded CLIPVisionModel instance.
         """
+        # Check if we have a string model name that exists in our model pool
         if (
             isinstance(model_name_or_config, str)
             and model_name_or_config in self._models
         ):
-            model = self._models[model_name_or_config]
-            if isinstance(model, str):
-                if rank_zero_only.rank == 0:
-                    log.info(f"Loading `transformers.CLIPVisionModel`: {model}")
-                repo_path = resolve_repo_path(
-                    model, repo_type="model", platform=self._platform
-                )
-                return CLIPVisionModel.from_pretrained(repo_path, *args, **kwargs)
-            if isinstance(model, nn.Module):
-                if rank_zero_only.rank == 0:
-                    log.info(f"Returning existing model: {model}")
-                return model
-            else:
-                # If the model is not a string, we use the default load_model method
-                return super().load_model(model_name_or_config, *args, **kwargs)
+            model_name = model_name_or_config
+
+            # handle different model configuration types
+            match self._models[model_name_or_config]:
+                case str() as model_path:
+                    # Handle string model paths (e.g., Hugging Face model IDs)
+                    if rank_zero_only.rank == 0:
+                        log.info(
+                            f"Loading model `{model_name}` of type `transformers.CLIPVisionModel` from {model_path}"
+                        )
+                    # Resolve the repository path (supports both HuggingFace and ModelScope)
+                    repo_path = resolve_repo_path(
+                        model_path, repo_type="model", platform=self._platform
+                    )
+                    # Load and return the CLIPVisionModel from the resolved path
+                    return CLIPVisionModel.from_pretrained(repo_path, *args, **kwargs)
+
+                case nn.Module() as model:
+                    # Handle pre-instantiated model objects
+                    if rank_zero_only.rank == 0:
+                        log.info(
+                            f"Returning existing model `{model_name}` of type {type(model)}"
+                        )
+                    return model
+
+                case _:
+                    # Handle other configuration types (e.g., DictConfig) via parent class
+                    # This fallback prevents returning None when the model config doesn't
+                    # match the expected string or nn.Module patterns
+                    return super().load_model(model_name_or_config, *args, **kwargs)
+
+        # If model_name_or_config is not a string in our pool, delegate to parent class
+        # This handles cases where model_name_or_config is a DictConfig directly
+        return super().load_model(model_name_or_config, *args, **kwargs)
 
     def load_train_dataset(self, dataset_name: str, *args, **kwargs):
         dataset_config = self._train_datasets[dataset_name]
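The match-based handling above lets one pool mix plain Hub IDs with `_target_` entries; a brief usage sketch, assuming `pool` is a `CLIPVisionModelPool` built from a config like the docstring's YAML example:

```python
# "cifar10" is a plain string entry: resolved with resolve_repo_path, then
# loaded via CLIPVisionModel.from_pretrained.
vision_model = pool.load_model("cifar10")

# "custom_model" is a DictConfig entry: it falls through the `case _` branch
# to BaseModelPool.load_model, which instantiates it from its `_target_`.
custom_model = pool.load_model("custom_model")
```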