fusion-bench 0.2.21__py3-none-any.whl → 0.2.23__py3-none-any.whl
This diff compares the contents of two publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
- fusion_bench/__init__.py +25 -2
- fusion_bench/compat/method/__init__.py +5 -2
- fusion_bench/compat/method/base_algorithm.py +3 -2
- fusion_bench/compat/modelpool/base_pool.py +3 -3
- fusion_bench/compat/taskpool/clip_image_classification.py +1 -1
- fusion_bench/constants/__init__.py +1 -0
- fusion_bench/constants/runtime.py +57 -0
- fusion_bench/dataset/gpt2_glue.py +1 -1
- fusion_bench/method/__init__.py +12 -4
- fusion_bench/method/analysis/task_vector_cos_similarity.py +95 -12
- fusion_bench/method/analysis/task_vector_violin_plot.py +160 -52
- fusion_bench/method/bitdelta/__init__.py +1 -0
- fusion_bench/method/bitdelta/bitdelta.py +7 -23
- fusion_bench/method/classification/clip_finetune.py +1 -1
- fusion_bench/method/expert_sparsity/mixtral/dynamic_skipping.py +2 -0
- fusion_bench/method/expert_sparsity/mixtral/layer_wise_pruning.py +2 -0
- fusion_bench/method/expert_sparsity/mixtral/progressive_pruning.py +2 -0
- fusion_bench/method/fisher_merging/clip_fisher_merging.py +0 -4
- fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +2 -2
- fusion_bench/method/linear/simple_average_for_llama.py +16 -11
- fusion_bench/method/model_stock/__init__.py +1 -0
- fusion_bench/method/model_stock/model_stock.py +309 -0
- fusion_bench/method/regmean/clip_regmean.py +3 -6
- fusion_bench/method/regmean/regmean.py +27 -56
- fusion_bench/method/regmean/utils.py +56 -0
- fusion_bench/method/regmean_plusplus/regmean_plusplus.py +21 -60
- fusion_bench/method/simple_average.py +7 -7
- fusion_bench/method/slerp/__init__.py +1 -1
- fusion_bench/method/slerp/slerp.py +110 -14
- fusion_bench/method/smile_upscaling/causal_lm_upscaling.py +371 -0
- fusion_bench/method/smile_upscaling/projected_energy.py +1 -2
- fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py +5 -1
- fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +40 -31
- fusion_bench/method/smile_upscaling/smile_upscaling.py +1 -1
- fusion_bench/method/we_moe/__init__.py +1 -0
- fusion_bench/method/we_moe/entropy_loss.py +25 -0
- fusion_bench/method/we_moe/flan_t5_we_moe.py +320 -0
- fusion_bench/method/we_moe/utils.py +15 -0
- fusion_bench/method/weighted_average/llama.py +1 -1
- fusion_bench/mixins/clip_classification.py +37 -48
- fusion_bench/mixins/serialization.py +30 -10
- fusion_bench/modelpool/base_pool.py +1 -1
- fusion_bench/modelpool/causal_lm/causal_lm.py +293 -75
- fusion_bench/modelpool/seq2seq_lm/modelpool.py +146 -0
- fusion_bench/models/__init__.py +5 -0
- fusion_bench/models/hf_utils.py +69 -86
- fusion_bench/models/linearized/vision_model.py +6 -6
- fusion_bench/models/model_card_templates/default.md +46 -0
- fusion_bench/models/modeling_smile_llama/__init__.py +7 -0
- fusion_bench/models/modeling_smile_llama/modeling_smile_llama.py +1 -8
- fusion_bench/models/modeling_smile_mistral/__init__.py +2 -1
- fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +1 -5
- fusion_bench/models/we_moe.py +8 -8
- fusion_bench/programs/fabric_fusion_program.py +29 -60
- fusion_bench/scripts/cli.py +34 -1
- fusion_bench/taskpool/base_pool.py +99 -17
- fusion_bench/taskpool/clip_vision/taskpool.py +10 -5
- fusion_bench/taskpool/dummy.py +101 -13
- fusion_bench/taskpool/lm_eval_harness/taskpool.py +80 -0
- fusion_bench/taskpool/nyuv2_taskpool.py +28 -0
- fusion_bench/utils/__init__.py +2 -0
- fusion_bench/utils/cache_utils.py +101 -1
- fusion_bench/utils/data.py +6 -4
- fusion_bench/utils/devices.py +7 -4
- fusion_bench/utils/dtype.py +3 -2
- fusion_bench/utils/fabric.py +2 -2
- fusion_bench/utils/lazy_imports.py +23 -0
- fusion_bench/utils/lazy_state_dict.py +117 -19
- fusion_bench/utils/modelscope.py +3 -3
- fusion_bench/utils/packages.py +3 -3
- fusion_bench/utils/parameters.py +0 -2
- fusion_bench/utils/path.py +56 -0
- fusion_bench/utils/pylogger.py +1 -1
- fusion_bench/utils/timer.py +92 -10
- {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/METADATA +1 -23
- {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/RECORD +89 -75
- fusion_bench_config/_get_started/llm_slerp.yaml +12 -0
- fusion_bench_config/method/fisher_merging/clip_fisher_merging.yaml +0 -1
- fusion_bench_config/method/linear/simple_average_for_llama.yaml +3 -2
- fusion_bench_config/method/model_stock/model_stock.yaml +12 -0
- fusion_bench_config/method/slerp/slerp_lm.yaml +4 -0
- fusion_bench_config/method/smile_upscaling/causal_lm_upscaling.yaml +21 -0
- fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +1 -1
- fusion_bench_config/method/wemoe/flan_t5_weight_ensembling_moe.yaml +20 -0
- fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml +1 -1
- {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/top_level.txt +0 -0
|
@@ -11,30 +11,81 @@ from numpy.typing import NDArray
|
|
|
11
11
|
from torch import nn
|
|
12
12
|
from tqdm.auto import tqdm
|
|
13
13
|
|
|
14
|
-
from fusion_bench import BaseAlgorithm, BaseModelPool
|
|
15
|
-
from fusion_bench.mixins import
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
state_dict_to_vector,
|
|
20
|
-
trainable_state_dict,
|
|
14
|
+
from fusion_bench import BaseAlgorithm, BaseModelPool, StateDictType, timeit_context
|
|
15
|
+
from fusion_bench.mixins import (
|
|
16
|
+
LightningFabricMixin,
|
|
17
|
+
SimpleProfilerMixin,
|
|
18
|
+
auto_register_config,
|
|
21
19
|
)
|
|
20
|
+
from fusion_bench.utils import state_dict_to_vector, trainable_state_dict
|
|
22
21
|
from fusion_bench.utils.state_dict_arithmetic import state_dict_sub
|
|
23
22
|
|
|
24
23
|
log = logging.getLogger(__name__)
|
|
25
24
|
|
|
26
25
|
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
26
|
+
@auto_register_config
|
|
27
|
+
class TaskVectorViolinPlot(
|
|
28
|
+
LightningFabricMixin,
|
|
29
|
+
SimpleProfilerMixin,
|
|
30
|
+
BaseAlgorithm,
|
|
31
|
+
):
|
|
32
|
+
"""
|
|
33
|
+
Creates violin plots to visualize the distribution of task vector values across models.
|
|
34
|
+
|
|
35
|
+
This class implements the task vector visualization technique described in:
|
|
36
|
+
"Efficient and Effective Weight-Ensembling Mixture of Experts for Multi-Task Model Merging"
|
|
37
|
+
by L. Shen, A. Tang, E. Yang et al. (https://arxiv.org/abs/2410.21804)
|
|
38
|
+
|
|
39
|
+
Task vectors represent the parameter differences between fine-tuned models and their
|
|
40
|
+
pretrained base model, computed as:
|
|
41
|
+
task_vector = finetuned_params - pretrained_params
|
|
42
|
+
|
|
43
|
+
The algorithm generates two types of violin plots:
|
|
44
|
+
1. Distribution of raw task vector values (positive and negative)
|
|
45
|
+
2. Distribution of absolute task vector values
|
|
46
|
+
|
|
47
|
+
Args:
|
|
48
|
+
trainable_only (bool): If True, only consider trainable parameters when computing
|
|
49
|
+
task vectors. If False, use all parameters.
|
|
50
|
+
max_points_per_model (int, optional): Maximum number of parameters to sample
|
|
51
|
+
per model for memory efficiency. If None or 0, uses all parameters.
|
|
52
|
+
Defaults to 1000.
|
|
53
|
+
fig_kwargs (dict, optional): Dictionary of keyword arguments to pass to
|
|
54
|
+
matplotlib.pyplot.subplots. Common options include:
|
|
55
|
+
- figsize: Tuple of (width, height) in inches
|
|
56
|
+
- dpi: Dots per inch for resolution
|
|
57
|
+
- facecolor: Figure background color
|
|
58
|
+
Defaults to None.
|
|
59
|
+
output_path (str, optional): Directory to save the violin plots. If None,
|
|
60
|
+
uses the fabric logger's log directory. Defaults to None.
|
|
61
|
+
|
|
62
|
+
Outputs:
|
|
63
|
+
- task_vector_violin.pdf: Violin plot of raw task vector value distributions
|
|
64
|
+
- task_vector_violin_abs.pdf: Violin plot of absolute task vector value distributions
|
|
65
|
+
|
|
66
|
+
Returns:
|
|
67
|
+
The pretrained model from the model pool.
|
|
68
|
+
|
|
69
|
+
Example:
|
|
70
|
+
```python
|
|
71
|
+
plotter = TaskVectorViolinPlot(
|
|
72
|
+
trainable_only=True,
|
|
73
|
+
max_points_per_model=5000,
|
|
74
|
+
fig_kwargs={'figsize': (12, 8), 'dpi': 300},
|
|
75
|
+
output_path='./analysis_plots'
|
|
76
|
+
)
|
|
77
|
+
pretrained_model = plotter.run(modelpool)
|
|
78
|
+
```
|
|
79
|
+
|
|
80
|
+
Note:
|
|
81
|
+
This visualization is particularly useful for understanding:
|
|
82
|
+
- How different tasks affect model parameters
|
|
83
|
+
- The magnitude and distribution of parameter changes
|
|
84
|
+
- Similarities and differences between task adaptations
|
|
31
85
|
"""
|
|
32
86
|
|
|
33
87
|
# config_mapping is a mapping from the attributes to the key in the configuration files
|
|
34
88
|
_config_mapping = BaseAlgorithm._config_mapping | {
|
|
35
|
-
"trainable_only": "trainable_only",
|
|
36
|
-
"max_points_per_model": "max_points_per_model",
|
|
37
|
-
"fig_kwargs": "fig_kwargs",
|
|
38
89
|
"_output_path": "output_path",
|
|
39
90
|
}
|
|
40
91
|
|
|
@@ -46,40 +97,34 @@ class TaskVectorViolinPlot(BaseAlgorithm, LightningFabricMixin, SimpleProfilerMi
|
|
|
46
97
|
output_path: Optional[str] = None,
|
|
47
98
|
**kwargs,
|
|
48
99
|
):
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
between fine-tuned models and their pretrained base model.
|
|
100
|
+
"""
|
|
101
|
+
Initialize the TaskVectorViolinPlot analyzer.
|
|
52
102
|
|
|
53
103
|
Args:
|
|
54
|
-
trainable_only (bool):
|
|
55
|
-
task vectors.
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
plotter.run(modelpool)
|
|
76
|
-
```
|
|
104
|
+
trainable_only (bool): Whether to consider only trainable parameters when
|
|
105
|
+
computing task vectors. Set to True to focus on learnable parameters,
|
|
106
|
+
False to include all parameters including frozen ones.
|
|
107
|
+
max_points_per_model (int, optional): Maximum number of parameter values
|
|
108
|
+
to sample per model for visualization. Useful for large models to
|
|
109
|
+
manage memory usage and plot clarity. Set to None or 0 to use all
|
|
110
|
+
parameters. Defaults to 1000.
|
|
111
|
+
fig_kwargs (dict, optional): Keyword arguments passed to matplotlib's
|
|
112
|
+
subplots function for plot customization. Examples:
|
|
113
|
+
- {'figsize': (10, 6)} for plot dimensions
|
|
114
|
+
- {'dpi': 300} for high resolution
|
|
115
|
+
- {'facecolor': 'white'} for background color
|
|
116
|
+
Defaults to None (uses matplotlib defaults).
|
|
117
|
+
output_path (str, optional): Directory path where violin plots will be saved.
|
|
118
|
+
If None, uses the fabric logger's log directory. The directory will be
|
|
119
|
+
created if it doesn't exist. Defaults to None.
|
|
120
|
+
**kwargs: Additional keyword arguments passed to parent classes.
|
|
121
|
+
|
|
122
|
+
Note:
|
|
123
|
+
The parameter name 'fig_kwawrgs' appears to be a typo for 'fig_kwargs'.
|
|
124
|
+
This should be corrected in the parameter name for consistency.
|
|
77
125
|
"""
|
|
78
|
-
self.trainable_only = trainable_only
|
|
79
|
-
self.fig_kwargs = fig_kwawrgs
|
|
80
|
-
self.max_points_per_model = max_points_per_model
|
|
81
|
-
self._output_path = output_path
|
|
82
126
|
super().__init__(**kwargs)
|
|
127
|
+
self._output_path = output_path
|
|
83
128
|
|
|
84
129
|
@property
|
|
85
130
|
def output_path(self):
|
|
@@ -89,20 +134,39 @@ class TaskVectorViolinPlot(BaseAlgorithm, LightningFabricMixin, SimpleProfilerMi
|
|
|
89
134
|
return self._output_path
|
|
90
135
|
|
|
91
136
|
def run(self, modelpool: BaseModelPool):
|
|
92
|
-
"""
|
|
137
|
+
"""
|
|
138
|
+
Execute the task vector violin plot analysis and visualization.
|
|
93
139
|
|
|
94
|
-
This method implements the
|
|
95
|
-
|
|
140
|
+
This method implements the core algorithm that:
|
|
141
|
+
1. Loads the pretrained base model from the model pool
|
|
142
|
+
2. Computes task vectors for each fine-tuned model (parameter differences)
|
|
143
|
+
3. Creates two violin plots showing the distribution of task vector values:
|
|
144
|
+
- Raw values plot: Shows positive and negative parameter changes
|
|
145
|
+
- Absolute values plot: Shows magnitude of parameter changes
|
|
146
|
+
4. Saves both plots as PDF files in the output directory
|
|
96
147
|
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
3. Creates violin plots showing the distribution of values in these task vectors
|
|
148
|
+
The visualization technique follows the approach described in:
|
|
149
|
+
"Efficient and Effective Weight-Ensembling Mixture of Experts for Multi-Task Model Merging"
|
|
100
150
|
|
|
101
151
|
Args:
|
|
102
|
-
modelpool (BaseModelPool):
|
|
152
|
+
modelpool (BaseModelPool): Pool containing both a pretrained model and
|
|
153
|
+
fine-tuned models. Must have `has_pretrained=True`.
|
|
103
154
|
|
|
104
155
|
Returns:
|
|
105
|
-
|
|
156
|
+
nn.Module: The pretrained model loaded from the model pool.
|
|
157
|
+
|
|
158
|
+
Raises:
|
|
159
|
+
AssertionError: If the model pool doesn't contain a pretrained model.
|
|
160
|
+
|
|
161
|
+
Side Effects:
|
|
162
|
+
- Creates output directory if it doesn't exist
|
|
163
|
+
- Saves 'task_vector_violin.pdf' (raw values distribution)
|
|
164
|
+
- Saves 'task_vector_violin_abs.pdf' (absolute values distribution)
|
|
165
|
+
- Prints progress information during task vector computation
|
|
166
|
+
|
|
167
|
+
Example Output Files:
|
|
168
|
+
- task_vector_violin.pdf: Shows how parameters change (+ and -)
|
|
169
|
+
- task_vector_violin_abs.pdf: Shows magnitude of parameter changes
|
|
106
170
|
"""
|
|
107
171
|
assert modelpool.has_pretrained
|
|
108
172
|
pretrained_model = modelpool.load_pretrained_model()
|
|
@@ -175,6 +239,34 @@ class TaskVectorViolinPlot(BaseAlgorithm, LightningFabricMixin, SimpleProfilerMi
|
|
|
175
239
|
return pretrained_model
|
|
176
240
|
|
|
177
241
|
def get_task_vector(self, pretrained_model, finetuned_model):
|
|
242
|
+
"""
|
|
243
|
+
Compute the task vector representing parameter changes from pretraining to fine-tuning.
|
|
244
|
+
|
|
245
|
+
The task vector quantifies how model parameters have changed during task-specific
|
|
246
|
+
fine-tuning and is computed as:
|
|
247
|
+
task_vector = finetuned_params - pretrained_params
|
|
248
|
+
|
|
249
|
+
Args:
|
|
250
|
+
pretrained_model (nn.Module): The base pretrained model
|
|
251
|
+
finetuned_model (nn.Module): The fine-tuned model for a specific task
|
|
252
|
+
|
|
253
|
+
Returns:
|
|
254
|
+
np.ndarray: Flattened numpy array containing parameter differences.
|
|
255
|
+
If max_points_per_model is set, the array may be randomly downsampled
|
|
256
|
+
for memory efficiency and visualization clarity.
|
|
257
|
+
|
|
258
|
+
Processing Steps:
|
|
259
|
+
1. Extract state dictionaries from both models
|
|
260
|
+
2. Compute parameter differences (subtraction)
|
|
261
|
+
3. Flatten to 1D vector
|
|
262
|
+
4. Convert to numpy array with float32 precision
|
|
263
|
+
5. Optionally downsample if max_points_per_model is specified
|
|
264
|
+
|
|
265
|
+
Note:
|
|
266
|
+
- Uses only trainable parameters if trainable_only=True
|
|
267
|
+
- Downsampling uses random sampling without replacement
|
|
268
|
+
- Preserves the relative distribution of parameter changes
|
|
269
|
+
"""
|
|
178
270
|
task_vector = state_dict_sub(
|
|
179
271
|
self.get_state_dict(finetuned_model),
|
|
180
272
|
self.get_state_dict(pretrained_model),
|
|
@@ -199,6 +291,22 @@ class TaskVectorViolinPlot(BaseAlgorithm, LightningFabricMixin, SimpleProfilerMi
|
|
|
199
291
|
return task_vector
|
|
200
292
|
|
|
201
293
|
def get_state_dict(self, model: nn.Module):
|
|
294
|
+
"""
|
|
295
|
+
Extract the state dictionary from a model based on parameter filtering settings.
|
|
296
|
+
|
|
297
|
+
Args:
|
|
298
|
+
model (nn.Module): The PyTorch model to extract parameters from
|
|
299
|
+
|
|
300
|
+
Returns:
|
|
301
|
+
Dict[str, torch.Tensor]: State dictionary containing model parameters.
|
|
302
|
+
If trainable_only=True, returns only parameters with requires_grad=True.
|
|
303
|
+
If trainable_only=False, returns all parameters including frozen ones.
|
|
304
|
+
|
|
305
|
+
Note:
|
|
306
|
+
This method respects the trainable_only configuration to focus analysis
|
|
307
|
+
on either learnable parameters or the complete parameter set depending
|
|
308
|
+
on the research question being addressed.
|
|
309
|
+
"""
|
|
202
310
|
if self.trainable_only:
|
|
203
311
|
return trainable_state_dict(model)
|
|
204
312
|
else:
|
|
@@ -6,7 +6,11 @@ import torch.nn.functional as F
|
|
|
6
6
|
from tqdm.auto import tqdm
|
|
7
7
|
|
|
8
8
|
from fusion_bench import BaseAlgorithm, BaseModelPool
|
|
9
|
-
from fusion_bench.mixins import
|
|
9
|
+
from fusion_bench.mixins import (
|
|
10
|
+
LightningFabricMixin,
|
|
11
|
+
SimpleProfilerMixin,
|
|
12
|
+
auto_register_config,
|
|
13
|
+
)
|
|
10
14
|
from fusion_bench.modelpool import CausalLMPool
|
|
11
15
|
|
|
12
16
|
from .bitdelta_utils.data import get_dataloader, get_dataset
|
|
@@ -15,23 +19,12 @@ from .bitdelta_utils.diff import compress_diff, save_diff, save_full_model
|
|
|
15
19
|
log = logging.getLogger(__name__)
|
|
16
20
|
|
|
17
21
|
|
|
22
|
+
@auto_register_config
|
|
18
23
|
class BitDeltaAlgorithm(
|
|
19
|
-
BaseAlgorithm,
|
|
20
24
|
LightningFabricMixin,
|
|
21
25
|
SimpleProfilerMixin,
|
|
26
|
+
BaseAlgorithm,
|
|
22
27
|
):
|
|
23
|
-
_config_mapping = BaseAlgorithm._config_mapping | {
|
|
24
|
-
"save_dir": "save_dir",
|
|
25
|
-
"save_full_model": "save_full_model",
|
|
26
|
-
"lr": "lr",
|
|
27
|
-
"batch_size": "batch_size",
|
|
28
|
-
"num_steps": "num_steps",
|
|
29
|
-
"dataset_name": "dataset_name",
|
|
30
|
-
"subset": "subset",
|
|
31
|
-
"split": "split",
|
|
32
|
-
"max_length": "max_length",
|
|
33
|
-
}
|
|
34
|
-
|
|
35
28
|
def __init__(
|
|
36
29
|
self,
|
|
37
30
|
save_dir: str,
|
|
@@ -46,15 +39,6 @@ class BitDeltaAlgorithm(
|
|
|
46
39
|
**kwargs,
|
|
47
40
|
):
|
|
48
41
|
super().__init__(**kwargs)
|
|
49
|
-
self.save_dir = save_dir
|
|
50
|
-
self.save_full_model = save_full_model
|
|
51
|
-
self.lr = lr
|
|
52
|
-
self.batch_size = batch_size
|
|
53
|
-
self.num_steps = num_steps
|
|
54
|
-
self.dataset_name = dataset_name
|
|
55
|
-
self.subset = subset
|
|
56
|
-
self.split = split
|
|
57
|
-
self.max_length = max_length
|
|
58
42
|
|
|
59
43
|
def run(self, modelpool: CausalLMPool):
|
|
60
44
|
if self.save_dir is None:
|
|
@@ -393,7 +393,7 @@ def convert_l_lora_state_dict_to_hf(
|
|
|
393
393
|
base_model_name: Optional[str] = None,
|
|
394
394
|
):
|
|
395
395
|
"""
|
|
396
|
-
Convert a linearized Lora model's checkpoint to
|
|
396
|
+
Convert a linearized Lora model's checkpoint to huggingface's format.
|
|
397
397
|
|
|
398
398
|
Args:
|
|
399
399
|
pretrained_path (str): The path to the pretrained model.
|
|
@@ -23,6 +23,7 @@ from transformers import MixtralForCausalLM
|
|
|
23
23
|
from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM
|
|
24
24
|
|
|
25
25
|
import fusion_bench as fb
|
|
26
|
+
from fusion_bench import auto_register_config
|
|
26
27
|
from fusion_bench.method.expert_sparsity.utils.calibration_data import (
|
|
27
28
|
build_calib_loader,
|
|
28
29
|
)
|
|
@@ -97,6 +98,7 @@ def dynamic_skipping(
|
|
|
97
98
|
return model, (res_median, res_mean)
|
|
98
99
|
|
|
99
100
|
|
|
101
|
+
@auto_register_config
|
|
100
102
|
class DynamicSkippingPruningForMixtral(
|
|
101
103
|
fb.BaseAlgorithm,
|
|
102
104
|
fb.mixins.LightningFabricMixin,
|
|
@@ -22,6 +22,7 @@ from transformers import MixtralForCausalLM
|
|
|
22
22
|
from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer
|
|
23
23
|
|
|
24
24
|
import fusion_bench as fb
|
|
25
|
+
from fusion_bench import auto_register_config
|
|
25
26
|
from fusion_bench.method.expert_sparsity.utils.calibration_data import (
|
|
26
27
|
build_calib_loader,
|
|
27
28
|
)
|
|
@@ -81,6 +82,7 @@ def layerwise_pruning(
|
|
|
81
82
|
return model, (global_loss_history,)
|
|
82
83
|
|
|
83
84
|
|
|
85
|
+
@auto_register_config
|
|
84
86
|
class LayerWisePruningForMixtral(
|
|
85
87
|
fb.BaseAlgorithm,
|
|
86
88
|
fb.mixins.LightningFabricMixin,
|
|
@@ -20,6 +20,7 @@ from tqdm import tqdm
|
|
|
20
20
|
from transformers import MixtralForCausalLM
|
|
21
21
|
|
|
22
22
|
import fusion_bench as fb
|
|
23
|
+
from fusion_bench import auto_register_config
|
|
23
24
|
from fusion_bench.method.expert_sparsity.utils.calibration_data import (
|
|
24
25
|
build_calib_loader,
|
|
25
26
|
)
|
|
@@ -95,6 +96,7 @@ def progressive_pruning(
|
|
|
95
96
|
return model, (global_loss_history,)
|
|
96
97
|
|
|
97
98
|
|
|
99
|
+
@auto_register_config
|
|
98
100
|
class ProgressivePruningForMixtral(
|
|
99
101
|
fb.BaseAlgorithm,
|
|
100
102
|
fb.mixins.LightningFabricMixin,
|
|
@@ -32,7 +32,6 @@ class FisherMergingForCLIPVisionModel(
|
|
|
32
32
|
zeroshot_weights = {}
|
|
33
33
|
|
|
34
34
|
_config_mapping = FisherMergingAlgorithm._config_mapping | {
|
|
35
|
-
"zeroshot_weights_cache_dir": "zeroshot_weights_cache_dir",
|
|
36
35
|
"_dataloader_kwargs": "dataloader_kwargs",
|
|
37
36
|
}
|
|
38
37
|
|
|
@@ -44,7 +43,6 @@ class FisherMergingForCLIPVisionModel(
|
|
|
44
43
|
minimal_fisher_weight,
|
|
45
44
|
num_fisher_examples,
|
|
46
45
|
dataloader_kwargs: DictConfig,
|
|
47
|
-
zeroshot_weights_cache_dir=None,
|
|
48
46
|
**kwargs,
|
|
49
47
|
):
|
|
50
48
|
"""
|
|
@@ -56,7 +54,6 @@ class FisherMergingForCLIPVisionModel(
|
|
|
56
54
|
minimal_fisher_weight (float): Minimal value for Fisher weights to avoid numerical issues.
|
|
57
55
|
num_fisher_examples (int): Number of examples to compute Fisher weights.
|
|
58
56
|
dataloader_kwargs (DictConfig): Configuration for the dataloader.
|
|
59
|
-
zeroshot_weights_cache_dir (str, optional): Directory to cache zero-shot weights. Defaults to None.
|
|
60
57
|
**kwargs: Additional keyword arguments.
|
|
61
58
|
"""
|
|
62
59
|
super().__init__(
|
|
@@ -66,7 +63,6 @@ class FisherMergingForCLIPVisionModel(
|
|
|
66
63
|
num_fisher_examples=num_fisher_examples,
|
|
67
64
|
)
|
|
68
65
|
self.dataloader_kwargs = dataloader_kwargs
|
|
69
|
-
self.zeroshot_weights_cache_dir = zeroshot_weights_cache_dir
|
|
70
66
|
for key, value in kwargs.items():
|
|
71
67
|
log.warning(f"Unused argument: {key}={value}")
|
|
72
68
|
setattr(self, key, value)
|
|
@@ -15,10 +15,10 @@ from transformers import GPT2ForSequenceClassification, GPT2Model
|
|
|
15
15
|
from transformers.data import default_data_collator
|
|
16
16
|
from transformers.models.gpt2.modeling_gpt2 import Conv1D
|
|
17
17
|
|
|
18
|
-
from fusion_bench.mixins import LightningFabricMixin
|
|
18
|
+
from fusion_bench.mixins import LightningFabricMixin, auto_register_config
|
|
19
19
|
from fusion_bench.modelpool import GPT2ForSequenceClassificationPool
|
|
20
20
|
from fusion_bench.utils import timeit_context
|
|
21
|
-
|
|
21
|
+
|
|
22
22
|
from .fisher_merging import FisherMergingAlgorithm, get_param_squared_gradients
|
|
23
23
|
|
|
24
24
|
|
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
import os
|
|
1
2
|
from copy import deepcopy
|
|
2
3
|
from typing import TYPE_CHECKING, Optional
|
|
3
4
|
|
|
@@ -7,13 +8,16 @@ from typing_extensions import override
|
|
|
7
8
|
from fusion_bench import timeit_context
|
|
8
9
|
from fusion_bench.method.base_algorithm import BaseAlgorithm
|
|
9
10
|
from fusion_bench.method.simple_average import SimpleAverageAlgorithm
|
|
11
|
+
from fusion_bench.mixins import auto_register_config
|
|
10
12
|
from fusion_bench.modelpool import CausalLMBackbonePool, CausalLMPool
|
|
13
|
+
from fusion_bench.models.hf_utils import create_default_model_card
|
|
11
14
|
from fusion_bench.utils import instantiate
|
|
12
|
-
from fusion_bench.utils.pylogger import
|
|
15
|
+
from fusion_bench.utils.pylogger import get_rankzero_logger
|
|
13
16
|
|
|
14
|
-
log =
|
|
17
|
+
log = get_rankzero_logger(__name__)
|
|
15
18
|
|
|
16
19
|
|
|
20
|
+
@auto_register_config
|
|
17
21
|
class SimpleAverageForLlama(BaseAlgorithm):
|
|
18
22
|
R"""
|
|
19
23
|
A simple averaging algorithm for LLama models. If `merge_backbone` is set to `True`, the backbone of the model will be averaged and the rest of the model will be loaded from the pre-trained model.
|
|
@@ -29,21 +33,14 @@ class SimpleAverageForLlama(BaseAlgorithm):
|
|
|
29
33
|
```
|
|
30
34
|
"""
|
|
31
35
|
|
|
32
|
-
_config_mapping = BaseAlgorithm._config_mapping | {
|
|
33
|
-
"merge_backbone": "merge_backbone",
|
|
34
|
-
"show_pbar": "show_pbar",
|
|
35
|
-
}
|
|
36
|
-
|
|
37
36
|
def __init__(
|
|
38
37
|
self,
|
|
39
38
|
merge_backbone: bool,
|
|
40
39
|
model_save_path: Optional[str] = None,
|
|
41
40
|
show_pbar: bool = False,
|
|
41
|
+
**kwargs,
|
|
42
42
|
):
|
|
43
|
-
super().__init__()
|
|
44
|
-
self.merge_backbone = merge_backbone
|
|
45
|
-
self.model_save_path = model_save_path
|
|
46
|
-
self.show_pbar = show_pbar
|
|
43
|
+
super().__init__(**kwargs)
|
|
47
44
|
|
|
48
45
|
@override
|
|
49
46
|
def run(self, modelpool: CausalLMPool):
|
|
@@ -75,4 +72,12 @@ class SimpleAverageForLlama(BaseAlgorithm):
|
|
|
75
72
|
with timeit_context(f"Saving the model to {self.model_save_path}"):
|
|
76
73
|
tokenizer.save_pretrained(self.model_save_path)
|
|
77
74
|
model.save_pretrained(self.model_save_path)
|
|
75
|
+
model_card_str = create_default_model_card(
|
|
76
|
+
models=[modelpool.get_model_path(m) for m in modelpool.model_names],
|
|
77
|
+
description="Merged model using simple averaging.",
|
|
78
|
+
algorithm_config=self.config,
|
|
79
|
+
modelpool_config=modelpool.config,
|
|
80
|
+
)
|
|
81
|
+
with open(os.path.join(self.model_save_path, "README.md"), "w") as f:
|
|
82
|
+
f.write(model_card_str)
|
|
78
83
|
return model
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .model_stock import ModelStock
|