fusion-bench 0.2.8__py3-none-any.whl → 0.2.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/__main__.py +4 -0
- fusion_bench/dataset/fer2013.py +1 -0
- fusion_bench/method/__init__.py +26 -4
- fusion_bench/method/classification/__init__.py +1 -0
- fusion_bench/method/classification/clip_finetune.py +1 -3
- fusion_bench/method/classification/continual_clip_finetune.py +297 -0
- fusion_bench/method/dare/__init__.py +1 -0
- fusion_bench/method/dare/task_arithmetic.py +14 -7
- fusion_bench/method/dare/ties_merging.py +100 -0
- fusion_bench/method/isotropic_merging/__init__.py +15 -0
- fusion_bench/method/isotropic_merging/iso.py +114 -0
- fusion_bench/method/isotropic_merging/iso_utils.py +176 -0
- fusion_bench/method/opcm/__init__.py +4 -0
- fusion_bench/method/opcm/opcm.py +277 -0
- fusion_bench/method/opcm/task_arithmetic.py +115 -0
- fusion_bench/method/opcm/ties_merging.py +156 -0
- fusion_bench/method/opcm/utils.py +73 -0
- fusion_bench/method/opcm/weight_average.py +120 -0
- fusion_bench/method/slerp/slerp.py +1 -1
- fusion_bench/method/task_singular_vector/TSVM.py +22 -2
- fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +91 -93
- fusion_bench/method/ties_merging/ties_merging.py +10 -0
- fusion_bench/metrics/continual_learning/backward_transfer.py +22 -0
- fusion_bench/mixins/clip_classification.py +4 -1
- fusion_bench/programs/fabric_fusion_program.py +22 -11
- fusion_bench/scripts/cli.py +1 -0
- fusion_bench/taskpool/base_pool.py +1 -1
- fusion_bench/taskpool/clip_vision/taskpool.py +12 -7
- fusion_bench/utils/__init__.py +2 -1
- fusion_bench/utils/dict.py +43 -0
- fusion_bench/utils/expr.py +90 -0
- fusion_bench/utils/fabric.py +17 -0
- fusion_bench/utils/instantiate.py +7 -1
- fusion_bench/utils/json.py +30 -0
- fusion_bench/utils/parameters.py +27 -7
- fusion_bench/utils/path.py +15 -0
- fusion_bench/utils/plot/color_data.py +1726 -0
- fusion_bench/utils/rich_utils.py +15 -0
- fusion_bench/utils/set.py +8 -0
- fusion_bench/utils/tensorboard.py +51 -0
- {fusion_bench-0.2.8.dist-info → fusion_bench-0.2.10.dist-info}/METADATA +17 -18
- {fusion_bench-0.2.8.dist-info → fusion_bench-0.2.10.dist-info}/RECORD +58 -29
- {fusion_bench-0.2.8.dist-info → fusion_bench-0.2.10.dist-info}/WHEEL +1 -1
- fusion_bench_config/method/classification/clip_continual_finetune.yaml +28 -0
- fusion_bench_config/method/classification/clip_finetune.yaml +26 -0
- fusion_bench_config/method/clip_finetune.yaml +2 -2
- fusion_bench_config/method/dare/ties_merging.yaml +15 -0
- fusion_bench_config/method/isotropic_merging/iso_c.yaml +4 -0
- fusion_bench_config/method/isotropic_merging/iso_cts.yaml +5 -0
- fusion_bench_config/method/opcm/opcm.yaml +12 -0
- fusion_bench_config/method/opcm/task_arithmetic.yaml +12 -0
- fusion_bench_config/method/opcm/ties_merging.yaml +18 -0
- fusion_bench_config/method/opcm/weight_average.yaml +10 -0
- fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +18 -0
- {fusion_bench-0.2.8.dist-info → fusion_bench-0.2.10.dist-info}/LICENSE +0 -0
- {fusion_bench-0.2.8.dist-info → fusion_bench-0.2.10.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.8.dist-info → fusion_bench-0.2.10.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,156 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import random
|
|
3
|
+
import time
|
|
4
|
+
from collections import defaultdict
|
|
5
|
+
from copy import deepcopy
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import TYPE_CHECKING, List, Literal, Optional, Tuple, cast
|
|
8
|
+
|
|
9
|
+
import lightning as L
|
|
10
|
+
import numpy as np
|
|
11
|
+
import torch
|
|
12
|
+
from omegaconf import DictConfig
|
|
13
|
+
from torch import Tensor, nn
|
|
14
|
+
from tqdm.auto import tqdm
|
|
15
|
+
from transformers import CLIPVisionModel
|
|
16
|
+
|
|
17
|
+
from fusion_bench import BaseAlgorithm, BaseModelPool
|
|
18
|
+
from fusion_bench.method.ties_merging.ties_merging_utils import (
|
|
19
|
+
state_dict_to_vector,
|
|
20
|
+
ties_merging,
|
|
21
|
+
vector_to_state_dict,
|
|
22
|
+
)
|
|
23
|
+
from fusion_bench.mixins import LightningFabricMixin
|
|
24
|
+
from fusion_bench.taskpool import CLIPVisionModelTaskPool
|
|
25
|
+
from fusion_bench.utils.json import load_from_json, save_to_json
|
|
26
|
+
from fusion_bench.utils.state_dict_arithmetic import state_dict_add, state_dict_sub
|
|
27
|
+
|
|
28
|
+
if TYPE_CHECKING:
|
|
29
|
+
from torch.utils.tensorboard import SummaryWriter
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class ContinualTiesMergingForCLIP(BaseAlgorithm, LightningFabricMixin):
    """Continual model merging for CLIP vision models via TIES-Merging.

    Models are merged one at a time: at each step the incoming task vector is
    combined with the task vector of the current merged model using TIES
    (trim, elect sign, disjoint merge), and the result is added back onto the
    merged model scaled by ``scaling_factor``.
    """

    def __init__(
        self,
        scaling_factor: float,
        threshold: float,
        remove_keys: Optional[List[str]] = None,
        merge_func: Literal["sum", "mean", "max"] = "sum",
        shuffle_order: bool = True,
        seed: Optional[int] = None,
        save_on_every_step: bool = True,
        evaluate_on_every_step: bool = False,
        **kwargs,
    ):
        """
        Continual Model Merging via Ties-Merging.

        Args:
            scaling_factor (float): the scaling factor to use.
            threshold (float): the TIES trimming threshold passed as ``reset_thresh``.
            remove_keys (Optional[List[str]]): state-dict keys to exclude from merging.
            merge_func (Literal["sum", "mean", "max"]): aggregation used by TIES merging.
            shuffle_order (bool): whether to shuffle the order of the models.
            seed (Optional[int]): the seed to use.
            save_on_every_step (bool): whether to save the merged model on every step.
            evaluate_on_every_step (bool): whether to evaluate the merged model on every step.
        """
        self.scaling_factor = scaling_factor
        self.threshold = threshold
        self.remove_keys = remove_keys if remove_keys is not None else []
        self.merge_func = merge_func
        self.shuffle_order = shuffle_order
        self.seed = seed
        self.save_on_every_step = save_on_every_step
        self.evaluate_on_every_step = evaluate_on_every_step
        super().__init__(**kwargs)

    @torch.no_grad()
    def run(self, modelpool: BaseModelPool):
        """Merge the models in ``modelpool`` sequentially and return the merged model."""
        if self.seed is not None:
            L.seed_everything(self.seed)

        # Bug fix: copy the name list before shuffling — ``random.shuffle`` is
        # in-place and would otherwise mutate the pool's own ``model_names``.
        model_names = list(modelpool.model_names)
        if self.shuffle_order:
            random.shuffle(model_names)

        self.taskpool = cast(CLIPVisionModelTaskPool, self._program.taskpool)
        # Configuration for the test datasets, kept so the taskpool can be
        # re-subset on every step without losing the full configuration.
        self._test_datasets = deepcopy(self.taskpool._test_datasets)

        # log the (possibly shuffled) model order
        if self.log_dir is not None:
            save_to_json(model_names, Path(self.log_dir) / "model_names.json")
            tensorboard_summarywriter: "SummaryWriter" = self.tensorboard_summarywriter
            tensorboard_summarywriter.add_text(
                "global/model_names", str(model_names), global_step=0
            )

        # start from the pretrained model
        pretrained_model = modelpool.load_pretrained_model()
        merged_model = deepcopy(pretrained_model)

        for model_idx, model_name in tqdm(
            enumerate(model_names), desc="Processing models"
        ):
            task_model = modelpool.load_model(model_name)

            task_vector = state_dict_sub(
                task_model.state_dict(),
                pretrained_model.state_dict(),
            )
            if model_idx == 0:
                # if is the first model, the merged task vector is equal to the task vector
                ties_merging_state_dict = task_vector
            else:
                # if is not the first model, we need to merge the task vector with the previous merged task vector
                merged_tv = state_dict_sub(
                    merged_model.state_dict(),
                    pretrained_model.state_dict(),
                )
                tv_flat_checks = torch.vstack(
                    [
                        state_dict_to_vector(merged_tv, remove_keys=self.remove_keys),
                        state_dict_to_vector(task_vector, remove_keys=self.remove_keys),
                    ]
                )
                # perform the TIES merging
                ties_merging_tv = ties_merging(
                    tv_flat_checks,
                    reset_thresh=self.threshold,
                    merge_func=self.merge_func,
                )
                # convert the merged task vector back to a state dict
                ties_merging_state_dict = vector_to_state_dict(
                    ties_merging_tv,
                    merged_model.state_dict(),
                    remove_keys=self.remove_keys,
                )

            # write the scaled merged task vector back (trainable params only)
            for param_name, param in task_model.named_parameters():
                if not param.requires_grad:
                    continue

                merged_param = merged_model.get_parameter(param_name)
                new_param = (
                    merged_param
                    + self.scaling_factor * ties_merging_state_dict[param_name]
                )
                merged_model.get_parameter(param_name).data = new_param

            if self.save_on_every_step:
                self.save_merged_model(merged_model, model_idx)

            if self.evaluate_on_every_step:
                # evaluate on all tasks seen so far
                self.taskpool._is_setup = False
                self.taskpool._test_datasets = DictConfig(
                    {n: self._test_datasets[n] for n in model_names[: model_idx + 1]}
                )
                report = self.taskpool.evaluate(deepcopy(merged_model))
                save_to_json(report, Path(self.log_dir) / f"report_{model_idx}.json")

        return merged_model

    def save_merged_model(self, merged_model: CLIPVisionModel, step: int):
        """Save the merged model's state dict to ``<log_dir>/checkpoints/model_<step>.pth``."""
        os.makedirs(Path(self.log_dir) / "checkpoints", exist_ok=True)
        torch.save(
            merged_model.state_dict(),
            Path(self.log_dir) / "checkpoints" / f"model_{step}.pth",
        )
|
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
from typing import Tuple
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from torch import Tensor, nn
|
|
5
|
+
|
|
6
|
+
from fusion_bench.utils.parameters import state_dict_to_vector
|
|
7
|
+
from fusion_bench.utils.state_dict_arithmetic import state_dict_sub
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def _svd(w: Tensor, full_matrices=True) -> Tuple[Tensor, Tensor, Tensor]:
|
|
11
|
+
"""
|
|
12
|
+
Perform Singular Value Decomposition (SVD) on a tensor.
|
|
13
|
+
|
|
14
|
+
Args:
|
|
15
|
+
w (Tensor): The input tensor.
|
|
16
|
+
full_matrices (bool): Whether to compute the full-sized U and V matrices.
|
|
17
|
+
|
|
18
|
+
Returns:
|
|
19
|
+
Tuple[Tensor, Tensor, Tensor]: The U, S, and V matrices from SVD.
|
|
20
|
+
"""
|
|
21
|
+
u, s, vh = torch.linalg.svd(
|
|
22
|
+
w, full_matrices=full_matrices, driver="gesvd" if w.is_cuda else None
|
|
23
|
+
)
|
|
24
|
+
v = vh.T
|
|
25
|
+
return u, s, v
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def svd(
    w: Tensor, full_matrices=True, accelerator=None
) -> Tuple[Tensor, Tensor, Tensor]:
    """
    Perform SVD on a tensor, optionally using a specified accelerator.

    Args:
        w (Tensor): The input tensor.
        full_matrices (bool): Whether to compute the full-sized U and V matrices.
        accelerator (str): The device to perform the computation on. If None,
            the decomposition runs on the tensor's current device.

    Returns:
        Tuple[Tensor, Tensor, Tensor]: The U, S, and V matrices from SVD,
        moved back to the input tensor's original device.
    """
    if accelerator is None:
        return _svd(w, full_matrices=full_matrices)
    original_device = w.device
    w = w.to(accelerator)
    # Bug fix: forward ``full_matrices`` — previously the accelerator path
    # called ``_svd(w)`` and silently ignored the argument, always computing
    # the full decomposition.
    u, s, v = _svd(w, full_matrices=full_matrices)
    return u.to(original_device), s.to(original_device), v.to(original_device)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def frobenius_inner_product(w1: Tensor, w2: Tensor) -> Tensor:
    """Return the Frobenius inner product ``trace(w1^T @ w2)`` of two matrices."""
    gram = w1.T @ w2
    return torch.trace(gram)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def is_leaf_module(module: nn.Module) -> bool:
    """Return True iff ``module`` has no child submodules."""
    return next(iter(module.children()), None) is None
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def get_task_vector_norm(model: nn.Module, pretrained_model: nn.Module) -> Tensor:
    """
    Get the vector norm of the task model.

    Args:
        model (nn.Module): The task model.
        pretrained_model (nn.Module): The pretrained model.

    Returns:
        Tensor: The vector norm of the task model.
    """
    # task vector = fine-tuned weights minus pretrained weights
    task_vector = state_dict_sub(model.state_dict(), pretrained_model.state_dict())
    flat = state_dict_to_vector(task_vector)
    return torch.linalg.norm(flat)
|
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import random
|
|
3
|
+
import time
|
|
4
|
+
from collections import defaultdict
|
|
5
|
+
from copy import deepcopy
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import TYPE_CHECKING, List, Literal, Optional, Tuple, cast
|
|
8
|
+
|
|
9
|
+
import lightning as L
|
|
10
|
+
import numpy as np
|
|
11
|
+
import torch
|
|
12
|
+
from omegaconf import DictConfig
|
|
13
|
+
from torch import Tensor, nn
|
|
14
|
+
from tqdm.auto import tqdm
|
|
15
|
+
from transformers import CLIPVisionModel
|
|
16
|
+
|
|
17
|
+
from fusion_bench import BaseAlgorithm, BaseModelPool
|
|
18
|
+
from fusion_bench.mixins import LightningFabricMixin
|
|
19
|
+
from fusion_bench.taskpool import CLIPVisionModelTaskPool
|
|
20
|
+
from fusion_bench.utils.json import load_from_json, save_to_json
|
|
21
|
+
|
|
22
|
+
if TYPE_CHECKING:
|
|
23
|
+
from torch.utils.tensorboard import SummaryWriter
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class ContinualWeightAverageForCLIP(
    BaseAlgorithm,
    LightningFabricMixin,
):
    """Continual model merging for CLIP vision models via running weight average.

    Models are merged one at a time; after step ``k`` the merged model holds
    the uniform average of the first ``k + 1`` task models' parameters.
    """

    def __init__(
        self,
        shuffle_order: bool = True,
        seed: Optional[int] = None,
        save_on_every_step: bool = True,
        evaluate_on_every_step: bool = False,
        **kwargs,
    ):
        """
        Continual Model Merging via Weight Average.

        Args:
            shuffle_order (bool): whether to shuffle the order of the models.
            seed (Optional[int]): the seed to use.
            save_on_every_step (bool): whether to save the merged model on every step.
            evaluate_on_every_step (bool): whether to evaluate the merged model on every step.
        """
        self.shuffle_order = shuffle_order
        self.seed = seed
        self.save_on_every_step = save_on_every_step
        self.evaluate_on_every_step = evaluate_on_every_step
        super().__init__(**kwargs)

    def run(self, modelpool: BaseModelPool):
        """Average the models in ``modelpool`` sequentially and return the merged model."""
        if self.seed is not None:
            L.seed_everything(self.seed)

        # Bug fix: copy the name list before shuffling — ``random.shuffle`` is
        # in-place and would otherwise mutate the pool's own ``model_names``.
        model_names = list(modelpool.model_names)
        if self.shuffle_order:
            random.shuffle(model_names)

        self.taskpool = cast(CLIPVisionModelTaskPool, self._program.taskpool)
        # Configuration for the test datasets, kept so the taskpool can be
        # re-subset on every step without losing the full configuration.
        self._test_datasets = deepcopy(self.taskpool._test_datasets)

        # log the (possibly shuffled) model order
        if self.log_dir is not None:
            save_to_json(model_names, Path(self.log_dir) / "model_names.json")
            tensorboard_summarywriter: "SummaryWriter" = self.tensorboard_summarywriter
            tensorboard_summarywriter.add_text(
                "global/model_names", str(model_names), global_step=0
            )

        # initialize the running average with the first model
        merged_model = modelpool.load_model(model_names[0])

        if self.evaluate_on_every_step:
            self.taskpool._is_setup = False
            self.taskpool._test_datasets = DictConfig(
                {model_names[0]: self._test_datasets[model_names[0]]}
            )
            report = self.taskpool.evaluate(deepcopy(merged_model))
            save_to_json(report, Path(self.log_dir) / "report_0.json")

        if self.save_on_every_step:
            self.save_merged_model(merged_model, 0)

        for model_idx, model_name in tqdm(
            enumerate(model_names[1:]), desc="Processing models"
        ):
            # enumerate starts at 0 but this is the (idx + 1)-th model overall
            model_idx += 1
            task_model = modelpool.load_model(model_name)

            for param_name, param in task_model.named_parameters():
                if not param.requires_grad:
                    continue

                task_param = param
                merged_param = merged_model.get_parameter(param_name)

                # incremental mean: avg_{k+1} = (avg_k * k + x_{k+1}) / (k + 1)
                new_param = (merged_param * model_idx + task_param) / (model_idx + 1)
                merged_model.get_parameter(param_name).data = new_param

            if self.save_on_every_step:
                self.save_merged_model(merged_model, model_idx)

            if self.evaluate_on_every_step:
                # evaluate on all tasks seen so far
                self.taskpool._is_setup = False
                self.taskpool._test_datasets = DictConfig(
                    {n: self._test_datasets[n] for n in model_names[: model_idx + 1]}
                )
                report = self.taskpool.evaluate(deepcopy(merged_model))
                save_to_json(report, Path(self.log_dir) / f"report_{model_idx}.json")

        return merged_model

    def save_merged_model(self, merged_model: CLIPVisionModel, step: int):
        """Save the merged model via ``save_pretrained`` under ``<log_dir>/checkpoints``."""
        os.makedirs(Path(self.log_dir) / "checkpoints", exist_ok=True)
        merged_model.save_pretrained(
            Path(self.log_dir) / "checkpoints" / f"merged_model_{step}"
        )
|
|
@@ -51,7 +51,7 @@ class SlerpMergeAlgorithm(BaseAlgorithm):
|
|
|
51
51
|
General purpose implementation of Slerp (Spherical Linear Interpolation) for PyTorch models.
|
|
52
52
|
"""
|
|
53
53
|
|
|
54
|
-
_config_mapping = BaseAlgorithm._config_mapping
|
|
54
|
+
_config_mapping = BaseAlgorithm._config_mapping | {
|
|
55
55
|
"t": "t",
|
|
56
56
|
"DOT_THRESHOLD": "DOT_THRESHOLD",
|
|
57
57
|
"epsilon": "epsilon",
|
|
@@ -9,15 +9,20 @@ fusion_bench \
|
|
|
9
9
|
```
|
|
10
10
|
"""
|
|
11
11
|
|
|
12
|
-
from typing import List, Optional
|
|
12
|
+
from typing import List, Optional, Union, Iterable
|
|
13
13
|
|
|
14
14
|
import torch
|
|
15
15
|
from torch import Tensor, nn
|
|
16
|
+
from omegaconf import ListConfig
|
|
16
17
|
|
|
17
18
|
from fusion_bench import BaseAlgorithm
|
|
18
19
|
from fusion_bench.mixins import LightningFabricMixin
|
|
19
20
|
from fusion_bench.utils import timeit_context
|
|
20
|
-
from fusion_bench.utils.state_dict_arithmetic import
|
|
21
|
+
from fusion_bench.utils.state_dict_arithmetic import (
|
|
22
|
+
state_dict_add,
|
|
23
|
+
state_dict_sub,
|
|
24
|
+
state_dict_mul,
|
|
25
|
+
)
|
|
21
26
|
from fusion_bench.utils.type import StateDictType
|
|
22
27
|
|
|
23
28
|
from .utils import (
|
|
@@ -33,9 +38,11 @@ class TaskSingularVectorMerging(BaseAlgorithm, LightningFabricMixin):
|
|
|
33
38
|
|
|
34
39
|
def __init__(
|
|
35
40
|
self,
|
|
41
|
+
alpha: Union[float, Iterable[float]] = None,
|
|
36
42
|
remove_keys: Optional[List[str]] = None,
|
|
37
43
|
**kwargs,
|
|
38
44
|
):
|
|
45
|
+
self.alpha = alpha
|
|
39
46
|
self.remove_keys = remove_keys if remove_keys is not None else []
|
|
40
47
|
super().__init__(**kwargs)
|
|
41
48
|
|
|
@@ -50,6 +57,14 @@ class TaskSingularVectorMerging(BaseAlgorithm, LightningFabricMixin):
|
|
|
50
57
|
|
|
51
58
|
with timeit_context("Flattening out Checkpoints"):
|
|
52
59
|
task_vectors = [state_dict_sub(check, ptm_check) for check in ft_checks]
|
|
60
|
+
if isinstance(self.alpha, Iterable):
|
|
61
|
+
assert len(self.alpha) == len(
|
|
62
|
+
task_vectors
|
|
63
|
+
), "Alpha and task vectors must have the same length"
|
|
64
|
+
task_vectors = [
|
|
65
|
+
state_dict_mul(state_dict=tv, scalar=alpha)
|
|
66
|
+
for alpha, tv in zip(self.alpha, task_vectors)
|
|
67
|
+
]
|
|
53
68
|
|
|
54
69
|
new_merged_tv = TSVM_utils.compute_and_sum_svd_mem_reduction(
|
|
55
70
|
task_vectors,
|
|
@@ -57,6 +72,11 @@ class TaskSingularVectorMerging(BaseAlgorithm, LightningFabricMixin):
|
|
|
57
72
|
accelerator=self.fabric.device,
|
|
58
73
|
)
|
|
59
74
|
|
|
75
|
+
# If alpha is a float, we need to scale the new merged task vector by alpha
|
|
76
|
+
if self.alpha is not None and isinstance(self.alpha, float):
|
|
77
|
+
print(f"Scaling new merged task vector by alpha: {self.alpha}")
|
|
78
|
+
new_merged_tv = state_dict_mul(state_dict=new_merged_tv, scalar=self.alpha)
|
|
79
|
+
|
|
60
80
|
pretrained_model.load_state_dict(
|
|
61
81
|
state_dict_add(new_merged_tv, pretrained_model.state_dict())
|
|
62
82
|
)
|