fusion-bench 0.2.23__py3-none-any.whl → 0.2.25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/__init__.py +152 -42
- fusion_bench/dataset/__init__.py +27 -4
- fusion_bench/dataset/clip_dataset.py +2 -2
- fusion_bench/method/__init__.py +18 -1
- fusion_bench/method/classification/__init__.py +27 -2
- fusion_bench/method/classification/image_classification_finetune.py +214 -0
- fusion_bench/method/ensemble.py +17 -2
- fusion_bench/method/linear/__init__.py +6 -2
- fusion_bench/method/linear/{simple_average_for_llama.py → simple_average_for_causallm.py} +8 -4
- fusion_bench/method/linear/{task_arithmetic_for_llama.py → task_arithmetic_for_causallm.py} +22 -12
- fusion_bench/method/linear/ties_merging_for_causallm.py +70 -0
- fusion_bench/method/opcm/opcm.py +1 -0
- fusion_bench/method/pwe_moe/module.py +0 -2
- fusion_bench/method/simple_average.py +2 -2
- fusion_bench/method/tall_mask/task_arithmetic.py +2 -2
- fusion_bench/method/task_arithmetic/task_arithmetic.py +35 -10
- fusion_bench/method/ties_merging/ties_merging.py +22 -6
- fusion_bench/method/wudi/__init__.py +1 -0
- fusion_bench/method/wudi/wudi.py +105 -0
- fusion_bench/mixins/__init__.py +2 -0
- fusion_bench/mixins/lightning_fabric.py +4 -0
- fusion_bench/mixins/pyinstrument.py +174 -0
- fusion_bench/mixins/serialization.py +25 -78
- fusion_bench/mixins/simple_profiler.py +106 -23
- fusion_bench/modelpool/__init__.py +2 -0
- fusion_bench/modelpool/base_pool.py +77 -14
- fusion_bench/modelpool/causal_lm/causal_lm.py +32 -10
- fusion_bench/modelpool/clip_vision/modelpool.py +56 -19
- fusion_bench/modelpool/resnet_for_image_classification.py +208 -0
- fusion_bench/models/__init__.py +35 -9
- fusion_bench/models/hf_clip.py +4 -0
- fusion_bench/models/hf_utils.py +2 -1
- fusion_bench/models/model_card_templates/default.md +8 -1
- fusion_bench/models/wrappers/ensemble.py +136 -7
- fusion_bench/optim/__init__.py +40 -2
- fusion_bench/optim/lr_scheduler/__init__.py +27 -1
- fusion_bench/optim/muon.py +339 -0
- fusion_bench/programs/__init__.py +2 -0
- fusion_bench/programs/fabric_fusion_program.py +2 -2
- fusion_bench/programs/fusion_program.py +271 -0
- fusion_bench/scripts/cli.py +2 -2
- fusion_bench/taskpool/clip_vision/taskpool.py +11 -4
- fusion_bench/tasks/clip_classification/__init__.py +15 -0
- fusion_bench/utils/__init__.py +167 -21
- fusion_bench/utils/devices.py +30 -8
- fusion_bench/utils/lazy_imports.py +91 -12
- fusion_bench/utils/lazy_state_dict.py +58 -5
- fusion_bench/utils/misc.py +104 -13
- fusion_bench/utils/packages.py +4 -0
- fusion_bench/utils/path.py +7 -0
- fusion_bench/utils/pylogger.py +6 -0
- fusion_bench/utils/rich_utils.py +8 -3
- fusion_bench/utils/state_dict_arithmetic.py +935 -162
- {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/METADATA +10 -3
- {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/RECORD +76 -55
- fusion_bench_config/method/classification/image_classification_finetune.yaml +16 -0
- fusion_bench_config/method/classification/image_classification_finetune_test.yaml +6 -0
- fusion_bench_config/method/ensemble/simple_ensemble.yaml +1 -0
- fusion_bench_config/method/linear/{simple_average_for_llama.yaml → simple_average_for_causallm.yaml} +1 -1
- fusion_bench_config/method/linear/task_arithmetic_for_causallm.yaml +4 -0
- fusion_bench_config/method/linear/ties_merging_for_causallm.yaml +13 -0
- fusion_bench_config/method/wudi/wudi.yaml +4 -0
- fusion_bench_config/model_fusion.yaml +45 -0
- fusion_bench_config/modelpool/CausalLMPool/{Qwen2.5-1.5B_math_and_coder.yaml → Qwen2.5-1.5B_math_and_code.yaml} +1 -2
- fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_three_models.yaml +11 -0
- fusion_bench_config/modelpool/CausalLMPool/llama-7b_3-models_v1.yaml +11 -0
- fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet152_cifar10.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet152_cifar100.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet18_cifar10.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet18_cifar100.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet50_cifar10.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet50_cifar100.yaml +14 -0
- fusion_bench_config/method/linear/task_arithmetic_for_llama.yaml +0 -4
- {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/top_level.txt +0 -0
fusion_bench/mixins/serialization.py:

````diff
@@ -6,6 +6,7 @@ from inspect import Parameter, _ParameterKind
 from pathlib import Path
 from typing import Dict, Mapping, Optional, Union
 
+from bidict import MutableBidict, bidict
 from omegaconf import DictConfig, OmegaConf
 
 from fusion_bench.constants import FUSION_BENCH_VERSION
@@ -15,22 +16,29 @@ from fusion_bench.utils.instantiate_utils import set_print_function_call
 log = logging.getLogger(__name__)
 
 __all__ = [
-    "YAMLSerializationMixin",
     "auto_register_config",
+    "YAMLSerializationMixin",
    "BaseYAMLSerializable",
 ]
 
 
-def …
+def _set_attr(self, param_name: str, value):
+    """
+    Set an attribute on the object using the parameter name from config mapping.
+
+    This function looks up the corresponding attribute name for the given parameter
+    name using the object's _config_mapping, then sets that attribute to the
+    specified value. It also logs the operation for debugging purposes.
 
+    Args:
+        self: The object instance to set the attribute on.
+        param_name (str): The parameter name (config key) to map to an attribute.
+        value: The value to assign to the attribute.
 
-…
+    Raises:
+        ValueError: If the parameter name is not found in the config mapping.
+    """
+    attr_name = self._config_mapping.inverse[param_name]
     log.debug(f"set {attr_name} to {value}. Parameter name: {param_name}")
     setattr(self, attr_name, value)
 
@@ -59,37 +67,16 @@ def auto_register_config(cls):
     functionality and modified __init__ behavior.
 
     Behavior:
-        - **Parameter Registration**: All non-variadic parameters (excluding …
+        - **Parameter Registration**: All non-variadic parameters (excluding ``*args``, ``**kwargs``)
           from the __init__ method are automatically added to _config_mapping
         - **Positional Arguments**: Handled in order and mapped to corresponding parameter names
         - **Keyword Arguments**: Processed after positional arguments, overriding any conflicts
         - **Default Values**: Applied when parameters are not provided via arguments
         - **Attribute Setting**: All parameters become instance attributes accessible via dot notation
 
-    Example:
-        ```python
-        @auto_register_config
-        class MyAlgorithm(BaseYAMLSerializable):
-            def __init__(self, learning_rate: float = 0.001, batch_size: int = 32, model_name: str = "default", **kwargs):
-                super().__init__(**kwargs)
-
-        # All instantiation methods work automatically:
-        algo1 = MyAlgorithm(0.01, 64)  # positional args
-        algo2 = MyAlgorithm(learning_rate=0.01, model_name="bert")  # keyword args
-        algo3 = MyAlgorithm(0.01, batch_size=128, model_name="gpt")  # mixed args
-
-        # Attributes are automatically set and can be serialized:
-        print(algo1.learning_rate)  # 0.01
-        print(algo1.batch_size)  # 64
-        print(algo1.model_name)  # "default" (from default value)
-
-        config = algo1.config
-        # DictConfig({'_target_': 'MyAlgorithm', 'learning_rate': 0.01, 'batch_size': 64, 'model_name': 'default'})
-        ```
-
     Note:
         - The decorator wraps the original __init__ method while preserving its signature for IDE support
-        - Parameters with …
+        - Parameters with ``*args`` or ``**kwargs`` signatures are ignored during registration
         - The attributes are auto-registered, then the original __init__ method is called,
         - Type hints, method name, and other metadata are preserved using functools.wraps
         - This decorator is designed to work seamlessly with the YAML serialization system
@@ -103,7 +90,10 @@ def auto_register_config(cls):
 
     # Auto-register parameters in _config_mapping
     if not "_config_mapping" in cls.__dict__:
-        cls._config_mapping = deepcopy(getattr(cls, "_config_mapping", …
+        cls._config_mapping = deepcopy(getattr(cls, "_config_mapping", bidict()))
+    if not isinstance(cls._config_mapping, bidict):
+        cls._config_mapping = bidict(cls._config_mapping)
+
     registered_parameters = tuple(cls._config_mapping.values())
 
     for param_name in list(sig.parameters.keys())[1:]:  # Skip 'self'
@@ -116,6 +106,7 @@ def auto_register_config(cls):
         ) and (param_name not in registered_parameters):
             cls._config_mapping[param_name] = param_name
 
+    @wraps(original_init)
     def __init__(self, *args, **kwargs):
         log.debug(f"set attributes for {self.__class__.__name__} in {cls.__name__}")
         # auto-register the attributes based on the signature
@@ -162,33 +153,10 @@ def auto_register_config(cls):
 
 class YAMLSerializationMixin:
     _config_key: Optional[str] = None
-    _config_mapping: …
+    _config_mapping: MutableBidict[str, str] = bidict()
     R"""
     `_config_mapping` is a dictionary mapping the attribute names of the class to the config option names. This is used to convert the class to a DictConfig.
 
-    For example, if an algorithm class is defined as follows:
-
-    ```python
-    class SomeModelFusionAlgorithm(BaseModelFusionAlgorithm):
-        hyper_parameter_1 = None
-        hyper_parameter_2 = None
-
-        _config_mapping = BaseModelFusionAlgorithm._config_mapping | {
-            "hyper_parameter_1": "hyper_param_1",
-            "hyper_parameter_2": "hyper_param_2",
-        }
-        def __init__(self, hyper_param_1: int, hyper_param_2: int):
-            self.hyper_parameter_1 = hyper_param_1
-            self.hyper_parameter_2 = hyper_param_2
-            super().__init__()
-    ```
-
-    The model pool will be converted to a DictConfig as follows:
-
-    ```python
-    algorithm = SomeModelFusionAlgorithm(hyper_param_1=1, hyper_param_2=2)
-    ```
-
     >>> algorithm.config
     DictCOnfig({'_target_': 'SomeModelFusionAlgorithm', 'hyper_param_1': 1, 'hyper_param_2': 2})
 
@@ -207,17 +175,6 @@ class YAMLSerializationMixin:
         This property converts the model pool instance into a dictionary
         configuration, which can be used for serialization or other purposes.
 
-        Example:
-
-        ```python
-        model = SomeModelFusionAlgorithm(hyper_param_1=1, hyper_param_2=2)
-        config = model.config
-        print(config)
-        # DictConfig({'_target_': 'SomeModelFusionAlgorithm', 'hyper_param_1': 1, 'hyper_param_2': 2})
-        ```
-
-        This is useful for serializing the object to a YAML file or for debugging.
-
         Returns:
             DictConfig: The configuration of the model pool.
         """
@@ -282,16 +239,6 @@ class YAMLSerializationMixin:
                 serialization. This is how the attribute will appear in YAML output.
             value: The value to assign to the attribute.
 
-        Example:
-            ```python
-            model = BaseYAMLSerializable()
-            model.set_option("learning_rate", "lr", 0.001)
-
-            # This sets model.learning_rate = 0.001
-            # and maps it to "lr" in the config output
-            config = model.config
-            # config will contain: {"lr": 0.001, ...}
-            ```
         """
        setattr(self, attr_name, value)
         self._config_mapping[attr_name] = param_name
````
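To make the new registration behaviour concrete, here is a minimal usage sketch based on the docstrings above. The class `MyAlgorithm` and its hyperparameters are illustrative, and the import assumes `auto_register_config` and `BaseYAMLSerializable` are importable from `fusion_bench.mixins.serialization`, where `__all__` lists them.

```python
# Minimal sketch (illustrative names): __init__ parameters are auto-registered
# into the bidict-backed _config_mapping and become instance attributes.
from fusion_bench.mixins.serialization import BaseYAMLSerializable, auto_register_config


@auto_register_config
class MyAlgorithm(BaseYAMLSerializable):
    def __init__(self, learning_rate: float = 0.001, batch_size: int = 32, **kwargs):
        super().__init__(**kwargs)


algo = MyAlgorithm(0.01, batch_size=64)
print(algo.learning_rate)  # 0.01, positional argument mapped to its parameter name
print(algo.batch_size)     # 64
print(algo.config)         # DictConfig with `_target_` plus the registered parameters
```

Because `_config_mapping` is now a `bidict`, the new `_set_attr` helper can look up the attribute name for a config key via `.inverse` instead of scanning the mapping.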
fusion_bench/mixins/simple_profiler.py:

````diff
@@ -9,27 +9,33 @@ __all__ = ["SimpleProfilerMixin"]
 
 class SimpleProfilerMixin:
     """
-    A mixin class that provides simple profiling capabilities.
+    A mixin class that provides simple profiling capabilities using Lightning's SimpleProfiler.
 
-    This mixin allows for easy profiling of code blocks using a context manager
-    …
-    a summary of the profiling results.
+    This mixin allows for easy profiling of code blocks using a context manager or manual
+    start/stop methods. It measures the execution time of named actions and provides
+    a summary of the profiling results. Unlike statistical profilers, this provides
+    precise timing measurements for specific code blocks.
 
-    …
-        def do_something(self):
-            with self.profile("work"):
-                # do some work here
-                ...
-            with self.profile("more work"):
-                # do more work here
-                ...
+    Note:
+        This mixin uses Lightning's SimpleProfiler which measures wall-clock time
+        for named actions. It's suitable for timing discrete operations rather than
+        detailed function-level profiling.
 
-    …
+    Examples:
+        ```python
+        class MyClass(SimpleProfilerMixin):
+            def do_something(self):
+                with self.profile("data_loading"):
+                    # Load data here
+                    data = load_data()
+
+                with self.profile("model_training"):
+                    # Train model here
+                    model.train(data)
+
+                # Print the profiling summary
+                self.print_profile_summary("Training Profile")
+        ```
 
     Attributes:
         _profiler (SimpleProfiler): An instance of the SimpleProfiler class used for profiling.
@@ -38,7 +44,13 @@ class SimpleProfilerMixin:
     _profiler: SimpleProfiler = None
 
     @property
-    def profiler(self):
+    def profiler(self) -> SimpleProfiler:
+        """
+        Get the SimpleProfiler instance, creating it if necessary.
+
+        Returns:
+            SimpleProfiler: The profiler instance used for timing measurements.
+        """
         # Lazy initialization of the profiler instance
         if self._profiler is None:
             self._profiler = SimpleProfiler()
@@ -47,14 +59,24 @@
     @contextmanager
     def profile(self, action_name: str) -> Generator:
         """
-        Context manager for profiling a code block
+        Context manager for profiling a code block.
+
+        This context manager automatically starts profiling when entering the block
+        and stops profiling when exiting the block (even if an exception occurs).
+
+        Args:
+            action_name: A descriptive name for the action being profiled.
+                This name will appear in the profiling summary.
+
+        Yields:
+            str: The action name that was provided.
 
         Example:
 
         ```python
-        with self.profile("…
-        # …
-        …
+        with self.profile("data_processing"):
+            # Process data here
+            result = process_large_dataset()
         ```
         """
         try:
@@ -64,18 +86,79 @@ class SimpleProfilerMixin:
             self.stop_profile(action_name)
 
     def start_profile(self, action_name: str):
+        """
+        Start profiling for a named action.
+
+        This method begins timing for the specified action. You must call
+        stop_profile() with the same action name to complete the measurement.
+
+        Args:
+            action_name: A descriptive name for the action being profiled.
+                This name will appear in the profiling summary.
+
+        Example:
+            ```python
+            self.start_profile("model_inference")
+            result = model.predict(data)
+            self.stop_profile("model_inference")
+            ```
+        """
         self.profiler.start(action_name)
 
     def stop_profile(self, action_name: str):
+        """
+        Stop profiling for a named action.
+
+        This method ends timing for the specified action that was previously
+        started with start_profile().
+
+        Args:
+            action_name: The name of the action to stop profiling.
+                Must match the name used in start_profile().
+
+        Example:
+            ```python
+            self.start_profile("data_loading")
+            data = load_data()
+            self.stop_profile("data_loading")
+            ```
+        """
         self.profiler.stop(action_name)
 
     @rank_zero_only
     def print_profile_summary(self, title: Optional[str] = None):
+        """
+        Print a summary of all profiled actions.
+
+        This method outputs a formatted summary showing the timing information
+        for all actions that have been profiled. The output includes action names
+        and their execution times.
+
+        Args:
+            title: Optional title to print before the profiling summary.
+                If provided, this will be printed as a header.
+
+        Note:
+            This method is decorated with @rank_zero_only, meaning it will only
+            execute on the main process in distributed training scenarios.
+
+        Example:
+            ```python
+            # After profiling some actions
+            self.print_profile_summary("Training Performance Summary")
+            ```
+        """
         if title is not None:
             print(title)
         print(self.profiler.summary())
 
     def __del__(self):
+        """
+        Cleanup when the object is destroyed.
+
+        Ensures that the profiler instance is properly cleaned up to prevent
+        memory leaks when the mixin instance is garbage collected.
+        """
         if self._profiler is not None:
             del self._profiler
             self._profiler = None
````
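A small, self-contained sketch of the mixin usage documented above; the `Trainer` class and the timed work are stand-ins.

```python
# Sketch (illustrative class and workload): time two named actions and print the summary.
import time

from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin


class Trainer(SimpleProfilerMixin):
    def run(self):
        with self.profile("data_loading"):
            time.sleep(0.1)  # stand-in for loading data
        with self.profile("training_step"):
            time.sleep(0.2)  # stand-in for one training step
        # Wall-clock totals per action; only rank zero prints in distributed runs.
        self.print_profile_summary("Toy run")


Trainer().run()
```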
fusion_bench/modelpool/__init__.py:

````diff
@@ -18,6 +18,7 @@ _import_structure = {
         "GPT2ForSequenceClassificationPool",
     ],
     "seq_classification_lm": ["SequenceClassificationModelPool"],
+    "resnet_for_image_classification": ["ResNetForImageClassificationPool"],
 }
 
 
@@ -33,6 +34,7 @@ if TYPE_CHECKING:
     from .nyuv2_modelpool import NYUv2ModelPool
     from .openclip_vision import OpenCLIPVisionModelPool
     from .PeftModelForSeq2SeqLM import PeftModelForSeq2SeqLMPool
+    from .resnet_for_image_classification import ResNetForImageClassificationPool
     from .seq2seq_lm import Seq2SeqLMPool
     from .seq_classification_lm import SequenceClassificationModelPool
 
````
fusion_bench/modelpool/base_pool.py:

````diff
@@ -180,26 +180,59 @@ class BaseModelPool(
 
         Args:
             model_name_or_config (Union[str, DictConfig]): The model name or configuration.
+                - If str: should be a key in self._models
+                - If DictConfig: should be a configuration dict for instantiation
+            *args: Additional positional arguments passed to model instantiation.
+            **kwargs: Additional keyword arguments passed to model instantiation.
 
         Returns:
-            nn.Module: The instantiated model.
+            nn.Module: The instantiated or retrieved model.
         """
         log.debug(f"Loading model: {model_name_or_config}", stacklevel=2)
-        …
+
+        if isinstance(model_name_or_config, str):
+            model_name = model_name_or_config
+            # Handle string model names - lookup in the model pool
+            if model_name not in self._models:
+                raise KeyError(
+                    f"Model '{model_name}' not found in model pool. "
+                    f"Available models: {list(self._models.keys())}"
+                )
+            model_config = self._models[model_name]
+
+            # Handle different types of model configurations
+            match model_config:
+                case dict() | DictConfig() as config:
+                    # Configuration that needs instantiation
+                    log.debug(f"Instantiating model '{model_name}' from configuration")
+                    return instantiate(config, *args, **kwargs)
+
+                case nn.Module() as model:
+                    # Pre-instantiated model - return directly
+                    log.debug(
+                        f"Returning pre-instantiated model '{model_name}' of type {type(model)}"
+                    )
+                    return model
+
+                case _:
+                    # Unsupported model configuration type
+                    raise ValueError(
+                        f"Unsupported model configuration type for '{model_name}': {type(model_config)}. "
+                        f"Expected nn.Module, dict, or DictConfig."
+                    )
+
+        elif isinstance(model_name_or_config, (dict, DictConfig)):
+            # Direct configuration - instantiate directly
+            log.debug("Instantiating model from direct DictConfig")
+            model_config = model_name_or_config
+            return instantiate(model_config, *args, **kwargs)
+
         else:
-            …
-            f"…
+            # Unsupported input type
+            raise TypeError(
+                f"Unsupported input type: {type(model_name_or_config)}. "
+                f"Expected str or DictConfig."
             )
-        return model
 
     def load_pretrained_model(self, *args, **kwargs):
         assert (
````
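A hedged usage sketch of the dispatch implemented above. It assumes `BaseModelPool` can be constructed with a `models` mapping (mirroring the `models:` key used in the YAML configs) and that the Hydra-style `instantiate` resolves `_target_`; the entries themselves are illustrative.

```python
# Illustrative sketch of BaseModelPool.load_model dispatch (assumed constructor kwargs).
from omegaconf import DictConfig
from torch import nn

from fusion_bench.modelpool import BaseModelPool

pool = BaseModelPool(
    models={
        # Pre-instantiated nn.Module: returned as-is by load_model.
        "tiny_linear": nn.Linear(4, 2),
        # dict/DictConfig entry: instantiated on demand via `instantiate(...)`.
        "lazy_linear": DictConfig(
            {"_target_": "torch.nn.Linear", "in_features": 4, "out_features": 2}
        ),
    }
)

print(type(pool.load_model("tiny_linear")))  # the same nn.Linear instance
print(type(pool.load_model("lazy_linear")))  # a new nn.Linear built from the config
# An unknown name raises KeyError listing the available models, and passing a
# DictConfig directly instantiates it without consulting the pool.
```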
fusion_bench/modelpool/base_pool.py (continued):

````diff
@@ -229,6 +262,36 @@ class BaseModelPool(
         for model_name in self.model_names:
             yield model_name, self.load_model(model_name)
 
+    @property
+    def has_train_dataset(self) -> bool:
+        """
+        Check if the model pool contains training datasets.
+
+        Returns:
+            bool: True if training datasets are available, False otherwise.
+        """
+        return self._train_datasets is not None and len(self._train_datasets) > 0
+
+    @property
+    def has_val_dataset(self) -> bool:
+        """
+        Check if the model pool contains validation datasets.
+
+        Returns:
+            bool: True if validation datasets are available, False otherwise.
+        """
+        return self._val_datasets is not None and len(self._val_datasets) > 0
+
+    @property
+    def has_test_dataset(self) -> bool:
+        """
+        Check if the model pool contains testing datasets.
+
+        Returns:
+            bool: True if testing datasets are available, False otherwise.
+        """
+        return self._test_datasets is not None and len(self._test_datasets) > 0
+
     def load_train_dataset(self, dataset_name: str, *args, **kwargs) -> Dataset:
         """
         Load the training dataset for the specified model.
````
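The new availability properties let callers guard optional phases. A short hedged sketch, continuing the illustrative `pool` above and assuming a train dataset named "my_dataset" is configured:

```python
# Guard optional dataset handling with the new properties (illustrative dataset key).
if pool.has_train_dataset:
    train_set = pool.load_train_dataset("my_dataset")
else:
    train_set = None
```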
fusion_bench/modelpool/causal_lm/causal_lm.py:

````diff
@@ -1,5 +1,5 @@
 """
-Online documentation for this module: https://tanganke.github.io/fusion_bench/modelpool/
+Online documentation for this module: https://tanganke.github.io/fusion_bench/modelpool/llm
 """
 
 import logging
@@ -26,6 +26,7 @@ from fusion_bench import (
     instantiate,
     parse_dtype,
 )
+from fusion_bench.models.hf_utils import create_default_model_card
 from fusion_bench.utils.lazy_state_dict import LazyStateDict
 
 log = logging.getLogger(__name__)
@@ -271,13 +272,16 @@ class CausalLMPool(BaseModelPool):
         save_tokenizer: bool = False,
         tokenizer_kwargs=None,
         tokenizer: Optional[PreTrainedTokenizer] = None,
+        algorithm_config: Optional[DictConfig] = None,
+        description: Optional[str] = None,
+        base_model_in_modelcard: bool = True,
         **kwargs,
     ):
         """Save a model to the specified path with optional tokenizer and Hub upload.
 
         This method provides comprehensive model saving capabilities including
-        optional tokenizer saving, dtype conversion, …
-        The model is saved in the standard Hugging Face format.
+        optional tokenizer saving, dtype conversion, model card creation, and
+        Hugging Face Hub upload. The model is saved in the standard Hugging Face format.
 
         Args:
             model: The PreTrainedModel instance to be saved.
@@ -295,15 +299,13 @@
                 when save_tokenizer is True.
             tokenizer: Optional pre-loaded tokenizer instance. If provided, this
                 tokenizer will be saved regardless of the save_tokenizer flag.
+            algorithm_config: Optional DictConfig containing algorithm configuration.
+                If provided, a model card will be created with algorithm details.
+            description: Optional description for the model card. If not provided
+                and algorithm_config is given, a default description will be generated.
             **kwargs: Additional keyword arguments passed to the model's
                 save_pretrained method.
 
-        Side Effects:
-            - Creates model files in the specified directory
-            - Optionally creates tokenizer files in the same directory
-            - May convert the model to a different dtype
-            - May upload files to Hugging Face Hub
-
         Example:
             ```python
             >>> pool = CausalLMPool(models=..., tokenizer=...)
@@ -313,7 +315,9 @@
             ...     "/path/to/save",
             ...     save_tokenizer=True,
             ...     model_dtype="float16",
-            ...     push_to_hub=True
+            ...     push_to_hub=True,
+            ...     algorithm_config=algorithm_config,
+            ...     description="Custom merged model"
             ... )
             ```
         """
@@ -337,6 +341,24 @@
             **kwargs,
         )
 
+        # Create and save model card if algorithm_config is provided
+        if algorithm_config is not None:
+            if description is None:
+                description = "Model created using FusionBench."
+            model_card_str = create_default_model_card(
+                base_model=(
+                    self.get_model_path("_pretrained_")
+                    if base_model_in_modelcard and self.has_pretrained
+                    else None
+                ),
+                models=[self.get_model_path(m) for m in self.model_names],
+                description=description,
+                algorithm_config=algorithm_config,
+                modelpool_config=self.config,
+            )
+            with open(os.path.join(path, "README.md"), "w") as f:
+                f.write(model_card_str)
+
 
 class CausalLMBackbonePool(CausalLMPool):
     """A specialized model pool that loads only the transformer backbone layers.
````
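A hedged end-to-end sketch of the new model-card path: when `algorithm_config` is given, `save_model` writes a `README.md` next to the weights. The model repos, output path, and algorithm config below are illustrative; the keyword usage follows the docstring example above.

```python
# Illustrative sketch of CausalLMPool.save_model with the new model-card arguments.
from omegaconf import DictConfig

from fusion_bench.modelpool import CausalLMPool

pool = CausalLMPool(
    models={
        "_pretrained_": "Qwen/Qwen2.5-1.5B",           # base model (illustrative repo id)
        "expert_a": "your-org/qwen2.5-1.5b-expert-a",   # fine-tuned experts (illustrative)
        "expert_b": "your-org/qwen2.5-1.5b-expert-b",
    },
    tokenizer="Qwen/Qwen2.5-1.5B",
)

# Stand-in for the output of a merging algorithm.
merged_model = pool.load_model("_pretrained_")

pool.save_model(
    merged_model,
    "outputs/merged-qwen2.5-1.5b",
    save_tokenizer=True,
    algorithm_config=DictConfig({"name": "simple_average"}),  # normally the algorithm's own config
    description="Merged with FusionBench (illustrative run).",
)
# -> saves the weights and tokenizer, and writes outputs/merged-qwen2.5-1.5b/README.md
#    listing the base model and the source models from the pool.
```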
fusion_bench/modelpool/clip_vision/modelpool.py:

````diff
@@ -93,42 +93,79 @@ class CLIPVisionModelPool(BaseModelPool):
         self, model_name_or_config: Union[str, DictConfig], *args, **kwargs
     ) -> CLIPVisionModel:
         """
-        …
+        Load a CLIPVisionModel from the model pool with support for various configuration formats.
 
-        …
+        This method provides flexible model loading capabilities, handling different types of model
+        configurations including string paths, pre-instantiated models, and complex configurations.
 
+        Supported configuration formats:
+        1. String model paths (e.g., Hugging Face model IDs)
+        2. Pre-instantiated nn.Module objects
+        3. DictConfig objects for complex configurations
+
+        Example configuration:
         ```yaml
         models:
+          # Simple string paths to Hugging Face models
           cifar10: tanganke/clip-vit-base-patch32_cifar10
           sun397: tanganke/clip-vit-base-patch32_sun397
           stanford-cars: tanganke/clip-vit-base-patch32_stanford-cars
+
+          # Complex configuration with additional parameters
+          custom_model:
+            _target_: transformers.CLIPVisionModel.from_pretrained
+            pretrained_model_name_or_path: openai/clip-vit-base-patch32
+            torch_dtype: float16
         ```
 
         Args:
-            model_name_or_config (Union[str, DictConfig]): …
+            model_name_or_config (Union[str, DictConfig]): Either a model name from the pool
+                or a configuration dictionary for instantiating the model.
+            *args: Additional positional arguments passed to model loading/instantiation.
+            **kwargs: Additional keyword arguments passed to model loading/instantiation.
 
         Returns:
-            CLIPVisionModel: The loaded CLIPVisionModel.
+            CLIPVisionModel: The loaded CLIPVisionModel instance.
         """
+        # Check if we have a string model name that exists in our model pool
         if (
             isinstance(model_name_or_config, str)
             and model_name_or_config in self._models
         ):
-            …
+            model_name = model_name_or_config
+
+            # handle different model configuration types
+            match self._models[model_name_or_config]:
+                case str() as model_path:
+                    # Handle string model paths (e.g., Hugging Face model IDs)
+                    if rank_zero_only.rank == 0:
+                        log.info(
+                            f"Loading model `{model_name}` of type `transformers.CLIPVisionModel` from {model_path}"
+                        )
+                    # Resolve the repository path (supports both HuggingFace and ModelScope)
+                    repo_path = resolve_repo_path(
+                        model_path, repo_type="model", platform=self._platform
+                    )
+                    # Load and return the CLIPVisionModel from the resolved path
+                    return CLIPVisionModel.from_pretrained(repo_path, *args, **kwargs)
+
+                case nn.Module() as model:
+                    # Handle pre-instantiated model objects
+                    if rank_zero_only.rank == 0:
+                        log.info(
+                            f"Returning existing model `{model_name}` of type {type(model)}"
+                        )
+                    return model
+
+                case _:
+                    # Handle other configuration types (e.g., DictConfig) via parent class
+                    # This fallback prevents returning None when the model config doesn't
+                    # match the expected string or nn.Module patterns
+                    return super().load_model(model_name_or_config, *args, **kwargs)
+
+        # If model_name_or_config is not a string in our pool, delegate to parent class
+        # This handles cases where model_name_or_config is a DictConfig directly
+        return super().load_model(model_name_or_config, *args, **kwargs)
 
     def load_train_dataset(self, dataset_name: str, *args, **kwargs):
         dataset_config = self._train_datasets[dataset_name]
````
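A hedged sketch of the two loading paths shown above: string entries are resolved (Hugging Face or ModelScope) and loaded with `CLIPVisionModel.from_pretrained`, while pre-instantiated modules are returned unchanged; anything else falls through to `BaseModelPool.load_model`. The construction assumes a `models` mapping keyword, and the repo ids are taken from the docstring example.

```python
# Illustrative sketch of CLIPVisionModelPool.load_model with string entries.
from fusion_bench.modelpool import CLIPVisionModelPool

pool = CLIPVisionModelPool(
    models={
        "cifar10": "tanganke/clip-vit-base-patch32_cifar10",
        "sun397": "tanganke/clip-vit-base-patch32_sun397",
    }
)

vision_model = pool.load_model("cifar10")  # resolved repo -> CLIPVisionModel.from_pretrained(...)
print(type(vision_model).__name__)         # CLIPVisionModel
# A DictConfig entry (like the `custom_model` example above) or a DictConfig passed
# directly is delegated to BaseModelPool.load_model and instantiated there.
```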