fusion-bench 0.2.16__py3-none-any.whl → 0.2.18__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/method/__init__.py +11 -0
- fusion_bench/method/adamerging/flan_t5_layer_wise_adamerging.py +1 -1
- fusion_bench/method/adamerging/gpt2_layer_wise_adamerging.py +1 -1
- fusion_bench/method/base_algorithm.py +1 -0
- fusion_bench/method/dawe/dawe_for_clip.py +1 -1
- fusion_bench/method/depth_upscaling/depth_upscaling_for_llama.py +3 -2
- fusion_bench/method/expert_sparsity/__init__.py +10 -0
- fusion_bench/method/expert_sparsity/mixtral/__init__.py +23 -0
- fusion_bench/method/expert_sparsity/mixtral/dynamic_skipping.py +175 -0
- fusion_bench/method/expert_sparsity/mixtral/layer_wise_pruning.py +159 -0
- fusion_bench/method/expert_sparsity/mixtral/progressive_pruning.py +173 -0
- fusion_bench/method/expert_sparsity/utils/calibration_data.py +153 -0
- fusion_bench/method/gossip/flan_t5_layer_wise_gossip.py +1 -1
- fusion_bench/method/knots/__init__.py +0 -0
- fusion_bench/method/knots/knots_utils.py +23 -0
- fusion_bench/method/pwe_moe/module.py +2 -7
- fusion_bench/method/simple_average.py +3 -2
- fusion_bench/method/task_singular_vector/TSVM.py +238 -25
- fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +52 -20
- fusion_bench/method/task_singular_vector/utils/__init__.py +1 -0
- fusion_bench/method/task_singular_vector/utils/task_singular_interference.py +41 -0
- fusion_bench/mixins/hydra_config.py +1 -1
- fusion_bench/mixins/lightning_fabric.py +25 -1
- fusion_bench/mixins/serialization.py +18 -2
- fusion_bench/modelpool/base_pool.py +1 -0
- fusion_bench/modelpool/causal_lm/causal_lm.py +8 -5
- fusion_bench/modelpool/clip_vision/modelpool.py +21 -13
- fusion_bench/models/__init__.py +1 -0
- fusion_bench/models/expert_sparsity/__init__.py +0 -0
- fusion_bench/models/expert_sparsity/mixtral/__init__.py +15 -0
- fusion_bench/models/expert_sparsity/mixtral/dataset.py +40 -0
- fusion_bench/models/expert_sparsity/mixtral/modeling_mixtral.py +207 -0
- fusion_bench/models/expert_sparsity/mixtral/wrapper.py +268 -0
- fusion_bench/models/parameter_dict.py +6 -1
- fusion_bench/programs/fabric_fusion_program.py +21 -13
- fusion_bench/taskpool/base_pool.py +1 -0
- fusion_bench/taskpool/dummy.py +6 -4
- fusion_bench/utils/__init__.py +4 -3
- fusion_bench/utils/dtype.py +2 -1
- fusion_bench/utils/fabric.py +11 -4
- fusion_bench/utils/{instantiate.py → instantiate_utils.py} +3 -0
- fusion_bench/utils/lazy_state_dict.py +80 -10
- fusion_bench/utils/pylogger.py +30 -0
- {fusion_bench-0.2.16.dist-info → fusion_bench-0.2.18.dist-info}/METADATA +3 -1
- {fusion_bench-0.2.16.dist-info → fusion_bench-0.2.18.dist-info}/RECORD +59 -38
- {fusion_bench-0.2.16.dist-info → fusion_bench-0.2.18.dist-info}/WHEEL +1 -1
- fusion_bench_config/fabric/loggers/mlflow_logger.yaml +2 -0
- fusion_bench_config/fabric_model_fusion.yaml +2 -2
- fusion_bench_config/method/expert_sparsity/README.md +6 -0
- fusion_bench_config/method/expert_sparsity/mixtral.yaml +17 -0
- fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +2 -1
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_cars_and_dtd.yaml +16 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_and_cars.yaml +16 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_and_dtd.yaml +16 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_cars_and_dtd.yaml +19 -0
- fusion_bench_config/taskpool/LMEvalHarnessTaskPool/lm_eval.yaml +0 -1
- {fusion_bench-0.2.16.dist-info → fusion_bench-0.2.18.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.16.dist-info → fusion_bench-0.2.18.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.16.dist-info → fusion_bench-0.2.18.dist-info}/top_level.txt +0 -0
fusion_bench/models/expert_sparsity/mixtral/wrapper.py
ADDED
@@ -0,0 +1,268 @@
+import itertools as I
+import logging
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from transformers.models.mixtral.modeling_mixtral import (
+    MixtralBlockSparseTop2MLP,
+    MixtralForCausalLM,
+    MixtralSparseMoeBlock,
+)
+
+from .dataset import CacheDataset
+
+logger = logging.getLogger(__name__)
+
+
+class PrunableMixtralSparseMoeBlockWrapper(torch.nn.Module):
+    """
+    Wrapper of `MixtralSparseMoeBlock` that supports expert pruning.
+    """
+
+    def __init__(
+        self,
+        model: MixtralSparseMoeBlock,
+        r: Optional[int] = None,
+    ):
+        """
+        Args:
+            model: The model to be wrapped.
+            r: The number of experts to keep.
+        """
+        super().__init__()
+        if isinstance(model, MixtralSparseMoeBlock):
+            self.model = model
+        else:
+            self.model = model.model
+        self.r = r
+
+        self.experts_to_drop = None
+        self.cache_space = CacheDataset()
+        self.cache_logits = False
+        self.cache_X = False
+        self.cache_Z = False
+
+    # Forward uses topk
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        """ """
+        batch_size, sequence_length, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+        # router_logits: (batch * sequence_length, n_experts)
+        router_logits = self.model.gate(hidden_states)
+
+        if self.experts_to_drop is not None:
+            for e in self.experts_to_drop:
+                router_logits[:, e] = -float("inf")
+
+        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+        routing_weights, selected_experts = torch.topk(
+            routing_weights, self.model.top_k, dim=-1
+        )
+        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+        # we cast back to the input dtype
+        routing_weights = routing_weights.to(hidden_states.dtype)
+
+        final_hidden_states = torch.zeros(
+            (batch_size * sequence_length, hidden_dim),
+            dtype=hidden_states.dtype,
+            device=hidden_states.device,
+        )
+
+        # One hot encode the selected experts to create an expert mask
+        # this will be used to easily index which expert is going to be sollicitated
+        expert_mask = torch.nn.functional.one_hot(
+            selected_experts, num_classes=self.model.num_experts
+        ).permute(2, 1, 0)
+
+        # Loop over all available experts in the model and perform the computation on each expert
+        for expert_idx in range(self.model.num_experts):
+            expert_layer = self.model.experts[expert_idx]
+            idx, top_x = torch.where(expert_mask[expert_idx])
+
+            if top_x.shape[0] == 0:
+                continue
+
+            # in torch it is faster to index using lists than torch tensors
+            top_x_list = top_x.tolist()
+            idx_list = idx.tolist()
+
+            # Index the correct hidden states and compute the expert hidden state for
+            # the current expert. We need to make sure to multiply the output hidden
+            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
+            current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
+            current_hidden_states = (
+                expert_layer(current_state)
+                * routing_weights[top_x_list, idx_list, None]
+            )
+
+            # However `index_add_` only support torch tensors for indexing so we'll use
+            # the `top_x` tensor here.
+            final_hidden_states.index_add_(
+                0, top_x, current_hidden_states.to(hidden_states.dtype)
+            )
+
+        if self.experts_to_drop is not None and (
+            self.cache_logits or self.cache_X or self.cache_Z
+        ):
+            logger.warn(
+                f"Already dropped {self.experts_to_drop} but still storing activations."
+            )
+        self.cache_space.append(
+            alpha=(router_logits if self.cache_logits else None),
+            X=(hidden_states if self.cache_X else None),
+            Z=(final_hidden_states if self.cache_Z else None),
+        )
+
+        final_hidden_states = final_hidden_states.reshape(
+            batch_size, sequence_length, hidden_dim
+        )
+
+        return final_hidden_states, router_logits
+
+    @torch.no_grad()
+    def enumerate(self):
+        # disable caching
+        self.cache_logits = False
+        self.cache_X = False
+        self.cache_Z = False
+        loss_history = dict()
+
+        with torch.inference_mode():
+            for dropped in I.combinations(
+                range(self.model.num_experts), self.model.num_experts - self.r
+            ):
+                self.experts_to_drop = dropped
+                loss = 0
+
+                for hidden_states, final_hidden_states in zip(
+                    self.cache_space.Xs, self.cache_space.Zs
+                ):
+                    hidden_states = hidden_states.to(
+                        device=self.model.gate.weight.data.device, non_blocking=True
+                    )
+                    final_hidden_states = final_hidden_states.to(
+                        dtype=torch.float64,
+                        device=self.model.gate.weight.data.device,
+                        non_blocking=True,
+                    )
+                    final_hidden_states_e, _ = self.forward(hidden_states.unsqueeze(0))
+                    # compute the |Z - Z_e|_2 L2 loss
+                    loss += torch.norm(
+                        final_hidden_states
+                        - final_hidden_states_e.squeeze(0).to(torch.float64)
+                    ).item()
+                loss_history[dropped] = loss
+
+        self.experts_to_drop = min(loss_history, key=loss_history.get)
+        return loss_history
+
+    @torch.no_grad()
+    def prune(self):
+        assert self.experts_to_drop is not None
+        assert len(self.experts_to_drop) == self.model.num_experts - self.r
+        del self.cache_space
+        self.cache_X = False
+        self.cache_Z = False
+
+        experts_to_reserve = sorted(
+            set(range(self.model.num_experts)) - set(self.experts_to_drop)
+        )
+
+        # create a new gate with the experts to reserve
+        gate_new = torch.nn.Linear(
+            in_features=self.model.gate.in_features,
+            out_features=self.r,
+            bias=False,
+            device=self.model.gate.weight.data.device,
+            dtype=torch.bfloat16,
+        )
+        gate_new.weight.data = self.model.gate.weight.data[list(experts_to_reserve)]
+        self.model.gate = gate_new
+
+        self.model.experts = torch.nn.ModuleList(
+            [self.model.experts[i] for i in experts_to_reserve]
+        )
+        self.model.num_experts = self.r
+
+
+class DynamicSkippingMixtralSparseMoeBlockWrapper(nn.Module):
+    def __init__(self, model: MixtralSparseMoeBlock, beta: float):
+        super().__init__()
+        assert isinstance(model, MixtralSparseMoeBlock)
+        assert model.top_k == 2
+        self.hidden_dim = model.hidden_dim
+        self.ffn_dim = model.ffn_dim
+        self.num_experts = model.num_experts
+        self.top_k = model.top_k
+        self.gate = model.gate
+        self.experts = model.experts
+
+        self.beta = beta
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        """ """
+        batch_size, sequence_length, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+        # router_logits: (batch * sequence_length, n_experts)
+        router_logits = self.gate(hidden_states)
+
+        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+        routing_weights, selected_experts = torch.topk(
+            routing_weights, self.top_k, dim=-1
+        )
+
+        # (batch * sequence_length)
+        mask_top1 = routing_weights[:, 1] < self.beta * routing_weights[:, 0]
+        routing_weights[mask_top1, 1] = 0
+
+        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+        # we cast back to the input dtype
+        routing_weights = routing_weights.to(hidden_states.dtype)
+
+        final_hidden_states = torch.zeros(
+            (batch_size * sequence_length, hidden_dim),
+            dtype=hidden_states.dtype,
+            device=hidden_states.device,
+        )
+
+        # One hot encode the selected experts to create an expert mask
+        # this will be used to easily index which expert is going to be sollicitated
+        # (batch * sequence_length, self.top_k, n_experts)
+        expert_mask = torch.nn.functional.one_hot(
+            selected_experts, num_classes=self.num_experts
+        )
+
+        expert_mask[mask_top1, 1, :] = 0
+        expert_mask = expert_mask.permute(2, 1, 0)
+
+        # Loop over all available experts in the model and perform the computation on each expert
+        for expert_idx in range(self.num_experts):
+            expert_layer = self.experts[expert_idx]
+            top_x, indices = torch.where(expert_mask[expert_idx])
+
+            if indices.shape[0] == 0:
+                continue
+
+            # in torch it is faster to index using lists than torch tensors
+            indices_list = indices.tolist()
+            top_x_list = top_x.tolist()
+
+            # Index the correct hidden states and compute the expert hidden state for
+            # the current expert. We need to make sure to multiply the output hidden
+            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
+            current_state = hidden_states[None, indices_list].reshape(-1, hidden_dim)
+            current_hidden_states = expert_layer(
+                current_state, routing_weights[indices_list, top_x_list, None]
+            )
+
+            # However `index_add_` only support torch tensors for indexing so we'll use
+            # the `top_x` tensor here.
+            final_hidden_states.index_add_(
+                0, indices, current_hidden_states.to(hidden_states.dtype)
+            )
+        final_hidden_states = final_hidden_states.reshape(
+            batch_size, sequence_length, hidden_dim
+        )
+        return final_hidden_states, router_logits
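The file above supplies the mechanics for two expert-sparsity strategies: `PrunableMixtralSparseMoeBlockWrapper` caches block inputs and outputs on calibration data, enumerates candidate sets of experts to drop, keeps the set with the lowest output-reconstruction error, and then rebuilds the gate and expert list; `DynamicSkippingMixtralSparseMoeBlockWrapper` zeroes the second expert whenever its routing weight falls below `beta` times the top-1 weight. Below is a minimal sketch of how the pruning wrapper might be driven; the driver loop, the `calib_batches` iterable, and the use of `layer.block_sparse_moe` are illustrative assumptions rather than code from this release (the packaged pipeline lives in `fusion_bench/method/expert_sparsity/mixtral/layer_wise_pruning.py`).

```python
# Hypothetical calibration-and-prune loop; `model` is assumed to be a
# transformers MixtralForCausalLM and `calib_batches` a small iterable of
# tokenized input dicts. Both names are placeholders for illustration.
import torch


def prune_experts(model, calib_batches, r: int = 6):
    # Wrap every sparse-MoE block so that its activations can be cached.
    wrappers = []
    for layer in model.model.layers:
        wrapper = PrunableMixtralSparseMoeBlockWrapper(layer.block_sparse_moe, r=r)
        wrapper.cache_X = True  # cache block inputs
        wrapper.cache_Z = True  # cache block outputs
        layer.block_sparse_moe = wrapper
        wrappers.append(wrapper)

    # Run calibration data through the model to fill the caches.
    with torch.no_grad():
        for batch in calib_batches:
            model(**batch)

    # For each block, search over expert subsets, prune, and unwrap.
    for layer, wrapper in zip(model.model.layers, wrappers):
        wrapper.enumerate()  # picks `experts_to_drop` by reconstruction loss
        wrapper.prune()      # rebuilds the gate and expert list with `r` experts
        layer.block_sparse_moe = wrapper.model  # restore the pruned MoE block
```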
fusion_bench/models/parameter_dict.py
CHANGED
@@ -66,7 +66,9 @@ class ParameterDictModel(nn.Module):
         super().__init__()
         if parameters is not None:
             for name, param in parameters.items():
-                assert isinstance(
+                assert isinstance(
+                    param, (nn.Parameter, nn.Buffer)
+                ), f"{name} is not a nn.Parameter or nn.Buffer"
                 _set_attr(
                     self,
                     name.split("."),
@@ -114,3 +116,6 @@ class ParameterDictModel(nn.Module):
 
     def values(self) -> List[nn.Parameter]:
         return [self[name] for name in self.keys()]
+
+    def __len__(self):
+        return len(self.keys())
fusion_bench/programs/fabric_fusion_program.py
CHANGED
@@ -9,7 +9,7 @@ from omegaconf import DictConfig, OmegaConf
 from torch import nn
 from tqdm.auto import tqdm
 
-import fusion_bench.utils.
+import fusion_bench.utils.instantiate_utils
 from fusion_bench.method import BaseAlgorithm
 from fusion_bench.mixins import LightningFabricMixin
 from fusion_bench.modelpool import BaseModelPool
@@ -19,8 +19,9 @@ from fusion_bench.utils import import_object, instantiate, timeit_context
 from fusion_bench.utils.hydra_utils import get_hydra_output_dir
 from fusion_bench.utils.json import print_json
 from fusion_bench.utils.rich_utils import print_bordered, print_config_tree
+from fusion_bench.utils.pylogger import getRankZeroLogger
 
-log =
+log = getRankZeroLogger(__name__)
 
 
 class FabricModelFusionProgram(
@@ -66,8 +67,8 @@ class FabricModelFusionProgram(
         self.merged_model_save_kwargs = merged_model_save_kwargs
         self.fast_dev_run = fast_dev_run
         self.seed = seed
+        fusion_bench.utils.instantiate_utils.PRINT_FUNCTION_CALL = print_function_call
         super().__init__(**kwargs)
-        fusion_bench.utils.instantiate.PRINT_FUNCTION_CALL = print_function_call
 
         if print_config:
             print_config_tree(
@@ -252,13 +253,16 @@ class FabricModelFusionProgram(
         if self.taskpool is not None:
             report = self.evaluate_merged_model(self.taskpool, merged_model)
             try:
-
+                if rank_zero_only.rank == 0:
+                    print_json(report, print_type=False)
             except Exception as e:
                 log.warning(f"Failed to pretty print the report: {e}")
-
+                log.info(report)
             if self.report_save_path is not None:
                 # save report (Dict) to a file
                 # if the directory of `save_report` does not exists, create it
+                if "{log_dir}" in self.report_save_path and self.log_dir is not None:
+                    self.report_save_path = self.report_save_path.format(log_dir=self.log_dir)
                 os.makedirs(os.path.dirname(self.report_save_path), exist_ok=True)
                 json.dump(report, open(self.report_save_path, "w"))
             else:
@@ -292,13 +296,17 @@ class FabricModelFusionProgram(
         if hydra_output_dir is not None:
             os.makedirs(self.log_dir, exist_ok=True)
             try:
-
-
-                    os.path.join(
-
-
-
-
-
+                # if the system is windows, use the `mklink` command in "CMD" to create the symlink
+                if os.name == "nt":
+                    os.system(f"mklink /J {os.path.abspath(os.path.join(self.log_dir, 'hydra_output_' + os.path.basename(hydra_output_dir)))} {os.path.abspath(hydra_output_dir)}")
+                else:
+                    os.symlink(
+                        hydra_output_dir,
+                        os.path.join(
+                            self.log_dir,
+                            "hydra_output_" + os.path.basename(hydra_output_dir),
+                        ),
+                        target_is_directory=True,
+                    )
             except OSError as e:
                 log.warning(f"Failed to create symbolic link: {e}")
fusion_bench/taskpool/dummy.py
CHANGED
@@ -10,6 +10,7 @@ from fusion_bench.models.separate_io import separate_save
 from fusion_bench.taskpool.base_pool import BaseTaskPool
 from fusion_bench.utils import timeit_context
 from fusion_bench.utils.parameters import count_parameters, print_parameters
+from lightning.pytorch.utilities import rank_zero_only
 
 
 def get_model_summary(model: nn.Module) -> dict:
@@ -49,10 +50,11 @@ class DummyTaskPool(BaseTaskPool):
         Args:
             model: The model to evaluate.
         """
-
+        if rank_zero_only.rank == 0:
+            print_parameters(model, is_human_readable=True)
 
-
-
-
+        if self.model_save_path is not None:
+            with timeit_context(f"Saving the model to {self.model_save_path}"):
+                separate_save(model, self.model_save_path)
 
         return get_model_summary(model)
fusion_bench/utils/__init__.py
CHANGED
@@ -2,14 +2,15 @@
 import importlib
 from typing import Iterable
 
-from . import data, functools, path
+from . import data, functools, path, pylogger
 from .cache_utils import *
 from .devices import *
 from .dtype import parse_dtype
 from .fabric import seed_everything_by_time
-from .
+from .instantiate_utils import instantiate, is_instantiable
+from .json import load_from_json, save_to_json
+from .lazy_state_dict import LazyStateDict
 from .misc import *
 from .packages import import_object
 from .parameters import *
 from .timer import timeit_context
-from .lazy_state_dict import LazyStateDict
fusion_bench/utils/dtype.py
CHANGED
@@ -13,6 +13,7 @@ from transformers.utils import (
 PRECISION_STR_TO_DTYPE: Dict[str, torch.dtype] = {
     "fp16": torch.float16,
     "float16": torch.float16,
+    "half": torch.float16,
     "bf16": torch.bfloat16,
     "bfloat16": torch.bfloat16,
     "float": torch.float32,
@@ -50,7 +51,7 @@ def parse_dtype(dtype: Optional[str]):
 
     dtype = dtype.strip('"')
    if dtype not in PRECISION_STR_TO_DTYPE:
-        raise ValueError(f"Unsupported dtype: {
+        raise ValueError(f"Unsupported dtype string: {dtype}")
 
     dtype = PRECISION_STR_TO_DTYPE[dtype]
     return dtype
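`parse_dtype` now also accepts the `"half"` alias for `torch.float16`. A quick illustrative check of the mapping (a sketch, not part of the diff):

```python
import torch

from fusion_bench.utils.dtype import parse_dtype

assert parse_dtype("half") is torch.float16  # alias added in this release
assert parse_dtype("bf16") is torch.bfloat16
```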
fusion_bench/utils/fabric.py
CHANGED
@@ -1,17 +1,24 @@
 import time
+from typing import Optional
 
 import lightning as L
 
+from fusion_bench.utils.pylogger import getRankZeroLogger
 
-
+log = getRankZeroLogger(__name__)
+
+
+def seed_everything_by_time(fabric: Optional[L.Fabric] = None):
     """
     Set seed for all processes by time.
     """
     # set seed for all processes
-    if fabric.is_global_zero:
+    if fabric is None or fabric.is_global_zero:
         seed = int(time.time())
     else:
         seed = None
-    fabric
-
+    if fabric is not None:
+        log.debug(f"Broadcasting seed `{seed}` to all processes")
+        fabric.barrier()
+        seed = fabric.broadcast(seed, src=0)
     L.seed_everything(seed)
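`seed_everything_by_time` can now be called without a `Fabric` instance; when one is passed, rank zero picks a time-based seed and broadcasts it to the other processes before seeding. Illustrative usage (a sketch):

```python
import lightning as L

from fusion_bench.utils import seed_everything_by_time

# Single-process use, new in this release: no Fabric required.
seed_everything_by_time()

# Distributed use: rank 0 chooses the seed and broadcasts it to all ranks.
fabric = L.Fabric(devices=1)
fabric.launch()
seed_everything_by_time(fabric)
```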
fusion_bench/utils/{instantiate.py → instantiate_utils.py}
RENAMED
@@ -41,6 +41,9 @@ def set_print_function_call(value: bool):
     finally:
         PRINT_FUNCTION_CALL = old_value
 
+def set_print_function_call_permeanent(value: bool):
+    global PRINT_FUNCTION_CALL
+    PRINT_FUNCTION_CALL = value
 
 def is_instantiable(config: Union[DictConfig, Any]) -> bool:
     if OmegaConf.is_dict(config):
fusion_bench/utils/lazy_state_dict.py
CHANGED
@@ -1,13 +1,16 @@
 import json
 import logging
 import os
-from
+from copy import deepcopy
+from typing import TYPE_CHECKING, Dict, Iterator, List, Optional, Tuple, Type
 
 import torch
+from accelerate import init_empty_weights
 from accelerate.utils.constants import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
 from huggingface_hub import snapshot_download
 from safetensors import safe_open
 from safetensors.torch import load_file
+from torch import nn
 from transformers import AutoConfig
 
 from fusion_bench.utils.dtype import parse_dtype
@@ -59,6 +62,8 @@ class LazyStateDict:
     def __init__(
         self,
         checkpoint: str,
+        meta_module_class: Optional[Type[nn.Module]] = None,
+        meta_module: Optional[nn.Module] = None,
         cache_state_dict: bool = False,
         torch_dtype: Optional[torch.dtype] = None,
         device: str = "cpu",
@@ -66,6 +71,22 @@ class LazyStateDict:
         hf_cache_dir: Optional[str] = None,
         hf_proxies: Optional[Dict] = None,
     ):
+        self.meta_module_class = meta_module_class
+        self.meta_module = meta_module
+        if self.meta_module_class is not None:
+            if self.meta_module is not None:
+                raise ValueError(
+                    "Cannot provide both meta_module_class and meta_module, please provide only one."
+                )
+            with init_empty_weights():
+                self.meta_module = self.meta_module_class.from_pretrained(
+                    checkpoint,
+                    torch_dtype=torch_dtype,
+                    revision=hf_revision,
+                    cache_dir=hf_cache_dir,
+                    proxies=hf_proxies,
+                )
+
         self._checkpoint = checkpoint
         self._local_path = resolve_checkpoint_path(
             checkpoint,
@@ -78,10 +99,32 @@ class LazyStateDict:
             self._resolve_checkpoint_files(self._local_path)
         )
 
-        if
-
+        if self._index is not None:
+            # if meta_module is provided, remove the keys that are not in the meta_module
+            if self.meta_module is not None:
+                meta_module_state_dict = self.meta_module.state_dict()
+                for key in tuple(self._index.keys()):
+                    if key not in meta_module_state_dict:
+                        self._index.pop(key)
+            if cache_state_dict:
+                self._state_dict_cache = {}
+            else:
+                self._state_dict_cache = None
+        elif len(self._checkpoint_files) == 1 and self._checkpoint_files[0].endswith(
+            WEIGHTS_NAME
+        ):
+            log.info(f"Loading full state dict from {WEIGHTS_NAME}")
+            self._state_dict_cache = torch.load(self._checkpoint_files[0])
+            # if meta_module is provided, remove the keys that are not in the meta_module
+            if self.meta_module is not None:
+                meta_module_state_dict = self.meta_module.state_dict()
+                for key in tuple(self._state_dict_cache.keys()):
+                    if key not in meta_module_state_dict:
+                        self._state_dict_cache.pop(key)
         else:
-
+            raise ValueError(
+                f"Cannot determine the type of checkpoint, please provide a checkpoint path to a file containing a whole state dict with file name {WEIGHTS_NAME} or {SAFE_WEIGHTS_NAME}, or the index of a sharded checkpoint ending with `.index.json`."
+            )
 
         self._torch_dtype = parse_dtype(torch_dtype)
         self._device = device
@@ -152,6 +195,8 @@ class LazyStateDict:
             checkpoint_files = [
                 os.path.join(checkpoint_folder, f) for f in checkpoint_files
             ]
+        else:
+            index = None
         return index, index_filename, checkpoint_files
 
     def _load_tensor_from_checkpoint_file(
@@ -248,16 +293,24 @@ class LazyStateDict:
     def __iter__(self) -> Iterator[str]:
         if self._index is not None:
            return iter(self._index)
-
+        elif self._state_dict_cache is not None:
+            return iter(self._state_dict_cache)
+        else:
+            raise RuntimeError(
+                "Unexpected error: cannot determine the keys in the state dict."
+            )
 
-    def keys(self) ->
-
+    def keys(self) -> Iterator[str]:
+        for key in self:
+            yield key
 
-    def values(self) ->
-
+    def values(self) -> Iterator[torch.Tensor]:
+        for key in self:
+            yield self[key]
 
     def items(self) -> Iterator[Tuple[str, torch.Tensor]]:
-
+        for key in self:
+            yield key, self[key]
 
     def __repr__(self) -> str:
         if self._index is not None:
@@ -266,3 +319,20 @@ class LazyStateDict:
            return (
                f"{self.__class__.__name__}(checkpoint_files={self._checkpoint_files})"
            )
+
+    def get_parameter(self, target: str) -> torch.Tensor:
+        return self[target]
+
+    def get_submodule(self, target: str) -> nn.Module:
+        if self.meta_module is not None:
+            module: nn.Module = deepcopy(self.meta_module.get_submodule(target))
+            module.to_empty(device=self._device)
+            state_dict = {}
+            for name, _ in module.named_parameters():
+                state_dict[name] = self[f"{target}.{name}"]
+            module.load_state_dict(state_dict)
+            return module
+        else:
+            raise RuntimeError(
+                "Cannot get submodule because meta_module is not provided."
+            )
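The `LazyStateDict` changes add an optional meta module (a weight-less copy of the model built under `accelerate.init_empty_weights`) that is used to filter checkpoint keys and to materialize submodules on demand. A minimal sketch of the new interface; the checkpoint path and the submodule name below are placeholders, not values from this release:

```python
from transformers import AutoModelForCausalLM

from fusion_bench.utils import LazyStateDict

# Placeholder checkpoint path; any local or Hugging Face checkpoint layout
# that LazyStateDict can resolve (a sharded index or a single weights file).
state_dict = LazyStateDict(
    "path/to/checkpoint",
    meta_module_class=AutoModelForCausalLM,  # built on the meta device internally
)

# Tensors are loaded lazily, key by key, rather than all at once.
for name, tensor in state_dict.items():
    print(name, tuple(tensor.shape))

# With a meta module available, a whole submodule can be materialized on demand
# (the submodule path is illustrative and depends on the model architecture).
mlp = state_dict.get_submodule("model.layers.0.mlp")
```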
fusion_bench/utils/pylogger.py
CHANGED
@@ -53,3 +53,33 @@ class RankedLogger(logging.LoggerAdapter):
                 self.logger.log(level, msg, *args, **kwargs)
             elif current_rank == rank:
                 self.logger.log(level, msg, *args, **kwargs)
+
+
+class RankZeroLogger(logging.Logger):
+    """A logger that logs only on rank zero and works just like logging.Logger"""
+
+    @rank_zero_only
+    def _log(self, *args, **kwargs):
+        if "stacklevel" in kwargs:
+            kwargs["stacklevel"] += 1
+        else:
+            kwargs["stacklevel"] = 2
+        return super()._log(*args, **kwargs)
+
+    def is_global_zero(self):
+        return rank_zero_only.rank == 0
+
+
+RankZeroLogger.manager = logging.Manager(RankZeroLogger.root)
+RankZeroLogger.manager.setLoggerClass(RankZeroLogger)
+
+
+def getRankZeroLogger(name=None):
+    """
+    Return a logger with the specified name, creating it if necessary.
+
+    If no name is specified, return the root logger.
+    """
+    if not name or isinstance(name, str) and name == logging.root.name:
+        return logging.root
+    return RankZeroLogger.manager.getLogger(name)
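`getRankZeroLogger` mirrors `logging.getLogger` but returns a `RankZeroLogger`, whose `_log` is guarded by `rank_zero_only`, so records are emitted only by the rank-zero process. Minimal usage, matching how the diff itself adopts it in `fabric_fusion_program.py` and `utils/fabric.py`:

```python
from fusion_bench.utils.pylogger import getRankZeroLogger

log = getRankZeroLogger(__name__)

# In a multi-process (DDP / Fabric) run, only global rank 0 emits this record;
# the other ranks call the same API, but the guarded `_log` is a no-op for them.
log.info("model fusion finished")
```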
{fusion_bench-0.2.16.dist-info → fusion_bench-0.2.18.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: fusion_bench
-Version: 0.2.16
+Version: 0.2.18
 Summary: A Comprehensive Benchmark of Deep Model Fusion
 Author-email: Anke Tang <tang.anke@foxmail.com>
 License: MIT License
@@ -171,6 +171,8 @@ It can be used to improve the performance and robustness of model or to combine
 For a more detailed introduction to deep model fusion, you can refer to [W. Li, 2023, 'Deep Model Fusion: A Survey'](https://arxiv.org/abs/2309.15698). We also provide a brief overview of deep model fusion in [our documentation](https://tanganke.github.io/fusion_bench/).
 In this benchmark, we evaluate the performance of different fusion methods on a variety of datasets and tasks.
 
+A comprehensive list of papers about model merging can be found at [this repository](https://github.com/EnnengYang/Awesome-Model-Merging-Methods-Theories-Applications), and [the arXiv paper](https://arxiv.org/abs/2408.07666) is also available.
+
 ## Project Structure
 
 The project is structured as follows: