fusion-bench 0.2.22__py3-none-any.whl → 0.2.23__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. fusion_bench/__init__.py +4 -0
  2. fusion_bench/compat/method/__init__.py +5 -2
  3. fusion_bench/compat/method/base_algorithm.py +3 -2
  4. fusion_bench/compat/modelpool/base_pool.py +3 -3
  5. fusion_bench/compat/taskpool/clip_image_classification.py +1 -1
  6. fusion_bench/dataset/gpt2_glue.py +1 -1
  7. fusion_bench/method/__init__.py +4 -2
  8. fusion_bench/method/analysis/task_vector_cos_similarity.py +95 -12
  9. fusion_bench/method/analysis/task_vector_violin_plot.py +160 -52
  10. fusion_bench/method/bitdelta/bitdelta.py +7 -23
  11. fusion_bench/method/expert_sparsity/mixtral/dynamic_skipping.py +2 -0
  12. fusion_bench/method/expert_sparsity/mixtral/layer_wise_pruning.py +2 -0
  13. fusion_bench/method/expert_sparsity/mixtral/progressive_pruning.py +2 -0
  14. fusion_bench/method/model_stock/__init__.py +1 -0
  15. fusion_bench/method/model_stock/model_stock.py +309 -0
  16. fusion_bench/method/regmean/clip_regmean.py +3 -6
  17. fusion_bench/method/regmean/regmean.py +27 -56
  18. fusion_bench/method/regmean/utils.py +56 -0
  19. fusion_bench/method/regmean_plusplus/regmean_plusplus.py +21 -60
  20. fusion_bench/method/slerp/__init__.py +1 -1
  21. fusion_bench/method/slerp/slerp.py +110 -14
  22. fusion_bench/method/we_moe/flan_t5_we_moe.py +9 -20
  23. fusion_bench/mixins/clip_classification.py +26 -6
  24. fusion_bench/mixins/serialization.py +25 -15
  25. fusion_bench/modelpool/base_pool.py +1 -1
  26. fusion_bench/modelpool/causal_lm/causal_lm.py +262 -43
  27. fusion_bench/modelpool/seq2seq_lm/modelpool.py +146 -0
  28. fusion_bench/models/hf_utils.py +9 -4
  29. fusion_bench/models/linearized/vision_model.py +6 -6
  30. fusion_bench/models/modeling_smile_mistral/__init__.py +1 -0
  31. fusion_bench/models/we_moe.py +8 -8
  32. fusion_bench/taskpool/base_pool.py +99 -17
  33. fusion_bench/taskpool/clip_vision/taskpool.py +1 -1
  34. fusion_bench/taskpool/dummy.py +101 -13
  35. fusion_bench/taskpool/lm_eval_harness/taskpool.py +80 -0
  36. fusion_bench/taskpool/nyuv2_taskpool.py +28 -0
  37. fusion_bench/utils/__init__.py +1 -0
  38. fusion_bench/utils/data.py +6 -4
  39. fusion_bench/utils/devices.py +7 -4
  40. fusion_bench/utils/dtype.py +3 -2
  41. fusion_bench/utils/lazy_state_dict.py +82 -19
  42. fusion_bench/utils/packages.py +3 -3
  43. fusion_bench/utils/parameters.py +0 -2
  44. fusion_bench/utils/timer.py +92 -10
  45. {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.23.dist-info}/METADATA +1 -1
  46. {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.23.dist-info}/RECORD +53 -47
  47. fusion_bench_config/_get_started/llm_slerp.yaml +12 -0
  48. fusion_bench_config/method/model_stock/model_stock.yaml +12 -0
  49. fusion_bench_config/method/slerp/slerp_lm.yaml +4 -0
  50. {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.23.dist-info}/WHEEL +0 -0
  51. {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.23.dist-info}/entry_points.txt +0 -0
  52. {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.23.dist-info}/licenses/LICENSE +0 -0
  53. {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.23.dist-info}/top_level.txt +0 -0
fusion_bench/method/slerp/slerp.py
@@ -1,16 +1,24 @@
  import logging
- from typing import Any, Dict
+ import os
+ from copy import deepcopy
+ from typing import TYPE_CHECKING, Any, Dict, Optional

  import torch
  from torch import nn
+ from tqdm import tqdm
  from typing_extensions import override

+ from fusion_bench import LazyStateDict, create_default_model_card, timeit_context
  from fusion_bench.method import BaseAlgorithm
- from fusion_bench.modelpool import BaseModelPool
+ from fusion_bench.mixins import SimpleProfilerMixin, auto_register_config
+ from fusion_bench.modelpool import BaseModelPool, CausalLMPool
  from fusion_bench.utils.type import StateDictType

  from .slerp_utils import slerp

+ if TYPE_CHECKING:
+     from transformers import PreTrainedModel
+
  log = logging.getLogger(__name__)

@@ -21,6 +29,7 @@ def slerp_on_state_dicts(
      *,
      DOT_THRESHOLD: float = 0.9995,
      epsilon: float = 1e-8,
+     show_pbar: bool = False,
  ) -> StateDictType:
      """
      Perform spherical linear interpolation (slerp) on the state dictionaries of two models.
@@ -36,7 +45,8 @@
          dict: The interpolated state dictionary.
      """
      state_dict = {}
-     for key in secondary_state_dict:
+     pbar = secondary_state_dict if not show_pbar else tqdm(secondary_state_dict)
+     for key in pbar:
          v0 = primary_state_dict[key]
          v1 = secondary_state_dict[key]
          if v0.shape != v1.shape:
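For reference, the per-tensor SLERP that `slerp_on_state_dicts` applies interpolates each weight tensor along the arc between the two models' vectors, falling back to LERP when they are nearly collinear. A minimal sketch of that computation (illustrative only; the package's own `slerp_utils.slerp` may differ in signature and details):

```python
import torch

def slerp_sketch(
    t: float,
    v0: torch.Tensor,
    v1: torch.Tensor,
    DOT_THRESHOLD: float = 0.9995,
    epsilon: float = 1e-8,
) -> torch.Tensor:
    """Illustrative per-tensor SLERP; not the package's slerp_utils.slerp."""
    shape = v0.shape
    a, b = v0.flatten().float(), v1.flatten().float()
    # Normalize copies only to measure the angle between the two weight vectors.
    a_n = a / (a.norm() + epsilon)
    b_n = b / (b.norm() + epsilon)
    dot = torch.clamp(a_n @ b_n, -1.0, 1.0)
    if dot.abs() > DOT_THRESHOLD:
        # Nearly collinear vectors: plain linear interpolation is numerically safer.
        out = (1 - t) * a + t * b
    else:
        omega = torch.acos(dot)      # angle between the two vectors
        sin_omega = torch.sin(omega)
        out = (torch.sin((1 - t) * omega) / sin_omega) * a \
            + (torch.sin(t * omega) / sin_omega) * b
    return out.reshape(shape).to(v0.dtype)
```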
@@ -49,18 +59,19 @@
      return state_dict


+ @auto_register_config
  class SlerpMergeAlgorithm(BaseAlgorithm):
      """
      General purpose implementation of Slerp (Spherical Linear Interpolation) for PyTorch models.
      """

-     _config_mapping = BaseAlgorithm._config_mapping | {
-         "t": "t",
-         "DOT_THRESHOLD": "DOT_THRESHOLD",
-         "epsilon": "epsilon",
-     }
-
-     def __init__(self, t: float, DOT_THRESHOLD: float = 0.9995, epsilon: float = 1e-8):
+     def __init__(
+         self,
+         t: float,
+         DOT_THRESHOLD: float = 0.9995,
+         epsilon: float = 1e-8,
+         **kwargs,
+     ):
          """
          Initialize the SlerpMergeAlgorithm.

@@ -69,10 +80,7 @@ class SlerpMergeAlgorithm(BaseAlgorithm):
              DOT_THRESHOLD (float, optional): The threshold for the dot product of the two vectors. Defaults to 0.9995.
              epsilon (float, optional): The epsilon value for numerical stability. Defaults to 1e-8.
          """
-         self.t = t
-         self.DOT_THRESHOLD = DOT_THRESHOLD
-         self.epsilon = epsilon
-         super().__init__()
+         super().__init__(**kwargs)

      @override
      def run(self, modelpool: BaseModelPool) -> nn.Module:
@@ -102,3 +110,91 @@

          primary_model.load_state_dict(state_dict)
          return primary_model
+
+
+ @auto_register_config
+ class SlerpForCausalLM(
+     SimpleProfilerMixin,
+     BaseAlgorithm,
+ ):
+     """
+     Slerp (Spherical Linear Interpolation) for Causal Language Models.
+     """
+
+     def __init__(
+         self,
+         t: float,
+         DOT_THRESHOLD: float = 0.9995,
+         epsilon: float = 1e-8,
+         model_save_path: Optional[str] = None,
+         show_pbar: bool = False,
+         **kwargs,
+     ):
+         """
+         Initialize the SlerpForCausalLM algorithm.
+
+         Args:
+             t (float): The interpolation parameter. Must be in the range [0, 1].
+                 t=0 returns the first model, t=1 returns the second model,
+                 t=0.5 provides balanced interpolation.
+             DOT_THRESHOLD (float, optional): The threshold for the dot product of normalized vectors.
+                 When the absolute dot product exceeds this threshold, vectors are considered
+                 nearly collinear and linear interpolation (LERP) is used instead of SLERP for
+                 numerical stability. Defaults to 0.9995.
+             epsilon (float, optional): Small value used for numerical stability to avoid
+                 division by zero during vector normalization. Defaults to 1e-8.
+             model_save_path (Optional[str], optional): Path where the merged model should be saved.
+                 If None, the model is not saved to disk. Defaults to None.
+             show_pbar (bool, optional): Whether to display a progress bar during the interpolation
+                 process. Useful for debugging or monitoring progress with large models.
+                 Defaults to False.
+             **kwargs: Additional keyword arguments passed to the parent BaseAlgorithm class.
+         """
+         super().__init__(**kwargs)
+
+     @override
+     def run(self, modelpool: CausalLMPool):
+         assert len(modelpool.all_model_names) == 2, "Slerp expect exactly 2 models"
+         primary_model = modelpool.load_model(modelpool.all_model_names[0])
+         secondary_model = modelpool.load_model(modelpool.all_model_names[1])
+
+         with torch.no_grad():
+             primary_state_dict = primary_model.state_dict()
+             secondary_state_dict = secondary_model.state_dict()
+             state_dict = slerp_on_state_dicts(
+                 self.t,
+                 primary_state_dict,
+                 secondary_state_dict,
+                 DOT_THRESHOLD=self.DOT_THRESHOLD,
+                 epsilon=self.epsilon,
+             )
+
+         if isinstance(primary_model, nn.Module):
+             model = primary_model
+             model.load_state_dict(state_dict)
+         elif isinstance(primary_model, LazyStateDict):
+             model: "PreTrainedModel" = deepcopy(primary_model.meta_module)
+             model.to(device=primary_model._device)
+             model.load_state_dict(state_dict)
+         else:
+             raise TypeError(
+                 f"Unsupported model type: {type(primary_model)}. "
+                 "Expected nn.Module or LazyStateDict."
+             )
+         if self.model_save_path is not None:
+             with timeit_context(f"Saving the model to {self.model_save_path}"):
+                 tokenizer = modelpool.load_tokenizer()
+                 tokenizer.save_pretrained(self.model_save_path)
+                 model.save_pretrained(self.model_save_path)
+                 model_card_str = create_default_model_card(
+                     models=[modelpool.get_model_path(m) for m in modelpool.model_names],
+                     description="Merged model using Slerp.",
+                     algorithm_config=self.config,
+                     modelpool_config=modelpool.config,
+                 )
+                 with open(os.path.join(self.model_save_path, "README.md"), "w") as f:
+                     f.write(model_card_str)
+         return model
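This release also ships `fusion_bench_config/method/slerp/slerp_lm.yaml` and `fusion_bench_config/_get_started/llm_slerp.yaml` for this algorithm. A hypothetical programmatic invocation is sketched below; the `CausalLMPool` arguments and model identifiers are placeholders and may not match the actual constructor or the shipped YAML, so check those configs for the real keys:

```python
# Hypothetical usage sketch; model identifiers and CausalLMPool arguments are placeholders.
from fusion_bench.method.slerp.slerp import SlerpForCausalLM
from fusion_bench.modelpool import CausalLMPool

modelpool = CausalLMPool(
    models={
        "model_a": "org/first-model",   # placeholder Hub ids or local paths
        "model_b": "org/second-model",
    },
    tokenizer="org/first-model",
)
algorithm = SlerpForCausalLM(t=0.5, model_save_path="outputs/slerp_merged", show_pbar=True)
merged_model = algorithm.run(modelpool)
# When model_save_path is set, run() also writes the tokenizer, weights, and a README.md model card.
```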
fusion_bench/method/we_moe/flan_t5_we_moe.py
@@ -16,11 +16,14 @@ from transformers.data import default_data_collator

  from fusion_bench.method import BaseAlgorithm
  from fusion_bench.method.task_arithmetic.task_arithmetic import task_arithmetic_merge
- from fusion_bench.mixins.lightning_fabric import LightningFabricMixin
- from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
+ from fusion_bench.mixins import (
+     LightningFabricMixin,
+     SimpleProfilerMixin,
+     auto_register_config,
+ )
  from fusion_bench.modelpool import Seq2SeqLMPool
  from fusion_bench.models.we_moe import WeightEnsemblingMoE
- from fusion_bench.utils import timeit_context
+ from fusion_bench.utils import print_parameters, timeit_context
  from fusion_bench.utils.data import InfiniteDataLoader, load_tensor_from_file
  from fusion_bench.utils.instantiate_utils import instantiate
  from fusion_bench.utils.parameters import print_parameters
@@ -31,10 +34,11 @@ from .utils import get_memory_usage
  log = logging.getLogger(__name__)


+ @auto_register_config
  class FlanT5WeightEnsemblingMoEAlgorithm(
-     BaseAlgorithm,
      LightningFabricMixin,
      SimpleProfilerMixin,
+     BaseAlgorithm,
  ):
      """
      FlanT5WeightEnsemblingMoEAlgorithm is a class that implements the WeightEnsemblingMoEAlgorithm
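Note the reordering of the base classes above: with `auto_register_config` and cooperative `super().__init__(**kwargs)` chains, listing the mixins before `BaseAlgorithm` lets each mixin consume its own keyword arguments before the base class receives the rest. A toy illustration of the MRO effect (not the fusion_bench classes themselves):

```python
# Toy classes showing why mixins are listed before the base class in cooperative
# __init__ chains: Python's MRO is walked left to right.
class Base:
    def __init__(self, **kwargs):
        super().__init__()

class ProfilerMixin:
    def __init__(self, profile: bool = False, **kwargs):
        self.profile = profile
        super().__init__(**kwargs)

class Algorithm(ProfilerMixin, Base):  # mixins first, base class last
    def __init__(self, lr: float = 1e-3, **kwargs):
        self.lr = lr
        super().__init__(**kwargs)

print(Algorithm.__mro__)  # Algorithm -> ProfilerMixin -> Base -> object
algo = Algorithm(lr=0.01, profile=True)  # 'profile' is handled by the mixin
```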
@@ -60,7 +64,6 @@ class FlanT5WeightEnsemblingMoEAlgorithm(
          num_workers: int = 0,
          max_steps: int = 1000,
          use_grad_accumulate: bool = True,
-         cache_dir: bool = "outputs",
          fast_dev_run: bool = False,
          **kwargs,
      ):
@@ -70,23 +73,9 @@
          Args:
              algorithm_config (DictConfig): The configuration for the algorithm.
          """
-         self.checkpoint = checkpoint
-         self.save_checkpoint = save_checkpoint
-         self.router_hidden_layers = router_hidden_layers
-         self.init_lambda = init_lambda
-         self.batch_reduce = batch_reduce
-         self.lr = lr
-         self.optimizer = optimizer
-         self.devices = devices
-         self.batch_size = batch_size
-         self.num_workers = num_workers
-         self.max_steps = max_steps
-         self.use_grad_accumulate = use_grad_accumulate
-         self.cache_dir = cache_dir
-         self.fast_dev_run = fast_dev_run
          super().__init__(**kwargs)

-     def construct_moe_model(self) -> WeightEnsemblingMoE:
+     def construct_moe_model(self):
          """
          Construct the Mixture of Experts (MoE) model using the models in the model pool.

fusion_bench/mixins/clip_classification.py
@@ -113,11 +113,27 @@ class CLIPClassificationMixin(LightningFabricMixin):
          clip_model: Optional[CLIPModel] = None,
          task_names: Optional[List[str]] = None,
      ):
+         """
+         Initializes a zero-shot classification head.
+
+         This method constructs a zero-shot classification head by generating text embeddings for each class name using a set of templates.
+         These embeddings function as the weights of the classification layer. The method also extracts the `visual_projection` and `logit_scale`
+         from the provided CLIP model, which are necessary for calculating the final logits.
+
+         Args:
+             clip_processor (Optional[CLIPProcessor]): The processor for the CLIP model. If not provided, it is loaded from the model pool.
+             clip_model (Optional[CLIPModel]): The CLIP model to use. If not provided, a pretrained model is loaded from the model pool.
+             task_names (Optional[List[str]]): A list of task names to set up the classification head for. If not provided, all models in the model pool will be used.
+         """
          self.whether_setup_zero_shot_classification_head = True
+         # load clip model if not provided
          if clip_model is None:
              if self.modelpool.has_pretrained:
                  clip_model = self.modelpool.load_clip_model("_pretrained_")
              else:
+                 log.warning(
+                     f"No pretrained CLIP model found, using the model from the model pool: {self.modelpool.model_names[0]}."
+                 )
                  clip_model = self.modelpool.load_clip_model(
                      self.modelpool.model_names[0]
                  )
@@ -166,16 +182,20 @@ class CLIPClassificationMixin(LightningFabricMixin):
          image_embeds: Optional[torch.Tensor] = None,
      ) -> torch.Tensor:
          """
-         Compute the logits of the images for a given task.
+         Computes the classification logits for a batch of images for a specific task.
+
+         This method performs zero-shot classification by calculating the cosine similarity between image and text embeddings.
+         The image embeddings are obtained from the provided vision model, and the text embeddings (zero-shot weights) are pre-computed for the task.
+         The similarity scores are then scaled by the CLIP model's `logit_scale` to produce the final logits.

          Args:
-             module (Union[nn.Module, CLIPVisionModel, "CLIPVisionTransformer"]): The module to compute the logits.
-             images (torch.Tensor): The images to compute the logits.
-             task (str): The task to compute the logits.
-             image_embeds (Optional[torch.Tensor]): The precomputed image embeddings. If None, the image embeddings will be computed.
+             module (Union[nn.Module, CLIPVisionModel, "CLIPVisionTransformer"]): The vision encoder part of the CLIP model.
+             images (torch.Tensor): A batch of images to classify.
+             task (str): The name of the classification task.
+             image_embeds (Optional[torch.Tensor]): Pre-computed image embeddings. If provided, the method skips the image encoding step.

          Returns:
-             torch.Tensor: The logits of the images.
+             torch.Tensor: A tensor of logits for each image, with shape (batch_size, num_classes).
          """
          text_embeds = self.zeroshot_weights[task]
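The computation described in the new docstring reduces to a scaled cosine similarity between normalized image and text embeddings. A minimal sketch, assuming `logit_scale` is stored in log space as in Hugging Face's `CLIPModel` (the mixin's actual code also routes image features through `visual_projection`):

```python
import torch

def zero_shot_logits_sketch(
    image_embeds: torch.Tensor,  # (batch_size, embed_dim)
    text_embeds: torch.Tensor,   # (num_classes, embed_dim), the zero-shot weights
    logit_scale: torch.Tensor,   # scalar, assumed stored in log space as in HF CLIPModel
) -> torch.Tensor:
    # Cosine similarity = dot product of L2-normalized embeddings.
    image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
    text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
    # (batch_size, embed_dim) @ (embed_dim, num_classes) -> (batch_size, num_classes)
    return logit_scale.exp() * image_embeds @ text_embeds.t()
```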
fusion_bench/mixins/serialization.py
@@ -4,7 +4,7 @@ from copy import deepcopy
  from functools import wraps
  from inspect import Parameter, _ParameterKind
  from pathlib import Path
- from typing import Dict, Optional, Union
+ from typing import Dict, Mapping, Optional, Union

  from omegaconf import DictConfig, OmegaConf

@@ -21,6 +21,20 @@
  ]


+ def _get_attr_name(config_mapping: Mapping[str, str], param_name):
+     for attr_name, p in config_mapping.items():
+         if p == param_name:
+             return attr_name
+     else:
+         raise ValueError(f"Parameter {param_name} not found in config mapping.")
+
+
+ def _set_attr(self, param_name: str, value):
+     attr_name = _get_attr_name(self._config_mapping, param_name)
+     log.debug(f"set {attr_name} to {value}. Parameter name: {param_name}")
+     setattr(self, attr_name, value)
+
+
  def auto_register_config(cls):
      """
      Decorator to automatically register __init__ parameters in _config_mapping.
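`_get_attr_name` is the reverse lookup of `_config_mapping`: given an `__init__` parameter name, it returns the attribute name that parameter was registered under, and the `for`/`else` raises only when the loop finds no match. A small example with a hypothetical mapping:

```python
# Hypothetical mapping where attribute "lr" was registered for the init parameter "learning_rate".
config_mapping = {"lr": "learning_rate", "batch_size": "batch_size"}

_get_attr_name(config_mapping, "learning_rate")  # returns "lr"
_get_attr_name(config_mapping, "momentum")       # raises ValueError
```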
@@ -56,8 +70,8 @@ def auto_register_config(cls):
      ```python
      @auto_register_config
      class MyAlgorithm(BaseYAMLSerializable):
-         def __init__(self, learning_rate: float = 0.001, batch_size: int = 32, model_name: str = "default"):
-             super().__init__()
+         def __init__(self, learning_rate: float = 0.001, batch_size: int = 32, model_name: str = "default", **kwargs):
+             super().__init__(**kwargs)

      # All instantiation methods work automatically:
      algo1 = MyAlgorithm(0.01, 64)  # positional args
@@ -103,8 +117,7 @@ def auto_register_config(cls):
              cls._config_mapping[param_name] = param_name

      def __init__(self, *args, **kwargs):
-         nonlocal original_init, registered_parameters
-
+         log.debug(f"set attributes for {self.__class__.__name__} in {cls.__name__}")
          # auto-register the attributes based on the signature
          sig = inspect.signature(original_init)
          param_names = list(sig.parameters.keys())[1:]  # Skip 'self'
@@ -117,29 +130,26 @@
                  _ParameterKind.VAR_POSITIONAL,
                  _ParameterKind.VAR_KEYWORD,
              ]:
-                 setattr(self, param_name, arg_value)
+                 _set_attr(self, param_name, arg_value)

          # Handle keyword arguments and defaults
          for param_name in param_names:
-             if (
-                 sig.parameters[param_name].kind
-                 not in [
-                     _ParameterKind.VAR_POSITIONAL,
-                     _ParameterKind.VAR_KEYWORD,
-                 ]
-             ) and (param_name not in registered_parameters):
+             if sig.parameters[param_name].kind not in [
+                 _ParameterKind.VAR_POSITIONAL,
+                 _ParameterKind.VAR_KEYWORD,
+             ]:
                  # Skip if already set by positional argument
                  param_index = param_names.index(param_name)
                  if param_index >= 0 and param_index < len(args):
                      continue

                  if param_name in kwargs:
-                     setattr(self, param_name, kwargs[param_name])
+                     _set_attr(self, param_name, kwargs[param_name])
                  else:
                      # Set default value if available and attribute doesn't exist
                      default_value = sig.parameters[param_name].default
                      if default_value is not Parameter.empty:
-                         setattr(self, param_name, default_value)
+                         _set_attr(self, param_name, default_value)

          # Call the original __init__
          result = original_init(self, *args, **kwargs)
fusion_bench/modelpool/base_pool.py
@@ -277,7 +277,7 @@ class BaseModelPool(
          for dataset_name in self.test_dataset_names:
              yield self.load_test_dataset(dataset_name)

-     def save_model(self, model: nn.Module, path: str):
+     def save_model(self, model: nn.Module, path: str, *args, **kwargs):
          """
          Save the state dictionary of the model to the specified path.
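The widened `save_model` signature lets callers pass backend-specific options through the pool. A hypothetical call, assuming the concrete implementation forwards the extra keyword arguments to the underlying save routine (not verifiable from this hunk alone):

```python
# Hypothetical call; whether the extra kwargs reach e.g. Hugging Face's save_pretrained
# depends on the concrete model pool implementation.
modelpool.save_model(merged_model, "outputs/merged", safe_serialization=True)
```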