fusion-bench 0.2.29__py3-none-any.whl → 0.2.31__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41) hide show
  1. fusion_bench/constants/runtime.py +4 -1
  2. fusion_bench/method/__init__.py +9 -1
  3. fusion_bench/method/base_algorithm.py +29 -19
  4. fusion_bench/method/classification/image_classification_finetune.py +1 -0
  5. fusion_bench/method/concrete_subspace/clip_concrete_tsvm.py +285 -0
  6. fusion_bench/method/task_singular_vector/TSVM.py +7 -6
  7. fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +0 -1
  8. fusion_bench/metrics/model_kinship/__init__.py +2 -0
  9. fusion_bench/metrics/model_kinship/calculate.py +77 -0
  10. fusion_bench/metrics/model_kinship/calculate_split.py +171 -0
  11. fusion_bench/metrics/model_kinship/utility.py +184 -0
  12. fusion_bench/mixins/lightning_fabric.py +2 -8
  13. fusion_bench/mixins/openclip_classification.py +155 -1
  14. fusion_bench/modelpool/base_pool.py +1 -0
  15. fusion_bench/modelpool/openclip_vision/modelpool.py +12 -3
  16. fusion_bench/models/masks/mask_model.py +8 -2
  17. fusion_bench/models/open_clip/modeling.py +68 -5
  18. fusion_bench/models/open_clip/utils.py +13 -2
  19. fusion_bench/models/wrappers/layer_wise_fusion.py +41 -3
  20. fusion_bench/models/wrappers/task_wise_fusion.py +14 -3
  21. fusion_bench/py.typed +1 -0
  22. fusion_bench/scripts/cli.py +21 -16
  23. fusion_bench/scripts/imgui.py +2 -2
  24. fusion_bench/scripts/webui.py +2 -2
  25. fusion_bench/utils/__init__.py +2 -0
  26. fusion_bench/utils/devices.py +3 -1
  27. fusion_bench/utils/hydra_utils.py +75 -0
  28. fusion_bench/utils/instantiate_utils.py +29 -18
  29. fusion_bench/utils/misc.py +16 -0
  30. fusion_bench/utils/parameters.py +33 -0
  31. fusion_bench/utils/rich_utils.py +165 -25
  32. {fusion_bench-0.2.29.dist-info → fusion_bench-0.2.31.dist-info}/METADATA +7 -7
  33. {fusion_bench-0.2.29.dist-info → fusion_bench-0.2.31.dist-info}/RECORD +41 -34
  34. fusion_bench_config/README.md +9 -0
  35. fusion_bench_config/fabric/auto.yaml +1 -0
  36. fusion_bench_config/hydra/default.yaml +3 -1
  37. fusion_bench_config/method/concrete_subspace/clip_concrete_tsvm.yaml +38 -0
  38. {fusion_bench-0.2.29.dist-info → fusion_bench-0.2.31.dist-info}/WHEEL +0 -0
  39. {fusion_bench-0.2.29.dist-info → fusion_bench-0.2.31.dist-info}/entry_points.txt +0 -0
  40. {fusion_bench-0.2.29.dist-info → fusion_bench-0.2.31.dist-info}/licenses/LICENSE +0 -0
  41. {fusion_bench-0.2.29.dist-info → fusion_bench-0.2.31.dist-info}/top_level.txt +0 -0
@@ -89,7 +89,10 @@ class RuntimeConstants:
89
89
  self._initialized = True
90
90
 
91
91
  debug = False
92
- """Global debug flag for enabling verbose logging and debugging features."""
92
+ """
93
+ Global debug flag for enabling verbose logging and debugging features.
94
+ Use `RuntimeConstants().debug` instead of `RuntimeConstants.debug`
95
+ """
93
96
 
94
97
  @property
95
98
  def cache_dir(self) -> Path:
