fusion-bench 0.2.29__py3-none-any.whl → 0.2.31__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/constants/runtime.py +4 -1
- fusion_bench/method/__init__.py +9 -1
- fusion_bench/method/base_algorithm.py +29 -19
- fusion_bench/method/classification/image_classification_finetune.py +1 -0
- fusion_bench/method/concrete_subspace/clip_concrete_tsvm.py +285 -0
- fusion_bench/method/task_singular_vector/TSVM.py +7 -6
- fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +0 -1
- fusion_bench/metrics/model_kinship/__init__.py +2 -0
- fusion_bench/metrics/model_kinship/calculate.py +77 -0
- fusion_bench/metrics/model_kinship/calculate_split.py +171 -0
- fusion_bench/metrics/model_kinship/utility.py +184 -0
- fusion_bench/mixins/lightning_fabric.py +2 -8
- fusion_bench/mixins/openclip_classification.py +155 -1
- fusion_bench/modelpool/base_pool.py +1 -0
- fusion_bench/modelpool/openclip_vision/modelpool.py +12 -3
- fusion_bench/models/masks/mask_model.py +8 -2
- fusion_bench/models/open_clip/modeling.py +68 -5
- fusion_bench/models/open_clip/utils.py +13 -2
- fusion_bench/models/wrappers/layer_wise_fusion.py +41 -3
- fusion_bench/models/wrappers/task_wise_fusion.py +14 -3
- fusion_bench/py.typed +1 -0
- fusion_bench/scripts/cli.py +21 -16
- fusion_bench/scripts/imgui.py +2 -2
- fusion_bench/scripts/webui.py +2 -2
- fusion_bench/utils/__init__.py +2 -0
- fusion_bench/utils/devices.py +3 -1
- fusion_bench/utils/hydra_utils.py +75 -0
- fusion_bench/utils/instantiate_utils.py +29 -18
- fusion_bench/utils/misc.py +16 -0
- fusion_bench/utils/parameters.py +33 -0
- fusion_bench/utils/rich_utils.py +165 -25
- {fusion_bench-0.2.29.dist-info → fusion_bench-0.2.31.dist-info}/METADATA +7 -7
- {fusion_bench-0.2.29.dist-info → fusion_bench-0.2.31.dist-info}/RECORD +41 -34
- fusion_bench_config/README.md +9 -0
- fusion_bench_config/fabric/auto.yaml +1 -0
- fusion_bench_config/hydra/default.yaml +3 -1
- fusion_bench_config/method/concrete_subspace/clip_concrete_tsvm.yaml +38 -0
- {fusion_bench-0.2.29.dist-info → fusion_bench-0.2.31.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.29.dist-info → fusion_bench-0.2.31.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.29.dist-info → fusion_bench-0.2.31.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.29.dist-info → fusion_bench-0.2.31.dist-info}/top_level.txt +0 -0
|
@@ -89,7 +89,10 @@ class RuntimeConstants:
|
|
|
89
89
|
self._initialized = True
|
|
90
90
|
|
|
91
91
|
debug = False
|
|
92
|
-
"""
|
|
92
|
+
"""
|
|
93
|
+
Global debug flag for enabling verbose logging and debugging features.
|
|
94
|
+
Use `RuntimeConstants().debug` instead of `RuntimeConstants.debug`
|
|
95
|
+
"""
|
|
93
96
|
|
|
94
97
|
@property
|
|
95
98
|
def cache_dir(self) -> Path:
|
fusion_bench/method/__init__.py
CHANGED
|
@@ -144,7 +144,15 @@ _extra_objects = {
|
|
|
144
144
|
|
|
145
145
|
if TYPE_CHECKING:
|
|
146
146
|
from .ada_svd import AdaSVDMergingForCLIPVisionModel
|
|
147
|
-
from .adamerging import
|
|
147
|
+
from .adamerging import (
|
|
148
|
+
CLIPLayerWiseAdaMergingAlgorithm,
|
|
149
|
+
CLIPTaskWiseAdaMergingAlgorithm,
|
|
150
|
+
FlanT5LayerWiseAdaMergingAlgorithm,
|
|
151
|
+
GPT2LayerWiseAdaMergingAlgorithm,
|
|
152
|
+
LayerWiseAdaMergingForLlamaSFT,
|
|
153
|
+
ResNetLayerWiseAdamerging,
|
|
154
|
+
ResNetTaskWiseAdamerging,
|
|
155
|
+
)
|
|
148
156
|
from .analysis import TaskVectorCosSimilarity, TaskVectorViolinPlot
|
|
149
157
|
from .base_algorithm import BaseAlgorithm, BaseModelFusionAlgorithm
|
|
150
158
|
from .bitdelta import BitDeltaAlgorithm
|
|
@@ -40,6 +40,7 @@ from typing import Optional # noqa: F401
|
|
|
40
40
|
|
|
41
41
|
from fusion_bench.mixins import BaseYAMLSerializable
|
|
42
42
|
from fusion_bench.modelpool import BaseModelPool
|
|
43
|
+
from fusion_bench.utils.misc import DeprecationWarningMeta
|
|
43
44
|
|
|
44
45
|
__all__ = ["BaseAlgorithm", "BaseModelFusionAlgorithm"]
|
|
45
46
|
|
|
@@ -202,27 +203,36 @@ class BaseAlgorithm(BaseYAMLSerializable):
|
|
|
202
203
|
pass
|
|
203
204
|
|
|
204
205
|
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
206
|
+
# Create a deprecated wrapper class that inherits from BaseAlgorithm
|
|
207
|
+
class BaseModelFusionAlgorithm(BaseAlgorithm, metaclass=DeprecationWarningMeta):
|
|
208
|
+
"""
|
|
209
|
+
Alias for BaseAlgorithm class.
|
|
208
210
|
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
tasks, while others may prefer the shorter 'BaseAlgorithm' name.
|
|
211
|
+
.. deprecated::
|
|
212
|
+
BaseModelFusionAlgorithm is deprecated and will be removed in a future version.
|
|
213
|
+
Use :class:`BaseAlgorithm` instead.
|
|
213
214
|
|
|
214
|
-
|
|
215
|
+
This alias was provided for backward compatibility and semantic clarity.
|
|
216
|
+
Both names refer to the same base class and can be used interchangeably,
|
|
217
|
+
but BaseAlgorithm is now the preferred name for all implementations.
|
|
215
218
|
|
|
216
|
-
Examples:
|
|
217
|
-
|
|
218
|
-
>>> class MyAlgorithm(BaseAlgorithm):
|
|
219
|
-
... def run(self, modelpool): pass
|
|
219
|
+
Examples:
|
|
220
|
+
Preferred (using BaseAlgorithm):
|
|
220
221
|
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
... def run(self, modelpool): pass
|
|
222
|
+
>>> class MyAlgorithm(BaseAlgorithm):
|
|
223
|
+
... def run(self, modelpool): pass
|
|
224
224
|
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
225
|
+
Deprecated (using BaseModelFusionAlgorithm):
|
|
226
|
+
|
|
227
|
+
>>> class MyAlgorithm(BaseModelFusionAlgorithm): # Will trigger deprecation warning
|
|
228
|
+
... def run(self, modelpool): pass
|
|
229
|
+
|
|
230
|
+
Note:
|
|
231
|
+
New implementations should use :class:`BaseAlgorithm` exclusively.
|
|
232
|
+
The BaseModelFusionAlgorithm alias will be removed in a future release.
|
|
233
|
+
|
|
234
|
+
Warning:
|
|
235
|
+
Using BaseModelFusionAlgorithm will trigger a DeprecationWarning.
|
|
236
|
+
"""
|
|
237
|
+
|
|
238
|
+
pass
|
|
@@ -0,0 +1,285 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import os
|
|
3
|
+
from copy import deepcopy
|
|
4
|
+
from typing import TYPE_CHECKING, Dict, Iterable, List, Literal, Optional
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
from omegaconf import DictConfig
|
|
8
|
+
from tqdm import tqdm
|
|
9
|
+
|
|
10
|
+
from fusion_bench import (
|
|
11
|
+
BaseAlgorithm,
|
|
12
|
+
OpenCLIPClassificationMixin,
|
|
13
|
+
OpenCLIPVisionModelPool,
|
|
14
|
+
SimpleProfilerMixin,
|
|
15
|
+
StateDictType,
|
|
16
|
+
auto_register_config,
|
|
17
|
+
get_rankzero_logger,
|
|
18
|
+
instantiate,
|
|
19
|
+
)
|
|
20
|
+
from fusion_bench.method.adamerging.entropy_loss import entropy_loss
|
|
21
|
+
from fusion_bench.method.task_singular_vector import TaskSingularVectorMerging
|
|
22
|
+
from fusion_bench.method.task_singular_vector.utils import (
|
|
23
|
+
TSVM_utils,
|
|
24
|
+
check_parameterNamesMatch,
|
|
25
|
+
check_state_dicts_equal,
|
|
26
|
+
state_dict_to_vector,
|
|
27
|
+
vector_to_state_dict,
|
|
28
|
+
)
|
|
29
|
+
from fusion_bench.models.masks import MaskModel, mask_sparsity
|
|
30
|
+
from fusion_bench.models.open_clip import (
|
|
31
|
+
ClassificationHead,
|
|
32
|
+
ImageClassifier,
|
|
33
|
+
ImageEncoder,
|
|
34
|
+
)
|
|
35
|
+
from fusion_bench.models.wrappers.task_wise_fusion import (
|
|
36
|
+
TaskWiseMergedModel,
|
|
37
|
+
get_task_wise_weights,
|
|
38
|
+
)
|
|
39
|
+
from fusion_bench.utils.devices import clear_cuda_cache
|
|
40
|
+
from fusion_bench.utils.dtype import parse_dtype
|
|
41
|
+
from fusion_bench.utils.parameters import print_parameters, print_trainable_parameters
|
|
42
|
+
from fusion_bench.utils.rich_utils import print_config_yaml
|
|
43
|
+
from fusion_bench.utils.state_dict_arithmetic import (
|
|
44
|
+
_validate_state_dict_same_keys,
|
|
45
|
+
state_dict_add,
|
|
46
|
+
state_dict_hadamard_product,
|
|
47
|
+
state_dict_mul,
|
|
48
|
+
state_dict_sub,
|
|
49
|
+
)
|
|
50
|
+
|
|
51
|
+
log = get_rankzero_logger(__name__)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
@auto_register_config
class ConcreteTSVMForOpenCLIP(
    OpenCLIPClassificationMixin,
    SimpleProfilerMixin,
    BaseAlgorithm,
):
    """
    Concrete-mask test-time adaptation combined with TSVM merging for OpenCLIP.

    ``run`` wraps the pre-trained model and the models returned by
    ``TaskSingularVectorMerging`` in a ``TaskWiseMergedModel``, builds a
    ``MaskModel`` over the pre-trained model, optionally trains the mask by
    minimising the entropy of the merged model's predictions on test batches,
    and returns the model merged with the final sampled mask.
    """

    def __init__(
        self,
        dataloader_kwargs: DictConfig,
        optimizer: DictConfig,
        lr_scheduler: DictConfig,
        max_steps: int,
        save_interval: int,
        initial_logits: float,
        temperature: float,
        eval_mask_type: Literal["continuous", "discrete"],
        mask_checkpoint: Optional[str],
        merge_dtype: str,
        clamp_weights: bool,
        tie_weights: bool,
        strict: bool,
        skip_training: bool,
        # === TSVM parameters ===
        exclude_keys: Optional[List[str]],
        alpha: float,
        return_single_task_models: bool = True,
        **kwargs,
    ):
        """
        Configure the algorithm.

        The arguments are not assigned here but are read later as ``self.<name>``;
        presumably ``@auto_register_config`` stores them on the instance
        (TODO confirm against the decorator).

        Args:
            dataloader_kwargs: Not referenced in this class body; presumably
                consumed by the data-loading mixin (TODO confirm).
            optimizer: Config instantiated with the trainable mask parameters.
            lr_scheduler: Config instantiated with the optimizer, or ``None``
                to train without a scheduler.
            max_steps: Number of mask-training steps (5 in debug mode).
            save_interval: A mask checkpoint is written every this many steps.
            initial_logits: Value the mask logits are filled with.
            temperature: Temperature passed to ``MaskModel.sample_mask``.
            eval_mask_type: Mask type sampled for the final merge.
            mask_checkpoint: When not ``None``, the mask is loaded from this
                path instead of being trained.
            merge_dtype: Parsed with ``parse_dtype``; applied to the mask model,
                the merged module and the input images.
            clamp_weights: Forwarded to ``TaskWiseMergedModel``.
            tie_weights: Forwarded to ``TaskWiseMergedModel``.
            strict: Forwarded to ``TaskWiseMergedModel``.
            skip_training: Skip mask training when no checkpoint is given.
            exclude_keys: Forwarded to ``TaskSingularVectorMerging``.
            alpha: Forwarded to ``TaskSingularVectorMerging`` and used as the
                initial task-wise weight.
            return_single_task_models: Ignored apart from a warning; the
                attribute is always forced to ``True``.
            **kwargs: Forwarded to the parent constructors.
        """
        super().__init__(**kwargs)
        if not return_single_task_models:
            log.warning("return_single_task_models is forced to be True here.")
        self.return_single_task_models = True

    @torch.no_grad()
    def setup_models(self):
        """
        Load the pre-trained model, run TSVM, and construct the mask model.

        The merged module is built on the global-zero rank only and then
        broadcast to the other ranks.

        Returns:
            tuple: ``(module, mask_model)`` where ``module`` is the broadcast
            ``TaskWiseMergedModel`` and ``mask_model`` is the ``MaskModel``
            filled with ``self.initial_logits``.
        """
        merge_dtype = parse_dtype(self.merge_dtype)
        modelpool = self.modelpool

        # load the pre-trained model
        pretrained_model = modelpool.load_pretrained_model()
        self.set_clip_processor(stage="test", processor=pretrained_model.val_preprocess)

        # construct mask model
        mask_model = MaskModel(
            pretrained_model, ignore_untrained_params=True, parameter_type="logits"
        )
        if merge_dtype is not None:
            mask_model.to(merge_dtype)
        mask_model.fill_(self.initial_logits)

        if self.fabric.is_global_zero:
            print("summary of mask model:")
            print_parameters(mask_model)

        if self.fabric.is_global_zero:
            tsvm_algo = TaskSingularVectorMerging(
                alpha=self.alpha,
                exclude_keys=self.exclude_keys,
                return_single_task_models=self.return_single_task_models,
            )
            # reuse this algorithm's fabric instead of letting TSVM create its own
            tsvm_algo._fabric_instance = self.fabric
            models = tsvm_algo.run(modelpool)

            finetuned_models = [models[name] for name in modelpool.model_names]

            task_wise_weight = get_task_wise_weights(
                num_models=len(modelpool.model_names),
                init_values=self.alpha,
            )

            # create a wrapped model
            module = TaskWiseMergedModel(
                task_wise_weight=task_wise_weight,
                pretrained_model=pretrained_model,
                finetuned_models=finetuned_models,
                clamp_weights=self.clamp_weights,
                tie_weights=self.tie_weights,
                strict=self.strict,
                task_vector_dtype=merge_dtype,
            )
            module = module.to(dtype=merge_dtype)

            print("trainable parameter summary of merged model (TaskWiseMergedModel):")
            print_trainable_parameters(module)
        else:
            module = None

        with torch.no_grad():
            self.fabric.barrier()
            module = self.fabric.broadcast(module, src=0)

        return module, mask_model

    def train_mask(self, module: TaskWiseMergedModel, mask_model: MaskModel):
        """
        Train the mask model using the provided module.

        This method configures the optimizer, sets up the mask model, and performs
        test-time adaptation: at every step a continuous mask is sampled, the
        merged weights are rebuilt with it, and the summed entropy of the
        predictions over all tasks is minimised.

        Args:
            module (TaskWiseMergedModel): The wrapped model with task-wise weights.
            mask_model (MaskModel): The mask model to be trained.
        """
        config = self.config
        merge_dtype = parse_dtype(self.merge_dtype)
        log.info(f"Using merge dtype: {merge_dtype}")

        optimizer: "torch.optim.Optimizer" = instantiate(
            self.optimizer,
            params=filter(lambda p: p.requires_grad, mask_model.parameters()),
        )
        print(f"{optimizer=}")
        if self.lr_scheduler is not None:
            lr_scheduler = instantiate(
                self.lr_scheduler,
                optimizer=optimizer,
            )
            print(f"{lr_scheduler=}")
        else:
            lr_scheduler = None

        log.info("Setup models and optimizer with Fabric.")
        mask_model, optimizer = self.fabric.setup(mask_model, optimizer)

        log.info("Move the merged module to the correct device and disable gradients.")
        module.requires_grad_(False)
        module.to(mask_model.device)

        mask_model.train()
        optimizer.zero_grad()
        for step_idx in (
            pbar := tqdm(
                range(self.config.max_steps if not self.is_debug_mode else 5),
                ("[DEBUG MODE] " if self.is_debug_mode else "")
                + "Concrete TSVM Test-Time Adaptation",
                dynamic_ncols=True,
                disable=not self.fabric.is_global_zero,
            )
        ):
            metrics = {}
            # sample a shared mask and merge weights
            with self.profile("sample mask"):
                mask = mask_model.sample_mask(
                    mask_type="continuous", temperature=config.temperature
                )
                metrics["train/sparsity"] = mask_sparsity(mask)
            with self.profile("merge weights"):
                # rescale mask so that every entry has unit mean
                for name, m in mask.items():
                    mask[name] = m / torch.mean(m)
                module.merge_weights(task_vector_mask=mask)

            # ------ inner optimization goes here ------
            # NOTE:
            #   Because the algorithmic parameters of TSVM are assumed to be chosen on a
            #   validation set, we do not need to perform inner optimization here.
            #   So the inner optimization step is skipped.
            # ------------------------------------------

            total_loss = None
            for task in self.modelpool.model_names:
                with self.profile("data loading"):
                    batch = next(self.get_shuffled_test_loader_iter(task))
                    # NOTE: The labels are not allowed to be used during test-time adaptation
                    images = batch[0].to(dtype=merge_dtype)
                with self.profile("forward pass"):
                    logits = self.compute_logits(module, images, task)
                    loss = entropy_loss(logits)
                    total_loss = loss if total_loss is None else total_loss + loss

            with self.profile("compute grad"):
                self.fabric.backward(total_loss)

            with self.profile("optimizer step"):
                optimizer.step()
                optimizer.zero_grad()

                if lr_scheduler is not None:
                    lr_scheduler.step()

            # report the objective that was actually optimised (sum over all tasks),
            # not just the loss of the last task in the loop
            metrics.update({"train/loss": total_loss.item()})
            self.fabric.log_dict(metrics, step=step_idx)
            pbar.set_postfix(metrics)

            if (step_idx + 1) % self.config.save_interval == 0:
                with self.profiler.profile("save checkpoint"):
                    save_dir = os.path.join(self.fabric.logger.log_dir, "checkpoints")
                    os.makedirs(save_dir, exist_ok=True)
                    save_path = os.path.join(save_dir, f"mask_steps_{step_idx}.pt")
                    print(f"saving checkpoint to {save_path}")
                    state = {"model": mask_model}
                    self.fabric.save(save_path, state)

                    # Create or update a link to the latest checkpoint.
                    # NOTE: ``os.link`` creates a hard link, not a symbolic one.
                    if self.fabric.is_global_zero:
                        symlink_path = os.path.join(save_dir, "latest_checkpoint.pt")
                        if os.path.exists(symlink_path):
                            os.remove(symlink_path)
                        os.link(os.path.abspath(save_path), symlink_path)

        self.print_profile_summary()

    def run(self, modelpool: OpenCLIPVisionModelPool):
        """
        Run the full pipeline and return the merged model.

        The mask is trained when ``mask_checkpoint`` is ``None`` and
        ``skip_training`` is false; when a checkpoint path is given the mask is
        loaded from it instead. A mask of type ``eval_mask_type`` is then
        sampled, rescaled to unit mean and used for the final merge.

        Args:
            modelpool (OpenCLIPVisionModelPool): Pool providing the pre-trained
                and task models.

        Returns:
            The merged and unloaded model, cast to ``torch.float32``.
        """
        self.modelpool = modelpool
        merge_dtype = parse_dtype(self.merge_dtype)

        with self.profile("setup models"):
            module, mask_model = self.setup_models()
            self.setup_zero_shot_classification_head(freeze=True, dtype=merge_dtype)

        if self.mask_checkpoint is None:
            if not self.skip_training:
                clear_cuda_cache()
                self.train_mask(module, mask_model=mask_model)
        else:
            if self.fabric.is_global_zero:
                print("loading mask from checkpoint", self.mask_checkpoint)
            self.fabric.load(self.mask_checkpoint, {"model": mask_model})

        with torch.no_grad():
            clear_cuda_cache()
            mask = mask_model.sample_mask(
                mask_type=self.eval_mask_type, temperature=self.temperature
            )
            # rescale mask
            for name, m in mask.items():
                mask[name] = m / torch.mean(m)
            model = module.merge_and_unload(mask)
        return model.to(dtype=torch.float32)
|
|
@@ -249,12 +249,13 @@ class TaskSingularVectorMerging(BaseAlgorithm, LightningFabricMixin):
|
|
|
249
249
|
# - SVD finds the principal components (most important directions)
|
|
250
250
|
# - Task vectors are reconstructed using only the most significant components
|
|
251
251
|
# - The reconstructed vectors are merged (summed) to create a unified task vector
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
252
|
+
with torch.no_grad():
|
|
253
|
+
new_merged_tv = TSVM_utils.compute_and_sum_svd_mem_reduction(
|
|
254
|
+
task_vectors,
|
|
255
|
+
exclude_keys=self.exclude_keys, # Skip certain parameters from SVD
|
|
256
|
+
accelerator=accelerator, # Use GPU if available
|
|
257
|
+
return_single_task_models=self.return_single_task_models,
|
|
258
|
+
)
|
|
258
259
|
|
|
259
260
|
# Handle the case where individual transformed task vectors are also returned
|
|
260
261
|
if self.return_single_task_models:
|
|
@@ -311,7 +311,6 @@ def compute_and_sum_svd_mem_reduction_lossless_eigen(
|
|
|
311
311
|
|
|
312
312
|
###############
|
|
313
313
|
#### TSV Merge Orthogonalization
|
|
314
|
-
@torch.no_grad()
|
|
315
314
|
def compute_and_sum_svd_mem_reduction(
|
|
316
315
|
task_vectors: List[StateDictType],
|
|
317
316
|
exclude_keys: Optional[List[str]] = None,
|
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from typing import List
|
|
3
|
+
|
|
4
|
+
import numpy
|
|
5
|
+
import torch
|
|
6
|
+
|
|
7
|
+
from .utility import Metric
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def cosine_similarity(a, b):
    """
    Return the absolute cosine similarity ``|a.b| / (|a| |b|)`` of two 1-D vectors.

    The sign of the cosine is discarded, so the result lies in ``[0, 1]``.
    The result is NaN when either vector is all zeros.

    Args:
        a: First vector (numpy array, sequence of numbers, or torch tensor).
        b: Second vector, same length as ``a``.

    Returns:
        numpy.float64: The absolute cosine similarity.
    """

    def _as_float64(vec):
        # Torch tensors on an accelerator, tracking gradients, or stored in
        # bf16 cannot be passed to numpy.dot directly; bring them to a
        # detached float64 CPU array first.
        if hasattr(vec, "detach"):
            vec = vec.detach().cpu().double().numpy()
        return numpy.asarray(vec, dtype=numpy.float64)

    a = _as_float64(a)
    b = _as_float64(b)
    # |a.b| / sqrt((a.a)(b.b)) equals sqrt((a.b)^2 / ((a.a)(b.b))) but avoids
    # squaring the dot product. A zero vector gives 0/0 -> NaN; keep that
    # value and only silence the runtime warning.
    with numpy.errstate(divide="ignore", invalid="ignore"):
        similarity = numpy.abs(numpy.dot(a, b)) / numpy.sqrt(
            numpy.dot(a, a) * numpy.dot(b, b)
        )
    return similarity
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def calculate_model_kinship(
    delta1: numpy.ndarray, delta2: numpy.ndarray, metrics: List[str]
) -> dict:
    """
    Calculate model kinship using specified metrics.

    Each metric is evaluated independently. A failure for one metric is logged
    and recorded as an error string under that metric's key, so the remaining
    metrics are still computed.

    Args:
        delta1: Delta parameters for first model
        delta2: Delta parameters for second model
        metrics: List of metrics to calculate

    Returns:
        dict: Dictionary of metric names and their calculated values
            (or an ``"Error calculating ..."`` string for a failed metric)
    """
    # The set of supported metrics does not change while looping; look it up once.
    supported_metrics = Metric.list()
    results = {}
    for metric in metrics:
        try:
            if metric not in supported_metrics:
                raise ValueError(f"Unsupported metric: {metric}")
            results[metric] = calculate_metric(delta1, delta2, metric)
        except Exception as e:
            # Best-effort by design: report the failure instead of aborting,
            # but make it visible in the log as well as in the result.
            logging.error(f"Error calculating {metric}: {str(e)}")
            results[metric] = f"Error calculating {metric}: {str(e)}"
    return results
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def calculate_metric(
    d_vector_1: torch.Tensor, d_vector_2: torch.Tensor, metric: str
) -> str:
    """
    Compute one kinship metric between two delta vectors and describe it.

    Args:
        d_vector_1 (torch.Tensor): Delta parameters for model 1.
        d_vector_2 (torch.Tensor): Delta parameters for model 2.
        metric (str): Which metric to compute: 'pcc', 'ed' or 'cs'.

    Returns:
        str: A sentence containing the value of the chosen metric, or
            "Invalid metric specified." for any other metric name.
    """
    logging.info(f"Starting calculation of {metric.upper()} metric...")

    if metric == "pcc":
        # Pearson correlation: off-diagonal entry of the 2x2 correlation matrix
        paired = torch.stack((d_vector_1, d_vector_2), dim=0)
        correlation = torch.corrcoef(paired)[0, 1].item()
        return f"Model Kinship based on Pearson Correlation Coefficient: {correlation}"

    if metric == "ed":
        # Euclidean distance between the two vectors
        euclidean = torch.dist(d_vector_1, d_vector_2).item()
        return f"Model Kinship based on Euclidean Distance: {euclidean}"

    if metric == "cs":
        # Cosine similarity via the module-level helper
        cosine = cosine_similarity(d_vector_1, d_vector_2)
        return f"Model Kinship based on Cosine Similarity: {cosine}"

    # Anything else is not a known metric
    return "Invalid metric specified."
|
|
@@ -0,0 +1,171 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from typing import Dict, List
|
|
3
|
+
|
|
4
|
+
import numpy
|
|
5
|
+
import torch
|
|
6
|
+
from tqdm import tqdm
|
|
7
|
+
|
|
8
|
+
from .utility import Metric, load_model_state_dict, quantize_8bit
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def cosine_similarity(a, b):
    """
    Return the absolute cosine similarity ``|a.b| / (|a| |b|)`` of two 1-D vectors.

    The sign of the cosine is discarded, so the result lies in ``[0, 1]``.
    The result is NaN when either vector is all zeros.

    Args:
        a: First vector (numpy array, sequence of numbers, or torch tensor).
        b: Second vector, same length as ``a``.

    Returns:
        numpy.float64: The absolute cosine similarity.
    """

    def _as_float64(vec):
        # Torch tensors on an accelerator, tracking gradients, or stored in
        # bf16 cannot be passed to numpy.dot directly; bring them to a
        # detached float64 CPU array first.
        if hasattr(vec, "detach"):
            vec = vec.detach().cpu().double().numpy()
        return numpy.asarray(vec, dtype=numpy.float64)

    a = _as_float64(a)
    b = _as_float64(b)
    # |a.b| / sqrt((a.a)(b.b)) equals sqrt((a.b)^2 / ((a.a)(b.b))) but avoids
    # squaring the dot product. A zero vector gives 0/0 -> NaN; keep that
    # value and only silence the runtime warning.
    with numpy.errstate(divide="ignore", invalid="ignore"):
        similarity = numpy.abs(numpy.dot(a, b)) / numpy.sqrt(
            numpy.dot(a, a) * numpy.dot(b, b)
        )
    return similarity
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def calculate_model_kinship_split(
    model_1_name: str,
    model_2_name: str,
    model_base_name: str,
    low_precision: bool,
    metrics: List[str],
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
) -> dict:
    """
    Compute kinship metrics between two models relative to a base model, key by key.

    The three state dictionaries are loaded with ``load_model_state_dict`` and
    every requested metric is evaluated with ``calculate_metrics_by_split``.
    A metric that is unsupported or fails is logged and reported as an error
    string under its own key; the other metrics are still computed.

    Args:
        model_1_name: Identifier of the first model.
        model_2_name: Identifier of the second model.
        model_base_name: Identifier of the base model.
        low_precision: Forwarded to ``calculate_metrics_by_split``.
        metrics: Names of the metrics to calculate.
        device: Device passed to ``load_model_state_dict``.

    Returns:
        dict: Maps each requested metric name to its result string.
    """
    # Load the weights of both models and of their common base
    first_state = load_model_state_dict(model_1_name, device)
    second_state = load_model_state_dict(model_2_name, device)
    base_state = load_model_state_dict(model_base_name, device)

    outcome = {}
    valid_metrics = Metric.list()

    for metric in metrics:
        try:
            # Reject unknown metric names before doing any heavy work
            if metric not in valid_metrics:
                raise ValueError(
                    f"Unsupported metric: {metric}. Valid metrics are: {', '.join(valid_metrics)}"
                )
            outcome[metric] = calculate_metrics_by_split(
                first_state, second_state, base_state, low_precision, metric
            )
        except Exception as e:
            message = f"Error calculating {metric}: {str(e)}"
            logging.error(message)
            outcome[metric] = message

    return outcome
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def calculate_metrics_by_split(
    state_dict_1: dict,
    state_dict_2: dict,
    state_dict_base: dict,
    low_precision: bool,
    metric: str,
) -> str:
    """
    Evaluate a metric per state-dict key and combine the per-key values.

    For every key of the base state dict that both models also contain, the
    deltas to the base parameters are flattened and compared with the chosen
    metric. The per-key values are averaged, weighted by the number of
    elements in the delta. Keys whose value is NaN, or whose processing
    raises, are skipped.

    Args:
        state_dict_1 (dict): State dictionary of first model
        state_dict_2 (dict): State dictionary of second model
        state_dict_base (dict): State dictionary of base model
        low_precision (bool): Whether to pass the deltas through ``quantize_8bit``
        metric (str): Metric to calculate ('pcc', 'ed', 'cs')

    Returns:
        str: Integrated metric result as formatted string, or an error string
            when no key produced a valid value
    """
    weighted_sum = 0.0
    weight_sum = 0.0
    per_key = {}

    # Size of the first dimension of the base model's lm_head; every model
    # parameter is truncated to this many leading rows before differencing.
    row_limit = state_dict_base["lm_head.weight"].shape[0]

    # Warn when the two models disagree on the lm_head's first dimension
    head_1 = state_dict_1["lm_head.weight"]
    head_2 = state_dict_2["lm_head.weight"]
    if head_1.shape[0] != head_2.shape[0]:
        shape_1 = head_1.shape
        shape_2 = head_2.shape
        logging.warning(
            f"Warning: Model architectures do not match. "
            f"Using sub weight space instead.\n"
            f"Vocab sizes in model 1: {shape_1[0]}, "
            f"Vocab sizes in model 2: {shape_2[0]}"
        )

    progress = tqdm(
        state_dict_base.items(), desc=f"Processing {metric.upper()} by key"
    )
    for key, base_params in progress:
        try:
            if not (key in state_dict_1 and key in state_dict_2):
                logging.warning(f"Key {key} not found in one of the models")
                continue

            # Flattened differences to the base parameters
            diff_1 = (state_dict_1[key][:row_limit] - base_params).view(-1)
            diff_2 = (state_dict_2[key][:row_limit] - base_params).view(-1)

            if low_precision:
                diff_1 = quantize_8bit(diff_1)
                diff_2 = quantize_8bit(diff_2)

            # Each key contributes proportionally to its element count
            weight = diff_1.numel()

            if metric == "pcc":
                value = torch.corrcoef(torch.stack((diff_1, diff_2), dim=0))[
                    0, 1
                ].item()
            elif metric == "ed":
                value = torch.dist(diff_1, diff_2).item()
            elif metric == "cs":
                value = cosine_similarity(diff_1, diff_2)
            else:
                raise ValueError(f"Unsupported metric: {metric}")

            # NaN values are excluded from both the details and the average
            if torch.isnan(torch.tensor(value)):
                logging.warning(f"Skipping key {key} due to NaN result")
                continue

            per_key[key] = value
            weighted_sum += value * weight
            weight_sum += weight

            # Only report progress for large tensors
            if weight > 1000000:
                logging.info(
                    f"Layer {key}: {metric.upper()} = {value:.4f}, parameters = {weight}"
                )

            # Release the deltas before moving on to the next key
            del diff_1, diff_2

        except Exception as e:
            logging.error(f"Error processing key {key}: {str(e)}")

    if total_is_empty := not (weight_sum > 0):
        return f"Error: No valid parameters found for {metric.upper()} calculation"

    final_result = weighted_sum / weight_sum

    # Summary followed by the per-key breakdown
    logging.info(f"\nSummary for {metric.upper()}:")
    logging.info(f"Total parameters: {weight_sum}")
    logging.info(f"\nDetailed {metric.upper()} results by key:")
    for key, value in per_key.items():
        logging.info(f"{key}: {value:.4f}")

    metric_names = {
        "pcc": "Pearson Correlation Coefficient",
        "ed": "Euclidean Distance",
        "cs": "Cosine Similarity",
    }
    return f"Model Kinship based on {metric_names[metric]} (weighted average): {final_result:.4f}"
|