fusion-bench 0.2.15-py3-none-any.whl → 0.2.16-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (77)
  1. fusion_bench/method/__init__.py +4 -0
  2. fusion_bench/method/fw_merging/__init__.py +2 -0
  3. fusion_bench/method/fw_merging/fw_hard.py +448 -0
  4. fusion_bench/method/fw_merging/fw_soft.py +519 -0
  5. fusion_bench/method/fw_merging/utils.py +331 -0
  6. fusion_bench/method/moe_pruner/__init__.py +7 -0
  7. fusion_bench/method/moe_pruner/hooks/__init__.py +6 -0
  8. fusion_bench/method/moe_pruner/hooks/deepseek_v2.py +85 -0
  9. fusion_bench/method/moe_pruner/hooks/hook.py +23 -0
  10. fusion_bench/method/moe_pruner/hooks/mixtral.py +93 -0
  11. fusion_bench/method/moe_pruner/moe_pruner.py +304 -0
  12. fusion_bench/method/moe_pruner/utils/__init__.py +1 -0
  13. fusion_bench/method/moe_pruner/utils/data.py +154 -0
  14. fusion_bench/method/moe_pruner/utils/layerwrapper.py +61 -0
  15. fusion_bench/method/moe_pruner/utils/prune.py +313 -0
  16. fusion_bench/method/moe_pruner/utils/score.py +41 -0
  17. fusion_bench/method/pruning/__init__.py +1 -0
  18. fusion_bench/method/pruning/llama_sparsegpt_prune.py +223 -0
  19. fusion_bench/method/pruning/sparsegpt_utils/__init__.py +1 -0
  20. fusion_bench/method/pruning/sparsegpt_utils/sparsegpt.py +128 -0
  21. fusion_bench/method/pruning/wanda_utils/data.py +33 -14
  22. fusion_bench/method/randes/__init__.py +15 -0
  23. fusion_bench/method/randes/base_algorithm.py +1013 -0
  24. fusion_bench/method/randes/modelsoup.py +126 -0
  25. fusion_bench/method/randes/task_arithmetic.py +318 -0
  26. fusion_bench/method/sparselo/sparselo.py +20 -2
  27. fusion_bench/method/tall_mask/__init__.py +1 -0
  28. fusion_bench/method/tall_mask/task_arithmetic.py +133 -0
  29. fusion_bench/modelpool/lazy_state_dict_pool.py +15 -0
  30. fusion_bench/models/modeling_deepseek_v2/__init__.py +15 -0
  31. fusion_bench/models/modeling_deepseek_v2/configuration_deepseek.py +208 -0
  32. fusion_bench/models/modeling_deepseek_v2/modeling_deepseek.py +1922 -0
  33. fusion_bench/models/modeling_deepseek_v2/tokenization_deepseek_fast.py +38 -0
  34. fusion_bench/programs/fabric_fusion_program.py +5 -0
  35. fusion_bench/taskpool/clip_vision/taskpool.py +8 -1
  36. fusion_bench/utils/__init__.py +1 -0
  37. fusion_bench/utils/data.py +1 -1
  38. fusion_bench/utils/lazy_state_dict.py +268 -0
  39. fusion_bench/utils/parameters.py +33 -0
  40. fusion_bench/utils/state_dict_arithmetic.py +74 -2
  41. fusion_bench/utils/type.py +1 -0
  42. {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.16.dist-info}/METADATA +6 -2
  43. {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.16.dist-info}/RECORD +77 -21
  44. {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.16.dist-info}/WHEEL +1 -1
  45. fusion_bench_config/dataset/image_classification/test/TALL10.yaml +28 -0
  46. fusion_bench_config/dataset/image_classification/test/TALL12.yaml +28 -0
  47. fusion_bench_config/dataset/image_classification/test/TALL16.yaml +28 -0
  48. fusion_bench_config/dataset/image_classification/test/TALL18.yaml +28 -0
  49. fusion_bench_config/dataset/image_classification/train/TALL10.yaml +28 -0
  50. fusion_bench_config/dataset/image_classification/train/TALL12.yaml +28 -0
  51. fusion_bench_config/dataset/image_classification/train/TALL16.yaml +28 -0
  52. fusion_bench_config/dataset/image_classification/train/TALL18.yaml +28 -0
  53. fusion_bench_config/method/fw_merging/fw_hard.yaml +11 -0
  54. fusion_bench_config/method/fw_merging/fw_soft.yaml +12 -0
  55. fusion_bench_config/method/moe_pruner/moe_pruner.yaml +15 -0
  56. fusion_bench_config/method/pruning/llama_sparsegpt_pruning.yaml +16 -0
  57. fusion_bench_config/method/randes/superposed_model_soup.yaml +18 -0
  58. fusion_bench_config/method/randes/superposed_task_arithmetic.yaml +20 -0
  59. fusion_bench_config/method/randes/superposed_task_arithmetic_lora.yaml +20 -0
  60. fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml +2 -1
  61. fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml +1 -1
  62. fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml +1 -1
  63. fusion_bench_config/method/tall_mask/task_arithmetic.yaml +4 -0
  64. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL10.yaml +29 -0
  65. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL12.yaml +29 -0
  66. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL16.yaml +29 -0
  67. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL18.yaml +29 -0
  68. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL10.yaml +8 -0
  69. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL12.yaml +8 -0
  70. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL16.yaml +8 -0
  71. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL18.yaml +8 -0
  72. fusion_bench_config/modelpool/CausalLMPool/deepseek-v2-lite.yaml +15 -0
  73. fusion_bench_config/modelpool/CausalLMPool/mixtral-8x7b.yaml +14 -0
  74. fusion_bench_config/modelpool/SeqenceClassificationModelPool/roberta-base_glue.yaml +69 -0
  75. {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.16.dist-info}/entry_points.txt +0 -0
  76. {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.16.dist-info}/licenses/LICENSE +0 -0
  77. {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.16.dist-info}/top_level.txt +0 -0
fusion_bench/models/modeling_deepseek_v2/tokenization_deepseek_fast.py
@@ -0,0 +1,38 @@
+from typing import List, Optional, Union
+
+
+from transformers.models.llama import LlamaTokenizerFast
+
+
+class DeepseekTokenizerFast(LlamaTokenizerFast):
+
+    def convert_ids_to_tokens(
+        self, ids: Union[int, List[int]], skip_special_tokens: bool = False
+    ) -> Union[str, List[str]]:
+        """
+        Converts a single index or a sequence of indices in a token or a sequence of tokens, using the vocabulary and
+        added tokens.
+
+        Args:
+            ids (`int` or `List[int]`):
+                The token id (or token ids) to convert to tokens.
+            skip_special_tokens (`bool`, *optional*, defaults to `False`):
+                Whether or not to remove special tokens in the decoding.
+
+        Returns:
+            `str` or `List[str]`: The decoded token(s).
+        """
+        if isinstance(ids, int):
+            return self._convert_id_to_token(ids)
+        tokens = []
+        for index in ids:
+            index = int(index)
+            if skip_special_tokens and index in self.all_special_ids:
+                continue
+            token = self._tokenizer.id_to_token(index)
+            tokens.append(token if token is not None else "")
+        return tokens
+
+    def _convert_id_to_token(self, index: int) -> Optional[str]:
+        token = self._tokenizer.id_to_token(int(index))
+        return token if token is not None else ""
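Note: the practical effect of this override is that unknown token ids decode to an empty string rather than None, so joined output never trips over None entries. A minimal usage sketch, assuming a DeepSeek-V2 checkpoint is reachable (the model id and the out-of-range id below are illustrative):

from fusion_bench.models.modeling_deepseek_v2.tokenization_deepseek_fast import (
    DeepseekTokenizerFast,
)

# Illustrative checkpoint id; any DeepSeek-V2 tokenizer.json works the same way.
tokenizer = DeepseekTokenizerFast.from_pretrained("deepseek-ai/DeepSeek-V2-Lite")
tokens = tokenizer.convert_ids_to_tokens([0, 1, 10**9])  # 10**9: id outside the vocab
assert all(t is not None for t in tokens)  # unknown ids become "" instead of None
text = "".join(tokens)  # safe to join without None checks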
fusion_bench/programs/fabric_fusion_program.py
@@ -196,6 +196,11 @@ class FabricModelFusionProgram(
         for key, item in merged_model.items():
             if isinstance(item, nn.Module):
                 report[key] = taskpool.evaluate(item, *args, **kwargs)
+            elif key == "models":
+                # for multi-model evaluation
+                report[key] = self.evaluate_merged_model(
+                    taskpool, item, *args, **kwargs
+                )
             else:
                 # metadata
                 report[key] = item
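Note: this new branch lets a fusion method return several models at once under a "models" key; evaluation recurses into that sub-dict while the remaining keys pass through as metadata. A self-contained sketch of the dispatch logic with a dummy taskpool (the classes below are illustrative, not the fusion_bench API):

import torch.nn as nn

class DummyTaskPool:
    def evaluate(self, model):
        return {"num_params": sum(p.numel() for p in model.parameters())}

def evaluate_merged(taskpool, merged):
    report = {}
    for key, item in merged.items():
        if isinstance(item, nn.Module):
            report[key] = taskpool.evaluate(item)
        elif key == "models":  # new in 0.2.16: recurse into the sub-dict
            report[key] = evaluate_merged(taskpool, item)
        else:
            report[key] = item  # metadata is copied through unchanged
    return report

merged = {"models": {"soup": nn.Linear(4, 4), "ties": nn.Linear(4, 4)}, "alpha": 0.3}
print(evaluate_merged(DummyTaskPool(), merged))
# -> {'models': {'soup': {'num_params': 20}, 'ties': {'num_params': 20}}, 'alpha': 0.3}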
fusion_bench/taskpool/clip_vision/taskpool.py
@@ -348,8 +348,15 @@ class CLIPVisionModelTaskPool(
 
         log.info(f"Evaluation Result: {report}")
         if self.fabric.is_global_zero and len(self.fabric._loggers) > 0:
-            with open(os.path.join(self.log_dir, "report.json"), "w") as fp:
+            save_path = os.path.join(self.log_dir, "report.json")
+            for version in itertools.count(1):
+                if not os.path.exists(save_path):
+                    break
+                # if the file already exists, increment the version to avoid overwriting
+                save_path = os.path.join(self.log_dir, f"report_{version}.json")
+            with open(save_path, "w") as fp:
                 json.dump(report, fp)
+            log.info(f"Evaluation report saved to {save_path}")
         return report
 
     def on_task_evaluation_begin(self, classifier: HFCLIPClassifier, task_name: str):
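Note: the loop above probes report.json, then report_1.json, report_2.json, and so on, until it finds an unused name, so repeated runs in the same log directory no longer overwrite earlier reports. The same pattern in isolation, as a standalone helper:

import itertools
import os

def next_free_report_path(log_dir: str) -> str:
    # Same pattern as the taskpool code above: bump the version suffix until
    # the candidate name is unused, so earlier reports are never clobbered.
    save_path = os.path.join(log_dir, "report.json")
    for version in itertools.count(1):
        if not os.path.exists(save_path):
            break
        save_path = os.path.join(log_dir, f"report_{version}.json")
    return save_path

# report.json -> report_1.json -> report_2.json -> ... on successive runs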
fusion_bench/utils/__init__.py
@@ -12,3 +12,4 @@ from .misc import *
 from .packages import import_object
 from .parameters import *
 from .timer import timeit_context
+from .lazy_state_dict import LazyStateDict
fusion_bench/utils/data.py
@@ -96,7 +96,7 @@ def train_validation_split(
 
     # Compute the number of samples for training and validation
     num_samples = len(dataset)
-    if validation_size is not None:
+    if validation_size is None:
         assert (
             0 < validation_fraction < 1
         ), "Validation fraction must be between 0 and 1"
fusion_bench/utils/lazy_state_dict.py
@@ -0,0 +1,268 @@
+import json
+import logging
+import os
+from typing import TYPE_CHECKING, Dict, Iterator, List, Optional, Tuple
+
+import torch
+from accelerate.utils.constants import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
+from huggingface_hub import snapshot_download
+from safetensors import safe_open
+from safetensors.torch import load_file
+from transformers import AutoConfig
+
+from fusion_bench.utils.dtype import parse_dtype
+
+if TYPE_CHECKING:
+    from transformers import PretrainedConfig
+
+log = logging.getLogger(__name__)
+
+__all__ = ["resolve_checkpoint_path", "LazyStateDict"]
+
+
+def resolve_checkpoint_path(
+    checkpoint: str,
+    hf_revision: Optional[str] = None,
+    hf_cache_dir: Optional[str] = None,
+    hf_proxies: Optional[Dict] = None,
+):
+    # If it's a local file or directory, return as is
+    if os.path.exists(checkpoint):
+        return checkpoint
+    # If it's a HuggingFace Hub model id, download snapshot
+    try:
+        # This will download the model to the cache and return the local path
+        local_path = snapshot_download(
+            repo_id=checkpoint,
+            revision=hf_revision,
+            cache_dir=hf_cache_dir,
+            proxies=hf_proxies,
+        )
+        return local_path
+    except Exception as e:
+        raise FileNotFoundError(
+            f"Could not resolve checkpoint: {checkpoint}. Error: {e}"
+        )
+
+
+class LazyStateDict:
+    """
+    Dictionary-like object that lazily loads a state dict from a checkpoint path.
+    """
+
+    _local_path: str
+    _state_dict_cache: Optional[Dict]
+    _index_filename: Optional[str]
+    _checkpoint_files: Optional[List[str]]
+    _index: Optional[Dict]
+
+    def __init__(
+        self,
+        checkpoint: str,
+        cache_state_dict: bool = False,
+        torch_dtype: Optional[torch.dtype] = None,
+        device: str = "cpu",
+        hf_revision: Optional[str] = None,
+        hf_cache_dir: Optional[str] = None,
+        hf_proxies: Optional[Dict] = None,
+    ):
+        self._checkpoint = checkpoint
+        self._local_path = resolve_checkpoint_path(
+            checkpoint,
+            hf_revision=hf_revision,
+            hf_cache_dir=hf_cache_dir,
+            hf_proxies=hf_proxies,
+        )
+
+        self._index, self._index_filename, self._checkpoint_files = (
+            self._resolve_checkpoint_files(self._local_path)
+        )
+
+        if cache_state_dict:
+            self._state_dict_cache = {}
+        else:
+            self._state_dict_cache = None
+
+        self._torch_dtype = parse_dtype(torch_dtype)
+        self._device = device
+
+    @property
+    def checkpoint(self) -> str:
+        return self._checkpoint
+
+    @property
+    def config(self) -> "PretrainedConfig":
+        return AutoConfig.from_pretrained(self._checkpoint)
+
+    def state_dict(self) -> "LazyStateDict":
+        return self
+
+    def _resolve_checkpoint_files(self, checkpoint: str):
+        # reference: https://huggingface.co/docs/accelerate/v0.17.1/en/usage_guides/big_modeling
+        checkpoint_files = None
+        index_filename = None
+        if os.path.isfile(checkpoint):
+            if str(checkpoint).endswith(".json"):
+                index_filename = checkpoint
+            else:
+                checkpoint_files = [checkpoint]
+        elif os.path.isdir(checkpoint):
+            # check if the whole state dict is present
+            potential_state_bin = [
+                f for f in os.listdir(checkpoint) if f == WEIGHTS_NAME
+            ]
+            potential_state_safetensor = [
+                f for f in os.listdir(checkpoint) if f == SAFE_WEIGHTS_NAME
+            ]
+            if len(potential_state_bin) == 1:
+                checkpoint_files = [os.path.join(checkpoint, potential_state_bin[0])]
+            elif len(potential_state_safetensor) == 1:
+                checkpoint_files = [
+                    os.path.join(checkpoint, potential_state_safetensor[0])
+                ]
+            else:
+                # otherwise check for sharded checkpoints
+                potential_index = [
+                    f for f in os.listdir(checkpoint) if f.endswith(".index.json")
+                ]
+                if len(potential_index) == 0:
+                    raise ValueError(
+                        f"{checkpoint} is not a folder containing a `.index.json` file or a {WEIGHTS_NAME} or a {SAFE_WEIGHTS_NAME} file"
+                    )
+                elif len(potential_index) == 1:
+                    index_filename = os.path.join(checkpoint, potential_index[0])
+                else:
+                    raise ValueError(
+                        f"{checkpoint} containing more than one `.index.json` file, delete the irrelevant ones."
+                    )
+        else:
+            raise ValueError(
+                "`checkpoint` should be the path to a file containing a whole state dict, or the index of a sharded "
+                f"checkpoint, or a folder containing a sharded checkpoint or the whole state dict, but got {checkpoint}."
+            )
+
+        if index_filename is not None:
+            checkpoint_folder = os.path.split(index_filename)[0]
+            with open(index_filename) as f:
+                index = json.loads(f.read())
+
+            if "weight_map" in index:
+                index = index["weight_map"]
+            checkpoint_files = sorted(list(set(index.values())))
+            checkpoint_files = [
+                os.path.join(checkpoint_folder, f) for f in checkpoint_files
+            ]
+        return index, index_filename, checkpoint_files
+
+    def _load_tensor_from_checkpoint_file(
+        self, checkpoint_file: str, key: str, update_cache: bool = True
+    ) -> torch.Tensor:
+        if checkpoint_file.endswith(".safetensors"):
+            with safe_open(checkpoint_file, framework="pt", device=self._device) as f:
+                tensor = f.get_tensor(key)
+                if self._torch_dtype is not None:
+                    tensor = tensor.to(self._torch_dtype)
+                if update_cache and self._state_dict_cache is not None:
+                    self._state_dict_cache[key] = tensor
+                return tensor
+        else:
+            state_dict = torch.load(checkpoint_file, map_location=self._device)
+            if update_cache:
+                if self._state_dict_cache is not None:
+                    self._state_dict_cache.update(state_dict)
+                else:
+                    log.warning(
+                        f"Load full state dict from file {checkpoint_file}, but state dict cache is disabled."
+                    )
+            return state_dict[key]
+
+    def __getitem__(self, key: str) -> torch.Tensor:
+        if self._state_dict_cache is not None and key in self._state_dict_cache:
+            return self._state_dict_cache[key]
+
+        if self._index is None:
+            if len(self._checkpoint_files) == 1 and os.path.isfile(
+                self._checkpoint_files[0]
+            ):
+                checkpoint_file = self._checkpoint_files[0]
+                tensor = self._load_tensor_from_checkpoint_file(
+                    checkpoint_file, key, update_cache=True
+                )
+                return tensor
+            else:
+                if len(self._checkpoint_files) > 1:
+                    raise RuntimeError(
+                        "Get multiple checkpoint files, but index is not provided."
+                    )
+                if not os.path.isfile(self._checkpoint_files[0]):
+                    raise FileNotFoundError(
+                        f"Checkpoint file {self._checkpoint_files[0]} not found."
+                    )
+                raise RuntimeError("Unexpected error.")
+        else:
+            if key not in self._index:
+                raise KeyError(f"Key {key} not found in index.")
+            checkpoint_file = os.path.join(self._local_path, self._index[key])
+            if not os.path.isfile(checkpoint_file):
+                raise FileNotFoundError(f"Checkpoint file {checkpoint_file} not found.")
+            tensor = self._load_tensor_from_checkpoint_file(
+                checkpoint_file, key, update_cache=True
+            )
+            return tensor
+
+    def __contains__(self, key: str) -> bool:
+        if self._state_dict_cache is not None and key in self._state_dict_cache:
+            return True
+        if self._index is not None and key in self._index:
+            return True
+        if len(self._checkpoint_files) == 1 and os.path.isfile(
+            self._checkpoint_files[0]
+        ):
+            try:
+                tensor = self._load_tensor_from_checkpoint_file(
+                    self._checkpoint_files[0], key, update_cache=False
+                )
+                return tensor is not None
+            except Exception:
+                return False
+        return False
+
+    def __len__(self) -> int:
+        if self._index is not None:
+            return len(self._index)
+        if len(self._checkpoint_files) == 1 and os.path.isfile(
+            self._checkpoint_files[0]
+        ):
+            checkpoint_file = self._checkpoint_files[0]
+            if checkpoint_file.endswith(".safetensors"):
+                with safe_open(checkpoint_file, framework="pt", device="cpu") as f:
+                    return len(tuple(f.keys()))
+            else:
+                return len(
+                    tuple(torch.load(checkpoint_file, map_location="cpu").keys())
+                )
+        raise RuntimeError(
+            "Unexpected error: cannot determine the number of keys in the state dict."
+        )
+
+    def __iter__(self) -> Iterator[str]:
+        if self._index is not None:
+            return iter(self._index)
+        return iter(self._checkpoint_files)
+
+    def keys(self) -> List[str]:
+        return list(self)
+
+    def values(self) -> List[torch.Tensor]:
+        return [self[key] for key in self]
+
+    def items(self) -> Iterator[Tuple[str, torch.Tensor]]:
+        return ((key, self[key]) for key in self)
+
+    def __repr__(self) -> str:
+        if self._index is not None:
+            return f"{self.__class__.__name__}(index={self._index})"
+        else:
+            return (
+                f"{self.__class__.__name__}(checkpoint_files={self._checkpoint_files})"
+            )
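Note: LazyStateDict resolves the checkpoint once (a local path or a Hugging Face Hub id), reads the shard index if one exists, and only loads individual tensors when they are accessed, via safetensors where possible, instead of materializing the whole state dict. A hedged usage sketch; the model id and tensor key are illustrative, not prescribed by the library:

import torch
from fusion_bench.utils.lazy_state_dict import LazyStateDict

# Illustrative checkpoint id; any local directory holding model.safetensors
# or a sharded *.index.json works the same way.
sd = LazyStateDict(
    "openai/clip-vit-base-patch32",
    cache_state_dict=True,      # keep tensors cached after first access
    torch_dtype=torch.float32,
)
print(len(sd), "tensors in the checkpoint")
# Only this tensor is read from disk; the rest stay untouched.
w = sd["vision_model.embeddings.patch_embedding.weight"]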
fusion_bench/utils/parameters.py
@@ -222,6 +222,39 @@ def count_parameters(module: nn.Module, non_zero_only: bool = False) -> tuple[in
     return trainable_params, all_param
 
 
+@torch.no_grad()
+def get_parameter_summary(
+    module_or_state_dict: Union[nn.Module, StateDictType], non_zero_only: bool = False
+) -> dict:
+    """
+    Get a summary of the parameters in a PyTorch model.
+    """
+    if isinstance(module_or_state_dict, nn.Module):
+        state_dict = module_or_state_dict.state_dict(keep_vars=True)
+    else:
+        state_dict = module_or_state_dict
+
+    trainable_params = 0
+    all_param = 0
+    bytes = 0
+
+    for name, param in state_dict.items():
+        # count the number of parameters
+        num_params = _numel(param, non_zero_only)
+        bytes += _numel(param, non_zero_only=False) * param.element_size()
+
+        # accumulate the number of trainable and total parameters
+        all_param += num_params
+        if param.requires_grad:
+            trainable_params += num_params
+
+    return {
+        "trainable_params": trainable_params,
+        "all_param": all_param,
+        "bytes": bytes,
+    }
+
+
 def print_parameters(
     module: nn.Module,
     is_human_readable: bool = True,
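Note: get_parameter_summary counts parameters with the module's internal _numel helper; the bytes figure always counts every element (non_zero_only affects only the parameter counts) times the per-element size. A quick sanity check on a small module, assuming the function is importable from fusion_bench.utils.parameters:

import torch.nn as nn
from fusion_bench.utils.parameters import get_parameter_summary

layer = nn.Linear(10, 5)  # 50 weights + 5 biases = 55 float32 parameters
print(get_parameter_summary(layer))
# expected: {'trainable_params': 55, 'all_param': 55, 'bytes': 220}  (55 * 4 bytes)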
fusion_bench/utils/state_dict_arithmetic.py
@@ -1,12 +1,12 @@
 from collections import OrderedDict
 from numbers import Number
-from typing import Dict, List, Union, cast
+from typing import Callable, Dict, List, Literal, Union, cast
 
 import torch
 from torch import Tensor
 
 from .parameters import check_parameters_all_equal
-from .type import StateDictType
+from .type import BoolStateDictType, StateDictType
 
 
 def to_device(
@@ -295,3 +295,75 @@ def state_dict_weighted_sum(
             device, non_blocking=True
         )
     return weighted_sum_state_dict
+
+
+def state_dict_diff_abs(a: StateDictType, b: StateDictType):
+    """
+    Returns the per-layer abs of the difference between two state dicts.
+
+    Args:
+        a (StateDictType): The first state dict.
+        b (StateDictType): The second state dict.
+
+    Returns:
+        StateDictType: The absolute difference between the two state dicts.
+    """
+    diff = state_dict_sub(a, b)
+    abs_diff = {key: diff[key].abs() for key in diff}
+    return abs_diff
+
+
+def state_dict_binary_mask(
+    a: StateDictType,
+    b: StateDictType,
+    compare_fn: Union[
+        Literal["greater", "less", "equal", "not_equal"],
+        Callable[[Tensor, Tensor], torch.BoolTensor],
+    ] = "greater",
+) -> BoolStateDictType:
+    """
+    Returns the binary mask of elements in a compared to elements in b using the provided comparison function.
+
+    Args:
+        a (StateDictType): The first state dict.
+        b (StateDictType): The second state dict.
+        compare_fn (Union[Literal["greater", "less", "equal", "not_equal"], Callable[[Tensor, Tensor], Tensor]]): A function that takes two tensors and returns a boolean tensor.
+            Defaults to greater than comparison (x > y).
+
+    Returns:
+        StateDictType: A dictionary containing binary masks (0 or 1) based on the comparison.
+    """
+    compare_fn_dict = {
+        "greater": lambda x, y: x > y,
+        "less": lambda x, y: x < y,
+        "equal": lambda x, y: x == y,
+        "not_equal": lambda x, y: x != y,
+    }
+    if isinstance(compare_fn, str):
+        compare_fn = compare_fn_dict[compare_fn]
+    elif not callable(compare_fn):
+        raise ValueError(
+            f"compare_fn must be a string or a callable, but got {type(compare_fn)}"
+        )
+
+    mask = OrderedDict()
+    for key in a:
+        mask[key] = compare_fn(a[key], b[key])
+    return mask
+
+
+def state_dict_hadmard_product(a: StateDictType, b: StateDictType) -> StateDictType:
+    """
+    Returns the Hadamard product of two state dicts, i.e. element-wise product.
+
+    Args:
+        a (StateDictType): The first state dict.
+        b (StateDictType): The second state dict.
+
+    Returns:
+        StateDictType: The Hadamard product of the two state dicts.
+    """
+    ans = OrderedDict()
+    for key in a:
+        ans[key] = a[key] * b[key]
+    return ans
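Note: these three helpers compose naturally for mask-based merging, in the spirit of the TALL-mask task arithmetic added in this release (whether that method calls exactly these functions is not shown in this diff). The pattern: take per-parameter absolute differences, threshold them into a boolean mask, then gate a task vector elementwise. A toy end-to-end sketch:

import torch
from fusion_bench.utils.state_dict_arithmetic import (
    state_dict_binary_mask,
    state_dict_diff_abs,
    state_dict_hadmard_product,
)

base = {"w": torch.tensor([1.0, 2.0, 3.0])}
finetuned = {"w": torch.tensor([1.0, 2.5, 0.0])}
task_vector = {"w": finetuned["w"] - base["w"]}     # [0.0, 0.5, -3.0]

delta = state_dict_diff_abs(finetuned, base)        # {"w": [0.0, 0.5, 3.0]}
threshold = {"w": torch.full_like(delta["w"], 0.4)}
mask = state_dict_binary_mask(delta, threshold)     # |delta| > 0.4 -> [F, T, T]
kept = state_dict_hadmard_product(task_vector, {"w": mask["w"].float()})
print(kept["w"])  # tensor([ 0.0000,  0.5000, -3.0000])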
fusion_bench/utils/type.py
@@ -9,6 +9,7 @@ try:
     from torch import Tensor, nn
 
     StateDictType: TypeAlias = Dict[str, Tensor]
+    BoolStateDictType: TypeAlias = Dict[str, torch.BoolTensor]
     TorchModelType = TypeVar("TorchModelType", bound=nn.Module)
 
 except ImportError:
{fusion_bench-0.2.15.dist-info → fusion_bench-0.2.16.dist-info}/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: fusion_bench
-Version: 0.2.15
+Version: 0.2.16
 Summary: A Comprehensive Benchmark of Deep Model Fusion
 Author-email: Anke Tang <tang.anke@foxmail.com>
 License: MIT License
@@ -70,7 +70,7 @@ Dynamic: license-file
 
 FusionBench is a benchmark suite designed to evaluate the performance of various deep model fusion techniques. It aims to provide a comprehensive comparison of different methods on a variety of datasets and tasks.
 
-Projects based on FusionBench and news from the community (descending order of date):
+Projects based on FusionBench and news from the community (descending order of date. If you have any work based on FusionBench, please feel free to let us know, we are willing to add it to the list. :partying_face:):
 
 <details>
 <summary>Hao Mark Chen, et al. FW-Merging: Scaling Model Merging with Frank-Wolfe Optimization. Mar 2025. https://arxiv.org/abs/2503.12649</summary>
@@ -139,6 +139,10 @@ cd fusion_bench
 pip install -e . # install the package in editable mode
 ```
 
+> [!TIP]
+> FusionBench is highly dependent on the use of [Hydra](https://hydra.cc/) for configuration management and command line argument parsing, and [Lightning Fabric](https://lightning.ai/) for device management.
+> If you are not familiar with these tools, it is strongly recommended to read the [Hydra](https://hydra.cc/docs/intro/) and [Lightning Fabric](https://lightning.ai/docs/fabric/stable/) documentation.
+
 ### Install with [Language Model Evaluation Harness](https://github.com/EleutherAI/lm-evaluation-harness)
 
 [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.10256836.svg)](https://doi.org/10.5281/zenodo.10256836)