fusion-bench 0.2.5__py3-none-any.whl → 0.2.7__py3-none-any.whl
This diff compares the contents of two publicly released package versions as they appear in their public registry. It is provided for informational purposes only.
- fusion_bench/compat/method/__init__.py +2 -0
- fusion_bench/compat/method/base_algorithm.py +7 -2
- fusion_bench/compat/modelpool/__init__.py +3 -2
- fusion_bench/compat/taskpool/__init__.py +1 -1
- fusion_bench/dataset/arc_agi/__init__.py +6 -1
- fusion_bench/dataset/arc_agi/arc.py +26 -7
- fusion_bench/dataset/arc_agi/arc_agi.py +156 -25
- fusion_bench/dataset/arc_agi/np_cache.py +0 -1
- fusion_bench/dataset/arc_agi/preprocess.py +51 -9
- fusion_bench/dataset/llama/__init__.py +1 -0
- fusion_bench/dataset/llama/alpaca.py +93 -3
- fusion_bench/dataset/llama/collate.py +72 -5
- fusion_bench/dataset/llama/metamathqa.py +50 -0
- fusion_bench/dataset/llama/preference_700k.py +70 -0
- fusion_bench/dataset/llama/stanford_shp.py +90 -0
- fusion_bench/dataset/llama/ultrachat.py +58 -0
- fusion_bench/dataset/llama/utils/__init__.py +0 -0
- fusion_bench/method/__init__.py +4 -1
- fusion_bench/method/adamerging/__init__.py +1 -1
- fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -4
- fusion_bench/method/adamerging/min_norm_solvers.py +4 -4
- fusion_bench/method/linear/expo.py +39 -0
- fusion_bench/method/lm_finetune/__init__.py +1 -0
- fusion_bench/method/lm_finetune/bradley_terry_rm.py +432 -0
- fusion_bench/method/lm_finetune/fullfinetune_sft.py +122 -150
- fusion_bench/method/lm_finetune/peftfinetune_sft.py +102 -157
- fusion_bench/method/pruning/llama_magnitude_prune.py +2 -2
- fusion_bench/method/pruning/llama_random_prune.py +2 -2
- fusion_bench/method/pruning/magnitude_diff_pruning.py +2 -1
- fusion_bench/method/rankone_moe/__init__.py +3 -0
- fusion_bench/method/rankone_moe/clip_rankone_moe.py +160 -0
- fusion_bench/method/rankone_moe/rankone_moe.py +249 -0
- fusion_bench/method/simple_average.py +1 -1
- fusion_bench/method/surgery/__init__.py +3 -0
- fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +157 -0
- fusion_bench/mixins/__init__.py +2 -0
- fusion_bench/mixins/clip_classification.py +60 -12
- fusion_bench/mixins/fabric_training.py +320 -0
- fusion_bench/mixins/lightning_fabric.py +11 -2
- fusion_bench/modelpool/__init__.py +2 -0
- fusion_bench/modelpool/causal_lm/__init__.py +1 -1
- fusion_bench/modelpool/causal_lm/causal_lm.py +21 -22
- fusion_bench/modelpool/seq_classification_lm/__init__.py +2 -0
- fusion_bench/modelpool/seq_classification_lm/reward_model.py +15 -0
- fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +98 -0
- fusion_bench/models/chat_templates/__init__.py +1 -0
- fusion_bench/models/chat_templates/llama_3_Instruct.py +1 -0
- fusion_bench/models/chat_templates/load_tokenizer.py +43 -0
- fusion_bench/models/hf_clip.py +50 -9
- fusion_bench/models/rankone_moe.py +410 -0
- fusion_bench/models/surgery/surgerymodelwrapper.py +157 -0
- fusion_bench/models/utils.py +8 -0
- fusion_bench/models/wrappers/layer_wise_fusion.py +14 -5
- fusion_bench/models/wrappers/task_wise_fusion.py +5 -5
- fusion_bench/optim/__init__.py +2 -0
- fusion_bench/optim/exception.py +47 -0
- fusion_bench/optim/lr_scheduler/__init__.py +1 -0
- fusion_bench/optim/lr_scheduler/linear_warmup.py +222 -0
- fusion_bench/optim/lr_scheduler/utils/__init__.py +1 -0
- fusion_bench/optim/lr_scheduler/utils/visualization.py +119 -0
- fusion_bench/optim/mezo.py +0 -2
- fusion_bench/programs/fabric_fusion_program.py +5 -1
- fusion_bench/taskpool/__init__.py +10 -2
- fusion_bench/taskpool/clip_vision/__init__.py +1 -0
- fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py +112 -0
- fusion_bench/taskpool/clip_vision/taskpool.py +43 -6
- fusion_bench/taskpool/llama/reward_model.py +157 -0
- fusion_bench/taskpool/nyuv2_taskpool.py +2 -0
- fusion_bench/tasks/flan_t5_text_generation/glue_load_dataset.py +2 -1
- fusion_bench/utils/hydra_utils.py +22 -0
- fusion_bench/utils/plot/__init__.py +0 -0
- fusion_bench/utils/plot/token.py +52 -0
- fusion_bench/utils/plot/token_notebook.py +127 -0
- fusion_bench/utils/type.py +5 -3
- {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/METADATA +1 -1
- {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/RECORD +104 -57
- fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
- fusion_bench_config/dataset/llm_sft/alpaca_cleaned.yaml +6 -0
- fusion_bench_config/dataset/llm_sft/ultrachat_200k.yaml +3 -0
- fusion_bench_config/fabric/llama_peft_fsdp.yaml +16 -0
- fusion_bench_config/fabric/loggers/wandb_logger.yaml +2 -0
- fusion_bench_config/fabric/strategy/deepspeed.yaml +10 -0
- fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +9 -0
- fusion_bench_config/fabric_model_fusion.yaml +1 -1
- fusion_bench_config/llama_full_finetune.yaml +19 -0
- fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +47 -0
- fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +13 -6
- fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +17 -9
- fusion_bench_config/method/rankone_moe/rankone_moe.yaml +26 -0
- fusion_bench_config/method/regmean/clip_regmean.yaml +1 -0
- fusion_bench_config/method/surgery/adamerging_surgery.yaml +27 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +21 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +21 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +19 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +18 -0
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +23 -0
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +14 -0
- fusion_bench_config/nyuv2_config.yaml +5 -1
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml +18 -0
- fusion_bench_config/taskpool/reward_model_evaluation.yaml +18 -0
- fusion_bench_config/llama_weighted_average.yaml +0 -26
- {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/LICENSE +0 -0
- {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/top_level.txt +0 -0
fusion_bench/method/lm_finetune/fullfinetune_sft.py

@@ -1,3 +1,4 @@
+import functools
 import itertools
 import logging
 import os
@@ -13,11 +14,12 @@ from omegaconf import DictConfig
 from torch import nn
 from torch.utils.data import ConcatDataset, DataLoader
 from tqdm.auto import tqdm
+from transformers import AutoModelForCausalLM, AutoTokenizer
 from typing_extensions import TYPE_CHECKING, override
 
 from fusion_bench import BaseAlgorithm, BaseModelPool
 from fusion_bench.dataset.llama.collate import padded_collate_sft
-from fusion_bench.mixins import LightningFabricMixin
+from fusion_bench.mixins import FabricTrainingMixin
 from fusion_bench.modelpool import CausalLMPool
 from fusion_bench.utils import instantiate
 from fusion_bench.utils.dtype import get_dtype
@@ -33,7 +35,7 @@ if TYPE_CHECKING:
 log = logging.getLogger(__name__)
 
 
-class FullFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
+class FullFinetuneSFT(BaseAlgorithm, FabricTrainingMixin):
 
     model: Union[nn.Module, "_FabricModule", "LlamaForCausalLM"]
     optimizer: Union[torch.optim.Optimizer, "_FabricOptimizer"]
@@ -58,7 +60,10 @@ class FullFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
         gradient_clip_algorithm: Literal["value", "norm"] = "norm",
         save_optimizer_state: bool = False,
         save_full_model: bool = False,
+        save_ckpt_type: Literal["lightning", "hf"] = "lightning",
         ckpt_path: Optional[str] = None,
+        max_length: int = 6144,
+        fix_token_embedding: bool = True,
         **kwargs,
     ):
         """
@@ -80,7 +85,10 @@ class FullFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
             gradient_clip_algorithm(str): Algorithm to use for gradient clipping. Available options: 'value', 'norm'. If set to 'value', the gradients will be clipped to the specified value. If set to 'norm', the gradients will be clipped to the specified norm.
             save_optimizer_state(bool): Whether to save the optimizer and lr_scheduler state along with the model checkpoint.
             save_full_model(bool): Whether to save the full model or only the trainable parameters in the model checkpoint.
+            save_ckpt_type (str): Type of checkpoint to save. Available options: 'lightning', 'hf'. If set to 'lightning', the checkpoint will be saved in the lightning format. If set to 'hf', the checkpoint will be saved in the huggingface format.
             ckpt_path(str): Path to the checkpoint to load before training. If set to None, no checkpoint will be loaded.
+            max_length(int): Maximum input length to consider. If the input length exceeds this value, it will be truncated.
+            fix_token_embedding(bool): Whether to fix the token embeddings during training. If set to True, the token embeddings will not be updated during training.
         """
         self._optimizer = optimizer
         self._lr_scheduler = lr_scheduler
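For reference, a sketch of overriding just the options added in 0.2.7 when constructing the algorithm; the remaining `FullFinetuneSFT` arguments (optimizer, lr_scheduler, dataloader settings, and so on) are assumed to be supplied as before and are not shown here. Illustrative only, not taken from the package:

# Defaults shown in the signature above: save_ckpt_type="lightning",
# max_length=6144, fix_token_embedding=True.
new_finetune_options = dict(
    save_ckpt_type="hf",        # write HuggingFace-format checkpoints instead of Lightning-format ones
    max_length=4096,            # truncate any batch whose sequence length exceeds 4096 tokens
    fix_token_embedding=False,  # allow the token embedding matrix to be updated during training
)
# e.g. FullFinetuneSFT(..., **new_finetune_options) alongside the existing arguments.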
@@ -97,18 +105,28 @@ class FullFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
         self.gradient_clip_algorithm = gradient_clip_algorithm
         self.save_optimizer_state = save_optimizer_state
         self.save_full_model = save_full_model
+        self.save_ckpt_type = save_ckpt_type
         self.ckpt_path = ckpt_path
+        self.max_length = max_length
+        self.fix_token_embedding = fix_token_embedding
         super().__init__(**kwargs)
 
     def run(self, modelpool: CausalLMPool):
         self.modelpool = modelpool
         self.setup()
-        self.train()
+        self.train(self.model, self.optimizer, self.lr_scheduler)
         return self.model
 
     def setup_model(self):
+        self.tokenizer = self.modelpool.load_tokenizer()
+        if self.tokenizer.pad_token_id is None:
+            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
+
         model = self.modelpool.load_pretrained_model()
-        self.model = model
+        self.model: "LlamaForCausalLM" = model
+
+        if self.fix_token_embedding:
+            self.model.model.embed_tokens.requires_grad_(False)
 
         if self.fabric.strategy == "fsdp" or isinstance(
             self.fabric.strategy, FSDPStrategy
@@ -117,21 +135,14 @@ class FullFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
             self.model.gradient_checkpointing_enable(
                 gradient_checkpointing_kwargs={"use_reentrant": True}
             )
+            self.use_cache = False
+        else:
+            self.use_cache = True
         self.model_dtype = get_dtype(self.model)
 
     def configure_optimizer(self):
         # compute expected total steps
-        self.expected_total_steps = []
-        if self.max_steps > 0:
-            self.expected_total_steps.append(self.max_steps)
-        if self.max_steps_per_epoch > 0 and self.max_epochs > 0:
-            self.expected_total_steps.append(self.max_steps_per_epoch * self.max_epochs)
-        if self.max_epochs > 0:
-            self.expected_total_steps.append(
-                len(self.train_dataloader) * self.max_epochs
-            )
-        self.expected_total_steps = min(self.expected_total_steps)
-        log.info(f"Expected total steps: {self.expected_total_steps}")
+        self.compute_expected_total_steps(self.train_dataloader)
 
         optimizer = instantiate(self._optimizer, self.model.parameters())
         if self._lr_scheduler is not None:
@@ -170,7 +181,9 @@ class FullFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
             train_dataset,
             **self.dataloader_kwargs,
             shuffle=True,
-            collate_fn=padded_collate_sft,
+            collate_fn=functools.partial(
+                padded_collate_sft, pad_token_id=self.tokenizer.pad_token_id
+            ),
         )
         self.train_dataloader = fabric.setup_dataloaders(self.train_dataloader)
 
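Binding `pad_token_id` into the collate function is needed because the pad token is now taken from the tokenizer (falling back to `eos_token_id` in `setup_model`). As a point of reference, a padded SFT collate function of this shape typically right-pads `input_ids` with the pad token, `attention_mask` with 0, and `labels` with -100 so padding positions are ignored by the loss. A minimal illustrative sketch, not the actual `padded_collate_sft` implementation:

import torch
from torch.nn.utils.rnn import pad_sequence

def collate_sft_example(samples, pad_token_id, ignore_index=-100):
    # Right-pad each field of the batch to the longest sequence in the batch.
    input_ids = pad_sequence(
        [torch.as_tensor(s["input_ids"]) for s in samples],
        batch_first=True,
        padding_value=pad_token_id,
    )
    attention_mask = pad_sequence(
        [torch.ones(len(s["input_ids"]), dtype=torch.long) for s in samples],
        batch_first=True,
        padding_value=0,  # padded positions are masked out
    )
    labels = pad_sequence(
        [torch.as_tensor(s["labels"]) for s in samples],
        batch_first=True,
        padding_value=ignore_index,  # -100 is ignored by the cross-entropy loss
    )
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}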
@@ -186,25 +199,15 @@ class FullFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
         self.model, self.optimizer = fabric.setup(self.model, optimizer)
         self.lr_scheduler = lr_scheduler
 
-
+    @override
+    def train_epoch(self, *args, **kwargs):
         fabric = self.fabric
 
-
-        if self.gradient_clip_algorithm == "value":
-            fabric.clip_gradients(self.model, clip_val=self.gradient_clip_val)
-        elif self.gradient_clip_algorithm == "norm":
-            fabric.clip_gradients(self.model, max_norm=self.gradient_clip_val)
-        else:
-            raise ValueError(
-                f"Unknown gradient clip algorithm: {self.gradient_clip_algorithm}. Available options: 'value', 'norm'"
-            )
-
-    def train_epoch(self):
-        fabric = self.fabric
+        accumulated_loss = 0
         for step_idx, batch in enumerate(
             pbar := tqdm(
                 self.train_dataloader,
-                desc="Training
+                desc="Training Batches",
                 dynamic_ncols=True,
                 leave=False,
                 disable=not fabric.is_global_zero,
@@ -212,24 +215,30 @@ class FullFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
         ):
             is_accumulating = (step_idx + 1) % self.accumulate_grad_batches != 0
 
+            if self.max_length > 0 and batch["input_ids"].shape[1] > self.max_length:
+                log.warning(
+                    f"Input length exceeds max_length: {batch['input_ids'].shape[1]} > {self.max_length}. Truncating input."
+                )
+                batch["input_ids"] = batch["input_ids"][:, : self.max_length]
+                batch["attention_mask"] = batch["attention_mask"][:, : self.max_length]
+                batch["labels"] = batch["labels"][:, : self.max_length]
+
             # disable gradient synchronization if accumulating gradients across steps for improved performance
             with fabric.no_backward_sync(self.model, enabled=is_accumulating):
                 # use_cache=True is not compatible with gradient checkpointing, so we disable it here
-                output = self.model(
-
+                output = self.model(
+                    input_ids=batch["input_ids"],
+                    attention_mask=batch["attention_mask"],
+                    labels=batch["labels"],
+                    use_cache=self.use_cache,
+                )
+                loss = output["loss"] / self.accumulate_grad_batches
 
                 fabric.backward(loss)
-
-                metrics = {
-                    "train/loss": loss.item(),
-                    "train/epoch_idx": self.epoch_idx,
-                    "train/lr": self.optimizer.param_groups[0]["lr"],
-                }
-                fabric.log_dict(metrics, step=self.global_step_idx)
-                pbar.set_postfix(metrics)
+                accumulated_loss += loss.item()
 
             if not is_accumulating:
-                self.
+                self.clip_gradients_if_needed(self.model, self.optimizer)
 
                 # run lr_scheduler at the end of the step if interval is set to "step"
                 if (
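The loss handling above follows the usual gradient-accumulation pattern: each micro-batch contributes `loss / accumulate_grad_batches` to the gradients, gradient synchronization is skipped on accumulation steps via `no_backward_sync`, and the optimizer steps once every `accumulate_grad_batches` micro-batches, so the logged `accumulated_loss` approximates the mean loss over the effective batch. A self-contained sketch of the same pattern outside Fabric (illustrative, not fusion_bench code):

import torch
from torch import nn
from torch.nn import functional as F

model = nn.Linear(16, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
accumulate_grad_batches = 4

micro_batches = [(torch.randn(8, 16), torch.randn(8, 1)) for _ in range(8)]
accumulated_loss = 0.0
for step_idx, (x, y) in enumerate(micro_batches):
    # Scale each loss so the summed gradients equal those of the mean loss over the effective batch.
    loss = F.mse_loss(model(x), y) / accumulate_grad_batches
    loss.backward()
    accumulated_loss += loss.item()
    if (step_idx + 1) % accumulate_grad_batches == 0:
        optimizer.step()
        optimizer.zero_grad()
        print(f"effective-batch loss: {accumulated_loss:.4f}")
        accumulated_loss = 0.0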
@@ -242,104 +251,30 @@ class FullFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
                 self.optimizer.step()
                 self.optimizer.zero_grad()
 
-
-
+                metrics = {
+                    "train/loss": accumulated_loss,
+                    "train/epoch_idx": self.epoch_idx,
+                    "train/lr": self.optimizer.param_groups[0]["lr"],
+                }
+                fabric.log_dict(metrics, step=self.global_step_idx)
+                pbar.set_postfix(metrics)
 
-
-
-                    self.max_steps_per_epoch > 0
-                    and step_idx + 1 >= self.max_steps_per_epoch
-                ):
-                    break
-                # break if max_steps is set, and exit training
-                if self.max_steps > 0 and self.global_step_idx >= self.max_steps:
-                    self.is_training = False
-                    break
+                # save the model at the end of the step if interval is set to "step" and frequency is met
+                self.conditional_checkpoint_save(stage="end_of_step")
 
-
+                # break if max_steps_per_epoch is set, and exit epoch
+                if (
+                    self.max_steps_per_epoch > 0
+                    and step_idx + 1 >= self.max_steps_per_epoch
+                ):
+                    break
+                # break if max_steps is set, and exit training
+                if self.max_steps > 0 and self.global_step_idx >= self.max_steps - 1:
+                    self.is_training = False
+                    break
 
-
-
-        self.is_training = True
-        self.global_step_idx = 0
-        self.model.train()
-        for epoch_idx in tqdm(
-            range(self.max_epochs) if self.max_epochs > 0 else itertools.count(0),
-            "Training Epoch",
-            dynamic_ncols=True,
-            leave=False,
-            disable=not fabric.is_global_zero,
-        ):
-            self.epoch_idx = epoch_idx
-            self.train_epoch()
-            # run lr_scheduler at the end of the epoch if interval is set to "epoch"
-            if (
-                self.lr_scheduler_interval == "epoch"
-                and (epoch_idx + 1) % self.lr_scheduler_frequency == 0
-            ):
-                self.lr_scheduler.step()
-
-            # save the model at the end of the epoch if interval is set to "epoch" and frequency is met
-            self._try_save_checkpoint(stage="end_of_epoch")
-
-            if not self.is_training:
-                break
-
-        # save the model at the end of training
-        self._try_save_checkpoint(stage="end_of_training")
-
-    def _try_save_checkpoint(
-        self, stage: Literal["end_of_step", "end_of_epoch", "end_of_training"]
-    ):
-        if stage == "end_of_step":
-            if (
-                self.checkpoint_save_interval == "step"
-                and (self.global_step_idx + 1) % self.checkpoint_save_frequency == 0
-            ):
-                self.save_checkpoint(
-                    os.path.join(
-                        self.log_dir, "checkpoints", f"step={self.global_step_idx}.ckpt"
-                    )
-                )
-        elif stage == "end_of_epoch":
-            if (
-                self.checkpoint_save_interval == "epoch"
-                and (self.epoch_idx + 1) % self.checkpoint_save_frequency == 0
-            ):
-                self.save_checkpoint(
-                    os.path.join(
-                        self.log_dir, "checkpoints", f"epoch={self.epoch_idx}.ckpt"
-                    )
-                )
-        elif stage == "end_of_training":
-            # if the checkpoint has not been saved yet, save it
-            if self.global_step_idx > self._latest_saved_checkpoint_global_step:
-                self.save_checkpoint(
-                    os.path.join(
-                        self.log_dir,
-                        "checkpoints",
-                        f"epoch={self.epoch_idx}_step={self.global_step_idx}.ckpt",
-                    )
-                )
-            try:
-                os.symlink(
-                    os.path.join(
-                        self.log_dir,
-                        "checkpoints",
-                        "latest_model.ckpt",
-                    ),
-                    os.path.join(
-                        self.log_dir,
-                        "checkpoints",
-                        f"epoch={self.epoch_idx}_step={self.global_step_idx}.ckpt",
-                    ),
-                )
-            except Exception as e:
-                pass
-        else:
-            raise ValueError(
-                f"Unknown stage: {stage}. Available options: 'end_of_step', 'end_of_epoch', 'end_of_training'"
-            )
+                self.global_step_idx += 1
+                accumulated_loss = 0
 
     def save_checkpoint(
         self,
@@ -351,24 +286,36 @@ class FullFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
             return log.warning(f"Checkpoint already exists at {path}. Skipping save.")
 
         fabric = self.fabric
-        state = {"model": self.model}
 
-
-
-
-
-
-
-
-
-
+        if self.save_ckpt_type == "lightning":
+            state = {"model": self.model}
+
+            # save the optimizer and lr_scheduler state if needed
+            if self.save_optimizer_state and save_optimizer_state is not False:
+                state.update(
+                    {
+                        "optimizer": self.optimizer,
+                        "lr_scheduler": self.lr_scheduler,
+                        "global_step_idx": self.global_step_idx,
+                        "epoch_idx": self.epoch_idx,
+                    }
+                )
+
+            trainable_param_names = set(
+                name
+                for name, param in self.model.state_dict(keep_vars=True).items()
+                if param.requires_grad
+            )
+            filter = (
+                None
+                if self.save_full_model
+                else {"model": lambda k, p: k in trainable_param_names}
             )
 
-
-
-
+            fabric.save(path, state=state, filter=filter)
+        else:
+            self.model.save_pretrained(path, is_main_process=fabric.is_global_zero)
 
-        fabric.save(path, state=state, filter=filter)
         self._latest_saved_checkpoint_global_step = self.global_step_idx
 
     def load_checkpoint(self, path: Union[str, Path]):
@@ -401,3 +348,28 @@ def load_checkpoint(
     state = {"model": model}
     state.update(state_components)
     fabric.load(ckpt_path, state=state, strict=strict)
+
+
+if __name__ == "__main__":
+    # convert a checkpoint to hf format
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--base-model-path", type=str)
+    parser.add_argument("--ckpt-path", type=str)
+    parser.add_argument("--output-path", type=str)
+
+    args = parser.parse_args()
+
+    fabric = L.Fabric(devices=1, strategy="fsdp")
+    fabric.launch()
+
+    tokenizer = AutoTokenizer.from_pretrained(args.base_model_path)
+    tokenizer.save_pretrained(args.output_path)
+
+    model = AutoModelForCausalLM.from_pretrained(
+        args.base_model_path, torch_dtype=torch.bfloat16
+    )
+    model = fabric.setup_module(model)
+    load_checkpoint(fabric, args.ckpt_path, model=model, strict=True)
+    model.save_pretrained(args.output_path)