fusion-bench 0.2.22__py3-none-any.whl → 0.2.23__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/__init__.py +4 -0
- fusion_bench/compat/method/__init__.py +5 -2
- fusion_bench/compat/method/base_algorithm.py +3 -2
- fusion_bench/compat/modelpool/base_pool.py +3 -3
- fusion_bench/compat/taskpool/clip_image_classification.py +1 -1
- fusion_bench/dataset/gpt2_glue.py +1 -1
- fusion_bench/method/__init__.py +4 -2
- fusion_bench/method/analysis/task_vector_cos_similarity.py +95 -12
- fusion_bench/method/analysis/task_vector_violin_plot.py +160 -52
- fusion_bench/method/bitdelta/bitdelta.py +7 -23
- fusion_bench/method/expert_sparsity/mixtral/dynamic_skipping.py +2 -0
- fusion_bench/method/expert_sparsity/mixtral/layer_wise_pruning.py +2 -0
- fusion_bench/method/expert_sparsity/mixtral/progressive_pruning.py +2 -0
- fusion_bench/method/model_stock/__init__.py +1 -0
- fusion_bench/method/model_stock/model_stock.py +309 -0
- fusion_bench/method/regmean/clip_regmean.py +3 -6
- fusion_bench/method/regmean/regmean.py +27 -56
- fusion_bench/method/regmean/utils.py +56 -0
- fusion_bench/method/regmean_plusplus/regmean_plusplus.py +21 -60
- fusion_bench/method/slerp/__init__.py +1 -1
- fusion_bench/method/slerp/slerp.py +110 -14
- fusion_bench/method/we_moe/flan_t5_we_moe.py +9 -20
- fusion_bench/mixins/clip_classification.py +26 -6
- fusion_bench/mixins/serialization.py +25 -15
- fusion_bench/modelpool/base_pool.py +1 -1
- fusion_bench/modelpool/causal_lm/causal_lm.py +262 -43
- fusion_bench/modelpool/seq2seq_lm/modelpool.py +146 -0
- fusion_bench/models/hf_utils.py +9 -4
- fusion_bench/models/linearized/vision_model.py +6 -6
- fusion_bench/models/modeling_smile_mistral/__init__.py +1 -0
- fusion_bench/models/we_moe.py +8 -8
- fusion_bench/taskpool/base_pool.py +99 -17
- fusion_bench/taskpool/clip_vision/taskpool.py +1 -1
- fusion_bench/taskpool/dummy.py +101 -13
- fusion_bench/taskpool/lm_eval_harness/taskpool.py +80 -0
- fusion_bench/taskpool/nyuv2_taskpool.py +28 -0
- fusion_bench/utils/__init__.py +1 -0
- fusion_bench/utils/data.py +6 -4
- fusion_bench/utils/devices.py +7 -4
- fusion_bench/utils/dtype.py +3 -2
- fusion_bench/utils/lazy_state_dict.py +82 -19
- fusion_bench/utils/packages.py +3 -3
- fusion_bench/utils/parameters.py +0 -2
- fusion_bench/utils/timer.py +92 -10
- {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.23.dist-info}/METADATA +1 -1
- {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.23.dist-info}/RECORD +53 -47
- fusion_bench_config/_get_started/llm_slerp.yaml +12 -0
- fusion_bench_config/method/model_stock/model_stock.yaml +12 -0
- fusion_bench_config/method/slerp/slerp_lm.yaml +4 -0
- {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.23.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.23.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.23.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.23.dist-info}/top_level.txt +0 -0
fusion_bench/__init__.py
CHANGED
@@ -32,6 +32,10 @@ from .models import (
 from .programs import BaseHydraProgram
 from .taskpool import BaseTaskPool
 from .utils import (
+    BoolStateDictType,
+    LazyStateDict,
+    StateDictType,
+    TorchModelType,
     cache_with_joblib,
     get_rankzero_logger,
     import_object,
fusion_bench/compat/method/__init__.py
CHANGED

@@ -1,4 +1,5 @@
 import warnings
+from typing import Any, List, Type

 from omegaconf import DictConfig

@@ -76,7 +77,9 @@ class AlgorithmFactory:
         return algorithm_cls(method_config)

     @staticmethod
-    def register_algorithm(
+    def register_algorithm(
+        name: str, algorithm_cls: Type[ModelFusionAlgorithm]
+    ) -> None:
         """
         Register a new algorithm with the factory.

@@ -87,7 +90,7 @@ class AlgorithmFactory:
         AlgorithmFactory._aglorithms[name] = algorithm_cls

     @classmethod
-    def available_algorithms(cls):
+    def available_algorithms(cls) -> List[str]:
         """
         Get a list of available algorithms.

fusion_bench/compat/method/base_algorithm.py
CHANGED

@@ -1,9 +1,10 @@
 from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Any, Optional

 from omegaconf import DictConfig

 if TYPE_CHECKING:
+    from fusion_bench import BaseModelPool
     from fusion_bench.programs.base_program import BaseHydraProgram

 __all__ = ["ModelFusionAlgorithm"]

@@ -51,7 +52,7 @@ class ModelFusionAlgorithm(ABC):
         pass

     @abstractmethod
-    def run(self, modelpool):
+    def run(self, modelpool: "BaseModelPool") -> Any:
         """
         Fuse the models in the given model pool.

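The two compat hunks above only add type annotations, but together they outline the intended workflow: subclass `ModelFusionAlgorithm`, implement `run(modelpool)`, and register the class with `AlgorithmFactory`. The sketch below illustrates that workflow; the import paths and the `IdentityFusion` class are illustrative assumptions, not code from the package.

```python
# Minimal sketch based on the signatures shown in this diff (assumed import paths,
# hypothetical class name); it presumes `run` is the only abstract method.
from typing import Any

from fusion_bench.compat.method import AlgorithmFactory, ModelFusionAlgorithm


class IdentityFusion(ModelFusionAlgorithm):
    def run(self, modelpool) -> Any:
        # `len(modelpool)` relies on ModelPool.__len__ shown in the next file's hunks.
        print(f"'fusing' {len(modelpool)} models (no-op example)")
        return None  # a real algorithm would return the fused model here


# Matches the new typed signatures:
#   register_algorithm(name: str, algorithm_cls: Type[ModelFusionAlgorithm]) -> None
#   available_algorithms(cls) -> List[str]
AlgorithmFactory.register_algorithm("identity_fusion", IdentityFusion)
print(AlgorithmFactory.available_algorithms())
```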
fusion_bench/compat/modelpool/base_pool.py
CHANGED

@@ -42,7 +42,7 @@ class ModelPool(ABC):
         ), "Duplicate model names found in model pool"
         self._model_names = model_names

-    def __len__(self):
+    def __len__(self) -> int:
         """
         Return the number of models in the model pool, exclude special models such as `_pretrained_`.

@@ -66,7 +66,7 @@
         return names

     @property
-    def has_pretrained(self):
+    def has_pretrained(self) -> bool:
         """
         Check if the pretrained model is available in the model pool.

@@ -78,7 +78,7 @@
                 return True
         return False

-    def get_model_config(self, model_name: str):
+    def get_model_config(self, model_name: str) -> Dict:
         """
         Retrieves the configuration for a specific model from the model pool.

fusion_bench/compat/taskpool/clip_image_classification.py
CHANGED

@@ -169,7 +169,7 @@ class CLIPImageClassificationTaskPool(TaskPool):
         self._fabric = L.Fabric(devices=1)
         self._fabric.launch()

-        # CLIPVisionModel works the same with
+        # CLIPVisionModel works the same with CLIPVisionTransformer, so we can use it directly
         self.clip_model.vision_model = model
         report = {}
         training_params, all_params = count_parameters(model)
fusion_bench/method/__init__.py
CHANGED
@@ -30,7 +30,7 @@ _import_structure = {
         "TaskArithmeticForLlama",
         "LinearInterpolationAlgorithm",
     ],
-    "slerp": ["SlerpMergeAlgorithm"],
+    "slerp": ["SlerpMergeAlgorithm", "SlerpForCausalLM"],
     "simple_average": ["SimpleAverageAlgorithm"],
     "weighted_average": ["WeightedAverageAlgorithm", "WeightedAverageForLLama"],
     "task_arithmetic": ["TaskArithmeticAlgorithm"],
@@ -71,6 +71,7 @@ _import_structure = {
     ],
     "fw_merging": ["FrankWolfeHardAlgorithm", "FrankWolfeSoftAlgorithm"],
     "tall_mask": ["TallMaskTaskArithmeticAlgorithm"],
+    "model_stock": ["ModelStock"],
     # plug-and-play model merging methods
     "concrete_subspace": [
         "ConcreteTaskArithmeticAlgorithmForCLIP",
@@ -194,6 +195,7 @@ if TYPE_CHECKING:
         MixtralUpscalingAlgorithm,
     )
     from .model_recombination import ModelRecombinationAlgorithm
+    from .model_stock import ModelStock
     from .opcm import OPCMForCLIP
     from .pruning import (
         MagnitudeDiffPruningAlgorithm,
@@ -213,7 +215,7 @@
         RegMeanAlgorithmPlusPlus,
     )
     from .simple_average import SimpleAverageAlgorithm
-    from .slerp import SlerpMergeAlgorithm
+    from .slerp import SlerpForCausalLM, SlerpMergeAlgorithm
     from .smile_upscaling import (
         SingularProjectionMergingAlgorithm,
         SmileUpscalingAlgorithm,
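`fusion_bench/method/__init__.py` lists each algorithm twice: once in the `_import_structure` table and once under `if TYPE_CHECKING:`, which is why `ModelStock` and `SlerpForCausalLM` appear in both hunks above. The snippet below is a simplified, generic sketch of how such a lazy-import table is typically wired to a module-level `__getattr__` (PEP 562); it is not fusion_bench's actual loader.

```python
# Generic sketch of a lazy-import package __init__.py (illustrative only).
import importlib
from typing import TYPE_CHECKING

_import_structure = {
    "slerp": ["SlerpMergeAlgorithm", "SlerpForCausalLM"],
    "model_stock": ["ModelStock"],
}

if TYPE_CHECKING:
    # Type checkers see ordinary imports; at runtime nothing is imported yet.
    from .model_stock import ModelStock  # noqa: F401
    from .slerp import SlerpForCausalLM, SlerpMergeAlgorithm  # noqa: F401
else:
    def __getattr__(name):
        # Import the owning submodule only when the symbol is first accessed.
        for submodule, exported_names in _import_structure.items():
            if name in exported_names:
                module = importlib.import_module(f".{submodule}", __name__)
                return getattr(module, name)
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```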
fusion_bench/method/analysis/task_vector_cos_similarity.py
CHANGED

@@ -11,7 +11,7 @@ from torch import nn
 from tqdm.auto import tqdm

 from fusion_bench.method import BaseAlgorithm
-from fusion_bench.mixins import LightningFabricMixin
+from fusion_bench.mixins import LightningFabricMixin, auto_register_config
 from fusion_bench.modelpool import BaseModelPool
 from fusion_bench.utils.parameters import (
     StateDictType,
@@ -23,14 +23,50 @@ from fusion_bench.utils.state_dict_arithmetic import state_dict_sub
 log = logging.getLogger(__name__)


-
+@auto_register_config
+class TaskVectorCosSimilarity(
+    LightningFabricMixin,
+    BaseAlgorithm,
+):
     """
-
-
+    Computes and analyzes cosine similarity between task vectors of models in a model pool.
+
+    This algorithm extracts task vectors from fine-tuned models by computing the difference
+    between their parameters and a pretrained base model. It then calculates the pairwise
+    cosine similarity between all task vectors to understand the relationships and overlap
+    between different tasks.
+
+    The task vector for a model is defined as:
+        task_vector = finetuned_model_params - pretrained_model_params
+
+    Args:
+        plot_heatmap (bool): Whether to generate and save a heatmap visualization
+        trainable_only (bool, optional): If True, only consider trainable parameters
+            when computing task vectors. Defaults to True.
+        max_points_per_model (int, optional): Maximum number of parameters to sample
+            per model for memory efficiency. If None, uses all parameters.
+        output_path (str, optional): Directory to save outputs. If None, uses the
+            fabric logger directory.
+
+    Outputs:
+        - task_vector_cos_similarity.csv: Pairwise cosine similarity matrix
+        - task_vector_cos_similarity.pdf: Heatmap visualization (if plot_heatmap=True)
+
+    Returns:
+        The pretrained model from the model pool.
+
+    Example:
+        ```python
+        >>> algorithm = TaskVectorCosSimilarity(
+        ...     plot_heatmap=True,
+        ...     trainable_only=True,
+        ...     output_path="/path/to/outputs"
+        ... )
+        >>> result = algorithm.run(modelpool)
+        ```
     """

     _config_mapping = BaseAlgorithm._config_mapping | {
-        "plot_heatmap": "plot_heatmap",
         "_output_path": "output_path",
     }

@@ -42,11 +78,8 @@ class TaskVectorCosSimilarity(BaseAlgorithm, LightningFabricMixin):
         output_path: Optional[str] = None,
         **kwargs,
     ):
-        self.plot_heatmap = plot_heatmap
-        self.trainable_only = trainable_only
-        self.max_points_per_model = max_points_per_model
-        self._output_path = output_path
         super().__init__(**kwargs)
+        self._output_path = output_path

     @property
     def output_path(self):
@@ -57,6 +90,22 @@ class TaskVectorCosSimilarity(BaseAlgorithm, LightningFabricMixin):

     @torch.no_grad()
     def run(self, modelpool: BaseModelPool):
+        """
+        Execute the task vector cosine similarity analysis.
+
+        This method:
+        1. Loads the pretrained base model from the model pool
+        2. Computes task vectors for each fine-tuned model
+        3. Calculates pairwise cosine similarities between all task vectors
+        4. Saves the similarity matrix as a CSV file
+        5. Optionally generates and saves a heatmap visualization
+
+        Args:
+            modelpool (BaseModelPool): Pool containing pretrained and fine-tuned models
+
+        Returns:
+            nn.Module: The pretrained model from the model pool
+        """
         pretrained_model = modelpool.load_pretrained_model()

         task_vectors = []
@@ -103,11 +152,14 @@ class TaskVectorCosSimilarity(BaseAlgorithm, LightningFabricMixin):

     def _plot_heatmap(self, data: pd.DataFrame):
         """
-
+        Generate and save a heatmap visualization of the cosine similarity matrix.
+
+        Creates a color-coded heatmap showing pairwise cosine similarities between
+        task vectors. The heatmap is saved as a PDF file in the output directory.

         Args:
-            data (pd.DataFrame):
-
+            data (pd.DataFrame): Symmetric matrix of cosine similarities between
+                task vectors, with model names as both index and columns.

         Returns:
             None
@@ -141,6 +193,26 @@ class TaskVectorCosSimilarity(BaseAlgorithm, LightningFabricMixin):
     def get_task_vector(
         self, pretrained_model: nn.Module, finetuned_model: nn.Module
     ) -> torch.Tensor:
+        """
+        Compute the task vector for a fine-tuned model.
+
+        The task vector represents the parameter changes from pretraining to
+        fine-tuning and is computed as:
+            task_vector = finetuned_params - pretrained_params
+
+        Args:
+            pretrained_model (nn.Module): The base pretrained model
+            finetuned_model (nn.Module): The fine-tuned model for a specific task
+
+        Returns:
+            torch.Tensor: Flattened task vector containing parameter differences.
+                If max_points_per_model is set, the vector may be downsampled.
+
+        Note:
+            - Converts parameters to float64 for numerical precision
+            - Supports optional downsampling for memory efficiency
+            - Uses only trainable parameters if trainable_only=True
+        """
         task_vector = state_dict_sub(
             self.get_state_dict(finetuned_model),
             self.get_state_dict(pretrained_model),
@@ -166,6 +238,17 @@ class TaskVectorCosSimilarity(BaseAlgorithm, LightningFabricMixin):
         return task_vector

     def get_state_dict(self, model: nn.Module):
+        """
+        Extract the state dictionary from a model.
+
+        Args:
+            model (nn.Module): The model to extract parameters from
+
+        Returns:
+            Dict[str, torch.Tensor]: State dictionary containing model parameters.
+                Returns only trainable parameters if trainable_only=True,
+                otherwise returns all parameters.
+        """
         if self.trainable_only:
             return trainable_state_dict(model)
         else:
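The docstrings added here pin down the computation: each task vector is `finetuned_params - pretrained_params`, flattened, optionally downsampled, and compared pairwise by cosine similarity. For reference, a self-contained sketch of that computation in plain PyTorch and pandas (assumed helper names, not the methods of the class above):

```python
# Standalone sketch (assumed, not the package's implementation) of the analysis the
# docstring describes: flatten each task vector, then build a pairwise cosine-similarity matrix.
import pandas as pd
import torch
import torch.nn.functional as F
from torch import nn


def task_vector(pretrained: nn.Module, finetuned: nn.Module) -> torch.Tensor:
    """Flattened difference finetuned_params - pretrained_params (float64, as the docstring notes)."""
    pre = pretrained.state_dict()
    return torch.cat(
        [
            (v.detach().to(torch.float64) - pre[k].detach().to(torch.float64)).flatten()
            for k, v in finetuned.state_dict().items()
        ]
    )


def cos_similarity_matrix(pretrained: nn.Module, finetuned: dict[str, nn.Module]) -> pd.DataFrame:
    names = list(finetuned)
    vectors = [task_vector(pretrained, finetuned[n]) for n in names]
    sim = torch.zeros(len(names), len(names), dtype=torch.float64)
    for i, vi in enumerate(vectors):
        for j, vj in enumerate(vectors):
            sim[i, j] = F.cosine_similarity(vi, vj, dim=0)
    # Model names label both rows and columns, mirroring the CSV output described above.
    return pd.DataFrame(sim.numpy(), index=names, columns=names)
```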
fusion_bench/method/analysis/task_vector_violin_plot.py
CHANGED

@@ -11,30 +11,81 @@ from numpy.typing import NDArray
 from torch import nn
 from tqdm.auto import tqdm

-from fusion_bench import BaseAlgorithm, BaseModelPool
-from fusion_bench.mixins import
-
-
-
-    state_dict_to_vector,
-    trainable_state_dict,
+from fusion_bench import BaseAlgorithm, BaseModelPool, StateDictType, timeit_context
+from fusion_bench.mixins import (
+    LightningFabricMixin,
+    SimpleProfilerMixin,
+    auto_register_config,
 )
+from fusion_bench.utils import state_dict_to_vector, trainable_state_dict
 from fusion_bench.utils.state_dict_arithmetic import state_dict_sub

 log = logging.getLogger(__name__)


-
-
-
-
+@auto_register_config
+class TaskVectorViolinPlot(
+    LightningFabricMixin,
+    SimpleProfilerMixin,
+    BaseAlgorithm,
+):
+    """
+    Creates violin plots to visualize the distribution of task vector values across models.
+
+    This class implements the task vector visualization technique described in:
+    "Efficient and Effective Weight-Ensembling Mixture of Experts for Multi-Task Model Merging"
+    by L. Shen, A. Tang, E. Yang et al. (https://arxiv.org/abs/2410.21804)
+
+    Task vectors represent the parameter differences between fine-tuned models and their
+    pretrained base model, computed as:
+        task_vector = finetuned_params - pretrained_params
+
+    The algorithm generates two types of violin plots:
+    1. Distribution of raw task vector values (positive and negative)
+    2. Distribution of absolute task vector values
+
+    Args:
+        trainable_only (bool): If True, only consider trainable parameters when computing
+            task vectors. If False, use all parameters.
+        max_points_per_model (int, optional): Maximum number of parameters to sample
+            per model for memory efficiency. If None or 0, uses all parameters.
+            Defaults to 1000.
+        fig_kwargs (dict, optional): Dictionary of keyword arguments to pass to
+            matplotlib.pyplot.subplots. Common options include:
+            - figsize: Tuple of (width, height) in inches
+            - dpi: Dots per inch for resolution
+            - facecolor: Figure background color
+            Defaults to None.
+        output_path (str, optional): Directory to save the violin plots. If None,
+            uses the fabric logger's log directory. Defaults to None.
+
+    Outputs:
+        - task_vector_violin.pdf: Violin plot of raw task vector value distributions
+        - task_vector_violin_abs.pdf: Violin plot of absolute task vector value distributions
+
+    Returns:
+        The pretrained model from the model pool.
+
+    Example:
+        ```python
+        plotter = TaskVectorViolinPlot(
+            trainable_only=True,
+            max_points_per_model=5000,
+            fig_kwargs={'figsize': (12, 8), 'dpi': 300},
+            output_path='./analysis_plots'
+        )
+        pretrained_model = plotter.run(modelpool)
+        ```
+
+    Note:
+        This visualization is particularly useful for understanding:
+        - How different tasks affect model parameters
+        - The magnitude and distribution of parameter changes
+        - Similarities and differences between task adaptations
     """

     # config_mapping is a mapping from the attributes to the key in the configuration files
     _config_mapping = BaseAlgorithm._config_mapping | {
-        "trainable_only": "trainable_only",
-        "max_points_per_model": "max_points_per_model",
-        "fig_kwargs": "fig_kwargs",
         "_output_path": "output_path",
     }

@@ -46,40 +97,34 @@ class TaskVectorViolinPlot(BaseAlgorithm, LightningFabricMixin, SimpleProfilerMi
         output_path: Optional[str] = None,
         **kwargs,
     ):
-
-
-            between fine-tuned models and their pretrained base model.
+        """
+        Initialize the TaskVectorViolinPlot analyzer.

         Args:
-            trainable_only (bool):
-                task vectors.
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                plotter.run(modelpool)
-            ```
+            trainable_only (bool): Whether to consider only trainable parameters when
+                computing task vectors. Set to True to focus on learnable parameters,
+                False to include all parameters including frozen ones.
+            max_points_per_model (int, optional): Maximum number of parameter values
+                to sample per model for visualization. Useful for large models to
+                manage memory usage and plot clarity. Set to None or 0 to use all
+                parameters. Defaults to 1000.
+            fig_kwargs (dict, optional): Keyword arguments passed to matplotlib's
+                subplots function for plot customization. Examples:
+                - {'figsize': (10, 6)} for plot dimensions
+                - {'dpi': 300} for high resolution
+                - {'facecolor': 'white'} for background color
+                Defaults to None (uses matplotlib defaults).
+            output_path (str, optional): Directory path where violin plots will be saved.
+                If None, uses the fabric logger's log directory. The directory will be
+                created if it doesn't exist. Defaults to None.
+            **kwargs: Additional keyword arguments passed to parent classes.
+
+        Note:
+            The parameter name 'fig_kwawrgs' appears to be a typo for 'fig_kwargs'.
+            This should be corrected in the parameter name for consistency.
         """
-        self.trainable_only = trainable_only
-        self.fig_kwargs = fig_kwawrgs
-        self.max_points_per_model = max_points_per_model
-        self._output_path = output_path
         super().__init__(**kwargs)
+        self._output_path = output_path

     @property
     def output_path(self):
@@ -89,20 +134,39 @@ class TaskVectorViolinPlot(BaseAlgorithm, LightningFabricMixin, SimpleProfilerMi
         return self._output_path

     def run(self, modelpool: BaseModelPool):
-        """
+        """
+        Execute the task vector violin plot analysis and visualization.

-        This method implements the
-
+        This method implements the core algorithm that:
+        1. Loads the pretrained base model from the model pool
+        2. Computes task vectors for each fine-tuned model (parameter differences)
+        3. Creates two violin plots showing the distribution of task vector values:
+            - Raw values plot: Shows positive and negative parameter changes
+            - Absolute values plot: Shows magnitude of parameter changes
+        4. Saves both plots as PDF files in the output directory

-
-
-        3. Creates violin plots showing the distribution of values in these task vectors
+        The visualization technique follows the approach described in:
+        "Efficient and Effective Weight-Ensembling Mixture of Experts for Multi-Task Model Merging"

         Args:
-            modelpool (BaseModelPool):
+            modelpool (BaseModelPool): Pool containing both a pretrained model and
+                fine-tuned models. Must have `has_pretrained=True`.

         Returns:
-
+            nn.Module: The pretrained model loaded from the model pool.
+
+        Raises:
+            AssertionError: If the model pool doesn't contain a pretrained model.
+
+        Side Effects:
+            - Creates output directory if it doesn't exist
+            - Saves 'task_vector_violin.pdf' (raw values distribution)
+            - Saves 'task_vector_violin_abs.pdf' (absolute values distribution)
+            - Prints progress information during task vector computation
+
+        Example Output Files:
+            - task_vector_violin.pdf: Shows how parameters change (+ and -)
+            - task_vector_violin_abs.pdf: Shows magnitude of parameter changes
         """
         assert modelpool.has_pretrained
         pretrained_model = modelpool.load_pretrained_model()
@@ -175,6 +239,34 @@ class TaskVectorViolinPlot(BaseAlgorithm, LightningFabricMixin, SimpleProfilerMi
         return pretrained_model

     def get_task_vector(self, pretrained_model, finetuned_model):
+        """
+        Compute the task vector representing parameter changes from pretraining to fine-tuning.
+
+        The task vector quantifies how model parameters have changed during task-specific
+        fine-tuning and is computed as:
+            task_vector = finetuned_params - pretrained_params
+
+        Args:
+            pretrained_model (nn.Module): The base pretrained model
+            finetuned_model (nn.Module): The fine-tuned model for a specific task
+
+        Returns:
+            np.ndarray: Flattened numpy array containing parameter differences.
+                If max_points_per_model is set, the array may be randomly downsampled
+                for memory efficiency and visualization clarity.
+
+        Processing Steps:
+            1. Extract state dictionaries from both models
+            2. Compute parameter differences (subtraction)
+            3. Flatten to 1D vector
+            4. Convert to numpy array with float32 precision
+            5. Optionally downsample if max_points_per_model is specified
+
+        Note:
+            - Uses only trainable parameters if trainable_only=True
+            - Downsampling uses random sampling without replacement
+            - Preserves the relative distribution of parameter changes
+        """
         task_vector = state_dict_sub(
             self.get_state_dict(finetuned_model),
             self.get_state_dict(pretrained_model),
@@ -199,6 +291,22 @@ class TaskVectorViolinPlot(BaseAlgorithm, LightningFabricMixin, SimpleProfilerMi
         return task_vector

     def get_state_dict(self, model: nn.Module):
+        """
+        Extract the state dictionary from a model based on parameter filtering settings.
+
+        Args:
+            model (nn.Module): The PyTorch model to extract parameters from
+
+        Returns:
+            Dict[str, torch.Tensor]: State dictionary containing model parameters.
+                If trainable_only=True, returns only parameters with requires_grad=True.
+                If trainable_only=False, returns all parameters including frozen ones.
+
+        Note:
+            This method respects the trainable_only configuration to focus analysis
+            on either learnable parameters or the complete parameter set depending
+            on the research question being addressed.
+        """
         if self.trainable_only:
             return trainable_state_dict(model)
         else:
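For the violin-plot analysis, the new docstrings describe sampling at most `max_points_per_model` values per task vector and plotting two distributions (raw and absolute values). A minimal matplotlib sketch of the raw-value plot under those assumptions (hypothetical helper, not the package's plotting code):

```python
# Rough sketch (assumed, not the package code) of the violin-plot step described above:
# sample task-vector values per model and plot their distributions with matplotlib.
import matplotlib.pyplot as plt
import numpy as np


def plot_task_vector_violin(task_vectors: dict[str, np.ndarray],
                            max_points_per_model: int = 1000,
                            output_path: str = "task_vector_violin.pdf") -> None:
    rng = np.random.default_rng(0)
    names, samples = [], []
    for name, vec in task_vectors.items():
        vec = np.asarray(vec, dtype=np.float32).ravel()
        if max_points_per_model and vec.size > max_points_per_model:
            # Random downsampling without replacement, as the docstring describes.
            vec = rng.choice(vec, size=max_points_per_model, replace=False)
        names.append(name)
        samples.append(vec)

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.violinplot(samples, showmedians=True)
    ax.set_xticks(range(1, len(names) + 1))
    ax.set_xticklabels(names)
    ax.set_ylabel("task vector value")
    fig.tight_layout()
    fig.savefig(output_path)
```

The absolute-value variant described in the docstring is the same plot applied to `np.abs(vec)`.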
fusion_bench/method/bitdelta/bitdelta.py
CHANGED

@@ -6,7 +6,11 @@ import torch.nn.functional as F
 from tqdm.auto import tqdm

 from fusion_bench import BaseAlgorithm, BaseModelPool
-from fusion_bench.mixins import
+from fusion_bench.mixins import (
+    LightningFabricMixin,
+    SimpleProfilerMixin,
+    auto_register_config,
+)
 from fusion_bench.modelpool import CausalLMPool

 from .bitdelta_utils.data import get_dataloader, get_dataset
@@ -15,23 +19,12 @@ from .bitdelta_utils.diff import compress_diff, save_diff, save_full_model
 log = logging.getLogger(__name__)


+@auto_register_config
 class BitDeltaAlgorithm(
-    BaseAlgorithm,
     LightningFabricMixin,
     SimpleProfilerMixin,
+    BaseAlgorithm,
 ):
-    _config_mapping = BaseAlgorithm._config_mapping | {
-        "save_dir": "save_dir",
-        "save_full_model": "save_full_model",
-        "lr": "lr",
-        "batch_size": "batch_size",
-        "num_steps": "num_steps",
-        "dataset_name": "dataset_name",
-        "subset": "subset",
-        "split": "split",
-        "max_length": "max_length",
-    }
-
     def __init__(
         self,
         save_dir: str,
@@ -46,15 +39,6 @@ class BitDeltaAlgorithm(
         **kwargs,
     ):
         super().__init__(**kwargs)
-        self.save_dir = save_dir
-        self.save_full_model = save_full_model
-        self.lr = lr
-        self.batch_size = batch_size
-        self.num_steps = num_steps
-        self.dataset_name = dataset_name
-        self.subset = subset
-        self.split = split
-        self.max_length = max_length

     def run(self, modelpool: CausalLMPool):
         if self.save_dir is None:
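The common thread in this hunk and the expert-sparsity files below is the new `@auto_register_config` decorator, which lets classes drop their hand-written `_config_mapping` entries and `self.x = x` assignments. The diff does not show the decorator's implementation; the sketch below is only a generic illustration of the idea (inspect `__init__`'s signature and store each argument as an attribute), with a deliberately different name so it is not mistaken for the real API.

```python
# Illustrative only: a generic decorator in the spirit of `@auto_register_config`,
# NOT fusion_bench's implementation. It wraps __init__ so every named parameter
# is stored as an attribute, removing the repetitive `self.x = x` lines seen in
# the old BitDeltaAlgorithm.__init__ above.
import functools
import inspect


def auto_store_init_args(cls):  # hypothetical name, not the package's decorator
    original_init = cls.__init__
    signature = inspect.signature(original_init)

    @functools.wraps(original_init)
    def wrapped_init(self, *args, **kwargs):
        bound = signature.bind(self, *args, **kwargs)
        bound.apply_defaults()
        for name, value in bound.arguments.items():
            if name not in ("self", "args", "kwargs"):
                setattr(self, name, value)  # e.g. self.save_dir = save_dir
        original_init(self, *args, **kwargs)

    cls.__init__ = wrapped_init
    return cls


@auto_store_init_args
class Example:
    def __init__(self, save_dir: str, lr: float = 1e-4, **kwargs):
        pass  # attributes are already set by the decorator when this body runs


e = Example("outputs/bitdelta", lr=3e-4)
print(e.save_dir, e.lr)  # -> outputs/bitdelta 0.0003
```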
fusion_bench/method/expert_sparsity/mixtral/dynamic_skipping.py
CHANGED

@@ -23,6 +23,7 @@ from transformers import MixtralForCausalLM
 from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM

 import fusion_bench as fb
+from fusion_bench import auto_register_config
 from fusion_bench.method.expert_sparsity.utils.calibration_data import (
     build_calib_loader,
 )
@@ -97,6 +98,7 @@ def dynamic_skipping(
     return model, (res_median, res_mean)


+@auto_register_config
 class DynamicSkippingPruningForMixtral(
     fb.BaseAlgorithm,
     fb.mixins.LightningFabricMixin,
fusion_bench/method/expert_sparsity/mixtral/layer_wise_pruning.py
CHANGED

@@ -22,6 +22,7 @@ from transformers import MixtralForCausalLM
 from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer

 import fusion_bench as fb
+from fusion_bench import auto_register_config
 from fusion_bench.method.expert_sparsity.utils.calibration_data import (
     build_calib_loader,
 )
@@ -81,6 +82,7 @@ def layerwise_pruning(
     return model, (global_loss_history,)


+@auto_register_config
 class LayerWisePruningForMixtral(
     fb.BaseAlgorithm,
     fb.mixins.LightningFabricMixin,
fusion_bench/method/expert_sparsity/mixtral/progressive_pruning.py
CHANGED

@@ -20,6 +20,7 @@ from tqdm import tqdm
 from transformers import MixtralForCausalLM

 import fusion_bench as fb
+from fusion_bench import auto_register_config
 from fusion_bench.method.expert_sparsity.utils.calibration_data import (
     build_calib_loader,
 )
@@ -95,6 +96,7 @@ def progressive_pruning(
     return model, (global_loss_history,)


+@auto_register_config
 class ProgressivePruningForMixtral(
     fb.BaseAlgorithm,
     fb.mixins.LightningFabricMixin,