@@ -144,7 +144,15 @@ _extra_objects = {
144
144
 
145
145
  if TYPE_CHECKING:
146
146
  from .ada_svd import AdaSVDMergingForCLIPVisionModel
147
- from .adamerging import *
147
+ from .adamerging import (
148
+ CLIPLayerWiseAdaMergingAlgorithm,
149
+ CLIPTaskWiseAdaMergingAlgorithm,
150
+ FlanT5LayerWiseAdaMergingAlgorithm,
151
+ GPT2LayerWiseAdaMergingAlgorithm,
152
+ LayerWiseAdaMergingForLlamaSFT,
153
+ ResNetLayerWiseAdamerging,
154
+ ResNetTaskWiseAdamerging,
155
+ )
148
156
  from .analysis import TaskVectorCosSimilarity, TaskVectorViolinPlot
149
157
  from .base_algorithm import BaseAlgorithm, BaseModelFusionAlgorithm
150
158
  from .bitdelta import BitDeltaAlgorithm
@@ -40,6 +40,7 @@ from typing import Optional # noqa: F401
40
40
 
41
41
  from fusion_bench.mixins import BaseYAMLSerializable
42
42
  from fusion_bench.modelpool import BaseModelPool
43
+ from fusion_bench.utils.misc import DeprecationWarningMeta
43
44
 
44
45
  __all__ = ["BaseAlgorithm", "BaseModelFusionAlgorithm"]
45
46
 
@@ -202,27 +203,36 @@ class BaseAlgorithm(BaseYAMLSerializable):
202
203
  pass
203
204
 
204
205
 
205
- BaseModelFusionAlgorithm = BaseAlgorithm
206
- """
207
- Alias for BaseAlgorithm class.
206
+ # Create a deprecated wrapper class that inherits from BaseAlgorithm
207
+ class BaseModelFusionAlgorithm(BaseAlgorithm, metaclass=DeprecationWarningMeta):
208
+ """
209
+ Alias for BaseAlgorithm class.
208
210
 
209
- This alias is provided for backward compatibility and semantic clarity.
210
- Some users may prefer the more explicit name 'BaseModelFusionAlgorithm'
211
- to emphasize that this class is specifically designed for model fusion
212
- tasks, while others may prefer the shorter 'BaseAlgorithm' name.
211
+ .. deprecated::
212
+ BaseModelFusionAlgorithm is deprecated and will be removed in a future version.
213
+ Use :class:`BaseAlgorithm` instead.
213
214
 
214
- Both names refer to the exact same class and can be used interchangeably.
215
+ This alias was provided for backward compatibility and semantic clarity.
216
+ Both names refer to the same base class and can be used interchangeably,
217
+ but BaseAlgorithm is now the preferred name for all implementations.
215
218
 
216
- Examples:
217
- Using the original name:
218
- >>> class MyAlgorithm(BaseAlgorithm):
219
- ... def run(self, modelpool): pass
219
+ Examples:
220
+ Preferred (using BaseAlgorithm):
220
221
 
221
- Using the alias:
222
- >>> class MyAlgorithm(BaseModelFusionAlgorithm):
223
- ... def run(self, modelpool): pass
222
+ >>> class MyAlgorithm(BaseAlgorithm):
223
+ ... def run(self, modelpool): pass
224
224
 
225
- Note:
226
- The alias is maintained for compatibility but BaseAlgorithm is the
227
- preferred name for new implementations.
228
- """
225
+ Deprecated (using BaseModelFusionAlgorithm):
226
+
227
+ >>> class MyAlgorithm(BaseModelFusionAlgorithm): # Will trigger deprecation warning
228
+ ... def run(self, modelpool): pass
229
+
230
+ Note:
231
+ New implementations should use :class:`BaseAlgorithm` exclusively.
232
+ The BaseModelFusionAlgorithm alias will be removed in a future release.
233
+
234
+ Warning:
235
+ Using BaseModelFusionAlgorithm will trigger a DeprecationWarning.
236
+ """
237
+
238
+ pass
@@ -173,6 +173,7 @@ class ImageClassificationFineTuning(BaseAlgorithm):
173
173
  ),
