fusion-bench 0.2.29__py3-none-any.whl → 0.2.30__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -144,7 +144,15 @@ _extra_objects = {
144
144
 
145
145
  if TYPE_CHECKING:
146
146
  from .ada_svd import AdaSVDMergingForCLIPVisionModel
147
- from .adamerging import *
147
+ from .adamerging import (
148
+ CLIPLayerWiseAdaMergingAlgorithm,
149
+ CLIPTaskWiseAdaMergingAlgorithm,
150
+ FlanT5LayerWiseAdaMergingAlgorithm,
151
+ GPT2LayerWiseAdaMergingAlgorithm,
152
+ LayerWiseAdaMergingForLlamaSFT,
153
+ ResNetLayerWiseAdamerging,
154
+ ResNetTaskWiseAdamerging,
155
+ )
148
156
  from .analysis import TaskVectorCosSimilarity, TaskVectorViolinPlot
149
157
  from .base_algorithm import BaseAlgorithm, BaseModelFusionAlgorithm
150
158
  from .bitdelta import BitDeltaAlgorithm
@@ -40,6 +40,7 @@ from typing import Optional # noqa: F401
40
40
 
41
41
  from fusion_bench.mixins import BaseYAMLSerializable
42
42
  from fusion_bench.modelpool import BaseModelPool
43
+ from fusion_bench.utils.misc import DeprecationWarningMeta
43
44
 
44
45
  __all__ = ["BaseAlgorithm", "BaseModelFusionAlgorithm"]
45
46
 
@@ -202,27 +203,36 @@ class BaseAlgorithm(BaseYAMLSerializable):
202
203
  pass
203
204
 
204
205
 
205
- BaseModelFusionAlgorithm = BaseAlgorithm
206
- """
207
- Alias for BaseAlgorithm class.
206
+ # Create a deprecated wrapper class that inherits from BaseAlgorithm
207
+ class BaseModelFusionAlgorithm(BaseAlgorithm, metaclass=DeprecationWarningMeta):
208
+ """
209
+ Alias for BaseAlgorithm class.
208
210
 
209
- This alias is provided for backward compatibility and semantic clarity.
210
- Some users may prefer the more explicit name 'BaseModelFusionAlgorithm'
211
- to emphasize that this class is specifically designed for model fusion
212
- tasks, while others may prefer the shorter 'BaseAlgorithm' name.
211
+ .. deprecated::
212
+ BaseModelFusionAlgorithm is deprecated and will be removed in a future version.
213
+ Use :class:`BaseAlgorithm` instead.
213
214
 
214
- Both names refer to the exact same class and can be used interchangeably.
215
+ This alias was provided for backward compatibility and semantic clarity.
216
+ Both names refer to the same base class and can be used interchangeably,
217
+ but BaseAlgorithm is now the preferred name for all implementations.
215
218
 
216
- Examples:
217
- Using the original name:
218
- >>> class MyAlgorithm(BaseAlgorithm):
219
- ... def run(self, modelpool): pass
219
+ Examples:
220
+ Preferred (using BaseAlgorithm):
220
221
 
221
- Using the alias:
222
- >>> class MyAlgorithm(BaseModelFusionAlgorithm):
223
- ... def run(self, modelpool): pass
222
+ >>> class MyAlgorithm(BaseAlgorithm):
223
+ ... def run(self, modelpool): pass
224
224
 
225
- Note:
226
- The alias is maintained for compatibility but BaseAlgorithm is the
227
- preferred name for new implementations.
228
- """
225
+ Deprecated (using BaseModelFusionAlgorithm):
226
+
227
+ >>> class MyAlgorithm(BaseModelFusionAlgorithm): # Will trigger deprecation warning
228
+ ... def run(self, modelpool): pass
229
+
230
+ Note:
231
+ New implementations should use :class:`BaseAlgorithm` exclusively.
232
+ The BaseModelFusionAlgorithm alias will be removed in a future release.
233
+
234
+ Warning:
235
+ Using BaseModelFusionAlgorithm will trigger a DeprecationWarning.
236
+ """
237
+
238
+ pass
@@ -0,0 +1,2 @@
1
+ # Exploring Model Kinship for Merging LLMs
2
+ # The implementation of this module is borrowed from: https://github.com/zjunlp/ModelKinship/
@@ -0,0 +1,77 @@
1
+ import logging
2
+ from typing import List
3
+
4
+ import numpy
5
+ import torch
6
+
7
+ from .utility import Metric
8
+
9
+
10
+ def cosine_similarity(a, b):
11
+ similarity = numpy.sqrt(numpy.dot(a, b) ** 2 / (numpy.dot(a, a) * numpy.dot(b, b)))
12
+ return similarity
13
+
14
+
15
+ def calculate_model_kinship(
16
+ delta1: numpy.ndarray, delta2: numpy.ndarray, metrics: List[str]
17
+ ) -> dict:
18
+ """
19
+ Calculate model kinship using specified metrics.
20
+
21
+ Args:
22
+ delta1: Delta parameters for first model
23
+ delta2: Delta parameters for second model
24
+ metrics: List of metrics to calculate
25
+
26
+ Returns:
27
+ dict: Dictionary of metric names and their calculated values
28
+ """
29
+ results = {}
30
+ for metric in metrics:
31
+ try:
32
+ if metric not in Metric.list():
33
+ raise ValueError(f"Unsupported metric: {metric}")
34
+ results[metric] = calculate_metric(delta1, delta2, metric)
35
+ except Exception as e:
36
+ results[metric] = f"Error calculating {metric}: {str(e)}"
37
+ return results
38
+
39
+
40
+ def calculate_metric(
41
+ d_vector_1: torch.Tensor, d_vector_2: torch.Tensor, metric: str
42
+ ) -> str:
43
+ """
44
+ Calculate the specified metric between two delta vectors.
45
+
46
+ Args:
47
+ d_vector_1 (torch.Tensor): Delta parameters for model 1.
48
+ d_vector_2 (torch.Tensor): Delta parameters for model 2.
49
+ metric (str): The metric to calculate ('pcc', 'ed', 'cs').
50
+
51
+ Returns:
52
+ str: A formatted string with the result of the chosen metric.
53
+ """
54
+ logging.info(f"Starting calculation of {metric.upper()} metric...")
55
+
56
+ # Pearson Correlation Coefficient (PCC)
57
+ if metric == "pcc":
58
+ # Stack the two vectors and calculate the Pearson correlation coefficient
59
+ stack = torch.stack((d_vector_1, d_vector_2), dim=0)
60
+ pcc = torch.corrcoef(stack)[0, 1].item()
61
+ return f"Model Kinship based on Pearson Correlation Coefficient: {pcc}"
62
+
63
+ # Euclidean Distance (ED)
64
+ elif metric == "ed":
65
+ # Compute the Euclidean distance between the vectors
66
+ distance = torch.dist(d_vector_1, d_vector_2).item()
67
+ return f"Model Kinship based on Euclidean Distance: {distance}"
68
+
69
+ # Cosine Similarity (CS)
70
+ elif metric == "cs":
71
+ # Compute cosine similarity
72
+ cs = cosine_similarity(d_vector_1, d_vector_2)
73
+ return f"Model Kinship based on Cosine Similarity: {cs}"
74
+
75
+ # If metric is not recognized
76
+ else:
77
+ return "Invalid metric specified."
@@ -0,0 +1,171 @@
1
+ import logging
2
+ from typing import Dict, List
3
+
4
+ import numpy
5
+ import torch
6
+ from tqdm import tqdm
7
+
8
+ from .utility import Metric, load_model_state_dict, quantize_8bit
9
+
10
+
11
def cosine_similarity(a, b):
    """Return the absolute cosine similarity between vectors *a* and *b*.

    Note: squaring the dot product before taking the square root discards
    the sign, so the result is always non-negative (|cos|, not cos).
    """
    squared_dot = numpy.dot(a, b) ** 2
    norm_product = numpy.dot(a, a) * numpy.dot(b, b)
    return numpy.sqrt(squared_dot / norm_product)
14
+
15
+
16
def calculate_model_kinship_split(
    model_1_name: str,
    model_2_name: str,
    model_base_name: str,
    low_precision: bool,
    metrics: List[str],
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
) -> dict:
    """
    Compute model-kinship metrics between two models, key by key.

    Loads the state dicts of both models and the shared base model, then
    evaluates every requested metric via :func:`calculate_metrics_by_split`.

    Args:
        model_1_name: Name or path of the first model.
        model_2_name: Name or path of the second model.
        model_base_name: Name or path of the shared base model.
        low_precision: Whether to quantize deltas to 8-bit before comparing.
        metrics: Metric identifiers to evaluate (see :class:`Metric`).
        device: Device the models are loaded onto.

    Returns:
        dict: Mapping from metric name to its formatted result string,
        or to an error message when the computation fails.
    """
    # Extract state dictionaries from models
    state_dict_1 = load_model_state_dict(model_1_name, device)
    state_dict_2 = load_model_state_dict(model_2_name, device)
    state_dict_base = load_model_state_dict(model_base_name, device)

    valid_metrics = Metric.list()
    results = {}
    for metric in metrics:
        try:
            # Validate before doing any heavy computation.
            if metric not in valid_metrics:
                raise ValueError(
                    f"Unsupported metric: {metric}. Valid metrics are: {', '.join(valid_metrics)}"
                )
            results[metric] = calculate_metrics_by_split(
                state_dict_1, state_dict_2, state_dict_base, low_precision, metric
            )
        except Exception as e:
            logging.error(f"Error calculating {metric}: {str(e)}")
            results[metric] = f"Error calculating {metric}: {str(e)}"

    return results
47
+
48
+
49
def calculate_metrics_by_split(
    state_dict_1: dict,
    state_dict_2: dict,
    state_dict_base: dict,
    low_precision: bool,
    metric: str,
) -> str:
    """
    Calculate a kinship metric key-by-key and integrate the results.

    For every parameter tensor shared by all three state dicts, the deltas of
    both models relative to the base are compared with the requested metric,
    and the per-key results are combined into a parameter-count-weighted
    average.

    Args:
        state_dict_1 (dict): State dictionary of first model
        state_dict_2 (dict): State dictionary of second model
        state_dict_base (dict): State dictionary of base model
        low_precision (bool): Whether to use 8-bit quantization
        metric (str): Metric to calculate ('pcc', 'ed', 'cs')

    Returns:
        str: Integrated metric result as formatted string, or an error string
        when no key produced a valid value.
    """
    total_similarity = 0.0
    total_weight = 0.0
    split_results = {}

    # NOTE(review): despite the name, this is the base model's vocabulary
    # size (rows of lm_head.weight); tensors are truncated to this many rows
    # below so models with extended vocabularies remain comparable.
    num_layers = state_dict_base["lm_head.weight"].shape[0]

    # Warn when the two models have different vocabulary sizes.
    if (
        state_dict_1["lm_head.weight"].shape[0]
        != state_dict_2["lm_head.weight"].shape[0]
    ):
        shape_1 = state_dict_1["lm_head.weight"].shape
        shape_2 = state_dict_2["lm_head.weight"].shape
        logging.warning(
            f"Warning: Model architectures do not match. "
            f"Using sub weight space instead.\n"
            f"Vocab sizes in model 1: {shape_1[0]}, "
            f"Vocab sizes in model 2: {shape_2[0]}"
        )

    # Process each key
    for key, base_params in tqdm(
        state_dict_base.items(), desc=f"Processing {metric.upper()} by key"
    ):
        try:
            if key not in state_dict_1 or key not in state_dict_2:
                logging.warning(f"Key {key} not found in one of the models")
                continue

            # Get parameters (truncated to the base vocab rows) and compute
            # flattened deltas relative to the base model.
            params_1 = state_dict_1[key][:num_layers]
            params_2 = state_dict_2[key][:num_layers]

            delta_1 = (params_1 - base_params).view(-1)
            delta_2 = (params_2 - base_params).view(-1)

            if low_precision:
                delta_1 = quantize_8bit(delta_1)
                delta_2 = quantize_8bit(delta_2)

            # Weight of this key in the average = its parameter count.
            # (Computed once; a redundant recomputation was removed here.)
            weight = delta_1.numel()

            # Calculate metric for current key
            if metric == "pcc":
                stack = torch.stack((delta_1, delta_2), dim=0)
                split_similarity = torch.corrcoef(stack)[0, 1].item()
            elif metric == "ed":
                split_similarity = torch.dist(delta_1, delta_2).item()
            elif metric == "cs":
                split_similarity = cosine_similarity(delta_1, delta_2)
            else:
                raise ValueError(f"Unsupported metric: {metric}")

            # Skip NaN values (e.g. zero-variance deltas under 'pcc')
            if torch.isnan(torch.tensor(split_similarity)):
                logging.warning(f"Skipping key {key} due to NaN result")
                continue

            # Store valid result and update the weighted accumulators
            split_results[key] = split_similarity
            total_similarity += split_similarity * weight
            total_weight += weight

            # Log progress for large layers
            if weight > 1000000:
                logging.info(
                    f"Layer {key}: {metric.upper()} = {split_similarity:.4f}, parameters = {weight}"
                )

            # Free memory
            del delta_1, delta_2

        except Exception as e:
            logging.error(f"Error processing key {key}: {str(e)}")
            continue

    # Calculate final weighted average
    if total_weight > 0:
        final_result = total_similarity / total_weight

        # Log summary statistics
        logging.info(f"\nSummary for {metric.upper()}:")
        logging.info(f"Total parameters: {total_weight}")

        # Log detailed results for valid splits
        logging.info(f"\nDetailed {metric.upper()} results by key:")
        for key, value in split_results.items():
            logging.info(f"{key}: {value:.4f}")

        metric_names = {
            "pcc": "Pearson Correlation Coefficient",
            "ed": "Euclidean Distance",
            "cs": "Cosine Similarity",
        }

        return f"Model Kinship based on {metric_names[metric]} (weighted average): {final_result:.4f}"
    else:
        return f"Error: No valid parameters found for {metric.upper()} calculation"
@@ -0,0 +1,184 @@
1
+ import logging
2
+ from enum import Enum
3
+ from typing import List
4
+
5
+ import click
6
+ import torch
7
+ from tqdm import tqdm
8
+ from transformers import (
9
+ AutoConfig,
10
+ AutoModelForCausalLM,
11
+ AutoTokenizer,
12
+ PretrainedConfig,
13
+ )
14
+
15
+
16
class Metric(str, Enum):
    """Enumeration of supported model-kinship metrics."""

    # Pearson Correlation Coefficient
    PCC = "pcc"
    # Euclidean Distance
    ED = "ed"
    # Cosine Similarity
    CS = "cs"

    @classmethod
    def list(cls) -> List[str]:
        """Return the string values of all supported metrics."""
        return [member.value for member in cls]
27
+
28
+
29
def get_config(model: str, trust_remote_code: bool = False) -> PretrainedConfig:
    """
    Fetch the configuration of a pretrained model from HuggingFace.

    Args:
        model (str): The name or path of the model to load configuration for.
        trust_remote_code (bool, optional): Whether to trust remote code during
            loading. Defaults to False.

    Returns:
        PretrainedConfig: The configuration object of the specified model.
    """
    # Delegate to HuggingFace's hub-aware loader.
    return AutoConfig.from_pretrained(model, trust_remote_code=trust_remote_code)
47
+
48
+
49
def validate_models(model_1: str, model_2: str, base_model: str) -> None:
    """
    Validate that the three model names are pairwise different.

    Args:
        model_1: Name of the first model
        model_2: Name of the second model
        base_model: Name of the base model

    Raises:
        click.BadParameter: If any two of the three names are identical.
    """
    # Three distinct names form a set of size 3; anything smaller means at
    # least one duplicate among the pairs.
    if len({model_1, model_2, base_model}) < 3:
        raise click.BadParameter("All model names must be different")
63
+
64
+
65
def quantize_8bit(x: torch.Tensor) -> torch.Tensor:
    """
    Symmetrically quantize a tensor onto the 8-bit signed integer grid.

    Values are scaled so that the largest magnitude maps to 127, rounded to
    the nearest integer, and clamped to [-127, 127]. The result keeps the
    input's floating dtype (values are integral, not int8).

    Args:
        x (torch.Tensor): Tensor to quantize.

    Returns:
        torch.Tensor: Quantized tensor with the same shape/dtype/device as ``x``.
    """
    # Get absolute max value to define the symmetric scale
    abs_max = torch.max(torch.abs(x))

    # Guard against an all-zero tensor: the original scaling would divide by
    # zero and produce NaNs; a zero input quantizes to a zero output.
    if abs_max == 0:
        return torch.zeros_like(x)

    # Scale to [-127, 127] range for 8-bit signed integers
    # Using 127 instead of 128 to keep zero exactly representable
    scaled = 127 * (x / abs_max)

    # Round to nearest integer
    quantized = torch.round(scaled)

    # Clamp values to ensure they stay in valid range
    return torch.clamp(quantized, -127, 127)
80
+
81
+
82
def load_model_state_dict(model_name: str, device: str) -> dict:
    """
    Load a causal-LM checkpoint and return its state dictionary.

    The model object itself is discarded right after the state dict is
    extracted, so only the parameter tensors remain referenced.

    Args:
        model_name (str): Name or path of the model to load
        device (str): Device to load the model on ('cuda' or 'cpu')

    Returns:
        dict: State dictionary of the loaded model
    """
    logging.info(f"Loading model: {model_name}")
    loaded = AutoModelForCausalLM.from_pretrained(model_name).to(device)
    params = loaded.state_dict()
    del loaded  # free the model wrapper; tensors live on in `params`
    return params
98
+
99
+
100
def extract_delta_parameters(
    model_1_name: str,
    model_2_name: str,
    model_base_name: str,
    low_precision: bool,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Extract the delta parameters (weight differences) between two models
    relative to a base model.

    Args:
        model_1_name (str): Name or path of the first model.
        model_2_name (str): Name or path of the second model.
        model_base_name (str): Name or path of the base model for comparison.
        low_precision (bool): Whether to quantize the deltas to 8-bit precision.
        device (str): Device to load the models on. Defaults to CUDA if available.

    Returns:
        (torch.Tensor, torch.Tensor): Flattened delta parameters of model_1
        and model_2 relative to the base model.
    """
    # Extract state dictionaries from models
    state_dict_1 = load_model_state_dict(model_1_name, device)
    state_dict_2 = load_model_state_dict(model_2_name, device)
    state_dict_base = load_model_state_dict(model_base_name, device)

    # NOTE(review): despite the name, this is the base model's vocabulary
    # size (rows of lm_head.weight); tensors are truncated to this many rows
    # so models with extended vocabularies remain comparable.
    num_layers = state_dict_base["lm_head.weight"].shape[0]

    # Check if model architectures match, log a warning if not
    if (
        state_dict_1["lm_head.weight"].shape[0]
        != state_dict_2["lm_head.weight"].shape[0]
    ):
        shape_1 = state_dict_1["lm_head.weight"].shape
        shape_2 = state_dict_2["lm_head.weight"].shape
        logging.warning(
            f"Warning: Model architectures do not match. "
            f"Using sub weight space instead.\n"
            f"Vocab sizes in model 1: {shape_1[0]}, "
            f"Vocab sizes in model 2: {shape_2[0]}"
        )

    # Initialize lists to store delta parameters for both models
    d_vector_1, d_vector_2 = [], []

    # Iterate over keys in the base model's state dictionary with tqdm
    for key, base_params in tqdm(
        state_dict_base.items(), desc="Processing keys", unit="key"
    ):
        # Only proceed if key exists in both models
        try:
            if key not in state_dict_1 or key not in state_dict_2:
                logging.warning(f"Key {key} not found in one of the models")
                continue
        except Exception as e:
            logging.error(f"Error processing key {key}: {str(e)}")
            # Fix: skip this key instead of falling through to the lookups
            # below, which would raise again on a broken/missing key.
            continue

        # Get the parameters for each model (truncate to num_layers for consistency)
        params_1 = state_dict_1[key][:num_layers]
        params_2 = state_dict_2[key][:num_layers]

        # Compute the deltas relative to the base model
        delta_1 = (params_1 - base_params).view(-1)
        delta_2 = (params_2 - base_params).view(-1)

        # Accumulate deltas
        d_vector_1.append(delta_1)
        d_vector_2.append(delta_2)

    # Clear memory before concatenating the deltas
    del state_dict_1, state_dict_2, state_dict_base

    logging.info("Concatenating delta vectors...")

    d_vector_1 = torch.cat(d_vector_1)
    d_vector_2 = torch.cat(d_vector_2)

    if low_precision:
        logging.info("Quantizing delta vectors to 8-bit precision...")
        d_vector_1 = quantize_8bit(d_vector_1)
        d_vector_2 = quantize_8bit(d_vector_2)
        logging.info("Quantization complete")

    return d_vector_1, d_vector_2
@@ -113,21 +113,27 @@ class MaskModel(ParameterDictModel):
113
113
  def get_distribution(
114
114
  self,
115
115
  mask_type: Literal["discrete", "continuous"],
116
+ temperature: float = 0.5,
116
117
  **kwargs,
117
118
  ):
118
119
  return {
119
- name: self._param_to_distribution(param, mask_type=mask_type, **kwargs)
120
+ name: self._param_to_distribution(
121
+ param, mask_type=mask_type, temperature=temperature, **kwargs
122
+ )
120
123
  for name, param in self.named_parameters()
121
124
  }
122
125
 
123
126
  def sample_mask(
124
127
  self,
125
128
  mask_type: Literal["discrete", "continuous"] = "discrete",
129
+ temperature: float = 0.5,
126
130
  **kwargs,
127
131
  ):
128
132
  mask = {}
129
133
  for name, param in self.named_parameters():
130
- dist = self._param_to_distribution(param, mask_type, **kwargs)
134
+ dist = self._param_to_distribution(
135
+ param, mask_type, temperature=temperature, **kwargs
136
+ )
131
137
  if mask_type == "discrete":
132
138
  mask[name] = dist.sample()
133
139
  elif mask_type == "continuous":
@@ -1,3 +1,10 @@
1
+ from fusion_bench.utils.packages import is_open_clip_available
2
+
3
+ if not is_open_clip_available():
4
+ raise ImportError(
5
+ "open_clip is not installed. Please install it with `pip install open_clip_torch`."
6
+ )
7
+
1
8
  from typing import Callable, List
2
9
 
3
10
  import open_clip
@@ -173,6 +173,24 @@ class LayerWiseMergedModel(nn.Module, Generic[TorchModelType]):
173
173
 
174
174
  @property
175
175
  def forward_model(self):
176
+ """
177
+ Get a functional model with merged parameters.
178
+
179
+ Returns a partial function that applies the pretrained model with the current
180
+ merged state dictionary. This allows for efficient forward passes without
181
+ modifying the original model's parameters.
182
+
183
+ Returns:
184
+ Callable: A partial function that can be called with (args, kwargs) to
185
+ perform forward pass with merged parameters.
186
+
187
+ Example:
188
+ ```python
189
+ # Internal usage during forward pass
190
+ forward_fn = merged_model.forward_model
191
+ output = forward_fn(args=(x,), kwargs={})
192
+ ```
193
+ """
176
194
  return functools.partial(
177
195
  functional_call,
178
196
  self.pretrained_model,
@@ -181,10 +199,30 @@ class LayerWiseMergedModel(nn.Module, Generic[TorchModelType]):
181
199
  strict=self.strict,
182
200
  )
183
201
 
184
- def merge_and_unload(self, task_vector_mask: Optional[Dict[str, Tensor]] = None):
202
+ def merge_and_unload(
203
+ self,
204
+ task_vector_mask: Optional[Dict[str, Tensor]] = None,
205
+ copy: bool = False,
206
+ ) -> TorchModelType:
207
+ """
208
+ Merge models and return the final merged model.
209
+
210
+ Args:
211
+ task_vector_mask (Optional[Dict[str, Tensor]], optional): Optional masks
212
+ for selective parameter merging. Defaults to None.
213
+ copy (bool, optional): Whether to return a deep copy of the pretrained model.
214
+ Defaults to False. If True, the original pretrained model remains unchanged.
215
+
216
+ Returns:
217
+ TorchModelType: The pretrained model with merged parameters loaded.
218
+ """
185
219
  self.merge_weights(task_vector_mask=task_vector_mask)
186
- self.pretrained_model.load_state_dict(self._merged_state_dict)
187
- return self.pretrained_model
220
+ if copy:
221
+ model = deepcopy(self.pretrained_model)
222
+ else:
223
+ model = self.pretrained_model
224
+ model.load_state_dict(self._merged_state_dict)
225
+ return model
188
226
 
189
227
  def merge_weights(self, task_vector_mask: Optional[Dict[str, Tensor]] = None):
190
228
  """
@@ -16,6 +16,7 @@ outputs = merged_model(inputs)
16
16
 
17
17
  import functools
18
18
  import logging
19
+ from copy import deepcopy
19
20
  from typing import Any, Callable, Dict, Generic, Iterator, List, Optional # noqa: F401
20
21
 
21
22
  import torch
@@ -327,7 +328,11 @@ class TaskWiseMergedModel(nn.Module, Generic[TorchModelType]):
327
328
  self._merged_state_dict = state_dict
328
329
  return state_dict
329
330
 
330
- def merge_and_unload(self, task_vector_mask: Optional[Dict[str, Tensor]] = None):
331
+ def merge_and_unload(
332
+ self,
333
+ task_vector_mask: Optional[Dict[str, Tensor]] = None,
334
+ copy: bool = False,
335
+ ) -> TorchModelType:
331
336
  """
332
337
  Merge models and return the final merged model.
333
338
 
@@ -338,6 +343,8 @@ class TaskWiseMergedModel(nn.Module, Generic[TorchModelType]):
338
343
  Args:
339
344
  task_vector_mask (Optional[Dict[str, Tensor]], optional): Optional masks
340
345
  for selective parameter merging. Defaults to None.
346
+ copy (bool, optional): Whether to return a deep copy of the pretrained model.
347
+ Defaults to False. If True, the original pretrained model remains unchanged.
341
348
 
342
349
  Returns:
343
350
  TorchModelType: The pretrained model with merged parameters loaded.
@@ -363,8 +370,12 @@ class TaskWiseMergedModel(nn.Module, Generic[TorchModelType]):
363
370
  The original pretrained model parameters will be lost.
364
371
  """
365
372
  self.merge_weights(task_vector_mask=task_vector_mask)
366
- self.pretrained_model.load_state_dict(self._merged_state_dict)
367
- return self.pretrained_model
373
+ if copy:
374
+ model = deepcopy(self.pretrained_model)
375
+ else:
376
+ model = self.pretrained_model
377
+ model.load_state_dict(self._merged_state_dict)
378
+ return model
368
379
 
369
380
  def forward(self, *args, **kwargs):
370
381
  """
@@ -69,6 +69,20 @@ def main(cfg: DictConfig) -> None:
69
69
  """
70
70
  OmegaConf.resolve(cfg)
71
71
  program: BaseHydraProgram = instantiate(cfg)
72
+
73
+ # Validate that instantiation succeeded and returned an object with 'run' method
74
+ if not hasattr(program, "run") or not callable(getattr(program, "run")):
75
+ err_msg = (
76
+ f"Expected an object with a callable 'run' method, but got {type(program).__name__}. "
77
+ "Ensure that the configuration specifies a concrete program class with '_target_'."
78
+ )
79
+ if "_target_" not in cfg:
80
+ err_msg += "\nThe '_target_' field is missing from the root configuration."
81
+ else:
82
+ err_msg += f"\nFound '_target_': {cfg._target_}"
83
+ err_msg += f"\n\nConfiguration content:\n{cfg}"
84
+ raise TypeError(err_msg)
85
+
72
86
  program.run()
73
87
 
74
88
 
@@ -32,11 +32,13 @@ def clear_cuda_cache():
32
32
  Clears the CUDA memory cache to free up GPU memory.
33
33
  Works only if CUDA is available.
34
34
  """
35
+
35
36
  gc.collect()
36
37
  if torch.cuda.is_available():
37
38
  torch.cuda.empty_cache()
39
+ gc.collect()
38
40
  else:
39
- log.warning("CUDA is not available. No cache to clear.")
41
+ log.debug("CUDA is not available. No cache to clear.")
40
42
 
41
43
 
42
44
  def to_device(
@@ -14,8 +14,8 @@ from lightning_utilities.core.rank_zero import rank_zero_only
14
14
  from omegaconf import DictConfig, OmegaConf, SCMode
15
15
  from omegaconf._utils import is_structured_config
16
16
  from rich import print
17
- from rich.panel import Panel
18
- from rich.syntax import Syntax
17
+
18
+ from fusion_bench.utils.rich_utils import print_bordered
19
19
 
20
20
  PRINT_FUNCTION_CALL = True
21
21
  """
@@ -67,12 +67,22 @@ def _resolve_callable_name(f: Callable[..., Any]) -> str:
67
67
  return full_name
68
68
 
69
69
 
70
- def _format_args_kwargs(args, kwargs):
70
+ def _get_obj_str(obj: Any) -> str:
71
+ if isinstance(obj, (str, int, float, bool, type(None))):
72
+ return repr(obj)
73
+ else:
74
+ return f"'<{type(obj).__name__} object>'"
75
+
76
+
77
+ def _format_args_kwargs(args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> str:
71
78
  result_strings = []
72
79
  if len(args) > 0:
73
- result_strings.append(", ".join(repr(arg) for arg in args))
80
+ result_strings.append(", ".join(_get_obj_str(arg) for arg in args))
81
+
74
82
  if len(kwargs) > 0:
75
- result_strings.append(", ".join(f"{k}={repr(v)}" for k, v in kwargs.items()))
83
+ result_strings.append(
84
+ ", ".join(f"{k}={_get_obj_str(v)}" for k, v in kwargs.items())
85
+ )
76
86
 
77
87
  if len(result_strings) == 0:
78
88
  return ""
@@ -145,14 +155,14 @@ def _call_target(
145
155
  if _partial_:
146
156
  if PRINT_FUNCTION_CALL and getattr(rank_zero_only, "rank", 0) == 0:
147
157
  call_str = f"functools.partial({_resolve_callable_name(_target_)}, {_format_args_kwargs(args, kwargs)})"
148
- PRINT_FUNCTION_CALL_FUNC(
149
- Panel(
150
- Syntax(call_str, "python", theme="monokai", word_wrap=True),
151
- title="Instantiate by calling partial",
152
- border_style="cyan",
153
- )
158
+ print_bordered(
159
+ call_str,
160
+ code_style="python",
161
+ title=f"Instantiate by calling {'function' if not isinstance(_target_, type) else 'class'}",
162
+ style="cyan",
163
+ expand=False,
164
+ print_fn=PRINT_FUNCTION_CALL_FUNC,
154
165
  )
155
-
156
166
  if CATCH_EXCEPTION:
157
167
  try:
158
168
  return functools.partial(_target_, *args, **kwargs)
@@ -169,12 +179,13 @@ def _call_target(
169
179
  else:
170
180
  if PRINT_FUNCTION_CALL and getattr(rank_zero_only, "rank", 0) == 0:
171
181
  call_str = f"{_resolve_callable_name(_target_)}({_format_args_kwargs(args, kwargs)})"
172
- PRINT_FUNCTION_CALL_FUNC(
173
- Panel(
174
- Syntax(call_str, "python", theme="monokai", word_wrap=True),
175
- title="Instantiate by calling function",
176
- border_style="green",
177
- )
182
+ print_bordered(
183
+ call_str,
184
+ code_style="python",
185
+ title=f"Instantiate by calling {'function' if not isinstance(_target_, type) else 'class'}",
186
+ style="green",
187
+ expand=False,
188
+ print_fn=PRINT_FUNCTION_CALL_FUNC,
178
189
  )
179
190
  if CATCH_EXCEPTION:
180
191
  try:
@@ -178,3 +178,19 @@ def validate_and_suggest_corrections(
178
178
  if matches:
179
179
  msg += f". Did you mean {', '.join(repr(m) for m in matches)}?"
180
180
  raise ValueError(msg)
181
+
182
+
183
class DeprecationWarningMeta(type):
    """
    Metaclass that emits a ``DeprecationWarning`` every time a class built
    with it is instantiated, then constructs the instance normally.
    """

    def __call__(cls, *args, **kwargs):
        import warnings

        message = (
            f"{cls.__name__} is deprecated and will be removed in a future version. "
        )
        warnings.warn(message, DeprecationWarning, stacklevel=2)
        return super().__call__(*args, **kwargs)
@@ -1,6 +1,6 @@
1
1
  import logging
2
2
  from pathlib import Path
3
- from typing import Sequence
3
+ from typing import Optional, Sequence
4
4
 
5
5
  import rich
6
6
  import rich.syntax
@@ -19,6 +19,9 @@ from rich.text import Text
19
19
  from rich.traceback import install as install_rich_traceback
20
20
 
21
21
  from fusion_bench.utils import pylogger
22
+ from fusion_bench.utils.packages import _is_package_available
23
+
24
+ install_rich_traceback()
22
25
 
23
26
  log = pylogger.RankedLogger(__name__, rank_zero_only=True)
24
27
 
@@ -61,7 +64,31 @@ def display_available_styles():
61
64
  console.print(Columns(style_samples, equal=True, expand=False))
62
65
 
63
66
 
64
- def print_bordered(message, title=None, style="blue", code_style=None):
67
def format_code_str(message: str, code_style="python"):
    """Normalize a code snippet for display.

    Python snippets are reformatted with ``black`` when it is installed;
    otherwise (or when black cannot parse the input) the snippet is returned
    unchanged apart from stripping surrounding whitespace.
    """
    if code_style.lower() == "python" and _is_package_available("black"):
        import black

        try:
            message = black.format_str(message, mode=black.Mode())
        except black.InvalidInput:
            # Not valid Python; fall back to the raw snippet.
            pass

    return message.strip()
78
+
79
+
80
+ def print_bordered(
81
+ message,
82
+ title=None,
83
+ style="blue",
84
+ code_style=None,
85
+ *,
86
+ expand: bool = True,
87
+ theme: str = "monokai",
88
+ background_color: Optional[str] = "default",
89
+ print_fn=print,
90
+ format_code: bool = True,
91
+ ):
65
92
  """
66
93
  Print a message with a colored border.
67
94
 
@@ -73,12 +100,63 @@ def print_bordered(message, title=None, style="blue", code_style=None):
73
100
  Set to None for plain text. Defaults to "python".
74
101
  """
75
102
  if code_style:
76
- content = Syntax(message, code_style, theme="monokai", word_wrap=True)
103
+ if format_code:
104
+ message = format_code_str(message, code_style)
105
+ content = Syntax(
106
+ message,
107
+ code_style,
108
+ word_wrap=True,
109
+ theme=theme,
110
+ background_color=background_color,
111
+ )
77
112
  else:
78
113
  content = Text(message)
79
114
 
80
- panel = Panel(content, title=title, border_style=style)
81
- print(panel)
115
+ panel = Panel(content, title=title, border_style=style, expand=expand)
116
+ print_fn(panel)
117
+
118
+
119
def print_code(
    message,
    title=None,
    code_style=None,
    *,
    expand: bool = True,
    theme: str = "monokai",
    background_color: Optional[str] = "default",
    print_fn=print,
):
    """
    Print code or plain text, with syntax highlighting when a lexer is given.

    Args:
        message (str): The message or code to print.
        title (str, optional): Accepted for API compatibility with the other
            printing helpers; not used by this function. Defaults to None.
        code_style (str, optional): Language/lexer name for syntax highlighting
            (for example, ``"python"``). If ``None``, the message is rendered
            as plain text. Defaults to ``None``.
        expand (bool, optional): Placeholder flag for API symmetry with other
            printing helpers; unused here. Defaults to True.
        theme (str, optional): Rich syntax-highlighting theme used when
            ``code_style`` is provided. Defaults to ``"monokai"``.
        background_color (str, optional): Background color applied to the code
            block when using syntax highlighting. Defaults to ``"default"``.
        print_fn (Callable, optional): Function used to render the resulting
            Rich object. Defaults to :func:`rich.print`.
    """
    if not code_style:
        content = Text(message)
    else:
        content = Syntax(
            message,
            code_style,
            word_wrap=True,
            theme=theme,
            background_color=background_color,
        )

    print_fn(content)
82
160
 
83
161
 
84
162
  @rank_zero_only
@@ -95,6 +173,9 @@ def print_config_tree(
95
173
  ),
96
174
  resolve: bool = False,
97
175
  save_to_file: bool = False,
176
+ *,
177
+ theme: str = "monokai",
178
+ background_color: Optional[str] = "default",
98
179
  ) -> None:
99
180
  """Prints the contents of a DictConfig as a tree structure using the Rich library.
100
181
 
@@ -134,7 +215,14 @@ def print_config_tree(
134
215
  else:
135
216
  branch_content = str(config_group)
136
217
 
137
- branch.add(rich.syntax.Syntax(branch_content, "yaml"))
218
+ branch.add(
219
+ rich.syntax.Syntax(
220
+ branch_content,
221
+ "yaml",
222
+ theme=theme,
223
+ background_color=background_color,
224
+ )
225
+ )
138
226
 
139
227
  # print config tree
140
228
  rich.print(tree)
@@ -145,6 +233,35 @@ def print_config_tree(
145
233
  rich.print(tree, file=file)
146
234
 
147
235
 
236
+ @rank_zero_only
237
+ def print_config_yaml(
238
+ cfg: DictConfig,
239
+ resolve: bool = False,
240
+ output_path: Optional[str] = False,
241
+ *,
242
+ theme: str = "monokai",
243
+ background_color: Optional[str] = "default",
244
+ ) -> None:
245
+ """
246
+ Prints the contents of a DictConfig as a YAML string using the Rich library.
247
+
248
+ Args:
249
+ cfg: A DictConfig composed by Hydra.
250
+ resolve: Whether to resolve reference fields of DictConfig. Default is ``False``.
251
+ output_path: Optional path to export the config YAML to. If provided, the file is written to this path.
252
+ """
253
+ config_yaml = OmegaConf.to_yaml(cfg, resolve=resolve)
254
+ syntax = rich.syntax.Syntax(
255
+ config_yaml, "yaml", theme=theme, background_color=background_color
256
+ )
257
+ rich.print(syntax)
258
+
259
+ if output_path:
260
+ Path(output_path).parent.mkdir(parents=True, exist_ok=True)
261
+ with open(Path(output_path), "w") as file:
262
+ rich.print(syntax, file=file)
263
+
264
+
148
265
  @rank_zero_only
149
266
  def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:
150
267
  """Prompts user to input tags from command line if no tags are provided in config.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: fusion-bench
3
- Version: 0.2.29
3
+ Version: 0.2.30
4
4
  Summary: A Comprehensive Benchmark of Deep Model Fusion
5
5
  Author-email: Anke Tang <tang.anke@foxmail.com>
6
6
  Project-URL: Repository, https://github.com/tanganke/fusion_bench
@@ -255,9 +255,9 @@ First, create a new Python file for the algorithm in the `fusion_bench/method` d
255
255
  Following the naming convention, the file should be named `{method_name_or_class}/{variant}.py`.
256
256
 
257
257
  ```python
258
- from fusion_bench import BaseModelFusionAlgorithm, BaseModelPool
258
+ from fusion_bench import BaseAlgorithm, BaseModelPool
259
259
 
260
- class DerivedModelFusionAlgorithm(BaseModelFusionAlgorithm):
260
+ class DerivedModelFusionAlgorithm(BaseAlgorithm):
261
261
  """
262
262
  An example of a derived model fusion algorithm.
263
263
  """
@@ -265,7 +265,7 @@ class DerivedModelFusionAlgorithm(BaseModelFusionAlgorithm):
265
265
  # _config_mapping maps the attribution to the corresponding key in the configuration file.
266
266
  # this is optional and can be used to serialize the object to a configuration file.
267
267
  # `self.config.hyperparam_1` will be mapped to the attribute `hyperparam_attr_1`.
268
- _config_mapping = BaseModelFusionAlgorithm._config_mapping | {
268
+ _config_mapping = BaseAlgorithm._config_mapping | {
269
269
  "hyperparam_attr_1": "hyperparam_1",
270
270
  "hyperparam_attr_2": "hyperparam_2",
271
271
  }
@@ -344,9 +344,9 @@ If you find this benchmark useful, please consider citing our work:
344
344
  ```bibtex
345
345
  @article{tang2024fusionbench,
346
346
  title={Fusionbench: A comprehensive benchmark of deep model fusion},
347
- author={Tang, Anke and Shen, Li and Luo, Yong and Hu, Han and Du, Bo and Tao, Dacheng},
348
- journal={arXiv preprint arXiv:2406.03280},
349
- year={2024}
347
+ author={Tang, Anke and Shen, Li and Luo, Yong and Yang, Enneng and Hu, Han and Zhang, Lefei and Du, Bo and Tao, Dacheng},
348
+ journal={Journal of Machine Learning Research},
349
+ year={2025}
350
350
  }
351
351
  ```
352
352
 
@@ -48,8 +48,8 @@ fusion_bench/dataset/llama/stanford_shp.py,sha256=6ueXKnFXIBBobacU1h5WxGLZrSOtBk
48
48
  fusion_bench/dataset/llama/ultrachat.py,sha256=Go7WvrDAYnm184fdazHGRYLbSY6Xd7jrESyQeUJtOww,1736
49
49
  fusion_bench/dataset/llama/wikitext.py,sha256=9ZHR-nMfXRumd3o-PIj3n7B83YlVeqpGkZ2zJs2B-9Y,2883
50
50
  fusion_bench/dataset/llama/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
51
- fusion_bench/method/__init__.py,sha256=sXjV8DDn3yXVjzsl6k-nMVx6EABQDjXjY3xK-I6nvr0,9527
52
- fusion_bench/method/base_algorithm.py,sha256=OnKSNPQ_nIdIWxryyblW_sko7uoEBN4lGh-eLkJ4kh4,9004
51
+ fusion_bench/method/__init__.py,sha256=Set_2GWpmI3q_WvbV1hBUfa6GFiIuajyiZR2hRbfrN0,9811
52
+ fusion_bench/method/base_algorithm.py,sha256=Pa3A7ON0YK3PJqFE77IY9dpQC-tQGJpX6kdf8IMnM_k,9453
53
53
  fusion_bench/method/dummy.py,sha256=hb1y6LR_geRZ5eRgGwt5zJUcHYorCeIbs5i76CvurUc,1031
54
54
  fusion_bench/method/ensemble.py,sha256=Bjzqxt-tUp5cawT1jIhqKswN5QH3bkYbmuI4LS4uTG0,3619
55
55
  fusion_bench/method/model_recombination.py,sha256=b2ku5wCrWd1QSZscIra4KlhLDxt04JjU30ItMNvpZ6g,5268
@@ -257,6 +257,10 @@ fusion_bench/method/wudi/wudi.py,sha256=HL3Y0MPjozp7NML_UNjIWWPbQDQxYH_WG_Buyrip
257
257
  fusion_bench/metrics/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
258
258
  fusion_bench/metrics/continual_learning/__init__.py,sha256=f-mkv4SpXTq5kiQVHbe2g0IPf4yLFgu1Dw7g2DOK6T4,57
259
259
  fusion_bench/metrics/continual_learning/backward_transfer.py,sha256=LCMWFFmBgWv7UIAJqiTaSvVvanx4qjnXIGuCMYvzmtc,559
260
+ fusion_bench/metrics/model_kinship/__init__.py,sha256=-XWD0NR6Xz-p4oE8AKGoWrq-s1ayqWse7qLgNRENsaU,137
261
+ fusion_bench/metrics/model_kinship/calculate.py,sha256=FoyBQuz3-q2NRfUW9w0dq9Tm51WG83iF_L_nHMOSI20,2447
262
+ fusion_bench/metrics/model_kinship/calculate_split.py,sha256=_aTw7nfAZeEhiyqWlUkzwafQXLI3iDQMHdFy6ZMb88w,5797
263
+ fusion_bench/metrics/model_kinship/utility.py,sha256=9iF9bWsJOFhhLqPMDyHyg-PAmat_zYUbud-umTfgBLs,5903
260
264
  fusion_bench/metrics/nyuv2/__init__.py,sha256=Ed1FQTJAxguJoorZLHIO-cSIgKYHHfqdf17J3o9_feI,1390
261
265
  fusion_bench/metrics/nyuv2/depth.py,sha256=xmUokztxyPrl90qtcoQaanti6DbFaIVqglAo3PDnEso,2851
262
266
  fusion_bench/metrics/nyuv2/loss.py,sha256=YKZSqycNyPWJV29Qa12--Wh87zZvtJcuUxUuiPbccpM,2529
@@ -330,7 +334,7 @@ fusion_bench/models/llama/model_utils/misc.py,sha256=3SJ7wk71zLMVF-AJEvQ_KCfFaMg
330
334
  fusion_bench/models/llama/model_utils/mod.py,sha256=xzNOgTRfOK9q8kml4Q2nmSOl23f33dE1tPi5zxgpWK0,1498
331
335
  fusion_bench/models/llama/model_utils/visual.py,sha256=wpqWqEASyA7WhJLCfC26h0Cdn5CXnwC1qPJUlSXggo4,8310
332
336
  fusion_bench/models/masks/__init__.py,sha256=vXG6jrBkDbPsnrX6nMEYAW1rQuGEWDgdjID7cKzXvrs,69
333
- fusion_bench/models/masks/mask_model.py,sha256=YXNZ_CGp6VPshZH__Znh6Z07BqOK53G-Ltc1LVy1E3I,5502
337
+ fusion_bench/models/masks/mask_model.py,sha256=NDVhtuvZ10NUfTLEI_ONTKiceuSF-W7T9SEeUnyZFYQ,5680
334
338
  fusion_bench/models/model_card_templates/default.md,sha256=OoU83l1hip1gKsoA08hoKx-nCrOYbKaVTVCjK0pt9WY,1028
335
339
  fusion_bench/models/modeling_deepseek_v2/__init__.py,sha256=trXrhtKb_gIxXVo7wSZ-il5sLJtDTiNZezRrEt3M8zM,505
336
340
  fusion_bench/models/modeling_deepseek_v2/configuration_deepseek.py,sha256=TblFOCfNwaXUnXnD-sxFhSn5Df-_yy2LMcrth-sBPFI,10301
@@ -364,7 +368,7 @@ fusion_bench/models/nyuv2/lightning_module.py,sha256=SLtC0yL6455uKeb-o07MR6v-xE4
364
368
  fusion_bench/models/nyuv2/resnet.py,sha256=PcCfBhEsxm7W8cu3epBbIbCYFARPrPTamIa3TtUAVa0,14305
365
369
  fusion_bench/models/nyuv2/resnet_dilated.py,sha256=4EXB6vrBJS307YP6k-TRY1dFJ50LURcTuzqN4tZzYRk,3125
366
370
  fusion_bench/models/open_clip/__init__.py,sha256=zT2sGAT98Py5vXMckZF4aD8MYEICEWa2p7nRg4IrS0w,192
367
- fusion_bench/models/open_clip/modeling.py,sha256=34wKcbxe5xb6fzAVdIz0QcsSXs-8FQFUyqRNlIJso78,5556
371
+ fusion_bench/models/open_clip/modeling.py,sha256=YOCsM1RfvhqJkUzwK9T4WqX1NW7LyAIi0UnN6ERQ-rk,5775
368
372
  fusion_bench/models/open_clip/utils.py,sha256=YM_vGQSxIDoB2euHG54hhRGIcINJfR0NxNT5U42KRCw,10394
369
373
  fusion_bench/models/open_clip/variables_and_paths.py,sha256=_OBcKvZwSGvYSmgKtXOuekEJI-btW94Ia-BQ9n4isfY,1231
370
374
  fusion_bench/models/smile_moe/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -376,9 +380,9 @@ fusion_bench/models/surgery/__init__.py,sha256=tcUSi2m9GzGWfvRDQScIbdEbFBS_35gm9
376
380
  fusion_bench/models/surgery/surgerymodelwrapper.py,sha256=F8jX88K5zVWC6HsfN-nGNkEiPwNrN11ydyQQ1EZHehM,5133
377
381
  fusion_bench/models/wrappers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
378
382
  fusion_bench/models/wrappers/ensemble.py,sha256=T-DAKrAm-ciZwV6Hbt8uASbjtoQpHTlvVyan3rhk_8k,11632
379
- fusion_bench/models/wrappers/layer_wise_fusion.py,sha256=A7LjG0inL5oeEVOkJwEUDM15v4dpQnsCq2y9zA78R3k,11198
383
+ fusion_bench/models/wrappers/layer_wise_fusion.py,sha256=T1sbujx_84Pj5yHFy5QqfipT6v3p96gUmnMgyy4lG0c,12560
380
384
  fusion_bench/models/wrappers/layer_wise_fusion_doge_ta.py,sha256=q5Hc4BtLpAawMbxsWJRL-8OR-x7994Jhr9IyN7vKZ9o,16930
381
- fusion_bench/models/wrappers/task_wise_fusion.py,sha256=ROLANdDq0bZ3sIROqIv3udPN8lzDdEwxD0Jonx-5ycw,17465
385
+ fusion_bench/models/wrappers/task_wise_fusion.py,sha256=iCrevrkG4uTr3U8_hgT_xEY4epnEK0EJO8yg-uEMIUI,17836
382
386
  fusion_bench/optim/__init__.py,sha256=JS7J2VjrM2LdkiFCxuQnIuFwBsWiPyFb7QuEU6V2bPY,845
383
387
  fusion_bench/optim/exception.py,sha256=fMgo1heiqfGhuI5RIbf30BwWSShn5RQiyeb30QtfTI0,1607
384
388
  fusion_bench/optim/mezo.py,sha256=Vm4vMGh10Fhe28_9L1MK8r_U7DrurA8Liprh2_gn4_U,3646
@@ -392,7 +396,7 @@ fusion_bench/programs/base_program.py,sha256=Bl_bv8SawEUc-GBTtZFMoii0y-r-0hOXBAJ
392
396
  fusion_bench/programs/fabric_fusion_program.py,sha256=wIHNpLUw6uAXpAasJRAMWut55hF_EGFShxn70zRRvfk,12449
393
397
  fusion_bench/programs/fusion_program.py,sha256=qLyA3FHJUMM1L3mlYn4jlnZzv9OKguWM5aGGIoLts2I,11309
394
398
  fusion_bench/scripts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
395
- fusion_bench/scripts/cli.py,sha256=kEWLEkZEBqUr1_-XTePzNC5NM8lwWvgUBf0Lcuk_FI8,2739
399
+ fusion_bench/scripts/cli.py,sha256=t3YFuscJluxxNdXawW8FOaYH2fKn7m_6bXNlJ8KcZZg,3414
396
400
  fusion_bench/scripts/imgui.py,sha256=r9Glbfbwu3JCsX9TKQFwcHarvwA_G7ff0jWBUPW1S1U,7613
397
401
  fusion_bench/scripts/nyuv2_mtl_train.py,sha256=W1C45R9NdF4O-UjCx1bUxRTdFE0-FlRpwJHZ5gY18rI,3602
398
402
  fusion_bench/scripts/webui.py,sha256=ROvZUIj-hR4JLgCiWEKGc25LMtAjaMAZLJ5ckDYt-w4,21513
@@ -457,24 +461,24 @@ fusion_bench/tasks/flan_t5_text_generation/glue_prompt_templates.py,sha256=mKMTX
457
461
  fusion_bench/utils/__init__.py,sha256=EvrvupFGAzxll_jO0HYk1-I6jCHqDrIwZ5vswlR-9Pw,5149
458
462
  fusion_bench/utils/cache_utils.py,sha256=-bTZijQgl4BuAx0VSJFD-bSDOXuq3o0NkrOaiLiyofU,4795
459
463
  fusion_bench/utils/data.py,sha256=QAXpsvzHOgfAf6G_Pe2a5HOKUAP8Mxz77avujQI9Fd8,10027
460
- fusion_bench/utils/devices.py,sha256=6AkGcs3flt0FSo9yfEREuehoTrgcc65gkwpTWQy8XsI,9546
464
+ fusion_bench/utils/devices.py,sha256=IyUBaWbnZGDsAxI97LEioUj-JIjYTzxQo_EhyKY3RZM,9566
461
465
  fusion_bench/utils/dict.py,sha256=ZCK0CRRT_B1Z18WY_GOYcmth7k5x9Jn1k7XhAVWRu98,1379
462
466
  fusion_bench/utils/dtype.py,sha256=z6UlPGF9dzG4Ik8rXGf59PJk_RKzG6Trp8O6wcBS9PU,4360
463
467
  fusion_bench/utils/expr.py,sha256=zwHNrtIbOMnIChU-0ZI5qLbDva8zvHbizL-4F2TwM14,2386
464
468
  fusion_bench/utils/fabric.py,sha256=qKcJ6Xj-6rEGy35dsUPHzxZT6az9RkSNcyBQl1uOv0M,6050
465
469
  fusion_bench/utils/functools.py,sha256=7_tYJ2WD88_2DDuOOj5aZz3cYuslYH5tsVyIgCeLtmk,1318
466
470
  fusion_bench/utils/hydra_utils.py,sha256=TklUDKDEZlg4keI-TEZiqh4gFjr9-61Rt1RMlqkoSGk,1174
467
- fusion_bench/utils/instantiate_utils.py,sha256=OXkfhq_o3Sgy5n3Psf-HI-dIfbK9oD2GBdfcx3gT63Q,17526
471
+ fusion_bench/utils/instantiate_utils.py,sha256=UNfx188feTDrMSgp-ocLHetj6uD6axZcC46dRfBMtko,17884
468
472
  fusion_bench/utils/json.py,sha256=XZvEqBGpq-e0MaKkkX-1_PD8xMf6IDLAn4BrAF7IeiU,4552
469
473
  fusion_bench/utils/lazy_imports.py,sha256=s-1ABhPyyHs7gW4aodCzu3NySzILzTL7kVNZ0DZRXJA,6156
470
474
  fusion_bench/utils/lazy_state_dict.py,sha256=mJaiAtKB1vlNUAoQILnnCmU80FGJ8MSwmdPpmdhOyDE,22206
471
- fusion_bench/utils/misc.py,sha256=_7BaS9dNKyySGU0qmTmE0Tk8WK82TEm7IBJxVRkuEAw,5315
475
+ fusion_bench/utils/misc.py,sha256=xntIUj4cwgx10y7Z1YqXT0zU4nDHfnKRK_M9biWgLH4,5780
472
476
  fusion_bench/utils/modelscope.py,sha256=P8fV6Eff8oP0LVGIFGbLvuk8MBteysN438djZ6ZEfE4,10699
473
477
  fusion_bench/utils/packages.py,sha256=m2E0ryIMI0NwWR9vUHkK9FtZEwA1G-A4dYOf87olli4,2217
474
478
  fusion_bench/utils/parameters.py,sha256=ufEDOYJwcQQxLfveK8hBAGwpu5J3LA_cTWiDgZ2zkJ0,11788
475
479
  fusion_bench/utils/path.py,sha256=piznok_znXkTY71VBwJrxBlXureYOdQnMfvqaZ26qvc,2643
476
480
  fusion_bench/utils/pylogger.py,sha256=1Uy_LkHkbrYdt1g5Ge_eAh2YoCJwn3U3Ndouz9sVA6g,3419
477
- fusion_bench/utils/rich_utils.py,sha256=3Z0di-1IOs3QoovF2frNA28ITVKWBLdm84zbXdTrM28,5924
481
+ fusion_bench/utils/rich_utils.py,sha256=CJKL1vIHm2EznWa4e7ExmY5-lRtRRHLd7ZFPcn2acUs,9664
478
482
  fusion_bench/utils/set.py,sha256=_43ZvGKJ_BK9sUslsSNhi7xEfuAQuyj3vViImnGpnCY,134
479
483
  fusion_bench/utils/state_dict_arithmetic.py,sha256=bXO3zewO3KDzRmTaznlsnURIoSlcW5V5IhuXGtI_nxk,41234
480
484
  fusion_bench/utils/tensorboard.py,sha256=9fkgNYR9LM38nPNkudcxL9TjLUseW-280M0k2nLff7o,1669
@@ -488,7 +492,7 @@ fusion_bench/utils/plot/token_notebook.py,sha256=bsntXf46Zz_RavTxNiB9c3-KvHw7LFw
488
492
  fusion_bench/utils/strenum/__init__.py,sha256=id9ORi1uXrDxhbmVxitJ1KDwLS4H3AAwFpaK5h1cQzw,8531
489
493
  fusion_bench/utils/strenum/_name_mangler.py,sha256=o11M5-bURW2RBvRTYXFQIPNeqLzburdoWLIqk8X3ydw,3397
490
494
  fusion_bench/utils/strenum/_version.py,sha256=6JQRo9LcvODbCOeVFYQb9HNJ_J9XiG_Zbn8ws2A3BV8,18466
491
- fusion_bench-0.2.29.dist-info/licenses/LICENSE,sha256=nhnOJlw4CPuPVE0qvkGmxfFgHmKi-6nzXvTu8t0NUdg,1066
495
+ fusion_bench-0.2.30.dist-info/licenses/LICENSE,sha256=nhnOJlw4CPuPVE0qvkGmxfFgHmKi-6nzXvTu8t0NUdg,1066
492
496
  fusion_bench_config/README.md,sha256=Lc8YSBJ5oxf9KV5kKDivJ9LRyGuraGQPmBbgbdVA-j4,703
493
497
  fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml,sha256=pZ5dFgg5n1W9cKdNyGNa7b4yPd4aQSu2iR2-yw9hhbY,442
494
498
  fusion_bench_config/fabric_model_fusion.yaml,sha256=kSQbhBsKypVFA3rmkdhY9BITnZWDXJof-I35t473_U0,2646
@@ -1015,8 +1019,8 @@ fusion_bench_config/taskpool/LMEvalHarnessTaskPool/lm_eval.yaml,sha256=3q-KMuFaM
1015
1019
  fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-16_TA8.yaml,sha256=GjpiiRownrBCpl-TNwWRW2PYePbF-Cl99jlLNPrK5T4,1017
1016
1020
  fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-32_TA8.yaml,sha256=WwiYMQKehtJixDPnu5o3vcWe4yJksXTWRqOzm3uVWXQ,1017
1017
1021
  fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-L-14_TA8.yaml,sha256=xGRt0J9joXTzWUew6DvoYprAWlPXhaVFw5AX4im5VQw,1017
1018
- fusion_bench-0.2.29.dist-info/METADATA,sha256=RivzHbrFvjc6WrrpTlsPwyCpUz8vw8Kc7GfxIwtIKxk,26292
1019
- fusion_bench-0.2.29.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
1020
- fusion_bench-0.2.29.dist-info/entry_points.txt,sha256=iUQ8MCJvda7HP4vYh2n1Teoapb4G9PBVYZkAfcc5SHU,116
1021
- fusion_bench-0.2.29.dist-info/top_level.txt,sha256=BuO4TL6iHL_2yPBUX9-LlIrHRczA_BNMIFwweK0PQEI,13
1022
- fusion_bench-0.2.29.dist-info/RECORD,,
1022
+ fusion_bench-0.2.30.dist-info/METADATA,sha256=fcL0hcELjiXF7XmX4E2efcc_v1SrlSL9fsqQ7WCxyVM,26298
1023
+ fusion_bench-0.2.30.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
1024
+ fusion_bench-0.2.30.dist-info/entry_points.txt,sha256=iUQ8MCJvda7HP4vYh2n1Teoapb4G9PBVYZkAfcc5SHU,116
1025
+ fusion_bench-0.2.30.dist-info/top_level.txt,sha256=BuO4TL6iHL_2yPBUX9-LlIrHRczA_BNMIFwweK0PQEI,13
1026
+ fusion_bench-0.2.30.dist-info/RECORD,,