fusion-bench 0.2.21__py3-none-any.whl → 0.2.22__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/__init__.py +21 -2
- fusion_bench/constants/__init__.py +1 -0
- fusion_bench/constants/runtime.py +57 -0
- fusion_bench/method/__init__.py +8 -2
- fusion_bench/method/bitdelta/__init__.py +1 -0
- fusion_bench/method/classification/clip_finetune.py +1 -1
- fusion_bench/method/fisher_merging/clip_fisher_merging.py +0 -4
- fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +2 -2
- fusion_bench/method/linear/simple_average_for_llama.py +16 -11
- fusion_bench/method/simple_average.py +7 -7
- fusion_bench/method/smile_upscaling/causal_lm_upscaling.py +371 -0
- fusion_bench/method/smile_upscaling/projected_energy.py +1 -2
- fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py +5 -1
- fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +40 -31
- fusion_bench/method/smile_upscaling/smile_upscaling.py +1 -1
- fusion_bench/method/we_moe/__init__.py +1 -0
- fusion_bench/method/we_moe/entropy_loss.py +25 -0
- fusion_bench/method/we_moe/flan_t5_we_moe.py +331 -0
- fusion_bench/method/we_moe/utils.py +15 -0
- fusion_bench/method/weighted_average/llama.py +1 -1
- fusion_bench/mixins/clip_classification.py +11 -42
- fusion_bench/mixins/serialization.py +18 -8
- fusion_bench/modelpool/causal_lm/causal_lm.py +32 -33
- fusion_bench/models/__init__.py +5 -0
- fusion_bench/models/hf_utils.py +65 -87
- fusion_bench/models/model_card_templates/default.md +46 -0
- fusion_bench/models/modeling_smile_llama/__init__.py +7 -0
- fusion_bench/models/modeling_smile_llama/modeling_smile_llama.py +1 -8
- fusion_bench/models/modeling_smile_mistral/__init__.py +1 -1
- fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +1 -5
- fusion_bench/programs/fabric_fusion_program.py +29 -60
- fusion_bench/scripts/cli.py +34 -1
- fusion_bench/taskpool/clip_vision/taskpool.py +9 -4
- fusion_bench/utils/__init__.py +1 -0
- fusion_bench/utils/cache_utils.py +101 -1
- fusion_bench/utils/fabric.py +2 -2
- fusion_bench/utils/lazy_imports.py +23 -0
- fusion_bench/utils/lazy_state_dict.py +38 -3
- fusion_bench/utils/modelscope.py +3 -3
- fusion_bench/utils/path.py +56 -0
- fusion_bench/utils/pylogger.py +1 -1
- {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.22.dist-info}/METADATA +1 -23
- {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.22.dist-info}/RECORD +53 -45
- fusion_bench_config/method/fisher_merging/clip_fisher_merging.yaml +0 -1
- fusion_bench_config/method/linear/simple_average_for_llama.yaml +3 -2
- fusion_bench_config/method/smile_upscaling/causal_lm_upscaling.yaml +21 -0
- fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +1 -1
- fusion_bench_config/method/wemoe/flan_t5_weight_ensembling_moe.yaml +20 -0
- fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml +1 -1
- {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.22.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.22.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.22.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.22.dist-info}/top_level.txt +0 -0
fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py

@@ -16,10 +16,11 @@ from transformers.models.qwen2.modeling_qwen2 import Qwen2DecoderLayer
 
 from fusion_bench import BaseAlgorithm, BaseModelPool
 from fusion_bench.compat.modelpool import to_modelpool
+from fusion_bench.constants import RuntimeConstants
 from fusion_bench.mixins import SimpleProfilerMixin, auto_register_config
 from fusion_bench.modelpool import CausalLMPool
 from fusion_bench.models.hf_utils import (
-
+    create_default_model_card,
     save_pretrained_with_remote_code,
 )
 from fusion_bench.models.modeling_smile_qwen2 import (
@@ -41,7 +42,10 @@ log = logging.getLogger(__name__)
 
 
 @auto_register_config
-class SmileQwen2UpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
+class SmileQwen2UpscalingAlgorithm(
+    SimpleProfilerMixin,
+    BaseAlgorithm,
+):
     R"""
     SmileQwen2UpscalingAlgorithm is a model fusion algorithm designed to upscale
     a pretrained Qwen2 model using a set of fine-tuned expert models. The algorithm
@@ -62,7 +66,7 @@ class SmileQwen2UpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
         self,
         device,
         accelerator,
-        model_path,
+        model_save_path,
         model_dtype,
         num_experts_per_tok,
         rank_of_router,
@@ -71,6 +75,11 @@ class SmileQwen2UpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
         **kwargs,
     ):
         super().__init__(**kwargs)
+        if not torch.cuda.is_available():
+            if "cuda" in self.device:
+                self.device = "cpu"
+            if "cuda" in self.accelerator:
+                self.accelerator = "cpu"
 
     @torch.no_grad()
     def run(self, modelpool) -> SmileQwen2ForCausalLM:
@@ -86,13 +95,6 @@ class SmileQwen2UpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
         self.modelpool = modelpool = to_modelpool(modelpool)
         config = self.config
 
-        # load model from path if provided and return directly
-        if config.model_path is not None and os.path.exists(config.model_path):
-            log.info(f"Loading model from {config.model_path}")
-            model = AutoModelForCausalLM.from_pretrained(config.model_path)
-            print_parameters(model)
-            return model
-
         with self.profile("load pretrained model"):
             pretrained_model = modelpool.load_pretrained_model()
         with self.profile("load fine-tuned model"):
@@ -100,7 +102,7 @@ class SmileQwen2UpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
                 m for m in tqdm(modelpool.models(), total=len(modelpool.model_names))
             ]
 
-        if config.device == "cuda" and torch.cuda.is_available():
+        if self.device == "cuda" and torch.cuda.is_available():
             pretrained_model = pretrained_model.cuda()
         print("parameter count of pretrained model:")
         print_parameters(pretrained_model)
@@ -114,17 +116,17 @@ class SmileQwen2UpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
         print_parameters(model)
         print(model)
 
-        if config.model_dtype is not None:
-            model.to(dtype=parse_dtype(config.model_dtype))
+        if self.model_dtype is not None:
+            model.to(dtype=parse_dtype(self.model_dtype))
 
-        if config.model_path is not None:
-            if os.path.dirname(config.model_path):
-                os.makedirs(os.path.dirname(config.model_path), exist_ok=True)
-            log.info(f"Saving model to {config.model_path}")
+        if self.model_save_path is not None:
+            if os.path.dirname(self.model_save_path):
+                os.makedirs(os.path.dirname(self.model_save_path), exist_ok=True)
+            log.info(f"Saving model to {self.model_save_path}")
             tokenizer = self.modelpool.load_tokenizer()
-            tokenizer.save_pretrained(config.model_path)
+            tokenizer.save_pretrained(self.model_save_path)
             if not self.save_with_remote_code:
-                model.save_pretrained(config.model_path)
+                model.save_pretrained(self.model_save_path)
             else:
                 save_pretrained_with_remote_code(
                     model,
@@ -133,17 +135,18 @@ class SmileQwen2UpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
                         "AutoModel": SmileQwen2Model,
                         "AutoModelForCausalLM": SmileQwen2ForCausalLM,
                     },
-                    save_directory=config.model_path,
+                    save_directory=self.model_save_path,
                 )
 
             # save readme
-
-                algorithm=self,
-                modelpool=modelpool,
+            model_card_str = create_default_model_card(
                 models=[modelpool.get_model_path(m) for m in modelpool.all_model_names],
+                description="Merged Qwen model using SMILE Upscaling",
+                algorithm_config=self.config,
+                modelpool_config=modelpool.config,
             )
-            with open(os.path.join(config.model_path, "README.md"), "w") as f:
-                f.write(
+            with open(os.path.join(self.model_save_path, "README.md"), "w") as f:
+                f.write(model_card_str)
 
         return model
 
@@ -174,9 +177,9 @@ class SmileQwen2UpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
         )
         base_config = AutoConfig.from_pretrained(pretrained_path)
         model_config = SmileQwen2Config(
-            num_experts_per_tok=config.num_experts_per_tok,
-            rank_of_router=config.rank_of_router,
-            rank_of_expert=config.rank_of_expert,
+            num_experts_per_tok=self.num_experts_per_tok,
+            rank_of_router=self.rank_of_router,
+            rank_of_expert=self.rank_of_expert,
             num_local_experts=len(finetuned_models),
             **base_config.to_dict(),
         )
@@ -186,7 +189,7 @@ class SmileQwen2UpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
 
         # copy pretrained model weights
         state_dict = model.state_dict()
-        pretrained_state_dict =
+        pretrained_state_dict = pretrained_model.state_dict()
         for key in list(pretrained_state_dict.keys()):
             if key not in state_dict:
                 pretrained_state_dict.pop(key)
@@ -198,6 +201,12 @@ class SmileQwen2UpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
             "Upscaling Modules (layer)",
             dynamic_ncols=True,
         ):
+            if RuntimeConstants.debug and layer_idx > 0:
+                log.info(
+                    "Debug mode enabled: processing only the first layer, skipping remaining layers"
+                )
+                break
+
             pretrained_layer: Qwen2DecoderLayer = pretrained_model.model.layers[
                 layer_idx
             ]
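The `RuntimeConstants` class consulted here lives in the new `fusion_bench/constants/runtime.py` (57 lines, listed above) whose contents are not part of the rendered hunks. From the usage, it exposes process-wide flags such as `debug`; a minimal sketch of the assumed shape:

    # Hypothetical sketch only; the real fusion_bench/constants/runtime.py
    # is not shown in this diff.
    class RuntimeConstants:
        # global flag consulted to short-circuit expensive loops while debugging
        debug: bool = False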
@@ -213,7 +222,7 @@ class SmileQwen2UpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
                         base=getattr(pretrained_layer.self_attn, n),
                         experts=[getattr(m.self_attn, n) for m in finetuned_layers],
                         target=getattr(target_layer.self_attn, n),
-                        accelerator=config.accelerator,
+                        accelerator=self.accelerator,
                     )
                 except ExpertNotTrainedError:
                     setattr(
@@ -228,7 +237,7 @@ class SmileQwen2UpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
                         base=getattr(pretrained_layer.mlp, n),
                         experts=[getattr(m.mlp, n) for m in finetuned_layers],
                         target=getattr(target_layer.mlp, n),
-                        accelerator=config.accelerator,
+                        accelerator=self.accelerator,
                     )
                 except ExpertNotTrainedError:
                     setattr(
fusion_bench/method/smile_upscaling/smile_upscaling.py

@@ -20,8 +20,8 @@ from fusion_bench.models.smile_moe.linear_from_module import (
     SmileMoELinear,
 )
 from fusion_bench.models.utils import get_attr, set_attr
-from fusion_bench.utils.parameters import print_parameters
 from fusion_bench.utils.devices import get_device
+from fusion_bench.utils.parameters import print_parameters
 
 log = logging.getLogger(__name__)
 
fusion_bench/method/we_moe/entropy_loss.py (new file)

@@ -0,0 +1,25 @@
+import torch
+from torch import Tensor
+
+
+def entropy_loss(logits: Tensor, eps: float = 1e-8) -> Tensor:
+    """
+    Compute the entropy loss of a set of logits.
+
+    Args:
+        logits (Tensor): The logits to compute the entropy loss of.
+        eps (float): A small value to avoid log(0). Default is 1e-8.
+
+    Returns:
+        Tensor: The entropy loss of the logits.
+    """
+    # Ensure the logits tensor has 2 dimensions
+    assert (
+        logits.dim() == 2
+    ), f"Expected logits to have 2 dimensions, found {logits.dim()}, {logits.size()=}"
+
+    # Compute the softmax probabilities
+    probs = torch.softmax(logits, dim=-1)
+
+    # Compute the entropy loss
+    return -torch.sum(probs * torch.log(probs + eps), dim=-1).mean()
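For orientation: `entropy_loss` returns the mean Shannon entropy of the softmax distribution over the last dimension, which the test-time-adaptation loop in the next file minimizes. A minimal usage sketch (assumes fusion-bench 0.2.22 is installed):

    import torch
    from fusion_bench.method.we_moe.entropy_loss import entropy_loss

    logits = torch.randn(8, 32)  # (batch_size, num_classes)
    loss = entropy_loss(logits)  # scalar: mean per-sample entropy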
fusion_bench/method/we_moe/flan_t5_we_moe.py (new file)

@@ -0,0 +1,331 @@
+import functools
+import logging
+import os
+from copy import deepcopy
+from typing import Any, Dict, List, Mapping, Optional, Union, cast  # noqa: F401
+
+import lightning
+import lightning as L
+import lightning.fabric.wrappers
+import torch
+from torch import Tensor
+from torch.utils.data import DataLoader
+from tqdm.autonotebook import tqdm
+from transformers import T5ForConditionalGeneration
+from transformers.data import default_data_collator
+
+from fusion_bench.method import BaseAlgorithm
+from fusion_bench.method.task_arithmetic.task_arithmetic import task_arithmetic_merge
+from fusion_bench.mixins.lightning_fabric import LightningFabricMixin
+from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
+from fusion_bench.modelpool import Seq2SeqLMPool
+from fusion_bench.models.we_moe import WeightEnsemblingMoE
+from fusion_bench.utils import timeit_context
+from fusion_bench.utils.data import InfiniteDataLoader, load_tensor_from_file
+from fusion_bench.utils.instantiate_utils import instantiate
+from fusion_bench.utils.parameters import print_parameters
+
+from .entropy_loss import entropy_loss
+from .utils import get_memory_usage
+
+log = logging.getLogger(__name__)
+
+
+class FlanT5WeightEnsemblingMoEAlgorithm(
+    BaseAlgorithm,
+    LightningFabricMixin,
+    SimpleProfilerMixin,
+):
+    """
+    FlanT5WeightEnsemblingMoEAlgorithm is a class that implements the WeightEnsemblingMoEAlgorithm
+    for FlanT5 models. It extends the WeightEnsemblingMoEAlgorithm and CLIPClassificationMixin classes.
+
+    Attributes:
+        modelpool (Seq2SeqLMPool): The model pool containing the FlanT5 models.
+    """
+
+    modelpool: Seq2SeqLMPool = None
+
+    def __init__(
+        self,
+        checkpoint: bool = False,
+        save_checkpoint: bool = False,
+        router_hidden_layers: int = 2,
+        init_lambda: float = 0.3,
+        batch_reduce: bool = True,
+        lr: float = 1e-4,
+        optimizer: str = "adam",
+        devices: int = 1,
+        batch_size: int = 16,
+        num_workers: int = 0,
+        max_steps: int = 1000,
+        use_grad_accumulate: bool = True,
+        cache_dir: bool = "outputs",
+        fast_dev_run: bool = False,
+        **kwargs,
+    ):
+        """
+        Initialize the WeightEnsemblingMoEAlgorithm with the given configuration.
+
+        Args:
+            algorithm_config (DictConfig): The configuration for the algorithm.
+        """
+        self.checkpoint = checkpoint
+        self.save_checkpoint = save_checkpoint
+        self.router_hidden_layers = router_hidden_layers
+        self.init_lambda = init_lambda
+        self.batch_reduce = batch_reduce
+        self.lr = lr
+        self.optimizer = optimizer
+        self.devices = devices
+        self.batch_size = batch_size
+        self.num_workers = num_workers
+        self.max_steps = max_steps
+        self.use_grad_accumulate = use_grad_accumulate
+        self.cache_dir = cache_dir
+        self.fast_dev_run = fast_dev_run
+        super().__init__(**kwargs)
+
+    def construct_moe_model(self) -> WeightEnsemblingMoE:
+        """
+        Construct the Mixture of Experts (MoE) model using the models in the model pool.
+
+        Returns:
+            WeightEnsemblingMoE: The constructed MoE model.
+        """
+        base_model = self.modelpool.load_model("_pretrained_")
+        expert_models = [
+            self.modelpool.load_model(name) for name in self.modelpool.model_names
+        ]
+
+        # Merge the models using task arithmetic
+        moe_model = task_arithmetic_merge(
+            # This function modifies the model in place, so we need to pass a deepcopy
+            deepcopy(base_model),
+            expert_models,
+            scaling_factor=self.init_lambda,
+        ).requires_grad_(False)
+
+        print(base_model)
+
+        # Up-scale MLP modules
+        num_layer = 12
+        encoder_mlp_index = 1
+        base_encoder = base_model.encoder
+        moe_encoder = moe_model.encoder
+        expert_encoders = [m.encoder for m in expert_models]
+
+        for layer_idx in range(num_layer):
+            base_mlp = (
+                base_encoder.block[layer_idx].layer[encoder_mlp_index].DenseReluDense
+            )
+            expert_mlps = [
+                e.block[layer_idx].layer[encoder_mlp_index].DenseReluDense
+                for e in expert_encoders
+            ]
+
+            moe_encoder.block[layer_idx].layer[encoder_mlp_index].DenseReluDense = (
+                WeightEnsemblingMoE(
+                    hidden_size=base_encoder.config.hidden_size,
+                    base_model=base_mlp,
+                    expert_models=expert_mlps,
+                    init_lambda=self.init_lambda,
+                    batch_first=True,
+                    router_hidden_layers=self.router_hidden_layers,
+                    batch_reduce=self.batch_reduce,
+                )
+            )
+
+        decoder_mlp_index = 2
+        base_decoder = base_model.decoder
+        moe_decoder = moe_model.decoder
+        expert_decoders = [m.decoder for m in expert_models]
+
+        for layer_idx in range(num_layer):
+            base_mlp = (
+                base_decoder.block[layer_idx].layer[decoder_mlp_index].DenseReluDense
+            )
+            expert_mlps = [
+                e.block[layer_idx].layer[decoder_mlp_index].DenseReluDense
+                for e in expert_decoders
+            ]
+
+            moe_decoder.block[layer_idx].layer[decoder_mlp_index].DenseReluDense = (
+                WeightEnsemblingMoE(
+                    hidden_size=base_decoder.config.hidden_size,
+                    base_model=base_mlp,
+                    expert_models=expert_mlps,
+                    init_lambda=self.init_lambda,
+                    batch_first=True,
+                    router_hidden_layers=self.router_hidden_layers,
+                    batch_reduce=self.batch_reduce,
+                )
+            )
+
+        print(moe_model)
+        return moe_model
+
+    @functools.cache
+    def get_shuffled_test_loader_iter(self, task: str) -> DataLoader:
+        """
+        Loader of test dataset for test-time adaptation. labels are not needed.
+
+        Args:
+            task (str): The name of the task.
+
+        Returns:
+            DataLoader: The data loader for the test dataset.
+        """
+        # dataloader_kwargs = dict(self.dataloader_kwargs)
+        # dataloader_kwargs.update(dict(shuffle=True, collate_fn=default_data_collator))
+
+        dataset = self.modelpool.load_test_dataset(task)
+        log.info("get_shuffled_test_loader_iter")
+        loader = DataLoader(
+            dataset,
+            batch_size=self.batch_size,
+            shuffle=True,
+            num_workers=self.num_workers,
+            collate_fn=default_data_collator,
+        )
+        # loader = DataLoader(dataset, **dataloader_kwargs)
+        if self.fabric is not None:
+            loader = self.fabric.setup_dataloaders(loader)
+        return iter(InfiniteDataLoader(loader))
+
+    def compute_logits(
+        self,
+        module: Union[T5ForConditionalGeneration],
+        batch,
+        task: str,
+    ) -> Tensor:
+        """
+        Compute the logits for the given images and task.
+
+        Args:
+            module: The model module.
+            images (Tensor): The input images.
+            task (str): The name of the task.
+
+        Returns:
+            Tensor: The computed logits.
+        """
+        input_ids: Tensor = batch["input_ids"]
+        attention_mask: Tensor = batch["attention_mask"]
+
+        # remove padding tokens from the input
+        while attention_mask[:, -1].eq(0).all():
+            input_ids = input_ids[:, :-1]
+            attention_mask = attention_mask[:, :-1]
+
+        outputs = module(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            decoder_input_ids=torch.ones(
+                input_ids.size(0), 1, dtype=torch.long, device=input_ids.device
+            ),
+        )
+        logits = outputs.logits[:, 0, :]
+        return logits
+
+    def test_time_adaptation(self, module):
+        """
+        Perform test-time adaptation for the given module.
+
+        Args:
+            module (WeightEnsemblingMoE): The MoE module to adapt.
+
+        Returns:
+            WeightEnsemblingMoE: The adapted MoE module.
+        """
+        self.on_test_time_adaptation_start()
+
+        # configure optimizer
+        if self.optimizer == "adam":
+            print([name for name, p in module.named_parameters() if p.requires_grad])
+            optimizer = torch.optim.Adam(
+                [p for p in module.parameters() if p.requires_grad], lr=self.lr
+            )
+        else:
+            raise ValueError(f"Unsupported optimizer: {self.optimizer}")
+
+        module, optimizer = self.fabric.setup(module, optimizer)
+
+        module.train()
+        # module.merge_weights()
+        for step_idx in (
+            pbar := tqdm(
+                range(self.max_steps if not self.is_debug_mode else 1),
+                ("[DEBUG MODE] " if self.is_debug_mode else "")
+                + "WEMoE Test-time adaptation",
+                dynamic_ncols=True,
+            )
+        ):
+            total_loss = 0
+            for task in self.modelpool.model_names:
+                with self.profile("data loading"):
+                    batch = next(self.get_shuffled_test_loader_iter(task))
+                with self.profile("forward pass"):
+                    logits = self.compute_logits(module, batch, task)
+                    logits = logits.mean(dim=0, keepdim=True)
+                    loss = entropy_loss(logits)
+                    total_loss += loss
+                with self.profile("backward pass"):
+                    self.fabric.backward(loss, retain_graph=True)
+
+            with self.profile("optimizer step"):
+                optimizer.step()
+                optimizer.zero_grad()
+
+            metrics = {
+                "train/loss": total_loss.item(),
+            }
+            self.fabric.log_dict(metrics, step=step_idx)
+            pbar.set_postfix(metrics)
+
+        log.info(get_memory_usage(f"after adamerging, the memory usage of GPU is:"))
+        self.print_profile_summary()
+        return module
+
+    def on_test_time_adaptation_start(self):
+        """
+        Something to do before the test-time adaptation starts. Such as setting up the task-specific heads.
+        """
+        pass
+
+    def run(self, modelpool: Seq2SeqLMPool, **kwargs):
+        """
+        Run the WeightEnsemblingMoEAlgorithm to fuse models using Weight Ensembling Mixture of Experts.
+
+        Args:
+            modelpool (ModelPool): The pool of models to be fused.
+
+        Returns:
+            WeightEnsemblingMoE: The fused MoE model.
+        """
+        log.info("Fusing models using layer-wise adaptive merging.")
+        self.modelpool = modelpool
+
+        with timeit_context("upscaling models to a weight-ensembling MoE model"):
+            moe_model = self.construct_moe_model()
+            print_parameters(moe_model)
+
+        if self.checkpoint != False:
+            log.info(
+                f"load checkpoint from {self.checkpoint}, test-time adaptation will be skipped."
+            )
+            self.load_checkpoint(moe_model, self.checkpoint)
+        else:
+            with self.profile("test-time adaptation"):
+                moe_model = self.test_time_adaptation(moe_model)
+            if self.save_checkpoint != False:
+                log.info(f"save checkpoint to {self.save_checkpoint}")
+                self.save_checkpoint(moe_model, self.save_checkpoint)
+
+        if lightning.fabric.wrappers.is_wrapped(moe_model):
+            moe_model = lightning.fabric.wrappers._unwrap_objects(moe_model)
+
+        # enable sample-wise adaptation
+        moe_model.batch_reduce = False
+        self.print_profile_summary()
+        return moe_model
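The new algorithm is wired up through the new config file fusion_bench_config/method/wemoe/flan_t5_weight_ensembling_moe.yaml (20 lines, listed above but not rendered here). A hypothetical programmatic sketch, using only the constructor signature shown in the hunk:

    # Sketch only; assumes fusion-bench 0.2.22 and a prepared Seq2SeqLMPool
    # with a "_pretrained_" base model plus one entry per expert.
    from fusion_bench.method.we_moe.flan_t5_we_moe import (
        FlanT5WeightEnsemblingMoEAlgorithm,
    )

    algorithm = FlanT5WeightEnsemblingMoEAlgorithm(lr=1e-4, max_steps=1000)
    # merged = algorithm.run(modelpool)  # modelpool: Seq2SeqLMPool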
fusion_bench/method/we_moe/utils.py (new file)

@@ -0,0 +1,15 @@
+import torch
+
+
+def get_memory_usage(desc):
+    """
+    obtain the current GPU memory usage
+
+    Returns:
+        str: A string containing the allocated and cached memory in MB.
+    """
+    allocated = torch.cuda.memory_allocated() / 1024**2  # convert to MB
+    cached = torch.cuda.memory_reserved() / 1024**2  # convert to MB
+    return (
+        f"{desc}\nAllocated Memory: {allocated:.2f} MB\nCached Memory: {cached:.2f} MB"
+    )
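Usage sketch for the helper above; the counters come from PyTorch's CUDA allocator, so guard the call on CPU-only machines:

    import torch
    from fusion_bench.method.we_moe.utils import get_memory_usage

    if torch.cuda.is_available():
        print(get_memory_usage("after merging:"))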
fusion_bench/method/weighted_average/llama.py

@@ -7,11 +7,11 @@ from transformers import PreTrainedModel
 from typing_extensions import override
 
 from fusion_bench.method import BaseAlgorithm
+from fusion_bench.mixins import auto_register_config
 from fusion_bench.modelpool import CausalLMPool
 from fusion_bench.utils import timeit_context
 from fusion_bench.utils.state_dict_arithmetic import state_dict_add, state_dict_mul
 from fusion_bench.utils.type import StateDictType
-from fusion_bench.mixins import auto_register_config
 
 log = logging.getLogger(__name__)
 
fusion_bench/mixins/clip_classification.py

@@ -22,6 +22,7 @@ from torch.utils.data import DataLoader
 from tqdm.auto import tqdm
 from transformers import CLIPModel, CLIPProcessor, CLIPVisionModel
 
+from fusion_bench import cache_with_joblib
 from fusion_bench.dataset.clip_dataset import CLIPDataset
 from fusion_bench.mixins import LightningFabricMixin
 from fusion_bench.modelpool import CLIPVisionModelPool
@@ -46,7 +47,6 @@ class CLIPClassificationMixin(LightningFabricMixin):
 
     - `_dataloader_kwargs` (Dict[str, Any]): Keyword arguments for the dataloader.
    - `modelpool` (CLIPVisionModelPool): The model pool containing the CLIP models.
-    - `zeroshot_weights_cache_dir` (Optional[str]): The directory to cache the zero-shot weights.
     """
 
     dataloader_kwargs: Dict[str, Any] = {}
@@ -54,7 +54,6 @@ class CLIPClassificationMixin(LightningFabricMixin):
     modelpool: CLIPVisionModelPool = None
     _clip_processor: CLIPProcessor = None
     # a dict of zeroshot weights for each task, each key is the task name
-    zeroshot_weights_cache_dir: str = "outputs/cache/clip_zeroshot_weights"
     zeroshot_weights: Dict[str, torch.Tensor] = {}
     whether_setup_zero_shot_classification_head = False
 
@@ -131,26 +130,16 @@ class CLIPClassificationMixin(LightningFabricMixin):
         self.visual_projection = self.fabric.to_device(self.visual_projection)
         self.logit_scale_exp = self.fabric.to_device(self.logit_scale_exp)
 
-
-
-
-
-
-
-
-
-
-        cache_dir = os.path.join(
-            self.zeroshot_weights_cache_dir,
-            os.path.normpath(model_name.split("/")[-1]),
-        )
-        if not os.path.exists(cache_dir):
-            log.info(
-                f"Creating cache directory for zero-shot classification head at {cache_dir}"
-            )
-            os.makedirs(cache_dir)
+        @cache_with_joblib()
+        def construct_classification_head(task: str):
+            nonlocal clip_classifier
+
+            classnames, templates = get_classnames_and_templates(task)
+            clip_classifier.set_classification_task(classnames, templates)
+            zeroshot_weights = clip_classifier.zeroshot_weights.detach().clone()
+
+            return zeroshot_weights
 
-        log.info(f"cache directory for zero-shot classification head: {cache_dir}")
         for task in tqdm(
             self.modelpool.model_names if task_names is None else task_names,
             "Setting up zero-shot classification head",
@@ -158,27 +147,7 @@ class CLIPClassificationMixin(LightningFabricMixin):
         ):
             zeroshot_weights = None
             if self.fabric.is_global_zero:
-                cache_file = os.path.join(
-                    cache_dir, os.path.normpath(f"{task}_zeroshot_weights.pt")
-                )
-                if os.path.exists(cache_file):
-                    zeroshot_weights = torch.load(
-                        cache_file,
-                        map_location="cpu",
-                        weights_only=True,
-                    ).detach()
-                    log.info(
-                        f"Loadded cached zeroshot weights for task: {task}, shape: {zeroshot_weights.shape}"
-                    )
-                else:
-                    log.info(
-                        f"Construct zero shot classification head for task: {task}"
-                    )
-                    classnames, templates = get_classnames_and_templates(task)
-                    clip_classifier.set_classification_task(classnames, templates)
-                    zeroshot_weights = clip_classifier.zeroshot_weights.detach().clone()
-                    log.info(f"save zeroshot weights to {cache_file}")
-                    torch.save(zeroshot_weights, cache_file)
+                zeroshot_weights = construct_classification_head(task)
 
             self.fabric.barrier()
             self.zeroshot_weights[task] = self.fabric.broadcast(zeroshot_weights, src=0)