174
174
  },
175
175
  )
176
+ lit_module.train()
176
177
 
177
178
  log_dir = (
178
179
  self._program.path.log_dir
@@ -0,0 +1,285 @@
1
+ import logging
2
+ import os
3
+ from copy import deepcopy
4
+ from typing import TYPE_CHECKING, Dict, Iterable, List, Literal, Optional
5
+
6
+ import torch
7
+ from omegaconf import DictConfig
8
+ from tqdm import tqdm
9
+
10
+ from fusion_bench import (
11
+ BaseAlgorithm,
12
+ OpenCLIPClassificationMixin,
13
+ OpenCLIPVisionModelPool,
14
+ SimpleProfilerMixin,
15
+ StateDictType,
16
+ auto_register_config,
17
+ get_rankzero_logger,
18
+ instantiate,
19
+ )
20
+ from fusion_bench.method.adamerging.entropy_loss import entropy_loss
21
+ from fusion_bench.method.task_singular_vector import TaskSingularVectorMerging
22
+ from fusion_bench.method.task_singular_vector.utils import (
23
+ TSVM_utils,
24
+ check_parameterNamesMatch,
25
+ check_state_dicts_equal,
26
+ state_dict_to_vector,
27
+ vector_to_state_dict,
28
+ )
29
+ from fusion_bench.models.masks import MaskModel, mask_sparsity
30
+ from fusion_bench.models.open_clip import (
31
+ ClassificationHead,
32
+ ImageClassifier,
33
+ ImageEncoder,
34
+ )
35
+ from fusion_bench.models.wrappers.task_wise_fusion import (
36
+ TaskWiseMergedModel,
37
+ get_task_wise_weights,
38
+ )
39
+ from fusion_bench.utils.devices import clear_cuda_cache
40
+ from fusion_bench.utils.dtype import parse_dtype
41
+ from fusion_bench.utils.parameters import print_parameters, print_trainable_parameters
42
+ from fusion_bench.utils.rich_utils import print_config_yaml
43
+ from fusion_bench.utils.state_dict_arithmetic import (
44
+ _validate_state_dict_same_keys,
45
+ state_dict_add,
46
+ state_dict_hadamard_product,
47
+ state_dict_mul,
48
+ state_dict_sub,
49
+ )
50
+
51
+ log = get_rankzero_logger(__name__)
52
+
53
+
54
+ @auto_register_config
55
+ class ConcreteTSVMForOpenCLIP(
56
+ OpenCLIPClassificationMixin,
57
+ SimpleProfilerMixin,
58
+ BaseAlgorithm,
59
+ ):
60
+ def __init__(
61
+ self,
62
+ dataloader_kwargs: DictConfig,
63
+ optimizer: DictConfig,
64
+ lr_scheduler: DictConfig,
65
+ max_steps: int,
66
+ save_interval: int,
67
+ initial_logits: float,
68
+ temperature: float,
69
+ eval_mask_type: Literal["continuous", "discrete"],
70
+ mask_checkpoint: Optional[str],
71
+ merge_dtype: str,
72
+ clamp_weights: bool,
73
+ tie_weights: bool,
74
+ strict: bool,
75
+ skip_training: bool,
76
+ # === TSVM parameters ===
77
+ exclude_keys: Optional[List[str]],
78
+ alpha: float,
79
+ return_single_task_models: bool = True,
80
+ **kwargs,
81
+ ):
82
+ super().__init__(**kwargs)
83
+ if not return_single_task_models:
84
+ log.warning("return_single_task_models is forced to be True here.")
85
+ self.return_single_task_models = True
86
+
87
+ @torch.no_grad()
88
+ def setup_models(self):
89
+ """
90
+ load the pre-trained model, task vectors, and construct the mask model.
91
+ """
92
+ merge_dtype = parse_dtype(self.merge_dtype)
93
+ modelpool = self.modelpool
94
+
95
+ # load the pre-trained model
96
+ pretrained_model = modelpool.load_pretrained_model()
97
+ self.set_clip_processor(stage="test", processor=pretrained_model.val_preprocess)
98
+
99
+ # construct mask model
100
+ mask_model = MaskModel(
101
+ pretrained_model, ignore_untrained_params=True, parameter_type="logits"
102
+ )
103
+ if merge_dtype is not None:
104
+ mask_model.to(merge_dtype)
105
+ mask_model.fill_(self.initial_logits)
106
+
107
+ if self.fabric.is_global_zero:
108
+ print("summary of mask model:")
109
+ print_parameters(mask_model)
110
+
111
+ if self.fabric.is_global_zero:
112
+ tsvm_algo = TaskSingularVectorMerging(
113
+ alpha=self.alpha,
114
+ exclude_keys=self.exclude_keys,
115
+ return_single_task_models=self.return_single_task_models,
116
+ )
117
+ tsvm_algo._fabric_instance = self.fabric
118
+ models = tsvm_algo.run(modelpool)
119
+
120
+ finetuned_models = [models[name] for name in modelpool.model_names]
121
+
122
+ task_wise_weight = get_task_wise_weights(
123
+ num_models=len(modelpool.model_names),
124
+ init_values=self.alpha,
125
+ )
126
+
127
+ # create a wrapped model
128
+ module = TaskWiseMergedModel(
129
+ task_wise_weight=task_wise_weight,
130
+ pretrained_model=pretrained_model,
131
+ finetuned_models=finetuned_models,
132
+ clamp_weights=self.clamp_weights,
133
+ tie_weights=self.tie_weights,
134
+ strict=self.strict,
135
+ task_vector_dtype=merge_dtype,
136
+ )
137
+ module = module.to(dtype=merge_dtype)
138
+
139
+ print("trainable parameter summary of merged model (TaskWiseMergedModel):")
140
+ print_trainable_parameters(module)
141
+ else:
142
+ module = None
143
+
144
+ with torch.no_grad():
145
+ self.fabric.barrier()
146
+ module = self.fabric.broadcast(module, src=0)
147
+
148
+ return module, mask_model
149
+
150
+ def train_mask(self, module: TaskWiseMergedModel, mask_model: MaskModel):
151
+ """
152
+ Train the mask model using the provided module.
153
+
154
+ This method configures the optimizer, sets up the mask model, and performs test-time adaptation to train the mask model.
155
+
156
+ Args:
157
+ module (TaskWiseMergedModel): The wrapped model with task-wise weights.
158
+ mask_model (MaskModel): The mask model to be trained.
159
+ """
160
+ config = self.config
161
+ merge_dtype = parse_dtype(self.merge_dtype)
162
+ log.info(f"Using merge dtype: {merge_dtype}")
163
+
164
+ optimizer: "torch.optim.Optimizer" = instantiate(
165
+ self.optimizer,
166
+ params=filter(lambda p: p.requires_grad, mask_model.parameters()),
167
+ )
168
+ print(f"{optimizer=}")
169
+ if self.lr_scheduler is not None:
170
+ lr_scheduler = instantiate(
171
+ self.lr_scheduler,
172
+ optimizer=optimizer,
173
+ )
174
+ print(f"{lr_scheduler=}")
175
+ else:
176
+ lr_scheduler = None
177
+
178
+ log.info("Setup models and optimizer with Fabric.")
179
+ mask_model, optimizer = self.fabric.setup(mask_model, optimizer)
180
+
181
+ log.info("Move the merged module to the correct device and disable gradients.")
182
+ module.requires_grad_(False)
183
+ module.to(mask_model.device)
184
+
185
+ mask_model.train()
186
+ optimizer.zero_grad()
187
+ for step_idx in (
188
+ pbar := tqdm(
189
+ range(self.config.max_steps if not self.is_debug_mode else 5),
190
+ ("[DEBUG MODE] " if self.is_debug_mode else "")
191
+ + "Concrete TSVM Test-Time Adaptation",
192
+ dynamic_ncols=True,
193
+ disable=not self.fabric.is_global_zero,
194
+ )
195
+ ):
196
+ metrics = {}
197
+ # sample a shared mask and merge weights
198
+ with self.profile("sample mask"):
199
+ mask = mask_model.sample_mask(
200
+ mask_type="continuous", temperature=config.temperature
201
+ )
202
+ metrics["train/sparsity"] = mask_sparsity(mask)
203
+ with self.profile("merge weights"):
204
+ # rescale mask
205
+ for name, m in mask.items():
206
+ mask[name] = m / torch.mean(m)
207
+ module.merge_weights(task_vector_mask=mask)
208
+
209
+ # ------ inner optimization goes here ------
210
+ # NOTE:
211
+ # Because the algorithmic parameters of TSVM are assumed to be chosen on a validation
212
+ # set, we do not need to perform inner optimization here, so this step is skipped.
213
+ # ------------------------------------------
214
+
215
+ total_loss = None
216
+ for task in self.modelpool.model_names:
217
+ with self.profile("data loading"):
218
+ batch = next(self.get_shuffled_test_loader_iter(task))
219
+ # NOTE: The labels are not allowed to be used during test-time adaptation
220
+ images = batch[0].to(dtype=merge_dtype)
221
+ with self.profile("forward pass"):
222
+ logits = self.compute_logits(module, images, task)
223
+ loss = entropy_loss(logits)
224
+ total_loss = loss if total_loss is None else total_loss + loss
225
+
226
+ with self.profile("compute grad"):
227
+ self.fabric.backward(total_loss)
228
+
229
+ with self.profile("optimizer step"):
230
+ optimizer.step()
231
+ optimizer.zero_grad()
232
+
233
+ if lr_scheduler is not None:
234
+ lr_scheduler.step()
235
+
236
+ metrics.update({"train/loss": loss.item()})
237
+ self.fabric.log_dict(metrics, step=step_idx)
238
+ pbar.set_postfix(metrics)
239
+
240
+ if (step_idx + 1) % self.config.save_interval == 0:
241
+ with self.profiler.profile("save checkpoint"):
242
+ save_dir = os.path.join(self.fabric.logger.log_dir, "checkpoints")
243
+ if not os.path.exists(save_dir):
244
+ os.makedirs(save_dir, exist_ok=True)
245
+ save_path = os.path.join(save_dir, f"mask_steps_{step_idx}.pt")
246
+ print(f"saving checkpoint to {save_path}")
247
+ state = {"model": mask_model}
248
+ self.fabric.save(save_path, state)
249
+
250
+ # Create or update a symbolic link to the latest checkpoint
251
+ if self.fabric.is_global_zero:
252
+ symlink_path = os.path.join(save_dir, "latest_checkpoint.pt")
253
+ if os.path.exists(symlink_path):
254
+ os.remove(symlink_path)
255
+ os.link(os.path.abspath(save_path), symlink_path)
256
+
257
+ self.print_profile_summary()
258
+
259
+ def run(self, modelpool: OpenCLIPVisionModelPool):
260
+ self.modelpool = modelpool
261
+ merge_dtype = parse_dtype(self.merge_dtype)
262
+
263
+ with self.profile("setup models"):
264
+ module, mask_model = self.setup_models()
265
+ self.setup_zero_shot_classification_head(freeze=True, dtype=merge_dtype)
266
+
267
+ if self.mask_checkpoint is None:
268
+ if not self.skip_training:
269
+ clear_cuda_cache()
270
+ self.train_mask(module, mask_model=mask_model)
271
+ else:
272
+ if self.fabric.is_global_zero:
273
+ print("loading mask from checkpoint", self.mask_checkpoint)
274
+ self.fabric.load(self.mask_checkpoint, {"model": mask_model})
275
+
276
+ with torch.no_grad():
277
+ clear_cuda_cache()
278
+ mask = mask_model.sample_mask(
279
+ mask_type=self.eval_mask_type, temperature=self.temperature
280
+ )
281
+ # rescale mask
282
+ for name, m in mask.items():
283
+ mask[name] = m / torch.mean(m)
284
+ model = module.merge_and_unload(mask)
285
+ return model.to(dtype=torch.float32)
@@ -249,12 +249,13 @@ class TaskSingularVectorMerging(BaseAlgorithm, LightningFabricMixin):
249
249
  # - SVD finds the principal components (most important directions)
250
250
  # - Task vectors are reconstructed using only the most significant components
251
251
  # - The reconstructed vectors are merged (summed) to create a unified task vector
252
- new_merged_tv = TSVM_utils.compute_and_sum_svd_mem_reduction(
253
- task_vectors,
254
- exclude_keys=self.exclude_keys, # Skip certain parameters from SVD
255
- accelerator=accelerator, # Use GPU if available
256
- return_single_task_models=self.return_single_task_models,
257
- )
252
+ with torch.no_grad():
253
+ new_merged_tv = TSVM_utils.compute_and_sum_svd_mem_reduction(
254
+ task_vectors,
255
+ exclude_keys=self.exclude_keys, # Skip certain parameters from SVD
256
+ accelerator=accelerator, # Use GPU if available
257
+ return_single_task_models=self.return_single_task_models,
258
+ )
258
259
 
259
260
  # Handle the case where individual transformed task vectors are also returned
260
261
  if self.return_single_task_models:
@@ -311,7 +311,6 @@ def compute_and_sum_svd_mem_reduction_lossless_eigen(
311
311
 
312
312
  ###############
313
313
  #### TSV Merge Orthogonalization
314
- @torch.no_grad()
315
314
  def compute_and_sum_svd_mem_reduction(
316
315
  task_vectors: List[StateDictType],
317
316
  exclude_keys: Optional[List[str]] = None,
@@ -0,0 +1,2 @@
1
+ # Exploring Model Kinship for Merging LLMs
2
+ # The implementation of this module is borrowed from: https://github.com/zjunlp/ModelKinship/
@@ -0,0 +1,77 @@
1
+ import logging
2
+ from typing import List
3
+
4
+ import numpy
5
+ import torch
6
+
7
+ from .utility import Metric
8
+
9
+
10
+ def cosine_similarity(a, b):
11
+ similarity = numpy.sqrt(numpy.dot(a, b) ** 2 / (numpy.dot(a, a) * numpy.dot(b, b)))
12
+ return similarity
13
+
14
+
15
+ def calculate_model_kinship(
16
+ delta1: numpy.ndarray, delta2: numpy.ndarray, metrics: List[str]
17
+ ) -> dict:
18
+ """
19
+ Calculate model kinship using specified metrics.
20
+
21
+ Args:
22
+ delta1: Delta parameters for first model
23
+ delta2: Delta parameters for second model
24
+ metrics: List of metrics to calculate
25
+
26
+ Returns:
27
+ dict: Dictionary of metric names and their calculated values
28
+ """
29
+ results = {}
30
+ for metric in metrics:
31
+ try:
32
+ if metric not in Metric.list():
33
+ raise ValueError(f"Unsupported metric: {metric}")
34
+ results[metric] = calculate_metric(delta1, delta2, metric)
35
+ except Exception as e:
36
+ results[metric] = f"Error calculating {metric}: {str(e)}"
37
+ return results
38
+
39
+
40
+ def calculate_metric(
41
+ d_vector_1: torch.Tensor, d_vector_2: torch.Tensor, metric: str
42
+ ) -> str:
43
+ """
44
+ Calculate the specified metric between two delta vectors.
45
+
46
+ Args:
47
+ d_vector_1 (torch.Tensor): Delta parameters for model 1.
48
+ d_vector_2 (torch.Tensor): Delta parameters for model 2.
49
+ metric (str): The metric to calculate ('pcc', 'ed', 'cs').
50
+
51
+ Returns:
52
+ str: A formatted string with the result of the chosen metric.
53
+ """
54
+ logging.info(f"Starting calculation of {metric.upper()} metric...")
55
+
56
+ # Pearson Correlation Coefficient (PCC)
57
+ if metric == "pcc":
58
+ # Stack the two vectors and calculate the Pearson correlation coefficient
59
+ stack = torch.stack((d_vector_1, d_vector_2), dim=0)
60
+ pcc = torch.corrcoef(stack)[0, 1].item()
61
+ return f"Model Kinship based on Pearson Correlation Coefficient: {pcc}"
62
+
63
+ # Euclidean Distance (ED)
64
+ elif metric == "ed":
65
+ # Compute the Euclidean distance between the vectors
66
+ distance = torch.dist(d_vector_1, d_vector_2).item()
67
+ return f"Model Kinship based on Euclidean Distance: {distance}"
68
+
69
+ # Cosine Similarity (CS)
70
+ elif metric == "cs":
71
+ # Compute cosine similarity
72
+ cs = cosine_similarity(d_vector_1, d_vector_2)
73
+ return f"Model Kinship based on Cosine Similarity: {cs}"
74
+
75
+ # If metric is not recognized
76
+ else:
77
+ return "Invalid metric specified."
@@ -0,0 +1,171 @@
1
+ import logging
2
+ from typing import Dict, List
3
+
4
+ import numpy
5
+ import torch
6
+ from tqdm import tqdm
7
+
8
+ from .utility import Metric, load_model_state_dict, quantize_8bit
9
+
10
+
11
+ def cosine_similarity(a, b):
12
+ similarity = numpy.sqrt(numpy.dot(a, b) ** 2 / (numpy.dot(a, a) * numpy.dot(b, b)))
13
+ return similarity
14
+
15
+
16
+ def calculate_model_kinship_split(
17
+ model_1_name: str,
18
+ model_2_name: str,
19
+ model_base_name: str,
20
+ low_precision: bool,
21
+ metrics: List[str],
22
+ device: str = "cuda" if torch.cuda.is_available() else "cpu",
23
+ ) -> dict:
24
+
25
+ # Extract state dictionaries from models
26
+ state_dict_1 = load_model_state_dict(model_1_name, device)
27
+ state_dict_2 = load_model_state_dict(model_2_name, device)
28
+ state_dict_base = load_model_state_dict(model_base_name, device)
29
+ results = {}
30
+
31
+ # Validate metrics before processing
32
+ valid_metrics = Metric.list()
33
+ for metric in metrics:
34
+ try:
35
+ if metric not in valid_metrics:
36
+ raise ValueError(
37
+ f"Unsupported metric: {metric}. Valid metrics are: {', '.join(valid_metrics)}"
38
+ )
39
+ results[metric] = calculate_metrics_by_split(
40
+ state_dict_1, state_dict_2, state_dict_base, low_precision, metric
41
+ )
42
+ except Exception as e:
43
+ logging.error(f"Error calculating {metric}: {str(e)}")
44
+ results[metric] = f"Error calculating {metric}: {str(e)}"
45
+
46
+ return results
47
+
48
+
49
+ def calculate_metrics_by_split(
50
+ state_dict_1: dict,
51
+ state_dict_2: dict,
52
+ state_dict_base: dict,
53
+ low_precision: bool,
54
+ metric: str,
55
+ ) -> str:
56
+ """
57
+ Calculate metrics for each key and integrate results.
58
+
59
+ Args:
60
+ state_dict_1 (dict): State dictionary of first model
61
+ state_dict_2 (dict): State dictionary of second model
62
+ state_dict_base (dict): State dictionary of base model
63
+ low_precision (bool): Whether to use 8-bit quantization
64
+ metric (str): Metric to calculate ('pcc', 'ed', 'cs')
65
+
66
+ Returns:
67
+ str: Integrated metric result as formatted string
68
+ """
69
+ total_similarity = 0.0
70
+ total_weight = 0.0
71
+ split_results = {}
72
+
73
+ # Determine the number of layers
74
+ num_layers = state_dict_base["lm_head.weight"].shape[0]
75
+
76
+ # Check architectures
77
+ if (
78
+ state_dict_1["lm_head.weight"].shape[0]
79
+ != state_dict_2["lm_head.weight"].shape[0]
80
+ ):
81
+ shape_1 = state_dict_1["lm_head.weight"].shape
82
+ shape_2 = state_dict_2["lm_head.weight"].shape
83
+ logging.warning(
84
+ f"Warning: Model architectures do not match. "
85
+ f"Using sub weight space instead.\n"
86
+ f"Vocab sizes in model 1: {shape_1[0]}, "
87
+ f"Vocab sizes in model 2: {shape_2[0]}"
88
+ )
89
+
90
+ # Process each key
91
+ for key, base_params in tqdm(
92
+ state_dict_base.items(), desc=f"Processing {metric.upper()} by key"
93
+ ):
94
+ try:
95
+ if key not in state_dict_1 or key not in state_dict_2:
96
+ logging.warning(f"Key {key} not found in one of the models")
97
+ continue
98
+
99
+ # Get parameters and calculate deltas
100
+ params_1 = state_dict_1[key][:num_layers]
101
+ params_2 = state_dict_2[key][:num_layers]
102
+
103
+ delta_1 = (params_1 - base_params).view(-1)
104
+ delta_2 = (params_2 - base_params).view(-1)
105
+
106
+ if low_precision:
107
+ delta_1 = quantize_8bit(delta_1)
108
+ delta_2 = quantize_8bit(delta_2)
109
+
110
+ # Calculate weight based on parameter count
111
+ weight = delta_1.numel()
112
+
113
+ # Calculate metric for current key
114
+ if metric == "pcc":
115
+ stack = torch.stack((delta_1, delta_2), dim=0)
116
+ split_similarity = torch.corrcoef(stack)[0, 1].item()
117
+ elif metric == "ed":
118
+ split_similarity = torch.dist(delta_1, delta_2).item()
119
+ elif metric == "cs":
120
+ split_similarity = cosine_similarity(delta_1, delta_2)
121
+ else:
122
+ raise ValueError(f"Unsupported metric: {metric}")
123
+
124
+ # Skip NaN values
125
+ if torch.isnan(torch.tensor(split_similarity)):
126
+ logging.warning(f"Skipping key {key} due to NaN result")
127
+ continue
128
+
129
+ # Store valid result
130
+ split_results[key] = split_similarity
131
+
132
+ # Update weighted average only for valid results
133
+ weight = delta_1.numel()
134
+ total_similarity += split_similarity * weight
135
+ total_weight += weight
136
+
137
+ # Log progress for large layers
138
+ if weight > 1000000:
139
+ logging.info(
140
+ f"Layer {key}: {metric.upper()} = {split_similarity:.4f}, parameters = {weight}"
141
+ )
142
+
143
+ # Free memory
144
+ del delta_1, delta_2
145
+
146
+ except Exception as e:
147
+ logging.error(f"Error processing key {key}: {str(e)}")
148
+ continue
149
+
150
+ # Calculate final weighted average
151
+ if total_weight > 0:
152
+ final_result = total_similarity / total_weight
153
+
154
+ # Log summary statistics
155
+ logging.info(f"\nSummary for {metric.upper()}:")
156
+ logging.info(f"Total parameters: {total_weight}")
157
+
158
+ # Log detailed results for valid splits
159
+ logging.info(f"\nDetailed {metric.upper()} results by key:")
160
+ for key, value in split_results.items():
161
+ logging.info(f"{key}: {value:.4f}")
162
+
163
+ metric_names = {
164
+ "pcc": "Pearson Correlation Coefficient",
165
+ "ed": "Euclidean Distance",
166
+ "cs": "Cosine Similarity",
167
+ }
168
+
169
+ return f"Model Kinship based on {metric_names[metric]} (weighted average): {final_result:.4f}"
170
+ else:
171
+ return f"Error: No valid parameters found for {metric.upper()} calculation"