fusion-bench 0.2.5__py3-none-any.whl → 0.2.7__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.
- fusion_bench/compat/method/__init__.py +2 -0
- fusion_bench/compat/method/base_algorithm.py +7 -2
- fusion_bench/compat/modelpool/__init__.py +3 -2
- fusion_bench/compat/taskpool/__init__.py +1 -1
- fusion_bench/dataset/arc_agi/__init__.py +6 -1
- fusion_bench/dataset/arc_agi/arc.py +26 -7
- fusion_bench/dataset/arc_agi/arc_agi.py +156 -25
- fusion_bench/dataset/arc_agi/np_cache.py +0 -1
- fusion_bench/dataset/arc_agi/preprocess.py +51 -9
- fusion_bench/dataset/llama/__init__.py +1 -0
- fusion_bench/dataset/llama/alpaca.py +93 -3
- fusion_bench/dataset/llama/collate.py +72 -5
- fusion_bench/dataset/llama/metamathqa.py +50 -0
- fusion_bench/dataset/llama/preference_700k.py +70 -0
- fusion_bench/dataset/llama/stanford_shp.py +90 -0
- fusion_bench/dataset/llama/ultrachat.py +58 -0
- fusion_bench/dataset/llama/utils/__init__.py +0 -0
- fusion_bench/method/__init__.py +4 -1
- fusion_bench/method/adamerging/__init__.py +1 -1
- fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -4
- fusion_bench/method/adamerging/min_norm_solvers.py +4 -4
- fusion_bench/method/linear/expo.py +39 -0
- fusion_bench/method/lm_finetune/__init__.py +1 -0
- fusion_bench/method/lm_finetune/bradley_terry_rm.py +432 -0
- fusion_bench/method/lm_finetune/fullfinetune_sft.py +122 -150
- fusion_bench/method/lm_finetune/peftfinetune_sft.py +102 -157
- fusion_bench/method/pruning/llama_magnitude_prune.py +2 -2
- fusion_bench/method/pruning/llama_random_prune.py +2 -2
- fusion_bench/method/pruning/magnitude_diff_pruning.py +2 -1
- fusion_bench/method/rankone_moe/__init__.py +3 -0
- fusion_bench/method/rankone_moe/clip_rankone_moe.py +160 -0
- fusion_bench/method/rankone_moe/rankone_moe.py +249 -0
- fusion_bench/method/simple_average.py +1 -1
- fusion_bench/method/surgery/__init__.py +3 -0
- fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +157 -0
- fusion_bench/mixins/__init__.py +2 -0
- fusion_bench/mixins/clip_classification.py +60 -12
- fusion_bench/mixins/fabric_training.py +320 -0
- fusion_bench/mixins/lightning_fabric.py +11 -2
- fusion_bench/modelpool/__init__.py +2 -0
- fusion_bench/modelpool/causal_lm/__init__.py +1 -1
- fusion_bench/modelpool/causal_lm/causal_lm.py +21 -22
- fusion_bench/modelpool/seq_classification_lm/__init__.py +2 -0
- fusion_bench/modelpool/seq_classification_lm/reward_model.py +15 -0
- fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +98 -0
- fusion_bench/models/chat_templates/__init__.py +1 -0
- fusion_bench/models/chat_templates/llama_3_Instruct.py +1 -0
- fusion_bench/models/chat_templates/load_tokenizer.py +43 -0
- fusion_bench/models/hf_clip.py +50 -9
- fusion_bench/models/rankone_moe.py +410 -0
- fusion_bench/models/surgery/surgerymodelwrapper.py +157 -0
- fusion_bench/models/utils.py +8 -0
- fusion_bench/models/wrappers/layer_wise_fusion.py +14 -5
- fusion_bench/models/wrappers/task_wise_fusion.py +5 -5
- fusion_bench/optim/__init__.py +2 -0
- fusion_bench/optim/exception.py +47 -0
- fusion_bench/optim/lr_scheduler/__init__.py +1 -0
- fusion_bench/optim/lr_scheduler/linear_warmup.py +222 -0
- fusion_bench/optim/lr_scheduler/utils/__init__.py +1 -0
- fusion_bench/optim/lr_scheduler/utils/visualization.py +119 -0
- fusion_bench/optim/mezo.py +0 -2
- fusion_bench/programs/fabric_fusion_program.py +5 -1
- fusion_bench/taskpool/__init__.py +10 -2
- fusion_bench/taskpool/clip_vision/__init__.py +1 -0
- fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py +112 -0
- fusion_bench/taskpool/clip_vision/taskpool.py +43 -6
- fusion_bench/taskpool/llama/reward_model.py +157 -0
- fusion_bench/taskpool/nyuv2_taskpool.py +2 -0
- fusion_bench/tasks/flan_t5_text_generation/glue_load_dataset.py +2 -1
- fusion_bench/utils/hydra_utils.py +22 -0
- fusion_bench/utils/plot/__init__.py +0 -0
- fusion_bench/utils/plot/token.py +52 -0
- fusion_bench/utils/plot/token_notebook.py +127 -0
- fusion_bench/utils/type.py +5 -3
- {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/METADATA +1 -1
- {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/RECORD +104 -57
- fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
- fusion_bench_config/dataset/llm_sft/alpaca_cleaned.yaml +6 -0
- fusion_bench_config/dataset/llm_sft/ultrachat_200k.yaml +3 -0
- fusion_bench_config/fabric/llama_peft_fsdp.yaml +16 -0
- fusion_bench_config/fabric/loggers/wandb_logger.yaml +2 -0
- fusion_bench_config/fabric/strategy/deepspeed.yaml +10 -0
- fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +9 -0
- fusion_bench_config/fabric_model_fusion.yaml +1 -1
- fusion_bench_config/llama_full_finetune.yaml +19 -0
- fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +47 -0
- fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +13 -6
- fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +17 -9
- fusion_bench_config/method/rankone_moe/rankone_moe.yaml +26 -0
- fusion_bench_config/method/regmean/clip_regmean.yaml +1 -0
- fusion_bench_config/method/surgery/adamerging_surgery.yaml +27 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +21 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +21 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +19 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +18 -0
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +23 -0
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +14 -0
- fusion_bench_config/nyuv2_config.yaml +5 -1
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml +18 -0
- fusion_bench_config/taskpool/reward_model_evaluation.yaml +18 -0
- fusion_bench_config/llama_weighted_average.yaml +0 -26
- {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/LICENSE +0 -0
- {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/top_level.txt +0 -0
fusion_bench/method/lm_finetune/peftfinetune_sft.py:

```diff
@@ -1,3 +1,4 @@
+import functools
 import itertools
 import logging
 import os
@@ -10,16 +11,16 @@ import peft
 import torch
 from lightning.fabric.strategies.fsdp import FSDPStrategy
 from lightning.fabric.utilities import rank_zero_only
-from omegaconf import DictConfig
-from peft import PeftModel, get_peft_config, get_peft_model
+from omegaconf import DictConfig, OmegaConf
+from peft import LoraConfig, PeftModel, get_peft_config, get_peft_model
 from torch import nn
-from torch.utils.data import
+from torch.utils.data import ConcatDataset, DataLoader
 from tqdm.auto import tqdm
 from typing_extensions import TYPE_CHECKING, override

 from fusion_bench import BaseAlgorithm, BaseModelPool
 from fusion_bench.dataset.llama.collate import padded_collate_sft
-from fusion_bench.mixins import
+from fusion_bench.mixins import FabricTrainingMixin
 from fusion_bench.modelpool import CausalLMPool
 from fusion_bench.utils import instantiate
 from fusion_bench.utils.dtype import get_dtype
@@ -35,7 +36,7 @@ if TYPE_CHECKING:
 log = logging.getLogger(__name__)


-class PeftFinetuneSFT(BaseAlgorithm,
+class PeftFinetuneSFT(BaseAlgorithm, FabricTrainingMixin):

     model: Union[
         nn.Module, "_FabricModule", "LlamaForCausalLM", PeftModel, peft.LoraModel
@@ -65,7 +66,9 @@ class PeftFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
         gradient_clip_algorithm: Literal["value", "norm"] = "norm",
         save_optimizer_state: bool = False,
         save_full_model: bool = False,
+        save_ckpt_type: Literal["lightning", "peft"] = "peft",
         ckpt_path: Optional[str] = None,
+        max_length: int = 6144,
         **kwargs,
     ):
         """
@@ -90,6 +93,7 @@ class PeftFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
             gradient_clip_algorithm(str): Algorithm to use for gradient clipping. Available options: 'value', 'norm'. If set to 'value', the gradients will be clipped to the specified value. If set to 'norm', the gradients will be clipped to the specified norm.
             save_optimizer_state(bool): Whether to save the optimizer and lr_scheduler state along with the model checkpoint.
             save_full_model(bool): Whether to save the full model or only the trainable parameters in the model checkpoint.
+            save_ckpt_type(str): Type of checkpoint to save. Available options: 'lightning', 'peft'. If set to 'lightning', the model will be saved using the Lightning checkpointing mechanism. If set to 'peft', the model will be saved using the PEFT checkpointing mechanism.
             ckpt_path(str): Path to the checkpoint to load before training. If set to None, no checkpoint will be loaded.
         """
         self._optimizer = optimizer
@@ -110,23 +114,31 @@ class PeftFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
         self.gradient_clip_algorithm = gradient_clip_algorithm
         self.save_optimizer_state = save_optimizer_state
         self.save_full_model = save_full_model
+        self.save_ckpt_type = save_ckpt_type
         self.ckpt_path = ckpt_path
+        self.max_length = max_length
         super().__init__(**kwargs)

     def run(self, modelpool: CausalLMPool):
         self.modelpool = modelpool
         self.setup()
-        self.train()
+        self.train(self.model, self.optimizer, self.lr_scheduler)

         if self.merge_and_unload:
             self.model = self.model.merge_and_unload()
         return self.model

     def setup_model(self):
+        # https://github.com/Lightning-AI/litgpt/blob/main/litgpt/finetune/lora.py
+        self.tokenizer = self.modelpool.load_tokenizer()
+        if self.tokenizer.pad_token_id is None:
+            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
+
         model = self.modelpool.load_pretrained_model()

         # get the PEFT model
-        peft_config =
+        peft_config = instantiate(self._peft_config, _convert_="all")
+        peft_config.save_pretrained(os.path.join(self.log_dir, "peft_config"))
         peft_model = get_peft_model(model, peft_config, self.adapter_name)
         peft_model.print_trainable_parameters()

@@ -139,21 +151,16 @@ class PeftFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
             self.model.gradient_checkpointing_enable(
                 gradient_checkpointing_kwargs={"use_reentrant": True}
             )
+            self.use_cache = False
+        else:
+            self.use_cache = True
+
         self.model_dtype = get_dtype(self.model)
+        self.model = self.model.to(dtype=self.model_dtype)

     def configure_optimizer(self):
         # compute expected total steps
-        self.
-        if self.max_steps > 0:
-            self.expected_total_steps.append(self.max_steps)
-        if self.max_steps_per_epoch > 0 and self.max_epochs > 0:
-            self.expected_total_steps.append(self.max_steps_per_epoch * self.max_epochs)
-        if self.max_epochs > 0:
-            self.expected_total_steps.append(
-                len(self.train_dataloader) * self.max_epochs
-            )
-        self.expected_total_steps = min(self.expected_total_steps)
-        log.info(f"Expected total steps: {self.expected_total_steps}")
+        self.compute_expected_total_steps(self.train_dataloader)

         optimizer = instantiate(self._optimizer, self.model.parameters())
         if self._lr_scheduler is not None:
@@ -192,7 +199,9 @@ class PeftFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
             train_dataset,
             **self.dataloader_kwargs,
             shuffle=True,
-            collate_fn=
+            collate_fn=functools.partial(
+                padded_collate_sft, pad_token_id=self.tokenizer.pad_token_id
+            ),
         )
         self.train_dataloader = fabric.setup_dataloaders(self.train_dataloader)

@@ -205,28 +214,19 @@ class PeftFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
         optimizer = self.configure_optimizer()
         optimizer, lr_scheduler = optimizer["optimizer"], optimizer["lr_scheduler"]

-        self.model
+        self.model = self.fabric.setup_module(self.model)
+        self.optimizer = self.fabric.setup_optimizers(optimizer)
         self.lr_scheduler = lr_scheduler

-
+    @override
+    def train_epoch(self, *args, **kwargs):
         fabric = self.fabric

-
-        if self.gradient_clip_algorithm == "value":
-            fabric.clip_gradients(self.model, clip_val=self.gradient_clip_val)
-        elif self.gradient_clip_algorithm == "norm":
-            fabric.clip_gradients(self.model, max_norm=self.gradient_clip_val)
-        else:
-            raise ValueError(
-                f"Unknown gradient clip algorithm: {self.gradient_clip_algorithm}. Available options: 'value', 'norm'"
-            )
-
-    def train_epoch(self):
-        fabric = self.fabric
+        accumulated_loss = 0
         for step_idx, batch in enumerate(
             pbar := tqdm(
                 self.train_dataloader,
-                desc="Training
+                desc="Training Batches",
                 dynamic_ncols=True,
                 leave=False,
                 disable=not fabric.is_global_zero,
@@ -234,24 +234,30 @@ class PeftFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
         ):
             is_accumulating = (step_idx + 1) % self.accumulate_grad_batches != 0

+            if self.max_length > 0 and batch["input_ids"].shape[1] > self.max_length:
+                log.warning(
+                    f"Input length exceeds max_length: {batch['input_ids'].shape[1]} > {self.max_length}. Truncating input."
+                )
+                batch["input_ids"] = batch["input_ids"][:, : self.max_length]
+                batch["attention_mask"] = batch["attention_mask"][:, : self.max_length]
+                batch["labels"] = batch["labels"][:, : self.max_length]
+
             # disable gradient synchronization if accumulating gradients across steps for improved performance
             with fabric.no_backward_sync(self.model, enabled=is_accumulating):
                 # use_cache=True is not compatible with gradient checkpointing, so we disable it here
-                output = self.model(
-
+                output = self.model(
+                    input_ids=batch["input_ids"],
+                    attention_mask=batch["attention_mask"],
+                    labels=batch["labels"],
+                    use_cache=self.use_cache,
+                )
+                loss = output["loss"] / self.accumulate_grad_batches

                 fabric.backward(loss)
-
-            metrics = {
-                "train/loss": loss.item(),
-                "train/epoch_idx": self.epoch_idx,
-                "train/lr": self.optimizer.param_groups[0]["lr"],
-            }
-            fabric.log_dict(metrics, step=self.global_step_idx)
-            pbar.set_postfix(metrics)
+            accumulated_loss += loss.item()

             if not is_accumulating:
-                self.
+                self.clip_gradients_if_needed(self.model, self.optimizer)

                 # run lr_scheduler at the end of the step if interval is set to "step"
                 if (
@@ -264,104 +270,30 @@ class PeftFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
                 self.optimizer.step()
                 self.optimizer.zero_grad()

-
-
+                metrics = {
+                    "train/loss": accumulated_loss,
+                    "train/epoch_idx": self.epoch_idx,
+                    "train/lr": self.optimizer.param_groups[0]["lr"],
+                }
+                fabric.log_dict(metrics, step=self.global_step_idx)
+                pbar.set_postfix(metrics)

-
-
-                    self.max_steps_per_epoch > 0
-                    and step_idx + 1 >= self.max_steps_per_epoch
-                ):
-                    break
-                # break if max_steps is set, and exit training
-                if self.max_steps > 0 and self.global_step_idx >= self.max_steps:
-                    self.is_training = False
-                    break
+                # save the model at the end of the step if interval is set to "step" and frequency is met
+                self.conditional_checkpoint_save(stage="end_of_step")

-
+                # break if max_steps_per_epoch is set, and exit epoch
+                if (
+                    self.max_steps_per_epoch > 0
+                    and step_idx + 1 >= self.max_steps_per_epoch
+                ):
+                    break
+                # break if max_steps is set, and exit training
+                if self.max_steps > 0 and self.global_step_idx >= self.max_steps - 1:
+                    self.is_training = False
+                    break

-
-
-        self.is_training = True
-        self.global_step_idx = 0
-        self.model.train()
-        for epoch_idx in tqdm(
-            range(self.max_epochs) if self.max_epochs > 0 else itertools.count(0),
-            "Training Epoch",
-            dynamic_ncols=True,
-            leave=False,
-            disable=not fabric.is_global_zero,
-        ):
-            self.epoch_idx = epoch_idx
-            self.train_epoch()
-            # run lr_scheduler at the end of the epoch if interval is set to "epoch"
-            if (
-                self.lr_scheduler_interval == "epoch"
-                and (epoch_idx + 1) % self.lr_scheduler_frequency == 0
-            ):
-                self.lr_scheduler.step()
-
-            # save the model at the end of the epoch if interval is set to "epoch" and frequency is met
-            self._try_save_checkpoint(stage="end_of_epoch")
-
-            if not self.is_training:
-                break
-
-        # save the model at the end of training
-        self._try_save_checkpoint(stage="end_of_training")
-
-    def _try_save_checkpoint(
-        self, stage: Literal["end_of_step", "end_of_epoch", "end_of_training"]
-    ):
-        if stage == "end_of_step":
-            if (
-                self.checkpoint_save_interval == "step"
-                and (self.global_step_idx + 1) % self.checkpoint_save_frequency == 0
-            ):
-                self.save_checkpoint(
-                    os.path.join(
-                        self.log_dir, "checkpoints", f"step={self.global_step_idx}.ckpt"
-                    )
-                )
-        elif stage == "end_of_epoch":
-            if (
-                self.checkpoint_save_interval == "epoch"
-                and (self.epoch_idx + 1) % self.checkpoint_save_frequency == 0
-            ):
-                self.save_checkpoint(
-                    os.path.join(
-                        self.log_dir, "checkpoints", f"epoch={self.epoch_idx}.ckpt"
-                    )
-                )
-        elif stage == "end_of_training":
-            # if the checkpoint has not been saved yet, save it
-            if self.global_step_idx > self._latest_saved_checkpoint_global_step:
-                self.save_checkpoint(
-                    os.path.join(
-                        self.log_dir,
-                        "checkpoints",
-                        f"epoch={self.epoch_idx}_step={self.global_step_idx}.ckpt",
-                    )
-                )
-            try:
-                os.symlink(
-                    os.path.join(
-                        self.log_dir,
-                        "checkpoints",
-                        "latest_model.ckpt",
-                    ),
-                    os.path.join(
-                        self.log_dir,
-                        "checkpoints",
-                        f"epoch={self.epoch_idx}_step={self.global_step_idx}.ckpt",
-                    ),
-                )
-            except Exception as e:
-                pass
-        else:
-            raise ValueError(
-                f"Unknown stage: {stage}. Available options: 'end_of_step', 'end_of_epoch', 'end_of_training'"
-            )
+                self.global_step_idx += 1
+                accumulated_loss = 0

     def save_checkpoint(
         self,
@@ -373,24 +305,37 @@ class PeftFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
             return log.warning(f"Checkpoint already exists at {path}. Skipping save.")

         fabric = self.fabric
-
-
-
-
-
-
-
-
-
-
-
+        if self.save_ckpt_type == "lightning":
+            state = {"model": self.model}
+
+            # save the optimizer and lr_scheduler state if needed
+            if self.save_optimizer_state and save_optimizer_state is not False:
+                state.update(
+                    {
+                        "optimizer": self.optimizer,
+                        "lr_scheduler": self.lr_scheduler,
+                        "global_step_idx": self.global_step_idx,
+                        "epoch_idx": self.epoch_idx,
+                    }
+                )
+            trainable_param_names = set(
+                name
+                for name, param in self.model.state_dict(keep_vars=True).items()
+                if param.requires_grad
+            )
+            filter = (
+                None
+                if self.save_full_model
+                else {"model": lambda k, p: k in trainable_param_names}
+            )
+            os.makedirs(os.path.dirname(path), exist_ok=True)
+            fabric.save(path, state=state, filter=filter)
+        elif self.save_ckpt_type == "peft":
+            self.model.save_pretrained(path, is_main_process=fabric.is_global_zero)
+        else:
+            raise ValueError(
+                f"Unknown save_ckpt_type: {self.save_ckpt_type}. Available options: 'lightning', 'peft'"
             )
-
-        filter = (
-            None if self.save_full_model else {"model": lambda k, p: p.requires_grad}
-        )
-
-        fabric.save(path, state=state, filter=filter)
         self._latest_saved_checkpoint_global_step = self.global_step_idx

     def load_checkpoint(self, path: Union[str, Path]):
```
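Two small patterns recur in the hunks above: the collate function is bound to the tokenizer's pad token with `functools.partial`, and over-long batches are clipped to `max_length` before the forward pass. The following is a minimal, self-contained sketch of those two patterns using toy helpers (`pad_collate`, `truncate_batch`); it is illustrative only and does not use fusion_bench's `padded_collate_sft`.

```python
import functools

import torch
from torch.utils.data import DataLoader


def pad_collate(samples, pad_token_id):
    """Right-pad variable-length 1-D token tensors to the longest sample in the batch."""
    max_len = max(s.shape[0] for s in samples)
    input_ids = torch.full((len(samples), max_len), pad_token_id, dtype=torch.long)
    attention_mask = torch.zeros(len(samples), max_len, dtype=torch.long)
    for i, s in enumerate(samples):
        input_ids[i, : s.shape[0]] = s
        attention_mask[i, : s.shape[0]] = 1
    return {"input_ids": input_ids, "attention_mask": attention_mask}


def truncate_batch(batch, max_length):
    """Clip every (batch, seq_len) tensor in the batch to at most max_length tokens."""
    if max_length > 0 and batch["input_ids"].shape[1] > max_length:
        batch = {k: v[:, :max_length] for k, v in batch.items()}
    return batch


# Toy dataset of variable-length "token id" tensors.
dataset = [torch.arange(1, n + 1) for n in (5, 9, 3)]
loader = DataLoader(
    dataset,
    batch_size=3,
    # Bind the pad token id once, so the DataLoader only sees a one-argument collate_fn.
    collate_fn=functools.partial(pad_collate, pad_token_id=0),
)
for batch in loader:
    batch = truncate_batch(batch, max_length=8)
    print(batch["input_ids"].shape)  # torch.Size([3, 8])
```

In the diff itself the same binding is done with `padded_collate_sft`, and the truncation is applied to `input_ids`, `attention_mask`, and `labels` together so the three tensors stay aligned.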
fusion_bench/method/pruning/ (one of the llama pruning modules, llama_magnitude_prune.py or llama_random_prune.py; both show +2 -2):

```diff
@@ -1,7 +1,7 @@
-from typing import Literal, Optional, Union  # noqa: F401
+from typing import Dict, Literal, Optional, Union  # noqa: F401

 import torch
-from torch import
+from torch import nn
 from tqdm.auto import tqdm
 from transformers import LlamaForCausalLM, LlamaModel

```
fusion_bench/method/pruning/magnitude_diff_pruning.py:

```diff
@@ -1,3 +1,4 @@
+import functools
 import logging
 import re
 from copy import deepcopy
@@ -10,7 +11,7 @@ from tqdm.auto import tqdm
 from fusion_bench.method import BaseAlgorithm
 from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
 from fusion_bench.modelpool import BaseModelPool
-
+
 from .prune_utils import unstructured_magnitude_prune_

 log = logging.getLogger(__name__)
```
fusion_bench/method/rankone_moe/clip_rankone_moe.py (new file):

```diff
@@ -0,0 +1,160 @@
+import functools
+import logging
+import os
+from copy import deepcopy
+
+import torch
+from torch import Tensor
+from torch.utils.data import DataLoader
+from transformers.models.clip.modeling_clip import CLIPEncoder
+
+from fusion_bench.dataset import CLIPDataset
+from fusion_bench.method.task_arithmetic.task_arithmetic import task_arithmetic_merge
+from fusion_bench.mixins import CLIPClassificationMixin
+from fusion_bench.modelpool import CLIPVisionModelPool
+from fusion_bench.models.rankone_moe import RankOneMoE
+from fusion_bench.utils.data import InfiniteDataLoader
+
+from .rankone_moe import RankOneMoEAlgorithm
+
+log = logging.getLogger(__name__)
+
+
+class CLIPRankOneMoEAlgorithm(
+    RankOneMoEAlgorithm,
+    CLIPClassificationMixin,
+):
+    """
+    CLIPRankOneMoEAlgorithm is a class that implements the RankOneMoEAlgorithm (https://github.com/EnnengYang/RankOne-MoE)
+    for CLIP models. It extends the RankOneMoEAlgorithm and CLIPClassificationMixin classes.
+
+    Attributes:
+        modelpool (CLIPVisionModelPool): The model pool containing the CLIP models.
+    """
+
+    modelpool: CLIPVisionModelPool = None
+
+    def load_checkpoint(self, model, checkpoint):
+        """
+        Load the checkpoint file.
+
+        Args:
+            model: The model to load the checkpoint into.
+            checkpoint: The path to the checkpoint file.
+        """
+        state = {"model": model}
+        self._fabric.load(checkpoint, state)
+
+    def save_checkpoint(self, model, checkpoint):
+        """
+        Save the checkpoint file.
+
+        Args:
+            model: The model to save the checkpoint from.
+            checkpoint: The path to the checkpoint file.
+        """
+        self._fabric.save(checkpoint, {"model": model})
+
+    def construct_moe_model(self) -> RankOneMoE:
+        """
+        Construct the RankOne-MoE model using the models in the model pool.
+
+        Returns:
+            RankOne-MoE: The constructed MoE model.
+        """
+        base_model = self.modelpool.load_model("_pretrained_")
+        expert_models = [
+            self.modelpool.load_model(m) for m in self.modelpool.model_names
+        ]
+
+        # Merge the models using task arithmetic
+        moe_model = task_arithmetic_merge(
+            # This function modifies the model in place, so we need to pass a deepcopy
+            deepcopy(base_model),
+            expert_models,
+            scaling_factor=self.config.init_lambda,
+        ).requires_grad_(False)
+
+        # Up-scale MLP modules
+        base_encoder: CLIPEncoder = base_model.vision_model.encoder
+        moe_encoder: CLIPEncoder = moe_model.vision_model.encoder
+        expert_encoders = [m.vision_model.encoder for m in expert_models]
+
+        num_layers = len(base_encoder.layers)
+        for layer_idx in range(num_layers):
+            base_mlp = base_encoder.layers[layer_idx].mlp
+            expert_mlps = [e.layers[layer_idx].mlp for e in expert_encoders]
+
+            moe_encoder.layers[layer_idx].mlp = RankOneMoE(
+                hidden_size=base_encoder.config.hidden_size,
+                base_model=base_mlp,
+                expert_models=expert_mlps,
+                init_lambda=self.config.init_lambda,
+                batch_first=True,  # For open_clip models this is False
+                router_hidden_layers=self.config.router_hidden_layers,
+                batch_reduce=self.config.batch_reduce,
+                svd_accelerator=self.config.svd_accelerator,
+                rank_k=self.config.rank_k,
+                select_k=self.config.select_k,
+            )
+
+        return moe_model
+
+    @functools.cache
+    def get_shuffled_test_loader_iter(self, tta_dataset: str):
+        """
+        Get an iterator for the shuffled test data loader.
+
+        Args:
+            tta_dataset (str): The name of the test-time adaptation dataset.
+
+        Returns:
+            Iterator: An iterator for the shuffled test data loader.
+        """
+        dataset = self.modelpool.load_test_dataset(tta_dataset)
+        dataset = CLIPDataset(dataset, processor=self.clip_processor)
+        log.info("get_shuffled_test_loader_iter")
+        loader = DataLoader(
+            dataset,
+            batch_size=self.config.batch_size,
+            shuffle=True,
+            num_workers=self.config.num_workers,
+            pin_memory=True,
+        )
+        loader = self.fabric.setup_dataloaders(loader)
+        return iter(InfiniteDataLoader(loader))
+
+    def on_test_time_adaptation_start(self):
+        """
+        Load the CLIP processor and construct the zero-shot classification head for each task.
+        """
+        self.setup_zero_shot_classification_head()
+
+    def compute_logits(self, module, batch, task) -> Tensor:
+        """
+        Compute the logits for the given batch and task.
+
+        Args:
+            module: The model module.
+            batch: The input batch.
+            task: The task name.
+
+        Returns:
+            Tensor: The computed logits.
+        """
+        images, _ = batch
+        text_embeds = self.zeroshot_weights[task]
+
+        image_embeds = module(images)[1]
+        image_embeds = self.visual_projection(image_embeds)
+
+        # Normalize embeddings
+        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
+
+        # Cosine similarity
+        logits_per_text = (
+            torch.matmul(text_embeds, image_embeds.t()) * self.logit_scale_exp
+        )
+        logits_per_image = logits_per_text.t()
+
+        return logits_per_image
```