fusion-bench 0.2.21__py3-none-any.whl → 0.2.23__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their public registries. It is provided for informational purposes only.
- fusion_bench/__init__.py +25 -2
- fusion_bench/compat/method/__init__.py +5 -2
- fusion_bench/compat/method/base_algorithm.py +3 -2
- fusion_bench/compat/modelpool/base_pool.py +3 -3
- fusion_bench/compat/taskpool/clip_image_classification.py +1 -1
- fusion_bench/constants/__init__.py +1 -0
- fusion_bench/constants/runtime.py +57 -0
- fusion_bench/dataset/gpt2_glue.py +1 -1
- fusion_bench/method/__init__.py +12 -4
- fusion_bench/method/analysis/task_vector_cos_similarity.py +95 -12
- fusion_bench/method/analysis/task_vector_violin_plot.py +160 -52
- fusion_bench/method/bitdelta/__init__.py +1 -0
- fusion_bench/method/bitdelta/bitdelta.py +7 -23
- fusion_bench/method/classification/clip_finetune.py +1 -1
- fusion_bench/method/expert_sparsity/mixtral/dynamic_skipping.py +2 -0
- fusion_bench/method/expert_sparsity/mixtral/layer_wise_pruning.py +2 -0
- fusion_bench/method/expert_sparsity/mixtral/progressive_pruning.py +2 -0
- fusion_bench/method/fisher_merging/clip_fisher_merging.py +0 -4
- fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +2 -2
- fusion_bench/method/linear/simple_average_for_llama.py +16 -11
- fusion_bench/method/model_stock/__init__.py +1 -0
- fusion_bench/method/model_stock/model_stock.py +309 -0
- fusion_bench/method/regmean/clip_regmean.py +3 -6
- fusion_bench/method/regmean/regmean.py +27 -56
- fusion_bench/method/regmean/utils.py +56 -0
- fusion_bench/method/regmean_plusplus/regmean_plusplus.py +21 -60
- fusion_bench/method/simple_average.py +7 -7
- fusion_bench/method/slerp/__init__.py +1 -1
- fusion_bench/method/slerp/slerp.py +110 -14
- fusion_bench/method/smile_upscaling/causal_lm_upscaling.py +371 -0
- fusion_bench/method/smile_upscaling/projected_energy.py +1 -2
- fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py +5 -1
- fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +40 -31
- fusion_bench/method/smile_upscaling/smile_upscaling.py +1 -1
- fusion_bench/method/we_moe/__init__.py +1 -0
- fusion_bench/method/we_moe/entropy_loss.py +25 -0
- fusion_bench/method/we_moe/flan_t5_we_moe.py +320 -0
- fusion_bench/method/we_moe/utils.py +15 -0
- fusion_bench/method/weighted_average/llama.py +1 -1
- fusion_bench/mixins/clip_classification.py +37 -48
- fusion_bench/mixins/serialization.py +30 -10
- fusion_bench/modelpool/base_pool.py +1 -1
- fusion_bench/modelpool/causal_lm/causal_lm.py +293 -75
- fusion_bench/modelpool/seq2seq_lm/modelpool.py +146 -0
- fusion_bench/models/__init__.py +5 -0
- fusion_bench/models/hf_utils.py +69 -86
- fusion_bench/models/linearized/vision_model.py +6 -6
- fusion_bench/models/model_card_templates/default.md +46 -0
- fusion_bench/models/modeling_smile_llama/__init__.py +7 -0
- fusion_bench/models/modeling_smile_llama/modeling_smile_llama.py +1 -8
- fusion_bench/models/modeling_smile_mistral/__init__.py +2 -1
- fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +1 -5
- fusion_bench/models/we_moe.py +8 -8
- fusion_bench/programs/fabric_fusion_program.py +29 -60
- fusion_bench/scripts/cli.py +34 -1
- fusion_bench/taskpool/base_pool.py +99 -17
- fusion_bench/taskpool/clip_vision/taskpool.py +10 -5
- fusion_bench/taskpool/dummy.py +101 -13
- fusion_bench/taskpool/lm_eval_harness/taskpool.py +80 -0
- fusion_bench/taskpool/nyuv2_taskpool.py +28 -0
- fusion_bench/utils/__init__.py +2 -0
- fusion_bench/utils/cache_utils.py +101 -1
- fusion_bench/utils/data.py +6 -4
- fusion_bench/utils/devices.py +7 -4
- fusion_bench/utils/dtype.py +3 -2
- fusion_bench/utils/fabric.py +2 -2
- fusion_bench/utils/lazy_imports.py +23 -0
- fusion_bench/utils/lazy_state_dict.py +117 -19
- fusion_bench/utils/modelscope.py +3 -3
- fusion_bench/utils/packages.py +3 -3
- fusion_bench/utils/parameters.py +0 -2
- fusion_bench/utils/path.py +56 -0
- fusion_bench/utils/pylogger.py +1 -1
- fusion_bench/utils/timer.py +92 -10
- {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/METADATA +1 -23
- {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/RECORD +89 -75
- fusion_bench_config/_get_started/llm_slerp.yaml +12 -0
- fusion_bench_config/method/fisher_merging/clip_fisher_merging.yaml +0 -1
- fusion_bench_config/method/linear/simple_average_for_llama.yaml +3 -2
- fusion_bench_config/method/model_stock/model_stock.yaml +12 -0
- fusion_bench_config/method/slerp/slerp_lm.yaml +4 -0
- fusion_bench_config/method/smile_upscaling/causal_lm_upscaling.yaml +21 -0
- fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +1 -1
- fusion_bench_config/method/wemoe/flan_t5_weight_ensembling_moe.yaml +20 -0
- fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml +1 -1
- {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/top_level.txt +0 -0
fusion_bench/__init__.py
CHANGED
@@ -19,9 +19,32 @@ from . import (
     tasks,
     utils,
 )
+from .constants import RuntimeConstants
 from .method import BaseAlgorithm, BaseModelFusionAlgorithm
 from .mixins import auto_register_config
 from .modelpool import BaseModelPool
-from .models import
+from .models import (
+    create_default_model_card,
+    load_model_card_template,
+    save_pretrained_with_remote_code,
+    separate_io,
+)
+from .programs import BaseHydraProgram
 from .taskpool import BaseTaskPool
-from .utils import
+from .utils import (
+    BoolStateDictType,
+    LazyStateDict,
+    StateDictType,
+    TorchModelType,
+    cache_with_joblib,
+    get_rankzero_logger,
+    import_object,
+    instantiate,
+    parse_dtype,
+    print_parameters,
+    seed_everything_by_time,
+    set_default_cache_dir,
+    set_print_function_call,
+    set_print_function_call_permeanent,
+    timeit_context,
+)
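A usage sketch (not from the package docs) of the expanded top-level re-exports; all names come from the import list above, and the `parse_dtype` call assumes a string-to-`torch.dtype` helper.

```python
# Sketch only: names re-exported from the package root per the diff above.
import fusion_bench as fb

constants = fb.RuntimeConstants()                # runtime-configuration singleton (new in 0.2.23)
constants.cache_dir = "/tmp/fusion_bench_cache"  # hypothetical cache location

dtype = fb.parse_dtype("bfloat16")               # assumed usage of the re-exported utility
print(dtype, constants.cache_dir)
```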
fusion_bench/compat/method/__init__.py
CHANGED

@@ -1,4 +1,5 @@
 import warnings
+from typing import Any, List, Type
 
 from omegaconf import DictConfig
 
@@ -76,7 +77,9 @@ class AlgorithmFactory:
         return algorithm_cls(method_config)
 
     @staticmethod
-    def register_algorithm(
+    def register_algorithm(
+        name: str, algorithm_cls: Type[ModelFusionAlgorithm]
+    ) -> None:
         """
         Register a new algorithm with the factory.
 
@@ -87,7 +90,7 @@ class AlgorithmFactory:
         AlgorithmFactory._aglorithms[name] = algorithm_cls
 
     @classmethod
-    def available_algorithms(cls):
+    def available_algorithms(cls) -> List[str]:
         """
         Get a list of available algorithms.
 
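The annotations above pin down the compat factory's registration API. A hypothetical sketch of how a custom algorithm might be registered; `MyAveragingAlgorithm` is illustrative, and any abstract methods of `ModelFusionAlgorithm` beyond `run` are not shown in this diff.

```python
# Hypothetical registration example based on the signatures annotated above.
from fusion_bench.compat.method import AlgorithmFactory
from fusion_bench.compat.method.base_algorithm import ModelFusionAlgorithm


class MyAveragingAlgorithm(ModelFusionAlgorithm):
    def run(self, modelpool):
        # Fuse the models in the pool; body omitted in this sketch.
        ...


AlgorithmFactory.register_algorithm("my_averaging", MyAveragingAlgorithm)
assert "my_averaging" in AlgorithmFactory.available_algorithms()
```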
fusion_bench/compat/method/base_algorithm.py
CHANGED

@@ -1,9 +1,10 @@
 from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Any, Optional
 
 from omegaconf import DictConfig
 
 if TYPE_CHECKING:
+    from fusion_bench import BaseModelPool
     from fusion_bench.programs.base_program import BaseHydraProgram
 
 __all__ = ["ModelFusionAlgorithm"]
@@ -51,7 +52,7 @@ class ModelFusionAlgorithm(ABC):
         pass
 
     @abstractmethod
-    def run(self, modelpool):
+    def run(self, modelpool: "BaseModelPool") -> Any:
         """
         Fuse the models in the given model pool.
 
fusion_bench/compat/modelpool/base_pool.py
CHANGED

@@ -42,7 +42,7 @@ class ModelPool(ABC):
         ), "Duplicate model names found in model pool"
         self._model_names = model_names
 
-    def __len__(self):
+    def __len__(self) -> int:
        """
         Return the number of models in the model pool, exclude special models such as `_pretrained_`.
 
@@ -66,7 +66,7 @@ class ModelPool(ABC):
         return names
 
     @property
-    def has_pretrained(self):
+    def has_pretrained(self) -> bool:
         """
         Check if the pretrained model is available in the model pool.
 
@@ -78,7 +78,7 @@ class ModelPool(ABC):
                 return True
         return False
 
-    def get_model_config(self, model_name: str):
+    def get_model_config(self, model_name: str) -> Dict:
         """
         Retrieves the configuration for a specific model from the model pool.
 
fusion_bench/compat/taskpool/clip_image_classification.py
CHANGED

@@ -169,7 +169,7 @@ class CLIPImageClassificationTaskPool(TaskPool):
         self._fabric = L.Fabric(devices=1)
         self._fabric.launch()
 
-        # CLIPVisionModel works the same with
+        # CLIPVisionModel works the same with CLIPVisionTransformer, so we can use it directly
         self.clip_model.vision_model = model
         report = {}
         training_params, all_params = count_parameters(model)
fusion_bench/constants/runtime.py
ADDED

@@ -0,0 +1,57 @@
+import threading
+from pathlib import Path
+from typing import Optional, Union
+
+
+class RuntimeConstants:
+    """
+    This class holds constants related to the runtime environment of the Fusion Bench framework.
+    It includes default values for cache directories and other runtime configurations.
+
+    Implemented as a thread-safe singleton to ensure consistent runtime configuration
+    across the entire application.
+    """
+
+    _instance: Optional["RuntimeConstants"] = None
+    _lock = threading.Lock()
+
+    def __new__(cls) -> "RuntimeConstants":
+        """Create a new instance using singleton pattern with thread safety."""
+        with cls._lock:
+            # Double-check locking pattern
+            if cls._instance is None:
+                cls._instance = super(RuntimeConstants, cls).__new__(cls)
+                cls._instance._initialized = False
+        return cls._instance
+
+    def __init__(self):
+        """Initialize the singleton instance only once."""
+        if not self._initialized:
+            # Add your runtime constants here
+            self._initialized = True
+
+    debug = False
+
+    @property
+    def cache_dir(self) -> Path:
+        from fusion_bench.utils.cache_utils import DEFAULT_CACHE_DIR
+
+        return DEFAULT_CACHE_DIR
+
+    @cache_dir.setter
+    def cache_dir(self, path: Union[str, Path]) -> None:
+        from fusion_bench.utils.cache_utils import set_default_cache_dir
+
+        set_default_cache_dir(path)
+
+    @property
+    def print_function_call(self) -> bool:
+        from fusion_bench.utils.instantiate_utils import PRINT_FUNCTION_CALL
+
+        return PRINT_FUNCTION_CALL
+
+    @print_function_call.setter
+    def print_function_call(self, enable: bool) -> None:
+        from fusion_bench.utils.instantiate_utils import set_print_function_call
+
+        set_print_function_call(enable)
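A minimal usage sketch of the new singleton; the import path follows the `from .constants import RuntimeConstants` line added to `fusion_bench/__init__.py` above.

```python
# Sketch of the singleton behavior implemented in constants/runtime.py above.
from fusion_bench.constants import RuntimeConstants

a = RuntimeConstants()
b = RuntimeConstants()
assert a is b                            # __new__ always hands back the same instance

a.cache_dir = "/tmp/fusion_bench_cache"  # delegates to set_default_cache_dir(...)
print(b.cache_dir)                       # the change is visible through every reference
b.print_function_call = False            # delegates to set_print_function_call(False)
```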
fusion_bench/method/__init__.py
CHANGED
@@ -30,7 +30,7 @@ _import_structure = {
         "TaskArithmeticForLlama",
         "LinearInterpolationAlgorithm",
     ],
-    "slerp": ["SlerpMergeAlgorithm"],
+    "slerp": ["SlerpMergeAlgorithm", "SlerpForCausalLM"],
     "simple_average": ["SimpleAverageAlgorithm"],
     "weighted_average": ["WeightedAverageAlgorithm", "WeightedAverageForLLama"],
     "task_arithmetic": ["TaskArithmeticAlgorithm"],
@@ -71,6 +71,7 @@ _import_structure = {
     ],
     "fw_merging": ["FrankWolfeHardAlgorithm", "FrankWolfeSoftAlgorithm"],
     "tall_mask": ["TallMaskTaskArithmeticAlgorithm"],
+    "model_stock": ["ModelStock"],
     # plug-and-play model merging methods
     "concrete_subspace": [
         "ConcreteTaskArithmeticAlgorithmForCLIP",
@@ -90,7 +91,10 @@ _import_structure = {
         "MixtralForCausalLMMergingAlgorithm",
     ],
     "dawe": ["DataAdaptiveWeightEnsemblingForCLIP"],
-    "we_moe": [
+    "we_moe": [
+        "CLIPWeightEnsemblingMoEAlgorithm",
+        "FlanT5WeightEnsemblingMoEAlgorithm",
+    ],
     "rankone_moe": ["CLIPRankOneMoEAlgorithm", "RankOneMoEAlgorithm"],
     "sparse_we_moe": [
         "SparseWeightEnsemblingMoEAlgorithm",
@@ -191,6 +195,7 @@ if TYPE_CHECKING:
         MixtralUpscalingAlgorithm,
     )
     from .model_recombination import ModelRecombinationAlgorithm
+    from .model_stock import ModelStock
     from .opcm import OPCMForCLIP
     from .pruning import (
         MagnitudeDiffPruningAlgorithm,
@@ -210,7 +215,7 @@ if TYPE_CHECKING:
         RegMeanAlgorithmPlusPlus,
     )
     from .simple_average import SimpleAverageAlgorithm
-    from .slerp import SlerpMergeAlgorithm
+    from .slerp import SlerpForCausalLM, SlerpMergeAlgorithm
     from .smile_upscaling import (
         SingularProjectionMergingAlgorithm,
         SmileUpscalingAlgorithm,
@@ -228,7 +233,10 @@ if TYPE_CHECKING:
     from .task_arithmetic import TaskArithmeticAlgorithm
     from .task_singular_vector import TaskSingularVectorMerging
     from .ties_merging import TiesMergingAlgorithm
-    from .we_moe import
+    from .we_moe import (
+        CLIPWeightEnsemblingMoEAlgorithm,
+        FlanT5WeightEnsemblingMoEAlgorithm,
+    )
     from .weighted_average import WeightedAverageAlgorithm, WeightedAverageForLLama
 
 else:
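These entries extend the module's existing lazy-import pattern: `_import_structure` maps submodules to exported names for runtime resolution, while the `if TYPE_CHECKING:` block keeps static analyzers happy. A minimal consumer sketch (illustrative only):

```python
# The new exports resolve through the lazy registry above at runtime.
from fusion_bench.method import ModelStock, SlerpForCausalLM

print(ModelStock, SlerpForCausalLM)
```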
fusion_bench/method/analysis/task_vector_cos_similarity.py
CHANGED

@@ -11,7 +11,7 @@ from torch import nn
 from tqdm.auto import tqdm
 
 from fusion_bench.method import BaseAlgorithm
-from fusion_bench.mixins import LightningFabricMixin
+from fusion_bench.mixins import LightningFabricMixin, auto_register_config
 from fusion_bench.modelpool import BaseModelPool
 from fusion_bench.utils.parameters import (
     StateDictType,
@@ -23,14 +23,50 @@ from fusion_bench.utils.state_dict_arithmetic import state_dict_sub
 log = logging.getLogger(__name__)
 
 
-class TaskVectorCosSimilarity(BaseAlgorithm, LightningFabricMixin):
+@auto_register_config
+class TaskVectorCosSimilarity(
+    LightningFabricMixin,
+    BaseAlgorithm,
+):
     """
-
-
+    Computes and analyzes cosine similarity between task vectors of models in a model pool.
+
+    This algorithm extracts task vectors from fine-tuned models by computing the difference
+    between their parameters and a pretrained base model. It then calculates the pairwise
+    cosine similarity between all task vectors to understand the relationships and overlap
+    between different tasks.
+
+    The task vector for a model is defined as:
+        task_vector = finetuned_model_params - pretrained_model_params
+
+    Args:
+        plot_heatmap (bool): Whether to generate and save a heatmap visualization
+        trainable_only (bool, optional): If True, only consider trainable parameters
+            when computing task vectors. Defaults to True.
+        max_points_per_model (int, optional): Maximum number of parameters to sample
+            per model for memory efficiency. If None, uses all parameters.
+        output_path (str, optional): Directory to save outputs. If None, uses the
+            fabric logger directory.
+
+    Outputs:
+        - task_vector_cos_similarity.csv: Pairwise cosine similarity matrix
+        - task_vector_cos_similarity.pdf: Heatmap visualization (if plot_heatmap=True)
+
+    Returns:
+        The pretrained model from the model pool.
+
+    Example:
+        ```python
+        >>> algorithm = TaskVectorCosSimilarity(
+        ...     plot_heatmap=True,
+        ...     trainable_only=True,
+        ...     output_path="/path/to/outputs"
+        ... )
+        >>> result = algorithm.run(modelpool)
+        ```
     """
 
     _config_mapping = BaseAlgorithm._config_mapping | {
-        "plot_heatmap": "plot_heatmap",
         "_output_path": "output_path",
     }
 
@@ -42,11 +78,8 @@ class TaskVectorCosSimilarity(BaseAlgorithm, LightningFabricMixin):
         output_path: Optional[str] = None,
         **kwargs,
     ):
-        self.plot_heatmap = plot_heatmap
-        self.trainable_only = trainable_only
-        self.max_points_per_model = max_points_per_model
-        self._output_path = output_path
         super().__init__(**kwargs)
+        self._output_path = output_path
 
     @property
     def output_path(self):
@@ -57,6 +90,22 @@ class TaskVectorCosSimilarity(BaseAlgorithm, LightningFabricMixin):
 
     @torch.no_grad()
     def run(self, modelpool: BaseModelPool):
+        """
+        Execute the task vector cosine similarity analysis.
+
+        This method:
+        1. Loads the pretrained base model from the model pool
+        2. Computes task vectors for each fine-tuned model
+        3. Calculates pairwise cosine similarities between all task vectors
+        4. Saves the similarity matrix as a CSV file
+        5. Optionally generates and saves a heatmap visualization
+
+        Args:
+            modelpool (BaseModelPool): Pool containing pretrained and fine-tuned models
+
+        Returns:
+            nn.Module: The pretrained model from the model pool
+        """
         pretrained_model = modelpool.load_pretrained_model()
 
         task_vectors = []
@@ -103,11 +152,14 @@ class TaskVectorCosSimilarity(BaseAlgorithm, LightningFabricMixin):
 
     def _plot_heatmap(self, data: pd.DataFrame):
         """
-
+        Generate and save a heatmap visualization of the cosine similarity matrix.
+
+        Creates a color-coded heatmap showing pairwise cosine similarities between
+        task vectors. The heatmap is saved as a PDF file in the output directory.
 
         Args:
-            data (pd.DataFrame):
-
+            data (pd.DataFrame): Symmetric matrix of cosine similarities between
+                task vectors, with model names as both index and columns.
 
         Returns:
             None
@@ -141,6 +193,26 @@ class TaskVectorCosSimilarity(BaseAlgorithm, LightningFabricMixin):
     def get_task_vector(
         self, pretrained_model: nn.Module, finetuned_model: nn.Module
     ) -> torch.Tensor:
+        """
+        Compute the task vector for a fine-tuned model.
+
+        The task vector represents the parameter changes from pretraining to
+        fine-tuning and is computed as:
+            task_vector = finetuned_params - pretrained_params
+
+        Args:
+            pretrained_model (nn.Module): The base pretrained model
+            finetuned_model (nn.Module): The fine-tuned model for a specific task
+
+        Returns:
+            torch.Tensor: Flattened task vector containing parameter differences.
+                If max_points_per_model is set, the vector may be downsampled.
+
+        Note:
+            - Converts parameters to float64 for numerical precision
+            - Supports optional downsampling for memory efficiency
+            - Uses only trainable parameters if trainable_only=True
+        """
         task_vector = state_dict_sub(
             self.get_state_dict(finetuned_model),
             self.get_state_dict(pretrained_model),
@@ -166,6 +238,17 @@ class TaskVectorCosSimilarity(BaseAlgorithm, LightningFabricMixin):
         return task_vector
 
     def get_state_dict(self, model: nn.Module):
+        """
+        Extract the state dictionary from a model.
+
+        Args:
+            model (nn.Module): The model to extract parameters from
+
+        Returns:
+            Dict[str, torch.Tensor]: State dictionary containing model parameters.
+                Returns only trainable parameters if trainable_only=True,
+                otherwise returns all parameters.
+        """
         if self.trainable_only:
             return trainable_state_dict(model)
         else: