fusion-bench: fusion_bench-0.2.5-py3-none-any.whl → fusion_bench-0.2.7-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as they appear in the supported public registries, and is provided for informational purposes only.
- fusion_bench/compat/method/__init__.py +2 -0
- fusion_bench/compat/method/base_algorithm.py +7 -2
- fusion_bench/compat/modelpool/__init__.py +3 -2
- fusion_bench/compat/taskpool/__init__.py +1 -1
- fusion_bench/dataset/arc_agi/__init__.py +6 -1
- fusion_bench/dataset/arc_agi/arc.py +26 -7
- fusion_bench/dataset/arc_agi/arc_agi.py +156 -25
- fusion_bench/dataset/arc_agi/np_cache.py +0 -1
- fusion_bench/dataset/arc_agi/preprocess.py +51 -9
- fusion_bench/dataset/llama/__init__.py +1 -0
- fusion_bench/dataset/llama/alpaca.py +93 -3
- fusion_bench/dataset/llama/collate.py +72 -5
- fusion_bench/dataset/llama/metamathqa.py +50 -0
- fusion_bench/dataset/llama/preference_700k.py +70 -0
- fusion_bench/dataset/llama/stanford_shp.py +90 -0
- fusion_bench/dataset/llama/ultrachat.py +58 -0
- fusion_bench/dataset/llama/utils/__init__.py +0 -0
- fusion_bench/method/__init__.py +4 -1
- fusion_bench/method/adamerging/__init__.py +1 -1
- fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -4
- fusion_bench/method/adamerging/min_norm_solvers.py +4 -4
- fusion_bench/method/linear/expo.py +39 -0
- fusion_bench/method/lm_finetune/__init__.py +1 -0
- fusion_bench/method/lm_finetune/bradley_terry_rm.py +432 -0
- fusion_bench/method/lm_finetune/fullfinetune_sft.py +122 -150
- fusion_bench/method/lm_finetune/peftfinetune_sft.py +102 -157
- fusion_bench/method/pruning/llama_magnitude_prune.py +2 -2
- fusion_bench/method/pruning/llama_random_prune.py +2 -2
- fusion_bench/method/pruning/magnitude_diff_pruning.py +2 -1
- fusion_bench/method/rankone_moe/__init__.py +3 -0
- fusion_bench/method/rankone_moe/clip_rankone_moe.py +160 -0
- fusion_bench/method/rankone_moe/rankone_moe.py +249 -0
- fusion_bench/method/simple_average.py +1 -1
- fusion_bench/method/surgery/__init__.py +3 -0
- fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +157 -0
- fusion_bench/mixins/__init__.py +2 -0
- fusion_bench/mixins/clip_classification.py +60 -12
- fusion_bench/mixins/fabric_training.py +320 -0
- fusion_bench/mixins/lightning_fabric.py +11 -2
- fusion_bench/modelpool/__init__.py +2 -0
- fusion_bench/modelpool/causal_lm/__init__.py +1 -1
- fusion_bench/modelpool/causal_lm/causal_lm.py +21 -22
- fusion_bench/modelpool/seq_classification_lm/__init__.py +2 -0
- fusion_bench/modelpool/seq_classification_lm/reward_model.py +15 -0
- fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +98 -0
- fusion_bench/models/chat_templates/__init__.py +1 -0
- fusion_bench/models/chat_templates/llama_3_Instruct.py +1 -0
- fusion_bench/models/chat_templates/load_tokenizer.py +43 -0
- fusion_bench/models/hf_clip.py +50 -9
- fusion_bench/models/rankone_moe.py +410 -0
- fusion_bench/models/surgery/surgerymodelwrapper.py +157 -0
- fusion_bench/models/utils.py +8 -0
- fusion_bench/models/wrappers/layer_wise_fusion.py +14 -5
- fusion_bench/models/wrappers/task_wise_fusion.py +5 -5
- fusion_bench/optim/__init__.py +2 -0
- fusion_bench/optim/exception.py +47 -0
- fusion_bench/optim/lr_scheduler/__init__.py +1 -0
- fusion_bench/optim/lr_scheduler/linear_warmup.py +222 -0
- fusion_bench/optim/lr_scheduler/utils/__init__.py +1 -0
- fusion_bench/optim/lr_scheduler/utils/visualization.py +119 -0
- fusion_bench/optim/mezo.py +0 -2
- fusion_bench/programs/fabric_fusion_program.py +5 -1
- fusion_bench/taskpool/__init__.py +10 -2
- fusion_bench/taskpool/clip_vision/__init__.py +1 -0
- fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py +112 -0
- fusion_bench/taskpool/clip_vision/taskpool.py +43 -6
- fusion_bench/taskpool/llama/reward_model.py +157 -0
- fusion_bench/taskpool/nyuv2_taskpool.py +2 -0
- fusion_bench/tasks/flan_t5_text_generation/glue_load_dataset.py +2 -1
- fusion_bench/utils/hydra_utils.py +22 -0
- fusion_bench/utils/plot/__init__.py +0 -0
- fusion_bench/utils/plot/token.py +52 -0
- fusion_bench/utils/plot/token_notebook.py +127 -0
- fusion_bench/utils/type.py +5 -3
- {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/METADATA +1 -1
- {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/RECORD +104 -57
- fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
- fusion_bench_config/dataset/llm_sft/alpaca_cleaned.yaml +6 -0
- fusion_bench_config/dataset/llm_sft/ultrachat_200k.yaml +3 -0
- fusion_bench_config/fabric/llama_peft_fsdp.yaml +16 -0
- fusion_bench_config/fabric/loggers/wandb_logger.yaml +2 -0
- fusion_bench_config/fabric/strategy/deepspeed.yaml +10 -0
- fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +9 -0
- fusion_bench_config/fabric_model_fusion.yaml +1 -1
- fusion_bench_config/llama_full_finetune.yaml +19 -0
- fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +47 -0
- fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +13 -6
- fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +17 -9
- fusion_bench_config/method/rankone_moe/rankone_moe.yaml +26 -0
- fusion_bench_config/method/regmean/clip_regmean.yaml +1 -0
- fusion_bench_config/method/surgery/adamerging_surgery.yaml +27 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +21 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +21 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +19 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +18 -0
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +23 -0
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +14 -0
- fusion_bench_config/nyuv2_config.yaml +5 -1
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml +18 -0
- fusion_bench_config/taskpool/reward_model_evaluation.yaml +18 -0
- fusion_bench_config/llama_weighted_average.yaml +0 -26
- {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/LICENSE +0 -0
- {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/top_level.txt +0 -0

fusion_bench/mixins/fabric_training.py (new file)

```diff
@@ -0,0 +1,320 @@
+import itertools
+import logging
+import os
+from abc import abstractmethod
+from typing import TYPE_CHECKING, Literal, Union
+
+import torch
+from torch import Tensor, nn
+from tqdm.auto import tqdm
+
+from .lightning_fabric import LightningFabricMixin
+
+if TYPE_CHECKING:
+    from lightning.fabric.wrappers import (
+        _FabricDataLoader,
+        _FabricModule,
+        _FabricOptimizer,
+    )
+
+log = logging.getLogger(__name__)
+
+
+class FabricTrainingMixin(LightningFabricMixin):
+    """
+    This is a general purpose mixin for training a model with PyTorch Lightning.
+    """
+
+    _latest_saved_checkpoint_global_step: int = -1
+    """The global step index of the latest saved checkpoint."""
+    _expected_total_steps: int = None
+    """The expected total number of steps of the entire training."""
+    is_training: bool
+    """Whether the training is in progress. If set to False, the training will stop."""
+    epoch_idx: int
+    """The epoch index, which is the number of epochs completed."""
+    global_step_idx: int
+    """The global step index, which is the number of parameter update steps."""
+    max_epochs: int
+    """Max number of epochs of the entire training."""
+    max_steps: int
+    """Max number of parameter update steps of the entire training."""
+    max_steps_per_epoch: int
+    """Max number of parameter update steps per epoch."""
+    gradient_clip_algorithm: Literal["value", "norm"]
+    """The algorithm to clip gradients. Available options: 'value', 'norm'."""
+    gradient_clip_val: float
+    """The value to clip gradients. If None, no clipping is applied."""
+    accumulate_grad_batches: int
+    """The number of gradient accumulation steps. The effective global batch size is `the batch size per device` x `the number of devices` x `the number of gradient accumulation steps`."""
+    lr_scheduler_interval: Literal["step", "epoch"]
+    """The interval to run the learning rate scheduler. Available options: 'step', 'epoch'."""
+    lr_scheduler_frequency: int
+    """The frequency to run the learning rate scheduler."""
+    checkpoint_save_interval: Literal["step", "epoch"]
+    """The interval to save the model checkpoint. Available options: 'step', 'epoch'."""
+    checkpoint_save_frequency: int
+    """The frequency to save the model checkpoint."""
+
+    def clip_gradients_if_needed(self, model, optimizer):
+        """
+        Clips gradients if the gradient clipping value is set.
+
+        Args:
+            model (nn.Module): The model whose gradients need to be clipped.
+            optimizer (torch.optim.Optimizer): The optimizer used for training.
+        """
+        fabric = self.fabric
+
+        if self.gradient_clip_val is not None:
+            if self.gradient_clip_algorithm == "value":
+                fabric.clip_gradients(model, optimizer, clip_val=self.gradient_clip_val)
+            elif self.gradient_clip_algorithm == "norm":
+                fabric.clip_gradients(model, optimizer, max_norm=self.gradient_clip_val)
+            else:
+                raise ValueError(
+                    f"Unknown gradient clip algorithm: {self.gradient_clip_algorithm}. Available options: 'value', 'norm'"
+                )
+
+    def compute_expected_total_steps(
+        self, train_dataloader: torch.utils.data.DataLoader
+    ):
+        """
+        Computes the expected total number of steps for the entire training.
+
+        Args:
+            train_dataloader (torch.utils.data.DataLoader): The dataloader for the training data.
+        """
+        # compute expected total steps
+        self._expected_total_steps = []
+        if self.max_steps > 0:
+            self._expected_total_steps.append(self.max_steps)
+        if self.max_steps_per_epoch > 0 and self.max_epochs > 0:
+            self._expected_total_steps.append(
+                self.max_steps_per_epoch * self.max_epochs
+            )
+        if self.max_epochs > 0:
+            self._expected_total_steps.append(
+                len(train_dataloader) * self.max_epochs // self.accumulate_grad_batches
+            )
+        self._expected_total_steps = min(self._expected_total_steps)
+        log.info(f"Expected total steps: {self._expected_total_steps}")
+
+    @property
+    def expected_total_steps(self):
+        """
+        The expected total number of steps of the entire training. You need to run `compute_expected_total_steps` method to compute this value before accessing it.
+
+        Raises:
+            ValueError: If the expected total steps have not been computed.
+        """
+        if self._expected_total_steps is None:
+            raise ValueError(
+                "The expected total steps have not been computed. Run `compute_expected_total_steps` method."
+            )
+        else:
+            return self._expected_total_steps
+
+    def conditional_checkpoint_save(
+        self,
+        stage: Literal["end_of_step", "end_of_epoch", "end_of_training"],
+        *args,
+        **kwargs,
+    ):
+        """
+        Conditionally saves a checkpoint based on the current training stage.
+
+        Args:
+            stage (Literal["end_of_step", "end_of_epoch", "end_of_training"]): The current stage of training.
+        """
+        if stage == "end_of_step":
+            if (
+                self.checkpoint_save_interval == "step"
+                and (self.global_step_idx + 1) % self.checkpoint_save_frequency == 0
+            ):
+                save_path = os.path.join(
+                    self.log_dir, "checkpoints", f"step={self.global_step_idx}.ckpt"
+                )
+                self.save_checkpoint(save_path, *args, **kwargs)
+        elif stage == "end_of_epoch":
+            if (
+                self.checkpoint_save_interval == "epoch"
+                and (self.epoch_idx + 1) % self.checkpoint_save_frequency == 0
+            ):
+                save_path = os.path.join(
+                    self.log_dir, "checkpoints", f"epoch={self.epoch_idx}.ckpt"
+                )
+                self.save_checkpoint(save_path, *args, **kwargs)
+        elif stage == "end_of_training":
+            # if the checkpoint has not been saved yet, save it
+            if self.global_step_idx > self._latest_saved_checkpoint_global_step:
+                save_path = os.path.join(
+                    self.log_dir,
+                    "checkpoints",
+                    f"epoch={self.epoch_idx}_step={self.global_step_idx}.ckpt",
+                )
+                self.save_checkpoint(save_path, *args, **kwargs)
+                try:
+                    os.symlink(
+                        src=save_path,
+                        dst=os.path.join(
+                            self.log_dir, "checkpoints", "latest_model.ckpt"
+                        ),
+                        target_is_directory=os.path.isdir(save_path),
+                    )
+                except Exception as e:
+                    log.error(f"Failed to create symlink: {e}")
+        else:
+            raise ValueError(
+                f"Unknown stage: {stage}. Available options: 'end_of_step', 'end_of_epoch', 'end_of_training'"
+            )
+
+    @abstractmethod
+    def save_checkpoint(self, path, **kwargs):
+        """
+        Saves a checkpoint of the model.
+
+        Args:
+            path (str): The path where the checkpoint will be saved.
+
+        Raises:
+            NotImplementedError: If the method is not implemented.
+        """
+        raise NotImplementedError("save_checkpoint method is not implemented")
+
+    def train(
+        self,
+        model: Union[nn.Module, "_FabricModule"],
+        optimizer: Union[torch.optim.Optimizer, "_FabricOptimizer"],
+        lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
+    ):
+        """
+        Trains the model.
+
+        The global batch size is `the batch size per device` x `the number of devices` x `the number of gradient accumulation steps`.
+
+        Args:
+            model (Union[nn.Module, "_FabricModule"]): The model to be trained.
+            optimizer (Union[torch.optim.Optimizer, "_FabricOptimizer"]): The optimizer used for training.
+            lr_scheduler (torch.optim.lr_scheduler.LRScheduler): The learning rate scheduler.
+        """
+        fabric = self.fabric
+        self.is_training = True
+        # number of parameter update iterations, not the number of batches
+        self.global_step_idx = 0
+        model.train()
+        optimizer.zero_grad()
+        for epoch_idx in tqdm(
+            range(self.max_epochs) if self.max_epochs > 0 else itertools.count(0),
+            "Training Epoch",
+            dynamic_ncols=True,
+            leave=False,
+            disable=not fabric.is_global_zero,
+        ):
+            self.epoch_idx = epoch_idx
+            self.train_epoch(model, optimizer, lr_scheduler)
+            # run lr_scheduler at the end of the epoch if interval is set to "epoch"
+            if (
+                self.lr_scheduler_interval == "epoch"
+                and (epoch_idx + 1) % self.lr_scheduler_frequency == 0
+            ):
+                lr_scheduler.step()
+
+            # save the model at the end of the epoch if interval is set to "epoch" and frequency is met
+            self.conditional_checkpoint_save(stage="end_of_epoch")
+
+            if not self.is_training:
+                break
+
+        optimizer.zero_grad()
+        # save the model at the end of training
+        self.conditional_checkpoint_save(stage="end_of_training")
+        return model
+
+    @abstractmethod
+    def train_epoch(
+        self,
+        model: Union[nn.Module, "_FabricModule"],
+        optimizer: Union[torch.optim.Optimizer, "_FabricOptimizer"],
+        lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
+    ):
+        """
+        Trains the model for one epoch.
+
+        Args:
+            model (Union[nn.Module, "_FabricModule"]): The model to be trained.
+            optimizer (Union[torch.optim.Optimizer, "_FabricOptimizer"]): The optimizer used for training.
+            lr_scheduler (torch.optim.lr_scheduler.LRScheduler): The learning rate scheduler.
+
+        Raises:
+            NotImplementedError: If the method is not implemented.
+        """
+        raise NotImplementedError(
+            "Copy this as a template and implement your own train_epoch method"
+        )
+        fabric = self.fabric
+
+        accumulated_loss = 0
+        for step_idx, batch in enumerate(
+            pbar := tqdm(
+                self.train_dataloader,
+                desc="Training Batches",
+                dynamic_ncols=True,
+                leave=False,
+                disable=not fabric.is_global_zero,
+            )
+        ):
+            is_accumulating = (step_idx + 1) % self.accumulate_grad_batches != 0
+
+            # disable gradient synchronization if accumulating gradients across steps for improved performance
+            with fabric.no_backward_sync(self.model, enabled=is_accumulating):
+                # use_cache=True is not compatible with gradient checkpointing, so we disable it here
+                output = self.compute_loss(batch)
+                loss = output["loss"] / self.accumulate_grad_batches
+
+                fabric.backward(loss)
+                accumulated_loss += loss.item()
+
+            # 1. update the model parameters if not accumulating gradients
+            # 2. step the lr_scheduler if interval is set to "step" and frequency is met
+            # 3. save the model if interval is set to "step" and frequency is met
+            # 4. log metrics
+            # 5. increase the global step index and reset the accumulated metrics
+            if not is_accumulating:
+                self.clip_gradients_if_needed(model, optimizer)
+
+                # run lr_scheduler at the end of the step if interval is set to "step"
+                if (
+                    self.lr_scheduler_interval == "step"
+                    and (self.global_step_idx + 1) % self.lr_scheduler_frequency == 0
+                ):
+                    lr_scheduler.step()
+
+                # update the model parameters and zero the gradients
+                optimizer.step()
+                optimizer.zero_grad()
+
+                metrics = {
+                    "train/loss": accumulated_loss,
+                    "train/lr": optimizer.param_groups[0]["lr"],
+                }
+
+                fabric.log_dict(metrics, step=self.global_step_idx)
+                pbar.set_postfix(metrics)
+
+                # save the model at the end of the step if interval is set to "step" and frequency is met
+                self.conditional_checkpoint_save(stage="end_of_step")
+
+                # break if max_steps_per_epoch is set, and exit epoch
+                if (
+                    self.max_steps_per_epoch > 0
+                    and step_idx + 1 >= self.max_steps_per_epoch
+                ):
+                    break
+                # break if max_steps is set, and exit training
+                if self.max_steps > 0 and self.global_step_idx >= self.max_steps - 1:
+                    self.is_training = False
+                    break
+
+                self.global_step_idx += 1
+                accumulated_loss = 0
```
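The new mixin separates reusable training-loop plumbing (epoch/step loops, gradient accumulation, clipping, LR-scheduler cadence, checkpoint cadence) from the algorithm-specific parts, which a subclass supplies via `train_epoch` and `save_checkpoint`. With `accumulate_grad_batches`, the effective global batch size is per-device batch × number of devices × accumulation steps (e.g. 4 × 2 × 8 = 64). Below is a minimal sketch, not taken from the package, of how such a subclass could be wired up; the class name, toy loss, and dataloader are illustrative assumptions, and the Fabric instance is assumed to be provided by `LightningFabricMixin` as configured elsewhere in FusionBench.

```python
# Minimal sketch (not from fusion_bench) of subclassing FabricTrainingMixin.
# Assumptions: self.fabric is supplied by LightningFabricMixin, and the toy
# regression loss / dataloader stand in for a real fine-tuning objective.
from torch import nn
from torch.utils.data import DataLoader

from fusion_bench.mixins.fabric_training import FabricTrainingMixin


class ToyFabricTrainer(FabricTrainingMixin):
    # hyper-parameters read by the mixin's train() loop
    max_epochs = 2
    max_steps = -1              # <= 0 disables the global step limit
    max_steps_per_epoch = -1    # <= 0 disables the per-epoch step limit
    accumulate_grad_batches = 1
    gradient_clip_val = 1.0
    gradient_clip_algorithm = "norm"
    lr_scheduler_interval = "epoch"
    lr_scheduler_frequency = 1
    checkpoint_save_interval = "epoch"
    checkpoint_save_frequency = 1

    def __init__(self, model: nn.Module, train_dataloader: DataLoader):
        self._model = model
        self.train_dataloader = train_dataloader

    def train_epoch(self, model, optimizer, lr_scheduler):
        # stripped-down version of the template body shipped in the mixin
        # (no gradient accumulation, per-step scheduling, or per-step saving)
        for batch in self.train_dataloader:
            x, y = batch
            loss = nn.functional.mse_loss(model(x), y)
            self.fabric.backward(loss)
            self.clip_gradients_if_needed(model, optimizer)
            optimizer.step()
            optimizer.zero_grad()
            self.global_step_idx += 1

    def save_checkpoint(self, path, **kwargs):
        # Fabric handles strategy-aware (e.g. FSDP/DeepSpeed) saving
        self.fabric.save(path, {"model": self._model})
```

In practice the model and optimizer would first be wrapped with `self.fabric.setup(...)` before calling `train(model, optimizer, lr_scheduler)`, as the package's SFT trainers do.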
fusion_bench/mixins/lightning_fabric.py

```diff
@@ -1,6 +1,7 @@
+import functools
 import logging
 import os
-from typing import TYPE_CHECKING, Any, Optional, TypeVar
+from typing import TYPE_CHECKING, Any, List, Optional, TypeVar
 
 import lightning as L
 import torch
@@ -8,11 +9,12 @@ from lightning.fabric.loggers import TensorBoardLogger
 from lightning.fabric.utilities.rank_zero import rank_zero_only
 from omegaconf import DictConfig, OmegaConf
 
-from fusion_bench.utils.instantiate import instantiate
 from fusion_bench.utils import import_object
+from fusion_bench.utils.instantiate import instantiate
 
 if TYPE_CHECKING:
     import lightning.fabric.loggers.tensorboard
+    from lightning.fabric.strategies import FSDPStrategy
 
 log = logging.getLogger(__name__)
 
@@ -32,6 +34,13 @@ def get_policy(*args: str) -> set:
     return {import_object(arg) for arg in args}
 
 
+def get_size_based_auto_wrap_policy(*args, **kwargs):
+    from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
+
+    policy = functools.partial(size_based_auto_wrap_policy, *args, **kwargs)
+    return policy
+
+
 class LightningFabricMixin:
     """
     A mixin class for integrating Lightning Fabric into a project.
```
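The new `get_size_based_auto_wrap_policy` helper simply returns a `functools.partial` of PyTorch's `size_based_auto_wrap_policy`, so an FSDP auto-wrap policy can be built by name from a strategy config. A hedged sketch of the equivalent plain-Python usage; the parameter threshold below is an arbitrary choice, not a value from the package:

```python
# Build a size-based FSDP auto-wrap policy; min_num_params is an assumed example value.
from fusion_bench.mixins.lightning_fabric import get_size_based_auto_wrap_policy

policy = get_size_based_auto_wrap_policy(min_num_params=100_000_000)

# `policy` is the same kind of callable one would pass directly to FSDP, e.g.:
# torch.distributed.fsdp.FullyShardedDataParallel(module, auto_wrap_policy=policy)
```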
fusion_bench/modelpool/__init__.py

```diff
@@ -16,6 +16,7 @@ _import_structure = {
         "HuggingFaceGPT2ClassificationPool",
         "GPT2ForSequenceClassificationPool",
     ],
+    "seq_classification_lm": ["SeqenceClassificationModelPool"],
 }
 
 
@@ -31,6 +32,7 @@ if TYPE_CHECKING:
     from .nyuv2_modelpool import NYUv2ModelPool
     from .PeftModelForSeq2SeqLM import PeftModelForSeq2SeqLMPool
     from .seq2seq_lm import Seq2SeqLMPool
+    from .seq_classification_lm import SeqenceClassificationModelPool
 
 else:
     sys.modules[__name__] = LazyImporter(
```
fusion_bench/modelpool/causal_lm/__init__.py

```diff
@@ -1,2 +1,2 @@
 # flake8: noqa F401
-from .causal_lm import CausalLMBackbonePool, CausalLMPool
+from .causal_lm import CausalLMBackbonePool, CausalLMPool, load_peft_causal_lm
```
fusion_bench/modelpool/causal_lm/causal_lm.py

```diff
@@ -3,6 +3,7 @@ import os
 from copy import deepcopy
 from typing import Any, Optional, TypeAlias, Union, cast  # noqa: F401
 
+import peft
 from omegaconf import DictConfig, flag_override
 from torch import nn
 from torch.nn.modules import Module
@@ -23,28 +24,6 @@ log = logging.getLogger(__name__)
 CausalLM: TypeAlias = Union[LlamaForCausalLM, MistralForCausalLM, Any]
 
 
-def config_priority_get(priority_config, general_config, key, default):
-    """
-    Retrieve a configuration value with priority.
-
-    This function retrieves the value associated with `key` from `priority_config` if it exists.
-    If the key is not found in `priority_config`, it retrieves the value from `general_config`.
-    If the key is not found in either configuration, it returns the provided `default` value.
-
-    Args:
-        priority_config (dict): The configuration dictionary with higher priority.
-        general_config (dict): The general configuration dictionary.
-        key (str): The key to look up in the configuration dictionaries.
-        default: The default value to return if the key is not found in either configuration.
-
-    Returns:
-        The value associated with `key` from `priority_config` or `general_config`, or the `default` value if the key is not found.
-    """
-    if key in priority_config:
-        return priority_config[key]
-    return general_config.get(key, default)
-
-
 class CausalLMPool(BaseModelPool):
     _config_mapping = BaseModelPool._config_mapping | {
         "_tokenizer": "tokenizer",
@@ -138,3 +117,23 @@ class CausalLMBackbonePool(CausalLMPool):
             model_name_or_config, *args, **kwargs
         )
         return model.model.layers
+
+
+def load_peft_causal_lm(
+    base_model_path: str,
+    peft_model_path: str,
+    torch_dtype: str = "bfloat16",
+    is_trainable: bool = True,
+    merge_and_unload: bool = False,
+):
+    base_model = LlamaForCausalLM.from_pretrained(
+        base_model_path, torch_dtype=torch_dtype
+    )
+    model = peft.PeftModel.from_pretrained(
+        base_model,
+        peft_model_path,
+        is_trainable=is_trainable,
+    )
+    if merge_and_unload:
+        model = model.merge_and_unload()
+    return model
```
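The new `load_peft_causal_lm` helper (now re-exported from `fusion_bench.modelpool.causal_lm`) loads a LLaMA-family base checkpoint, attaches a PEFT adapter, and optionally merges the adapter back into the base weights. A hedged usage sketch; both paths below are placeholders rather than values from this diff:

```python
from fusion_bench.modelpool.causal_lm import load_peft_causal_lm

model = load_peft_causal_lm(
    base_model_path="meta-llama/Meta-Llama-3-8B",  # placeholder base checkpoint
    peft_model_path="outputs/llama_lora_adapter",  # placeholder adapter directory
    torch_dtype="bfloat16",
    is_trainable=False,     # load the adapter for inference only
    merge_and_unload=True,  # fold the LoRA weights back into the base model
)
```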
fusion_bench/modelpool/seq_classification_lm/reward_model.py (new file)

```diff
@@ -0,0 +1,15 @@
+from transformers import AutoModelForSequenceClassification
+
+
+def create_reward_model_from_pretrained(pretrained_model_name_or_path: str, **kwargs):
+    """
+    Create a reward model for reward modeling (RLHF).
+
+    Args:
+        pretrained_model_name_or_path (str): The name or path of the pretrained model.
+        **kwargs: Additional keyword arguments passed to the model class.
+    """
+    model = AutoModelForSequenceClassification.from_pretrained(
+        pretrained_model_name_or_path, num_labels=1, **kwargs
+    )
+    return model
```
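`create_reward_model_from_pretrained` is a thin wrapper over `AutoModelForSequenceClassification` with `num_labels=1`, i.e. a scalar reward head for reward modeling. A hedged usage sketch with a placeholder checkpoint name:

```python
import torch

from fusion_bench.modelpool.seq_classification_lm.reward_model import (
    create_reward_model_from_pretrained,
)

reward_model = create_reward_model_from_pretrained(
    "meta-llama/Meta-Llama-3-8B",  # placeholder; any compatible checkpoint works
    torch_dtype=torch.bfloat16,    # forwarded to from_pretrained via **kwargs
)
```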
fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py (new file)

```diff
@@ -0,0 +1,98 @@
+import logging
+import os
+from copy import deepcopy
+from typing import TYPE_CHECKING, Any, Optional, TypeAlias, Union, cast  # noqa: F401
+
+from omegaconf import DictConfig, flag_override
+from transformers import PreTrainedModel, PreTrainedTokenizer
+from typing_extensions import override
+
+from fusion_bench.modelpool import BaseModelPool
+from fusion_bench.utils import instantiate
+from fusion_bench.utils.dtype import parse_dtype
+
+if TYPE_CHECKING:
+    from transformers import LlamaForSequenceClassification
+
+log = logging.getLogger(__name__)
+
+
+class SeqenceClassificationModelPool(BaseModelPool):
+
+    def __init__(
+        self,
+        models,
+        *,
+        tokenizer: Optional[DictConfig],
+        model_kwargs: Optional[DictConfig] = None,
+        **kwargs,
+    ):
+        super().__init__(models, **kwargs)
+        # process `model_kwargs`
+        self._tokenizer = tokenizer
+        self._model_kwargs = model_kwargs
+        if self._model_kwargs is None:
+            self._model_kwargs = DictConfig({})
+        with flag_override(self._model_kwargs, "allow_objects", True):
+            if hasattr(self._model_kwargs, "torch_dtype"):
+                self._model_kwargs.torch_dtype = parse_dtype(
+                    self._model_kwargs.torch_dtype
+                )
+
+    @override
+    def load_model(
+        self,
+        model_name_or_config: str | DictConfig,
+        *args,
+        **kwargs,
+    ) -> Union[PreTrainedModel, "LlamaForSequenceClassification"]:
+        model_kwargs = deepcopy(self._model_kwargs)
+        model_kwargs.update(kwargs)
+        if isinstance(model_name_or_config, str):
+            log.info(f"Loading model: {model_name_or_config}", stacklevel=2)
+        return super().load_model(model_name_or_config, *args, **model_kwargs)
+
+    def load_tokenizer(self, *args, **kwargs) -> PreTrainedTokenizer:
+        assert self._tokenizer is not None, "Tokenizer is not defined in the config"
+        log.info("Loading tokenizer.", stacklevel=2)
+        tokenizer = instantiate(self._tokenizer, *args, **kwargs)
+        return tokenizer
+
+    @override
+    def save_model(
+        self,
+        model: PreTrainedModel,
+        path: str,
+        push_to_hub: bool = False,
+        model_dtype: Optional[str] = None,
+        save_tokenizer: bool = False,
+        tokenizer_kwargs=None,
+        **kwargs,
+    ):
+        """
+        Save the model to the specified path.
+
+        Args:
+            model (PreTrainedModel): The model to be saved.
+            path (str): The path where the model will be saved.
+            push_to_hub (bool, optional): Whether to push the model to the Hugging Face Hub. Defaults to False.
+            save_tokenizer (bool, optional): Whether to save the tokenizer along with the model. Defaults to False.
+            **kwargs: Additional keyword arguments passed to the `save_pretrained` method.
+        """
+        path = os.path.expanduser(path)
+        if save_tokenizer:
+            if tokenizer_kwargs is None:
+                tokenizer_kwargs = {}
+            # load the tokenizer
+            tokenizer = self.load_tokenizer(**tokenizer_kwargs)
+            tokenizer.save_pretrained(
+                path,
+                push_to_hub=push_to_hub,
+            )
+        if model_dtype is not None:
+            model.to(dtype=parse_dtype(model_dtype))
+        model.save_pretrained(
+            path,
+            push_to_hub=push_to_hub,
+            **kwargs,
+        )
```
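The new `SeqenceClassificationModelPool` (registered in the model-pool lazy importer above and driven by the `SeqenceClassificationModelPool/*.yaml` configs added in this release) resolves `torch_dtype` strings in `model_kwargs`, instantiates models and the tokenizer from their configs, and can save both together. A hedged in-Python sketch of how such a pool could be configured; the `_target_` entries, model name, checkpoint, and output path are assumptions about typical usage rather than values from this diff:

```python
from omegaconf import DictConfig

from fusion_bench.modelpool.seq_classification_lm import SeqenceClassificationModelPool

pool = SeqenceClassificationModelPool(
    models=DictConfig(
        {
            "reward_model": {  # hypothetical model entry
                "_target_": (
                    "fusion_bench.modelpool.seq_classification_lm"
                    ".reward_model.create_reward_model_from_pretrained"
                ),
                "pretrained_model_name_or_path": "meta-llama/Meta-Llama-3-8B",
            }
        }
    ),
    tokenizer=DictConfig(
        {
            "_target_": "transformers.AutoTokenizer.from_pretrained",
            "pretrained_model_name_or_path": "meta-llama/Meta-Llama-3-8B",
        }
    ),
    model_kwargs=DictConfig({"torch_dtype": "bfloat16"}),  # parsed via parse_dtype
)

model = pool.load_model("reward_model")
tokenizer = pool.load_tokenizer()
pool.save_model(model, "outputs/reward_model", save_tokenizer=True)
```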
fusion_bench/models/chat_templates/__init__.py (new file)

```diff
@@ -0,0 +1 @@
+from .load_tokenizer import chat_template_mapping, load_tokenizer_with_chat_template
```
fusion_bench/models/chat_templates/llama_3_Instruct.py (new file, a single added line defining the Llama 3 Instruct chat template)

```diff
@@ -0,0 +1 @@
CHAT_TEMPLATE = '{{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- if strftime_now is defined %}\n {%- set date_string = strftime_now("%d %b %Y") %}\n {%- else %}\n {%- set date_string = "26 Jul 2024" %}\n {%- endif %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0][\'role\'] == \'system\' %}\n {%- set system_message = messages[0][\'content\']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = "" %}\n{%- endif %}\n\n{#- System message #}\n{{- "<|start_header_id|>system<|end_header_id|>\\n\\n" }}\n{%- if tools is not none %}\n {{- "Environment: ipython\\n" }}\n{%- endif %}\n{{- "Cutting Knowledge Date: December 2023\\n" }}\n{{- "Today Date: " + date_string + "\\n\\n" }}\n{%- if tools is not none and not tools_in_user_message %}\n {{- "You have access to the following functions. To call a function, please respond with JSON for a function call." }}\n {{- \'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.\' }}\n {{- "Do not use variables.\\n\\n" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- "\\n\\n" }}\n {%- endfor %}\n{%- endif %}\n{{- system_message }}\n{{- "<|eot_id|>" }}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0][\'content\']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception("Cannot put tools in the first user message when there\'s no first user message!") }}\n{%- endif %}\n {{- \'<|start_header_id|>user<|end_header_id|>\\n\\n\' -}}\n {{- "Given the following functions, please respond with a JSON for a function call " }}\n {{- "with its proper arguments that best answers the given prompt.\\n\\n" }}\n {{- \'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.\' }}\n {{- "Do not use variables.\\n\\n" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- "\\n\\n" }}\n {%- endfor %}\n {{- first_user_message + "<|eot_id|>"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == \'ipython\' or message.role == \'tool\' or \'tool_calls\' in message) %}\n {{- \'<|start_header_id|>\' + message[\'role\'] + \'<|end_header_id|>\\n\\n\'+ message[\'content\'] | trim + \'<|eot_id|>\' }}\n {%- elif \'tool_calls\' in message %}\n {%- if not message.tool_calls|length == 1 %}\n {{- raise_exception("This model only supports single tool-calls at once!") }}\n {%- endif %}\n {%- set tool_call = message.tool_calls[0].function %}\n {{- \'<|start_header_id|>assistant<|end_header_id|>\\n\\n\' -}}\n {{- \'{"name": "\' + tool_call.name + \'", \' }}\n {{- \'"parameters": \' }}\n {{- tool_call.arguments | tojson }}\n {{- "}" }}\n {{- "<|eot_id|>" }}\n {%- elif message.role == "tool" or message.role == "ipython" %}\n {{- "<|start_header_id|>ipython<|end_header_id|>\\n\\n" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- 
"<|eot_id|>" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- \'<|start_header_id|>assistant<|end_header_id|>\\n\\n\' }}\n{%- endif %}\n'
```
fusion_bench/models/chat_templates/load_tokenizer.py (new file)

```diff
@@ -0,0 +1,43 @@
+import logging
+
+from transformers import AutoTokenizer
+
+from .llama_3_Instruct import CHAT_TEMPLATE as LLAMA_3_INSTRUCT_CHAT_TEMPLATE
+
+chat_template_mapping = {"llama_3_instruct": LLAMA_3_INSTRUCT_CHAT_TEMPLATE}
+
+log = logging.getLogger(__name__)
+
+
+def load_tokenizer_with_chat_template(
+    pretrained_model_name_or_path: str,
+    model_family: str,
+    overwrite_chat_template: bool = True,
+    **kwargs,
+):
+    """
+    Load the tokenizer for Llama 3 model.
+
+    Args:
+        pretrained_model_name_or_path (str): The name or path of the pretrained model.
+        model_family (str): The model family.
+        **kwargs: Additional keyword arguments passed to the tokenizer class.
+    """
+    assert (
+        model_family in chat_template_mapping
+    ), f"Model family {model_family} not found. Available model families: {chat_template_mapping.keys()}"
+
+    tokenizer = AutoTokenizer.from_pretrained(
+        pretrained_model_name_or_path,
+        **kwargs,
+    )
+
+    if tokenizer.chat_template is None:
+        tokenizer.chat_template = chat_template_mapping[model_family]
+    else:
+        if overwrite_chat_template:
+            log.warning("Overwriting the chat template with the default chat template.")
+            tokenizer.chat_template = chat_template_mapping[model_family]
+        else:
+            log.warning("Chat template already exists. Skipping overwriting.")
+    return tokenizer
```
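`load_tokenizer_with_chat_template` loads a tokenizer and installs (or, by default, overwrites) the chat template registered in `chat_template_mapping`; currently only `"llama_3_instruct"` is mapped. A hedged usage sketch with a placeholder model name:

```python
from fusion_bench.models.chat_templates import load_tokenizer_with_chat_template

tokenizer = load_tokenizer_with_chat_template(
    "meta-llama/Meta-Llama-3-8B-Instruct",  # placeholder checkpoint
    model_family="llama_3_instruct",
)

# render a conversation with the installed template
prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Summarize model merging in one sentence."}],
    tokenize=False,
    add_generation_prompt=True,
)
```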