fusion-bench 0.2.16__py3-none-any.whl → 0.2.18__py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. fusion_bench/method/__init__.py +11 -0
  2. fusion_bench/method/adamerging/flan_t5_layer_wise_adamerging.py +1 -1
  3. fusion_bench/method/adamerging/gpt2_layer_wise_adamerging.py +1 -1
  4. fusion_bench/method/base_algorithm.py +1 -0
  5. fusion_bench/method/dawe/dawe_for_clip.py +1 -1
  6. fusion_bench/method/depth_upscaling/depth_upscaling_for_llama.py +3 -2
  7. fusion_bench/method/expert_sparsity/__init__.py +10 -0
  8. fusion_bench/method/expert_sparsity/mixtral/__init__.py +23 -0
  9. fusion_bench/method/expert_sparsity/mixtral/dynamic_skipping.py +175 -0
  10. fusion_bench/method/expert_sparsity/mixtral/layer_wise_pruning.py +159 -0
  11. fusion_bench/method/expert_sparsity/mixtral/progressive_pruning.py +173 -0
  12. fusion_bench/method/expert_sparsity/utils/calibration_data.py +153 -0
  13. fusion_bench/method/gossip/flan_t5_layer_wise_gossip.py +1 -1
  14. fusion_bench/method/knots/__init__.py +0 -0
  15. fusion_bench/method/knots/knots_utils.py +23 -0
  16. fusion_bench/method/pwe_moe/module.py +2 -7
  17. fusion_bench/method/simple_average.py +3 -2
  18. fusion_bench/method/task_singular_vector/TSVM.py +238 -25
  19. fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +52 -20
  20. fusion_bench/method/task_singular_vector/utils/__init__.py +1 -0
  21. fusion_bench/method/task_singular_vector/utils/task_singular_interference.py +41 -0
  22. fusion_bench/mixins/hydra_config.py +1 -1
  23. fusion_bench/mixins/lightning_fabric.py +25 -1
  24. fusion_bench/mixins/serialization.py +18 -2
  25. fusion_bench/modelpool/base_pool.py +1 -0
  26. fusion_bench/modelpool/causal_lm/causal_lm.py +8 -5
  27. fusion_bench/modelpool/clip_vision/modelpool.py +21 -13
  28. fusion_bench/models/__init__.py +1 -0
  29. fusion_bench/models/expert_sparsity/__init__.py +0 -0
  30. fusion_bench/models/expert_sparsity/mixtral/__init__.py +15 -0
  31. fusion_bench/models/expert_sparsity/mixtral/dataset.py +40 -0
  32. fusion_bench/models/expert_sparsity/mixtral/modeling_mixtral.py +207 -0
  33. fusion_bench/models/expert_sparsity/mixtral/wrapper.py +268 -0
  34. fusion_bench/models/parameter_dict.py +6 -1
  35. fusion_bench/programs/fabric_fusion_program.py +21 -13
  36. fusion_bench/taskpool/base_pool.py +1 -0
  37. fusion_bench/taskpool/dummy.py +6 -4
  38. fusion_bench/utils/__init__.py +4 -3
  39. fusion_bench/utils/dtype.py +2 -1
  40. fusion_bench/utils/fabric.py +11 -4
  41. fusion_bench/utils/{instantiate.py → instantiate_utils.py} +3 -0
  42. fusion_bench/utils/lazy_state_dict.py +80 -10
  43. fusion_bench/utils/pylogger.py +30 -0
  44. {fusion_bench-0.2.16.dist-info → fusion_bench-0.2.18.dist-info}/METADATA +3 -1
  45. {fusion_bench-0.2.16.dist-info → fusion_bench-0.2.18.dist-info}/RECORD +59 -38
  46. {fusion_bench-0.2.16.dist-info → fusion_bench-0.2.18.dist-info}/WHEEL +1 -1
  47. fusion_bench_config/fabric/loggers/mlflow_logger.yaml +2 -0
  48. fusion_bench_config/fabric_model_fusion.yaml +2 -2
  49. fusion_bench_config/method/expert_sparsity/README.md +6 -0
  50. fusion_bench_config/method/expert_sparsity/mixtral.yaml +17 -0
  51. fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +2 -1
  52. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_cars_and_dtd.yaml +16 -0
  53. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_and_cars.yaml +16 -0
  54. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_and_dtd.yaml +16 -0
  55. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_cars_and_dtd.yaml +19 -0
  56. fusion_bench_config/taskpool/LMEvalHarnessTaskPool/lm_eval.yaml +0 -1
  57. {fusion_bench-0.2.16.dist-info → fusion_bench-0.2.18.dist-info}/entry_points.txt +0 -0
  58. {fusion_bench-0.2.16.dist-info → fusion_bench-0.2.18.dist-info}/licenses/LICENSE +0 -0
  59. {fusion_bench-0.2.16.dist-info → fusion_bench-0.2.18.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,268 @@
+ import itertools as I
+ import logging
+ from typing import Optional
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from transformers.models.mixtral.modeling_mixtral import (
+     MixtralBlockSparseTop2MLP,
+     MixtralForCausalLM,
+     MixtralSparseMoeBlock,
+ )
+
+ from .dataset import CacheDataset
+
+ logger = logging.getLogger(__name__)
+
+
+ class PrunableMixtralSparseMoeBlockWrapper(torch.nn.Module):
+     """
+     Wrapper of `MixtralSparseMoeBlock` that supports expert pruning.
+     """
+
+     def __init__(
+         self,
+         model: MixtralSparseMoeBlock,
+         r: Optional[int] = None,
+     ):
+         """
+         Args:
+             model: The model to be wrapped.
+             r: The number of experts to keep.
+         """
+         super().__init__()
+         if isinstance(model, MixtralSparseMoeBlock):
+             self.model = model
+         else:
+             self.model = model.model
+         self.r = r
+
+         self.experts_to_drop = None
+         self.cache_space = CacheDataset()
+         self.cache_logits = False
+         self.cache_X = False
+         self.cache_Z = False
+
+     # Forward uses topk
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         """ """
+         batch_size, sequence_length, hidden_dim = hidden_states.shape
+         hidden_states = hidden_states.view(-1, hidden_dim)
+         # router_logits: (batch * sequence_length, n_experts)
+         router_logits = self.model.gate(hidden_states)
+
+         if self.experts_to_drop is not None:
+             for e in self.experts_to_drop:
+                 router_logits[:, e] = -float("inf")
+
+         routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+         routing_weights, selected_experts = torch.topk(
+             routing_weights, self.model.top_k, dim=-1
+         )
+         routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+         # we cast back to the input dtype
+         routing_weights = routing_weights.to(hidden_states.dtype)
+
+         final_hidden_states = torch.zeros(
+             (batch_size * sequence_length, hidden_dim),
+             dtype=hidden_states.dtype,
+             device=hidden_states.device,
+         )
+
+         # One hot encode the selected experts to create an expert mask
+         # this will be used to easily index which expert is going to be sollicitated
+         expert_mask = torch.nn.functional.one_hot(
+             selected_experts, num_classes=self.model.num_experts
+         ).permute(2, 1, 0)
+
+         # Loop over all available experts in the model and perform the computation on each expert
+         for expert_idx in range(self.model.num_experts):
+             expert_layer = self.model.experts[expert_idx]
+             idx, top_x = torch.where(expert_mask[expert_idx])
+
+             if top_x.shape[0] == 0:
+                 continue
+
+             # in torch it is faster to index using lists than torch tensors
+             top_x_list = top_x.tolist()
+             idx_list = idx.tolist()
+
+             # Index the correct hidden states and compute the expert hidden state for
+             # the current expert. We need to make sure to multiply the output hidden
+             # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
+             current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
+             current_hidden_states = (
+                 expert_layer(current_state)
+                 * routing_weights[top_x_list, idx_list, None]
+             )
+
+             # However `index_add_` only support torch tensors for indexing so we'll use
+             # the `top_x` tensor here.
+             final_hidden_states.index_add_(
+                 0, top_x, current_hidden_states.to(hidden_states.dtype)
+             )
+
+         if self.experts_to_drop is not None and (
+             self.cache_logits or self.cache_X or self.cache_Z
+         ):
+             logger.warn(
+                 f"Already dropped {self.experts_to_drop} but still storing activations."
+             )
+         self.cache_space.append(
+             alpha=(router_logits if self.cache_logits else None),
+             X=(hidden_states if self.cache_X else None),
+             Z=(final_hidden_states if self.cache_Z else None),
+         )
+
+         final_hidden_states = final_hidden_states.reshape(
+             batch_size, sequence_length, hidden_dim
+         )
+
+         return final_hidden_states, router_logits
+
+     @torch.no_grad()
+     def enumerate(self):
+         # disable caching
+         self.cache_logits = False
+         self.cache_X = False
+         self.cache_Z = False
+         loss_history = dict()
+
+         with torch.inference_mode():
+             for dropped in I.combinations(
+                 range(self.model.num_experts), self.model.num_experts - self.r
+             ):
+                 self.experts_to_drop = dropped
+                 loss = 0
+
+                 for hidden_states, final_hidden_states in zip(
+                     self.cache_space.Xs, self.cache_space.Zs
+                 ):
+                     hidden_states = hidden_states.to(
+                         device=self.model.gate.weight.data.device, non_blocking=True
+                     )
+                     final_hidden_states = final_hidden_states.to(
+                         dtype=torch.float64,
+                         device=self.model.gate.weight.data.device,
+                         non_blocking=True,
+                     )
+                     final_hidden_states_e, _ = self.forward(hidden_states.unsqueeze(0))
+                     # compute the |Z - Z_e|_2 L2 loss
+                     loss += torch.norm(
+                         final_hidden_states
+                         - final_hidden_states_e.squeeze(0).to(torch.float64)
+                     ).item()
+                 loss_history[dropped] = loss
+
+         self.experts_to_drop = min(loss_history, key=loss_history.get)
+         return loss_history
+
+     @torch.no_grad()
+     def prune(self):
+         assert self.experts_to_drop is not None
+         assert len(self.experts_to_drop) == self.model.num_experts - self.r
+         del self.cache_space
+         self.cache_X = False
+         self.cache_Z = False
+
+         experts_to_reserve = sorted(
+             set(range(self.model.num_experts)) - set(self.experts_to_drop)
+         )
+
+         # create a new gate with the experts to reserve
+         gate_new = torch.nn.Linear(
+             in_features=self.model.gate.in_features,
+             out_features=self.r,
+             bias=False,
+             device=self.model.gate.weight.data.device,
+             dtype=torch.bfloat16,
+         )
+         gate_new.weight.data = self.model.gate.weight.data[list(experts_to_reserve)]
+         self.model.gate = gate_new
+
+         self.model.experts = torch.nn.ModuleList(
+             [self.model.experts[i] for i in experts_to_reserve]
+         )
+         self.model.num_experts = self.r
+
+
+ class DynamicSkippingMixtralSparseMoeBlockWrapper(nn.Module):
+     def __init__(self, model: MixtralSparseMoeBlock, beta: float):
+         super().__init__()
+         assert isinstance(model, MixtralSparseMoeBlock)
+         assert model.top_k == 2
+         self.hidden_dim = model.hidden_dim
+         self.ffn_dim = model.ffn_dim
+         self.num_experts = model.num_experts
+         self.top_k = model.top_k
+         self.gate = model.gate
+         self.experts = model.experts
+
+         self.beta = beta
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         """ """
+         batch_size, sequence_length, hidden_dim = hidden_states.shape
+         hidden_states = hidden_states.view(-1, hidden_dim)
+         # router_logits: (batch * sequence_length, n_experts)
+         router_logits = self.gate(hidden_states)
+
+         routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+         routing_weights, selected_experts = torch.topk(
+             routing_weights, self.top_k, dim=-1
+         )
+
+         # (batch * sequence_length)
+         mask_top1 = routing_weights[:, 1] < self.beta * routing_weights[:, 0]
+         routing_weights[mask_top1, 1] = 0
+
+         routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+         # we cast back to the input dtype
+         routing_weights = routing_weights.to(hidden_states.dtype)
+
+         final_hidden_states = torch.zeros(
+             (batch_size * sequence_length, hidden_dim),
+             dtype=hidden_states.dtype,
+             device=hidden_states.device,
+         )
+
+         # One hot encode the selected experts to create an expert mask
+         # this will be used to easily index which expert is going to be sollicitated
+         # (batch * sequence_length, self.top_k, n_experts)
+         expert_mask = torch.nn.functional.one_hot(
+             selected_experts, num_classes=self.num_experts
+         )
+
+         expert_mask[mask_top1, 1, :] = 0
+         expert_mask = expert_mask.permute(2, 1, 0)
+
+         # Loop over all available experts in the model and perform the computation on each expert
+         for expert_idx in range(self.num_experts):
+             expert_layer = self.experts[expert_idx]
+             top_x, indices = torch.where(expert_mask[expert_idx])
+
+             if indices.shape[0] == 0:
+                 continue
+
+             # in torch it is faster to index using lists than torch tensors
+             indices_list = indices.tolist()
+             top_x_list = top_x.tolist()
+
+             # Index the correct hidden states and compute the expert hidden state for
+             # the current expert. We need to make sure to multiply the output hidden
+             # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
+             current_state = hidden_states[None, indices_list].reshape(-1, hidden_dim)
+             current_hidden_states = expert_layer(
+                 current_state, routing_weights[indices_list, top_x_list, None]
+             )
+
+             # However `index_add_` only support torch tensors for indexing so we'll use
+             # the `top_x` tensor here.
+             final_hidden_states.index_add_(
+                 0, indices, current_hidden_states.to(hidden_states.dtype)
+             )
+         final_hidden_states = final_hidden_states.reshape(
+             batch_size, sequence_length, hidden_dim
+         )
+         return final_hidden_states, router_logits
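
For context, the sketch below is not part of the diff; it only illustrates how the new `fusion_bench/models/expert_sparsity/mixtral/wrapper.py` class above might be driven end to end. The checkpoint name and `calibration_dataloader` are placeholders, and the actual algorithm shipped in this release lives in `fusion_bench/method/expert_sparsity/mixtral/layer_wise_pruning.py`; only the wrapper API (`cache_X`, `cache_Z`, `enumerate`, `prune`, `r`) is taken from the file shown.

```python
import torch
from transformers import MixtralForCausalLM

from fusion_bench.models.expert_sparsity.mixtral.wrapper import (
    PrunableMixtralSparseMoeBlockWrapper,
)

# Placeholder checkpoint; any Mixtral-style model with 8 experts per layer works.
model = MixtralForCausalLM.from_pretrained(
    "mistralai/Mixtral-8x7B-v0.1", torch_dtype=torch.bfloat16
)

# 1) Wrap every sparse-MoE block and turn on activation caching.
for layer in model.model.layers:
    layer.block_sparse_moe = PrunableMixtralSparseMoeBlockWrapper(
        layer.block_sparse_moe, r=4  # keep 4 of the 8 experts
    )
    layer.block_sparse_moe.cache_X = True
    layer.block_sparse_moe.cache_Z = True

# 2) Run a few calibration batches; each forward pass appends (X, Z) pairs
#    to the wrapper's cache_space.
with torch.no_grad():
    for batch in calibration_dataloader:  # hypothetical calibration DataLoader
        model(**batch)

# 3) Per layer: search over which experts to drop (smallest L2 gap between the
#    full and the pruned block outputs), then rebuild the gate and expert list.
for layer in model.model.layers:
    layer.block_sparse_moe.enumerate()
    layer.block_sparse_moe.prune()
```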
@@ -66,7 +66,9 @@ class ParameterDictModel(nn.Module):
          super().__init__()
          if parameters is not None:
              for name, param in parameters.items():
-                 assert isinstance(param, nn.Parameter), f"{name} is not a nn.Parameter"
+                 assert isinstance(
+                     param, (nn.Parameter, nn.Buffer)
+                 ), f"{name} is not a nn.Parameter or nn.Buffer"
                  _set_attr(
                      self,
                      name.split("."),
@@ -114,3 +116,6 @@ class ParameterDictModel(nn.Module):

      def values(self) -> List[nn.Parameter]:
          return [self[name] for name in self.keys()]
+
+     def __len__(self):
+         return len(self.keys())
@@ -9,7 +9,7 @@ from omegaconf import DictConfig, OmegaConf
  from torch import nn
  from tqdm.auto import tqdm

- import fusion_bench.utils.instantiate
+ import fusion_bench.utils.instantiate_utils
  from fusion_bench.method import BaseAlgorithm
  from fusion_bench.mixins import LightningFabricMixin
  from fusion_bench.modelpool import BaseModelPool
@@ -19,8 +19,9 @@ from fusion_bench.utils import import_object, instantiate, timeit_context
  from fusion_bench.utils.hydra_utils import get_hydra_output_dir
  from fusion_bench.utils.json import print_json
  from fusion_bench.utils.rich_utils import print_bordered, print_config_tree
+ from fusion_bench.utils.pylogger import getRankZeroLogger

- log = logging.getLogger(__name__)
+ log = getRankZeroLogger(__name__)


  class FabricModelFusionProgram(
@@ -66,8 +67,8 @@ class FabricModelFusionProgram(
          self.merged_model_save_kwargs = merged_model_save_kwargs
          self.fast_dev_run = fast_dev_run
          self.seed = seed
+         fusion_bench.utils.instantiate_utils.PRINT_FUNCTION_CALL = print_function_call
          super().__init__(**kwargs)
-         fusion_bench.utils.instantiate.PRINT_FUNCTION_CALL = print_function_call

          if print_config:
              print_config_tree(
@@ -252,13 +253,16 @@ class FabricModelFusionProgram(
          if self.taskpool is not None:
              report = self.evaluate_merged_model(self.taskpool, merged_model)
              try:
-                 print_json(report, print_type=False)
+                 if rank_zero_only.rank == 0:
+                     print_json(report, print_type=False)
              except Exception as e:
                  log.warning(f"Failed to pretty print the report: {e}")
-                 print(report)
+                 log.info(report)
              if self.report_save_path is not None:
                  # save report (Dict) to a file
                  # if the directory of `save_report` does not exists, create it
+                 if "{log_dir}" in self.report_save_path and self.log_dir is not None:
+                     self.report_save_path = self.report_save_path.format(log_dir=self.log_dir)
                  os.makedirs(os.path.dirname(self.report_save_path), exist_ok=True)
                  json.dump(report, open(self.report_save_path, "w"))
          else:
@@ -292,13 +296,17 @@ class FabricModelFusionProgram(
          if hydra_output_dir is not None:
              os.makedirs(self.log_dir, exist_ok=True)
              try:
-                 os.symlink(
-                     hydra_output_dir,
-                     os.path.join(
-                         self.log_dir,
-                         "hydra_output_" + os.path.basename(hydra_output_dir),
-                     ),
-                     target_is_directory=True,
-                 )
+                 # if the system is windows, use the `mklink` command in "CMD" to create the symlink
+                 if os.name == "nt":
+                     os.system(f"mklink /J {os.path.abspath(os.path.join(self.log_dir, 'hydra_output_' + os.path.basename(hydra_output_dir)))} {os.path.abspath(hydra_output_dir)}")
+                 else:
+                     os.symlink(
+                         hydra_output_dir,
+                         os.path.join(
+                             self.log_dir,
+                             "hydra_output_" + os.path.basename(hydra_output_dir),
+                         ),
+                         target_is_directory=True,
+                     )
              except OSError as e:
                  log.warning(f"Failed to create symbolic link: {e}")
@@ -5,6 +5,7 @@ from fusion_bench.mixins import BaseYAMLSerializableModel

  class BaseTaskPool(BaseYAMLSerializableModel):
      _program = None
+     _config_key = "taskpool"

      @abstractmethod
      def evaluate(self, model, *args, **kwargs):
@@ -10,6 +10,7 @@ from fusion_bench.models.separate_io import separate_save
  from fusion_bench.taskpool.base_pool import BaseTaskPool
  from fusion_bench.utils import timeit_context
  from fusion_bench.utils.parameters import count_parameters, print_parameters
+ from lightning.pytorch.utilities import rank_zero_only


  def get_model_summary(model: nn.Module) -> dict:
@@ -49,10 +50,11 @@ class DummyTaskPool(BaseTaskPool):
          Args:
              model: The model to evaluate.
          """
-         print_parameters(model, is_human_readable=True)
+         if rank_zero_only.rank == 0:
+             print_parameters(model, is_human_readable=True)

-         if self.model_save_path is not None:
-             with timeit_context(f"Saving the model to {self.model_save_path}"):
-                 separate_save(model, self.model_save_path)
+         if self.model_save_path is not None:
+             with timeit_context(f"Saving the model to {self.model_save_path}"):
+                 separate_save(model, self.model_save_path)

          return get_model_summary(model)
@@ -2,14 +2,15 @@
  import importlib
  from typing import Iterable

- from . import data, functools, path
+ from . import data, functools, path, pylogger
  from .cache_utils import *
  from .devices import *
  from .dtype import parse_dtype
  from .fabric import seed_everything_by_time
- from .instantiate import instantiate, is_instantiable
+ from .instantiate_utils import instantiate, is_instantiable
+ from .json import load_from_json, save_to_json
+ from .lazy_state_dict import LazyStateDict
  from .misc import *
  from .packages import import_object
  from .parameters import *
  from .timer import timeit_context
- from .lazy_state_dict import LazyStateDict
@@ -13,6 +13,7 @@ from transformers.utils import (
  PRECISION_STR_TO_DTYPE: Dict[str, torch.dtype] = {
      "fp16": torch.float16,
      "float16": torch.float16,
+     "half": torch.float16,
      "bf16": torch.bfloat16,
      "bfloat16": torch.bfloat16,
      "float": torch.float32,
@@ -50,7 +51,7 @@ def parse_dtype(dtype: Optional[str]):

      dtype = dtype.strip('"')
      if dtype not in PRECISION_STR_TO_DTYPE:
-         raise ValueError(f"Unsupported dtype: {type(dtype)}")
+         raise ValueError(f"Unsupported dtype string: {dtype}")

      dtype = PRECISION_STR_TO_DTYPE[dtype]
      return dtype
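
A quick illustration of the two `dtype.py` changes above, assuming `parse_dtype` (re-exported from `fusion_bench.utils`) keeps its existing behavior of returning a `torch.dtype` for known strings and raising `ValueError` otherwise:

```python
import torch

from fusion_bench.utils import parse_dtype

# The new "half" alias resolves like the existing "fp16"/"float16" keys.
assert parse_dtype("half") is torch.float16
assert parse_dtype("bf16") is torch.bfloat16

try:
    parse_dtype("int4")  # not in PRECISION_STR_TO_DTYPE
except ValueError as err:
    print(err)  # now reports the offending string, e.g. "Unsupported dtype string: int4"
```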
@@ -1,17 +1,24 @@
  import time
+ from typing import Optional

  import lightning as L

+ from fusion_bench.utils.pylogger import getRankZeroLogger

- def seed_everything_by_time(fabric: L.Fabric):
+ log = getRankZeroLogger(__name__)
+
+
+ def seed_everything_by_time(fabric: Optional[L.Fabric] = None):
      """
      Set seed for all processes by time.
      """
      # set seed for all processes
-     if fabric.is_global_zero:
+     if fabric is None or fabric.is_global_zero:
          seed = int(time.time())
      else:
          seed = None
-     fabric.barrier()
-     seed = fabric.broadcast(seed, src=0)
+     if fabric is not None:
+         log.debug(f"Broadcasting seed `{seed}` to all processes")
+         fabric.barrier()
+         seed = fabric.broadcast(seed, src=0)
      L.seed_everything(seed)
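
A small sketch of the relaxed signature (the helper is re-exported as `fusion_bench.utils.seed_everything_by_time` per the `utils/__init__.py` hunk above; the Fabric configuration here is only illustrative):

```python
import lightning as L

from fusion_bench.utils import seed_everything_by_time

# Single-process script: seeds from the current time, nothing to broadcast.
seed_everything_by_time()

# With Fabric: rank 0 draws the seed and broadcasts it to the other ranks.
fabric = L.Fabric(devices=1)
fabric.launch()
seed_everything_by_time(fabric)
```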
@@ -41,6 +41,9 @@ def set_print_function_call(value: bool):
      finally:
          PRINT_FUNCTION_CALL = old_value

+ def set_print_function_call_permeanent(value: bool):
+     global PRINT_FUNCTION_CALL
+     PRINT_FUNCTION_CALL = value

  def is_instantiable(config: Union[DictConfig, Any]) -> bool:
      if OmegaConf.is_dict(config):
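
The sketch below contrasts the new permanent setter with the existing helper. It assumes `set_print_function_call` is the try/finally context manager its visible body suggests; if it is not, only the last line applies.

```python
from fusion_bench.utils.instantiate_utils import (
    set_print_function_call,             # existing helper, assumed to be a context manager
    set_print_function_call_permeanent,  # new permanent setter added in this diff
)

# Temporarily silence the instantiation-call printing for a block of code.
with set_print_function_call(False):
    ...  # instantiate(...) calls inside this block are not printed

# Silence it for the rest of the run (note the spelling used in this release).
set_print_function_call_permeanent(False)
```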
@@ -1,13 +1,16 @@
  import json
  import logging
  import os
- from typing import TYPE_CHECKING, Dict, Iterator, List, Optional, Tuple
+ from copy import deepcopy
+ from typing import TYPE_CHECKING, Dict, Iterator, List, Optional, Tuple, Type

  import torch
+ from accelerate import init_empty_weights
  from accelerate.utils.constants import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
  from huggingface_hub import snapshot_download
  from safetensors import safe_open
  from safetensors.torch import load_file
+ from torch import nn
  from transformers import AutoConfig

  from fusion_bench.utils.dtype import parse_dtype
@@ -59,6 +62,8 @@ class LazyStateDict:
      def __init__(
          self,
          checkpoint: str,
+         meta_module_class: Optional[Type[nn.Module]] = None,
+         meta_module: Optional[nn.Module] = None,
          cache_state_dict: bool = False,
          torch_dtype: Optional[torch.dtype] = None,
          device: str = "cpu",
@@ -66,6 +71,22 @@ class LazyStateDict:
          hf_cache_dir: Optional[str] = None,
          hf_proxies: Optional[Dict] = None,
      ):
+         self.meta_module_class = meta_module_class
+         self.meta_module = meta_module
+         if self.meta_module_class is not None:
+             if self.meta_module is not None:
+                 raise ValueError(
+                     "Cannot provide both meta_module_class and meta_module, please provide only one."
+                 )
+             with init_empty_weights():
+                 self.meta_module = self.meta_module_class.from_pretrained(
+                     checkpoint,
+                     torch_dtype=torch_dtype,
+                     revision=hf_revision,
+                     cache_dir=hf_cache_dir,
+                     proxies=hf_proxies,
+                 )
+
          self._checkpoint = checkpoint
          self._local_path = resolve_checkpoint_path(
              checkpoint,
@@ -78,10 +99,32 @@ class LazyStateDict:
              self._resolve_checkpoint_files(self._local_path)
          )

-         if cache_state_dict:
-             self._state_dict_cache = {}
+         if self._index is not None:
+             # if meta_module is provided, remove the keys that are not in the meta_module
+             if self.meta_module is not None:
+                 meta_module_state_dict = self.meta_module.state_dict()
+                 for key in tuple(self._index.keys()):
+                     if key not in meta_module_state_dict:
+                         self._index.pop(key)
+             if cache_state_dict:
+                 self._state_dict_cache = {}
+             else:
+                 self._state_dict_cache = None
+         elif len(self._checkpoint_files) == 1 and self._checkpoint_files[0].endswith(
+             WEIGHTS_NAME
+         ):
+             log.info(f"Loading full state dict from {WEIGHTS_NAME}")
+             self._state_dict_cache = torch.load(self._checkpoint_files[0])
+             # if meta_module is provided, remove the keys that are not in the meta_module
+             if self.meta_module is not None:
+                 meta_module_state_dict = self.meta_module.state_dict()
+                 for key in tuple(self._state_dict_cache.keys()):
+                     if key not in meta_module_state_dict:
+                         self._state_dict_cache.pop(key)
          else:
-             self._state_dict_cache = None
+             raise ValueError(
+                 f"Cannot determine the type of checkpoint, please provide a checkpoint path to a file containing a whole state dict with file name {WEIGHTS_NAME} or {SAFE_WEIGHTS_NAME}, or the index of a sharded checkpoint ending with `.index.json`."
+             )

          self._torch_dtype = parse_dtype(torch_dtype)
          self._device = device
@@ -152,6 +195,8 @@ class LazyStateDict:
          checkpoint_files = [
              os.path.join(checkpoint_folder, f) for f in checkpoint_files
          ]
+         else:
+             index = None
          return index, index_filename, checkpoint_files

      def _load_tensor_from_checkpoint_file(
@@ -248,16 +293,24 @@ class LazyStateDict:
      def __iter__(self) -> Iterator[str]:
          if self._index is not None:
              return iter(self._index)
-         return iter(self._checkpoint_files)
+         elif self._state_dict_cache is not None:
+             return iter(self._state_dict_cache)
+         else:
+             raise RuntimeError(
+                 "Unexpected error: cannot determine the keys in the state dict."
+             )

-     def keys(self) -> List[str]:
-         return list(self)
+     def keys(self) -> Iterator[str]:
+         for key in self:
+             yield key

-     def values(self) -> List[torch.Tensor]:
-         return [self[key] for key in self]
+     def values(self) -> Iterator[torch.Tensor]:
+         for key in self:
+             yield self[key]

      def items(self) -> Iterator[Tuple[str, torch.Tensor]]:
-         return ((key, self[key]) for key in self)
+         for key in self:
+             yield key, self[key]

      def __repr__(self) -> str:
          if self._index is not None:
@@ -266,3 +319,20 @@ class LazyStateDict:
          return (
              f"{self.__class__.__name__}(checkpoint_files={self._checkpoint_files})"
          )
+
+     def get_parameter(self, target: str) -> torch.Tensor:
+         return self[target]
+
+     def get_submodule(self, target: str) -> nn.Module:
+         if self.meta_module is not None:
+             module: nn.Module = deepcopy(self.meta_module.get_submodule(target))
+             module.to_empty(device=self._device)
+             state_dict = {}
+             for name, _ in module.named_parameters():
+                 state_dict[name] = self[f"{target}.{name}"]
+             module.load_state_dict(state_dict)
+             return module
+         else:
+             raise RuntimeError(
+                 "Cannot get submodule because meta_module is not provided."
+             )
@@ -53,3 +53,33 @@ class RankedLogger(logging.LoggerAdapter):
              self.logger.log(level, msg, *args, **kwargs)
          elif current_rank == rank:
              self.logger.log(level, msg, *args, **kwargs)
+
+
+ class RankZeroLogger(logging.Logger):
+     """A logger that logs only on rank zero and works just like logging.Logger"""
+
+     @rank_zero_only
+     def _log(self, *args, **kwargs):
+         if "stacklevel" in kwargs:
+             kwargs["stacklevel"] += 1
+         else:
+             kwargs["stacklevel"] = 2
+         return super()._log(*args, **kwargs)
+
+     def is_global_zero(self):
+         return rank_zero_only.rank == 0
+
+
+ RankZeroLogger.manager = logging.Manager(RankZeroLogger.root)
+ RankZeroLogger.manager.setLoggerClass(RankZeroLogger)
+
+
+ def getRankZeroLogger(name=None):
+     """
+     Return a logger with the specified name, creating it if necessary.
+
+     If no name is specified, return the root logger.
+     """
+     if not name or isinstance(name, str) and name == logging.root.name:
+         return logging.root
+     return RankZeroLogger.manager.getLogger(name)
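
Usage of the new rank-zero logger from `pylogger.py` is a drop-in replacement for `logging.getLogger`; a minimal sketch:

```python
from fusion_bench.utils.pylogger import getRankZeroLogger

log = getRankZeroLogger(__name__)
log.info("Printed on rank 0 only.")
log.warning("Warnings are likewise suppressed on non-zero ranks.")
```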
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: fusion_bench
- Version: 0.2.16
+ Version: 0.2.18
  Summary: A Comprehensive Benchmark of Deep Model Fusion
  Author-email: Anke Tang <tang.anke@foxmail.com>
  License: MIT License
@@ -171,6 +171,8 @@ It can be used to improve the performance and robustness of model or to combine
  For a more detailed introduction to deep model fusion, you can refer to [W. Li, 2023, 'Deep Model Fusion: A Survey'](https://arxiv.org/abs/2309.15698). We also provide a brief overview of deep model fusion in [our documentation](https://tanganke.github.io/fusion_bench/).
  In this benchmark, we evaluate the performance of different fusion methods on a variety of datasets and tasks.

+ A comprehensive list of papers about model merging can be found at [this repository](https://github.com/EnnengYang/Awesome-Model-Merging-Methods-Theories-Applications), and [the arXiv paper](https://arxiv.org/abs/2408.07666) is also available.
+
  ## Project Structure

  The project is structured as follows: