fusion-bench 0.2.21__py3-none-any.whl → 0.2.23__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89)
  1. fusion_bench/__init__.py +25 -2
  2. fusion_bench/compat/method/__init__.py +5 -2
  3. fusion_bench/compat/method/base_algorithm.py +3 -2
  4. fusion_bench/compat/modelpool/base_pool.py +3 -3
  5. fusion_bench/compat/taskpool/clip_image_classification.py +1 -1
  6. fusion_bench/constants/__init__.py +1 -0
  7. fusion_bench/constants/runtime.py +57 -0
  8. fusion_bench/dataset/gpt2_glue.py +1 -1
  9. fusion_bench/method/__init__.py +12 -4
  10. fusion_bench/method/analysis/task_vector_cos_similarity.py +95 -12
  11. fusion_bench/method/analysis/task_vector_violin_plot.py +160 -52
  12. fusion_bench/method/bitdelta/__init__.py +1 -0
  13. fusion_bench/method/bitdelta/bitdelta.py +7 -23
  14. fusion_bench/method/classification/clip_finetune.py +1 -1
  15. fusion_bench/method/expert_sparsity/mixtral/dynamic_skipping.py +2 -0
  16. fusion_bench/method/expert_sparsity/mixtral/layer_wise_pruning.py +2 -0
  17. fusion_bench/method/expert_sparsity/mixtral/progressive_pruning.py +2 -0
  18. fusion_bench/method/fisher_merging/clip_fisher_merging.py +0 -4
  19. fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +2 -2
  20. fusion_bench/method/linear/simple_average_for_llama.py +16 -11
  21. fusion_bench/method/model_stock/__init__.py +1 -0
  22. fusion_bench/method/model_stock/model_stock.py +309 -0
  23. fusion_bench/method/regmean/clip_regmean.py +3 -6
  24. fusion_bench/method/regmean/regmean.py +27 -56
  25. fusion_bench/method/regmean/utils.py +56 -0
  26. fusion_bench/method/regmean_plusplus/regmean_plusplus.py +21 -60
  27. fusion_bench/method/simple_average.py +7 -7
  28. fusion_bench/method/slerp/__init__.py +1 -1
  29. fusion_bench/method/slerp/slerp.py +110 -14
  30. fusion_bench/method/smile_upscaling/causal_lm_upscaling.py +371 -0
  31. fusion_bench/method/smile_upscaling/projected_energy.py +1 -2
  32. fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py +5 -1
  33. fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +40 -31
  34. fusion_bench/method/smile_upscaling/smile_upscaling.py +1 -1
  35. fusion_bench/method/we_moe/__init__.py +1 -0
  36. fusion_bench/method/we_moe/entropy_loss.py +25 -0
  37. fusion_bench/method/we_moe/flan_t5_we_moe.py +320 -0
  38. fusion_bench/method/we_moe/utils.py +15 -0
  39. fusion_bench/method/weighted_average/llama.py +1 -1
  40. fusion_bench/mixins/clip_classification.py +37 -48
  41. fusion_bench/mixins/serialization.py +30 -10
  42. fusion_bench/modelpool/base_pool.py +1 -1
  43. fusion_bench/modelpool/causal_lm/causal_lm.py +293 -75
  44. fusion_bench/modelpool/seq2seq_lm/modelpool.py +146 -0
  45. fusion_bench/models/__init__.py +5 -0
  46. fusion_bench/models/hf_utils.py +69 -86
  47. fusion_bench/models/linearized/vision_model.py +6 -6
  48. fusion_bench/models/model_card_templates/default.md +46 -0
  49. fusion_bench/models/modeling_smile_llama/__init__.py +7 -0
  50. fusion_bench/models/modeling_smile_llama/modeling_smile_llama.py +1 -8
  51. fusion_bench/models/modeling_smile_mistral/__init__.py +2 -1
  52. fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +1 -5
  53. fusion_bench/models/we_moe.py +8 -8
  54. fusion_bench/programs/fabric_fusion_program.py +29 -60
  55. fusion_bench/scripts/cli.py +34 -1
  56. fusion_bench/taskpool/base_pool.py +99 -17
  57. fusion_bench/taskpool/clip_vision/taskpool.py +10 -5
  58. fusion_bench/taskpool/dummy.py +101 -13
  59. fusion_bench/taskpool/lm_eval_harness/taskpool.py +80 -0
  60. fusion_bench/taskpool/nyuv2_taskpool.py +28 -0
  61. fusion_bench/utils/__init__.py +2 -0
  62. fusion_bench/utils/cache_utils.py +101 -1
  63. fusion_bench/utils/data.py +6 -4
  64. fusion_bench/utils/devices.py +7 -4
  65. fusion_bench/utils/dtype.py +3 -2
  66. fusion_bench/utils/fabric.py +2 -2
  67. fusion_bench/utils/lazy_imports.py +23 -0
  68. fusion_bench/utils/lazy_state_dict.py +117 -19
  69. fusion_bench/utils/modelscope.py +3 -3
  70. fusion_bench/utils/packages.py +3 -3
  71. fusion_bench/utils/parameters.py +0 -2
  72. fusion_bench/utils/path.py +56 -0
  73. fusion_bench/utils/pylogger.py +1 -1
  74. fusion_bench/utils/timer.py +92 -10
  75. {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/METADATA +1 -23
  76. {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/RECORD +89 -75
  77. fusion_bench_config/_get_started/llm_slerp.yaml +12 -0
  78. fusion_bench_config/method/fisher_merging/clip_fisher_merging.yaml +0 -1
  79. fusion_bench_config/method/linear/simple_average_for_llama.yaml +3 -2
  80. fusion_bench_config/method/model_stock/model_stock.yaml +12 -0
  81. fusion_bench_config/method/slerp/slerp_lm.yaml +4 -0
  82. fusion_bench_config/method/smile_upscaling/causal_lm_upscaling.yaml +21 -0
  83. fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +1 -1
  84. fusion_bench_config/method/wemoe/flan_t5_weight_ensembling_moe.yaml +20 -0
  85. fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml +1 -1
  86. {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/WHEEL +0 -0
  87. {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/entry_points.txt +0 -0
  88. {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/licenses/LICENSE +0 -0
  89. {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/top_level.txt +0 -0
--- a/fusion_bench/mixins/clip_classification.py
+++ b/fusion_bench/mixins/clip_classification.py
@@ -22,6 +22,7 @@ from torch.utils.data import DataLoader
 from tqdm.auto import tqdm
 from transformers import CLIPModel, CLIPProcessor, CLIPVisionModel
 
+from fusion_bench import cache_with_joblib
 from fusion_bench.dataset.clip_dataset import CLIPDataset
 from fusion_bench.mixins import LightningFabricMixin
 from fusion_bench.modelpool import CLIPVisionModelPool
@@ -46,7 +47,6 @@ class CLIPClassificationMixin(LightningFabricMixin):
 
     - `_dataloader_kwargs` (Dict[str, Any]): Keyword arguments for the dataloader.
    - `modelpool` (CLIPVisionModelPool): The model pool containing the CLIP models.
-    - `zeroshot_weights_cache_dir` (Optional[str]): The directory to cache the zero-shot weights.
     """
 
    dataloader_kwargs: Dict[str, Any] = {}
@@ -54,7 +54,6 @@ class CLIPClassificationMixin(LightningFabricMixin):
     modelpool: CLIPVisionModelPool = None
     _clip_processor: CLIPProcessor = None
     # a dict of zeroshot weights for each task, each key is the task name
-    zeroshot_weights_cache_dir: str = "outputs/cache/clip_zeroshot_weights"
     zeroshot_weights: Dict[str, torch.Tensor] = {}
     whether_setup_zero_shot_classification_head = False
 
@@ -114,11 +113,27 @@ class CLIPClassificationMixin(LightningFabricMixin):
         clip_model: Optional[CLIPModel] = None,
         task_names: Optional[List[str]] = None,
     ):
+        """
+        Initializes a zero-shot classification head.
+
+        This method constructs a zero-shot classification head by generating text embeddings for each class name using a set of templates.
+        These embeddings function as the weights of the classification layer. The method also extracts the `visual_projection` and `logit_scale`
+        from the provided CLIP model, which are necessary for calculating the final logits.
+
+        Args:
+            clip_processor (Optional[CLIPProcessor]): The processor for the CLIP model. If not provided, it is loaded from the model pool.
+            clip_model (Optional[CLIPModel]): The CLIP model to use. If not provided, a pretrained model is loaded from the model pool.
+            task_names (Optional[List[str]]): A list of task names to set up the classification head for. If not provided, all models in the model pool will be used.
+        """
         self.whether_setup_zero_shot_classification_head = True
+        # load clip model if not provided
         if clip_model is None:
             if self.modelpool.has_pretrained:
                 clip_model = self.modelpool.load_clip_model("_pretrained_")
             else:
+                log.warning(
+                    f"No pretrained CLIP model found, using the model from the model pool: {self.modelpool.model_names[0]}."
+                )
                 clip_model = self.modelpool.load_clip_model(
                     self.modelpool.model_names[0]
                 )
@@ -131,26 +146,16 @@ class CLIPClassificationMixin(LightningFabricMixin):
         self.visual_projection = self.fabric.to_device(self.visual_projection)
         self.logit_scale_exp = self.fabric.to_device(self.logit_scale_exp)
 
-        # get cache directory
-        if self.modelpool.has_pretrained:
-            model_name = self.modelpool.get_model_config("_pretrained_")
-            if not isinstance(model_name, str):
-                model_name = model_name.pretrained_model_name_or_path
-        else:
-            model_name = self.modelpool.get_model_config(self.modelpool.model_names[0])
-            if not isinstance(model_name, str):
-                model_name = model_name.pretrained_model_name_or_path
-        cache_dir = os.path.join(
-            self.zeroshot_weights_cache_dir,
-            os.path.normpath(model_name.split("/")[-1]),
-        )
-        if not os.path.exists(cache_dir):
-            log.info(
-                f"Creating cache directory for zero-shot classification head at {cache_dir}"
-            )
-            os.makedirs(cache_dir)
+        @cache_with_joblib()
+        def construct_classification_head(task: str):
+            nonlocal clip_classifier
+
+            classnames, templates = get_classnames_and_templates(task)
+            clip_classifier.set_classification_task(classnames, templates)
+            zeroshot_weights = clip_classifier.zeroshot_weights.detach().clone()
+
+            return zeroshot_weights
 
-        log.info(f"cache directory for zero-shot classification head: {cache_dir}")
         for task in tqdm(
             self.modelpool.model_names if task_names is None else task_names,
             "Setting up zero-shot classification head",
@@ -158,27 +163,7 @@ class CLIPClassificationMixin(LightningFabricMixin):
         ):
             zeroshot_weights = None
             if self.fabric.is_global_zero:
-                cache_file = os.path.join(
-                    cache_dir, os.path.normpath(f"{task}_zeroshot_weights.pt")
-                )
-                if os.path.exists(cache_file):
-                    zeroshot_weights = torch.load(
-                        cache_file,
-                        map_location="cpu",
-                        weights_only=True,
-                    ).detach()
-                    log.info(
-                        f"Loadded cached zeroshot weights for task: {task}, shape: {zeroshot_weights.shape}"
-                    )
-                else:
-                    log.info(
-                        f"Construct zero shot classification head for task: {task}"
-                    )
-                    classnames, templates = get_classnames_and_templates(task)
-                    clip_classifier.set_classification_task(classnames, templates)
-                    zeroshot_weights = clip_classifier.zeroshot_weights.detach().clone()
-                    log.info(f"save zeroshot weights to {cache_file}")
-                    torch.save(zeroshot_weights, cache_file)
+                zeroshot_weights = construct_classification_head(task)
 
             self.fabric.barrier()
             self.zeroshot_weights[task] = self.fabric.broadcast(zeroshot_weights, src=0)
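
Note: the two hunks above replace the hand-rolled `torch.save`/`torch.load` cache (and the now-deleted `zeroshot_weights_cache_dir` attribute) with the `cache_with_joblib` decorator imported at the top of the file. Its implementation is not part of this diff (it lives in `fusion_bench/utils/cache_utils.py`, changed +101 -1 above); the following is only a minimal sketch of how such a decorator can be built on `joblib.Memory`, with the cache location as an assumed default:

```python
# Illustrative sketch; the real cache_with_joblib in
# fusion_bench/utils/cache_utils.py may differ in signature and defaults.
import functools

from joblib import Memory


def cache_with_joblib(cache_dir: str = "outputs/cache"):
    """Memoize a function's return value on disk, keyed by its arguments."""
    memory = Memory(cache_dir, verbose=0)

    def decorator(func):
        cached_func = memory.cache(func)

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # joblib hashes the arguments and the function source; on a
            # cache hit it loads the pickled result instead of recomputing.
            return cached_func(*args, **kwargs)

        return wrapper

    return decorator
```

With disk memoization handled by the decorator, the per-task cache-file bookkeeping removed above becomes unnecessary.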
@@ -197,16 +182,20 @@ class CLIPClassificationMixin(LightningFabricMixin):
         image_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """
-        Compute the logits of the images for a given task.
+        Computes the classification logits for a batch of images for a specific task.
+
+        This method performs zero-shot classification by calculating the cosine similarity between image and text embeddings.
+        The image embeddings are obtained from the provided vision model, and the text embeddings (zero-shot weights) are pre-computed for the task.
+        The similarity scores are then scaled by the CLIP model's `logit_scale` to produce the final logits.
 
         Args:
-            module (Union[nn.Module, CLIPVisionModel, "CLIPVisionTransformer"]): The module to compute the logits.
-            images (torch.Tensor): The images to compute the logits.
-            task (str): The task to compute the logits.
-            image_embeds (Optional[torch.Tensor]): The precomputed image embeddings. If None, the image embeddings will be computed.
+            module (Union[nn.Module, CLIPVisionModel, "CLIPVisionTransformer"]): The vision encoder part of the CLIP model.
+            images (torch.Tensor): A batch of images to classify.
+            task (str): The name of the classification task.
+            image_embeds (Optional[torch.Tensor]): Pre-computed image embeddings. If provided, the method skips the image encoding step.
 
         Returns:
-            torch.Tensor: The logits of the images.
+            torch.Tensor: A tensor of logits for each image, with shape (batch_size, num_classes).
         """
         text_embeds = self.zeroshot_weights[task]
 
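
Note: the rewritten docstring describes standard CLIP zero-shot scoring; the method body itself is unchanged in this hunk. For reference, the computation it documents amounts to the following sketch, with names matching the mixin's attributes:

```python
import torch


def zero_shot_logits(
    image_embeds: torch.Tensor,  # (batch_size, projection_dim)
    text_embeds: torch.Tensor,  # (num_classes, projection_dim), the task's zeroshot weights
    logit_scale_exp: torch.Tensor,  # scalar, exp(logit_scale) from the CLIP model
) -> torch.Tensor:
    # L2-normalize both sides so their dot product is a cosine similarity.
    image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
    text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
    # Scale the similarities to obtain logits of shape (batch_size, num_classes).
    return logit_scale_exp * image_embeds @ text_embeds.t()
```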
--- a/fusion_bench/mixins/serialization.py
+++ b/fusion_bench/mixins/serialization.py
@@ -4,7 +4,7 @@ from copy import deepcopy
 from functools import wraps
 from inspect import Parameter, _ParameterKind
 from pathlib import Path
-from typing import Dict, Optional, Union
+from typing import Dict, Mapping, Optional, Union
 
 from omegaconf import DictConfig, OmegaConf
 
@@ -21,6 +21,20 @@ __all__ = [
 ]
 
 
+def _get_attr_name(config_mapping: Mapping[str, str], param_name):
+    for attr_name, p in config_mapping.items():
+        if p == param_name:
+            return attr_name
+    else:
+        raise ValueError(f"Parameter {param_name} not found in config mapping.")
+
+
+def _set_attr(self, param_name: str, value):
+    attr_name = _get_attr_name(self._config_mapping, param_name)
+    log.debug(f"set {attr_name} to {value}. Parameter name: {param_name}")
+    setattr(self, attr_name, value)
+
+
 def auto_register_config(cls):
     """
     Decorator to automatically register __init__ parameters in _config_mapping.
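
Note: `_config_mapping` maps attribute names to `__init__` parameter names, so `_get_attr_name` is the reverse lookup: given a parameter name, find the attribute it should be stored under. The `else` clause on the `for` loop only runs when the loop exhausts without returning, which makes the `ValueError` reachable only for unmapped parameters. A hypothetical mapping to illustrate, assuming `_get_attr_name` from the hunk above is in scope:

```python
# Hypothetical example: attribute `_model_name` is configured through the
# `model_name` parameter, so writes to `model_name` must land on
# `_model_name` rather than creating a new attribute.
config_mapping = {"_model_name": "model_name", "lr": "lr"}

assert _get_attr_name(config_mapping, "model_name") == "_model_name"
assert _get_attr_name(config_mapping, "lr") == "lr"
_get_attr_name(config_mapping, "missing")  # raises ValueError
```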
@@ -56,8 +70,8 @@ def auto_register_config(cls):
     ```python
     @auto_register_config
     class MyAlgorithm(BaseYAMLSerializable):
-        def __init__(self, learning_rate: float = 0.001, batch_size: int = 32, model_name: str = "default"):
-            super().__init__()
+        def __init__(self, learning_rate: float = 0.001, batch_size: int = 32, model_name: str = "default", **kwargs):
+            super().__init__(**kwargs)
 
     # All instantiation methods work automatically:
     algo1 = MyAlgorithm(0.01, 64)  # positional args
@@ -90,14 +104,20 @@ def auto_register_config(cls):
     # Auto-register parameters in _config_mapping
     if not "_config_mapping" in cls.__dict__:
         cls._config_mapping = deepcopy(getattr(cls, "_config_mapping", {}))
+    registered_parameters = tuple(cls._config_mapping.values())
+
     for param_name in list(sig.parameters.keys())[1:]:  # Skip 'self'
-        if sig.parameters[param_name].kind not in [
-            _ParameterKind.VAR_POSITIONAL,
-            _ParameterKind.VAR_KEYWORD,
-        ]:
+        if (
+            sig.parameters[param_name].kind
+            not in [
+                _ParameterKind.VAR_POSITIONAL,
+                _ParameterKind.VAR_KEYWORD,
+            ]
+        ) and (param_name not in registered_parameters):
             cls._config_mapping[param_name] = param_name
 
     def __init__(self, *args, **kwargs):
+        log.debug(f"set attributes for {self.__class__.__name__} in {cls.__name__}")
         # auto-register the attributes based on the signature
         sig = inspect.signature(original_init)
         param_names = list(sig.parameters.keys())[1:]  # Skip 'self'
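
Note: the new `registered_parameters` guard skips parameters a parent class already registered, so an inherited mapping such as `_model_name -> model_name` is not shadowed by a second, auto-generated `model_name -> model_name` entry. A hedged sketch of the situation it guards against, using hypothetical classes and an assumed import path:

```python
# Hypothetical classes for illustration only.
from fusion_bench.mixins.serialization import auto_register_config  # path assumed


class Base:
    # Parent maps parameter `model_name` onto attribute `_model_name`.
    _config_mapping = {"_model_name": "model_name"}


@auto_register_config
class Child(Base):
    # Without the guard, decoration would also add `model_name -> model_name`,
    # leaving two attributes competing for the same parameter.
    def __init__(self, model_name: str = "default", **kwargs):
        super().__init__(**kwargs)
```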
@@ -110,7 +130,7 @@ def auto_register_config(cls):
                 _ParameterKind.VAR_POSITIONAL,
                 _ParameterKind.VAR_KEYWORD,
             ]:
-                setattr(self, param_name, arg_value)
+                _set_attr(self, param_name, arg_value)
 
         # Handle keyword arguments and defaults
         for param_name in param_names:
@@ -124,12 +144,12 @@ def auto_register_config(cls):
                 continue
 
             if param_name in kwargs:
-                setattr(self, param_name, kwargs[param_name])
+                _set_attr(self, param_name, kwargs[param_name])
             else:
                 # Set default value if available and attribute doesn't exist
                 default_value = sig.parameters[param_name].default
                 if default_value is not Parameter.empty:
-                    setattr(self, param_name, default_value)
+                    _set_attr(self, param_name, default_value)
 
         # Call the original __init__
         result = original_init(self, *args, **kwargs)
--- a/fusion_bench/modelpool/base_pool.py
+++ b/fusion_bench/modelpool/base_pool.py
@@ -277,7 +277,7 @@ class BaseModelPool(
         for dataset_name in self.test_dataset_names:
             yield self.load_test_dataset(dataset_name)
 
-    def save_model(self, model: nn.Module, path: str):
+    def save_model(self, model: nn.Module, path: str, *args, **kwargs):
         """
         Save the state dictionary of the model to the specified path.
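
Note: widening `save_model` to accept `*args, **kwargs` lets subclass pools expose backend-specific saving options without breaking the base signature. A sketch of a hypothetical override (the class and parameter names here are illustrative, not from this diff):

```python
from torch import nn

from fusion_bench.modelpool import BaseModelPool  # import path assumed


class MyModelPool(BaseModelPool):
    def save_model(self, model: nn.Module, path: str, *args, upload: bool = False, **kwargs):
        # Consume one extra option here and forward everything else unchanged.
        super().save_model(model, path, *args, **kwargs)
        if upload:
            ...  # e.g. push the saved checkpoint to a model hub
```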