fusion-bench 0.2.5__py3-none-any.whl → 0.2.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/compat/method/__init__.py +2 -0
- fusion_bench/compat/method/base_algorithm.py +7 -2
- fusion_bench/compat/modelpool/__init__.py +3 -2
- fusion_bench/compat/taskpool/__init__.py +1 -1
- fusion_bench/dataset/arc_agi/__init__.py +6 -1
- fusion_bench/dataset/arc_agi/arc.py +26 -7
- fusion_bench/dataset/arc_agi/arc_agi.py +156 -25
- fusion_bench/dataset/arc_agi/np_cache.py +0 -1
- fusion_bench/dataset/arc_agi/preprocess.py +51 -9
- fusion_bench/dataset/llama/__init__.py +1 -0
- fusion_bench/dataset/llama/alpaca.py +93 -3
- fusion_bench/dataset/llama/collate.py +72 -5
- fusion_bench/dataset/llama/metamathqa.py +50 -0
- fusion_bench/dataset/llama/preference_700k.py +70 -0
- fusion_bench/dataset/llama/stanford_shp.py +90 -0
- fusion_bench/dataset/llama/ultrachat.py +58 -0
- fusion_bench/dataset/llama/utils/__init__.py +0 -0
- fusion_bench/method/__init__.py +4 -1
- fusion_bench/method/adamerging/__init__.py +1 -1
- fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -4
- fusion_bench/method/adamerging/min_norm_solvers.py +4 -4
- fusion_bench/method/linear/expo.py +39 -0
- fusion_bench/method/lm_finetune/__init__.py +1 -0
- fusion_bench/method/lm_finetune/bradley_terry_rm.py +432 -0
- fusion_bench/method/lm_finetune/fullfinetune_sft.py +122 -150
- fusion_bench/method/lm_finetune/peftfinetune_sft.py +102 -157
- fusion_bench/method/pruning/llama_magnitude_prune.py +2 -2
- fusion_bench/method/pruning/llama_random_prune.py +2 -2
- fusion_bench/method/pruning/magnitude_diff_pruning.py +2 -1
- fusion_bench/method/rankone_moe/__init__.py +3 -0
- fusion_bench/method/rankone_moe/clip_rankone_moe.py +160 -0
- fusion_bench/method/rankone_moe/rankone_moe.py +249 -0
- fusion_bench/method/simple_average.py +1 -1
- fusion_bench/method/surgery/__init__.py +3 -0
- fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +157 -0
- fusion_bench/mixins/__init__.py +2 -0
- fusion_bench/mixins/clip_classification.py +60 -12
- fusion_bench/mixins/fabric_training.py +320 -0
- fusion_bench/mixins/lightning_fabric.py +11 -2
- fusion_bench/modelpool/__init__.py +2 -0
- fusion_bench/modelpool/causal_lm/__init__.py +1 -1
- fusion_bench/modelpool/causal_lm/causal_lm.py +21 -22
- fusion_bench/modelpool/seq_classification_lm/__init__.py +2 -0
- fusion_bench/modelpool/seq_classification_lm/reward_model.py +15 -0
- fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +98 -0
- fusion_bench/models/chat_templates/__init__.py +1 -0
- fusion_bench/models/chat_templates/llama_3_Instruct.py +1 -0
- fusion_bench/models/chat_templates/load_tokenizer.py +43 -0
- fusion_bench/models/hf_clip.py +50 -9
- fusion_bench/models/rankone_moe.py +410 -0
- fusion_bench/models/surgery/surgerymodelwrapper.py +157 -0
- fusion_bench/models/utils.py +8 -0
- fusion_bench/models/wrappers/layer_wise_fusion.py +14 -5
- fusion_bench/models/wrappers/task_wise_fusion.py +5 -5
- fusion_bench/optim/__init__.py +2 -0
- fusion_bench/optim/exception.py +47 -0
- fusion_bench/optim/lr_scheduler/__init__.py +1 -0
- fusion_bench/optim/lr_scheduler/linear_warmup.py +222 -0
- fusion_bench/optim/lr_scheduler/utils/__init__.py +1 -0
- fusion_bench/optim/lr_scheduler/utils/visualization.py +119 -0
- fusion_bench/optim/mezo.py +0 -2
- fusion_bench/programs/fabric_fusion_program.py +5 -1
- fusion_bench/taskpool/__init__.py +10 -2
- fusion_bench/taskpool/clip_vision/__init__.py +1 -0
- fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py +112 -0
- fusion_bench/taskpool/clip_vision/taskpool.py +43 -6
- fusion_bench/taskpool/llama/reward_model.py +157 -0
- fusion_bench/taskpool/nyuv2_taskpool.py +2 -0
- fusion_bench/tasks/flan_t5_text_generation/glue_load_dataset.py +2 -1
- fusion_bench/utils/hydra_utils.py +22 -0
- fusion_bench/utils/plot/__init__.py +0 -0
- fusion_bench/utils/plot/token.py +52 -0
- fusion_bench/utils/plot/token_notebook.py +127 -0
- fusion_bench/utils/type.py +5 -3
- {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/METADATA +1 -1
- {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/RECORD +104 -57
- fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
- fusion_bench_config/dataset/llm_sft/alpaca_cleaned.yaml +6 -0
- fusion_bench_config/dataset/llm_sft/ultrachat_200k.yaml +3 -0
- fusion_bench_config/fabric/llama_peft_fsdp.yaml +16 -0
- fusion_bench_config/fabric/loggers/wandb_logger.yaml +2 -0
- fusion_bench_config/fabric/strategy/deepspeed.yaml +10 -0
- fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +9 -0
- fusion_bench_config/fabric_model_fusion.yaml +1 -1
- fusion_bench_config/llama_full_finetune.yaml +19 -0
- fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +47 -0
- fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +13 -6
- fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +17 -9
- fusion_bench_config/method/rankone_moe/rankone_moe.yaml +26 -0
- fusion_bench_config/method/regmean/clip_regmean.yaml +1 -0
- fusion_bench_config/method/surgery/adamerging_surgery.yaml +27 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +21 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +21 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +19 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +18 -0
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +23 -0
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +14 -0
- fusion_bench_config/nyuv2_config.yaml +5 -1
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml +18 -0
- fusion_bench_config/taskpool/reward_model_evaluation.yaml +18 -0
- fusion_bench_config/llama_weighted_average.yaml +0 -26
- {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/LICENSE +0 -0
- {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/top_level.txt +0 -0
fusion_bench/method/linear/expo.py

@@ -6,6 +6,10 @@ Reference:
 """
 
 import logging
+from copy import deepcopy
+
+import torch
+from torch import nn
 
 from fusion_bench import BaseAlgorithm, BaseModelPool
 from fusion_bench.method import SimpleAverageAlgorithm
@@ -18,6 +22,41 @@ from fusion_bench.utils.state_dict_arithmetic import (
 log = logging.getLogger(__name__)
 
 
+def expo_merge(
+    sft_model: nn.Module,
+    rlhf_model: nn.Module,
+    extrapolation_factor: float,
+    inplace: bool = True,
+    enable_grad: bool = False,
+):
+    """
+    Minimal implementation of ExPO merge.
+
+    Args:
+        sft_model (nn.Module): The pretrained model (base model).
+        rlhf_model (nn.Module): The finetuned model (medium-aligned model).
+        extrapolation_factor (float): The extrapolation factor.
+        inplace (bool): Whether to perform the merge in-place. If not, the rlhf_model will be copied before merging.
+        enable_grad (bool): Whether to enable gradient computation during the merge.
+
+    Returns:
+        nn.Module: The merged model.
+    """
+
+    if not inplace:
+        rlhf_model = deepcopy(rlhf_model)
+
+    with torch.set_grad_enabled(enable_grad):
+        for (sft_name, sft_param), (rlhf_name, rlhf_param) in zip(
+            sft_model.named_parameters(), rlhf_model.named_parameters()
+        ):
+            assert sft_name == rlhf_name, f"Model mismatch: {sft_name} != {rlhf_name}"
+            rlhf_param.data = rlhf_param.data + extrapolation_factor * (
+                rlhf_param.data - sft_param.data
+            )
+    return rlhf_model
+
+
 class ExPOAlgorithm(BaseAlgorithm):
     R"""
     ExPO merge algorithm.
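The `expo_merge` helper added to `fusion_bench/method/linear/expo.py` applies the ExPO extrapolation parameter-wise: theta_merged = theta_rlhf + alpha * (theta_rlhf - theta_sft). A minimal usage sketch, assuming fusion-bench 0.2.7 is installed; the toy `nn.Linear` models and the factor 0.5 are illustrative, not part of the diff:

```python
# Illustrative sketch only: apply expo_merge to two toy modules and verify the
# extrapolation formula. The toy models and the factor 0.5 are assumptions.
import torch
from torch import nn

from fusion_bench.method.linear.expo import expo_merge

sft = nn.Linear(4, 4)   # stands in for the SFT (base) checkpoint
rlhf = nn.Linear(4, 4)  # stands in for the medium-aligned RLHF checkpoint

# inplace=False leaves rlhf untouched and returns a merged copy
merged = expo_merge(sft, rlhf, extrapolation_factor=0.5, inplace=False)

with torch.no_grad():
    expected = rlhf.weight + 0.5 * (rlhf.weight - sft.weight)
    assert torch.allclose(merged.weight, expected)
```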
fusion_bench/method/lm_finetune/bradley_terry_rm.py (new file)

@@ -0,0 +1,432 @@
+R"""
+This is basically the same as fullfinetune_sft.py, but with a different loss function.
+
+The dataset contains the following fields:
+
+- chosen_input_ids: The input token ids for the winner.
+- chosen_attention_mask: The attention mask for the winner.
+- rejected_input_ids: The input token ids for the loser.
+- rejected_attention_mask: The attention mask for the loser.
+
+"""
+
+import functools
+import itertools
+import logging
+import os
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union, override
+
+import lightning as L
+import omegaconf
+import torch
+from lightning.fabric.strategies.fsdp import FSDPStrategy
+from lightning.fabric.utilities import rank_zero_only
+from omegaconf import DictConfig
+from torch import Tensor, nn
+from torch.utils.data import ConcatDataset, DataLoader
+from tqdm.auto import tqdm
+from transformers import AutoModelForSequenceClassification, AutoTokenizer
+
+from fusion_bench.dataset.llama.collate import bradley_terry_rm_collate
+from fusion_bench.method import BaseAlgorithm
+from fusion_bench.mixins import FabricTrainingMixin
+from fusion_bench.modelpool import SeqenceClassificationModelPool
+from fusion_bench.utils import instantiate
+from fusion_bench.utils.dtype import get_dtype
+
+if TYPE_CHECKING:
+    from lightning.fabric.wrappers import (
+        _FabricDataLoader,
+        _FabricModule,
+        _FabricOptimizer,
+    )
+    from transformers.models.llama.modeling_llama import LlamaForSequenceClassification
+
+log = logging.getLogger(__name__)
+
+
+class BradleyTerryRewardModeling(BaseAlgorithm, FabricTrainingMixin):
+
+    model: Union[nn.Module, "_FabricModule", "LlamaForSequenceClassification"]
+    optimizer: Union[torch.optim.Optimizer, "_FabricOptimizer"]
+    train_dataloader: Union[DataLoader, "_FabricDataLoader"]
+    lr_scheduler: torch.optim.lr_scheduler.LRScheduler
+
+    def __init__(
+        self,
+        optimizer: DictConfig,
+        lr_scheduler: Optional[DictConfig],
+        dataloader_kwargs: DictConfig,
+        max_epochs: int,
+        max_steps: int = -1,
+        max_steps_per_epoch: int = -1,
+        lr_scheduler_interval: Literal["epoch", "step"] = "step",
+        lr_scheduler_frequency: int = 1,
+        checkpoint_save_interval: Literal["epoch", "step"] = "epoch",
+        checkpoint_save_frequency: int = 1,
+        accumulate_grad_batches: int = 1,
+        gradient_clip_val: Optional[float] = None,
+        gradient_clip_algorithm: Literal["value", "norm"] = "norm",
+        save_optimizer_state: bool = False,
+        save_full_model: bool = False,
+        save_ckpt_type: Literal["lightning", "hf"] = "lightning",
+        ckpt_path: Optional[str] = None,
+        max_length: int = 6144,
+        fix_token_embedding: bool = True,
+        **kwargs,
+    ):
+        """
+        Class for reward modeling using Bradley-Terry model.
+
+        Args:
+            optimizer(DictConfig): Configuration for the optimizer.
+            lr_scheduler(DictConfig): Configuration for the learning rate scheduler.
+            dataloader_kwargs(DictConfig): Configuration for the dataloader, such as batch size, num_workers, etc.
+            max_epochs(int): Maximum number of epochs to train the model. If set to -1, the training will continue indefinitely or until max_steps is reached.
+            max_steps(int): Maximum number of steps to train the model. If set to -1, the training will continue indefinitely or until max_epochs is reached.
+            max_steps_per_epoch(int): Maximum number of steps to train the model in each epoch. If set to -1, the training will continue until the end of the epoch.
+            lr_scheduler_interval(str): Interval at which to run the learning rate scheduler. Available options: 'epoch', 'step'. If set to 'epoch', the scheduler will run at the end of each epoch. If set to 'step', the scheduler will run at the end of each step.
+            lr_scheduler_frequency(int): Frequency at which to run the learning rate scheduler. The scheduler will run every `lr_scheduler_frequency` epochs or steps, depending on the value of `lr_scheduler_interval`.
+            checkpoint_save_interval(str): Interval at which to save the model checkpoint. Available options: 'epoch', 'step'. If set to 'epoch', the model will be saved at the end of each epoch. If set to 'step', the model will be saved at the end of each step.
+            checkpoint_save_frequency(int): Frequency at which to save the model checkpoint. The model will be saved every `checkpoint_save_frequency` epochs or steps, depending on the value of `checkpoint_save_interval`.
+            accumulate_grad_batches(int): Number of batches to accumulate gradients across before updating the model parameters.
+            gradient_clip_val(float): Value to clip the gradients. If set to None, no gradient clipping will be applied.
+            gradient_clip_algorithm(str): Algorithm to use for gradient clipping. Available options: 'value', 'norm'. If set to 'value', the gradients will be clipped to the specified value. If set to 'norm', the gradients will be clipped to the specified norm.
+            save_optimizer_state(bool): Whether to save the optimizer and lr_scheduler state along with the model checkpoint.
+            save_full_model(bool): Whether to save the full model or only the trainable parameters in the model checkpoint.
+            save_ckpt_type (str): Type of checkpoint to save. Available options: 'lightning', 'hf'. If set to 'lightning', the checkpoint will be saved in the lightning format. If set to 'hf', the checkpoint will be saved in the huggingface format.
+            ckpt_path(str): Path to the checkpoint to load before training. If set to None, no checkpoint will be loaded.
+            max_length(int): Maximum input length to consider. If the input length exceeds this value, it will be truncated.
+            fix_token_embedding(bool): Whether to fix the token embeddings during training. If set to True, the token embeddings will not be updated during training.
+        """
+        self._optimizer = optimizer
+        self._lr_scheduler = lr_scheduler
+        self.dataloader_kwargs = dataloader_kwargs
+        self.max_epochs = max_epochs
+        self.max_steps = max_steps
+        self.max_steps_per_epoch = max_steps_per_epoch
+        self.lr_scheduler_interval = lr_scheduler_interval
+        self.lr_scheduler_frequency = lr_scheduler_frequency
+        self.checkpoint_save_interval = checkpoint_save_interval
+        self.checkpoint_save_frequency = checkpoint_save_frequency
+        self.accumulate_grad_batches = accumulate_grad_batches
+        self.gradient_clip_val = gradient_clip_val
+        self.gradient_clip_algorithm = gradient_clip_algorithm
+        self.save_optimizer_state = save_optimizer_state
+        self.save_full_model = save_full_model
+        self.save_ckpt_type = save_ckpt_type
+        self.ckpt_path = ckpt_path
+        self.max_length = max_length
+        self.fix_token_embedding = fix_token_embedding
+        super().__init__(**kwargs)
+
+    def run(self, modelpool: SeqenceClassificationModelPool):
+        self.modelpool = modelpool
+        self.setup()
+        self.train(self.model, self.optimizer, self.lr_scheduler)
+        return self.model
+
+    def setup_model(self):
+        self.tokenizer = self.modelpool.load_tokenizer()
+        if self.tokenizer.pad_token_id is None:
+            self.tokenizer.pad_token_id = (
+                self.tokenizer.eos_token_id
+            )  #! make sure eos_token_id only show up at the end of the sequence
+
+        model = self.modelpool.load_pretrained_model()
+        self.model: "LlamaForSequenceClassification" = model
+
+        if model.config.pad_token_id is None:
+            model.config.pad_token_id = self.tokenizer.pad_token_id
+
+        if self.fix_token_embedding:
+            self.model.model.embed_tokens.requires_grad_(False)
+
+        if self.fabric.strategy == "fsdp" or isinstance(
+            self.fabric.strategy, FSDPStrategy
+        ):
+            # https://github.com/Lightning-AI/pytorch-lightning/issues/19267
+            self.model.gradient_checkpointing_enable(
+                gradient_checkpointing_kwargs={"use_reentrant": True}
+            )
+            self.use_cache = False
+        else:
+            self.use_cache = True
+        self.model_dtype = get_dtype(self.model)
+
+    def setup_data(self):
+        fabric = self.fabric
+        modelpool = self.modelpool
+        assert (
+            len(modelpool.train_dataset_names) > 0
+        ), "No training datasets found in modelpool."
+
+        train_datasets = [
+            modelpool.load_train_dataset(dataset_name)
+            for dataset_name in modelpool.train_dataset_names
+        ]
+        if len(train_datasets) > 1:
+            train_dataset = ConcatDataset(train_datasets)
+        else:
+            train_dataset = train_datasets[0]
+
+        self.train_dataset = train_dataset
+        self.train_dataloader = DataLoader(
+            train_dataset,
+            **self.dataloader_kwargs,
+            shuffle=True,
+            collate_fn=functools.partial(
+                bradley_terry_rm_collate,
+                pad_token_id=self.tokenizer.pad_token_id,
+            ),  # NOTE: different from SFT, uses bradley_terry_rm_collate
+        )
+        self.train_dataloader = fabric.setup_dataloaders(self.train_dataloader)
+
+    def configure_optimizer(self):
+        # compute expected total steps
+        self.compute_expected_total_steps(self.train_dataloader)
+
+        optimizer = instantiate(self._optimizer, self.model.parameters())
+        if self._lr_scheduler is not None:
+            for key, arg in self._lr_scheduler.items():
+                if arg == "_T_max_":
+                    log.info(
+                        f"Setting key `{key}` of lr_scheduler configuration to {self.expected_total_steps}"
+                    )
+                    self._lr_scheduler[key] = self.expected_total_steps
+            lr_scheduler: torch.optim.lr_scheduler.LRScheduler = instantiate(
+                self._lr_scheduler,
+                optimizer=optimizer,
+            )
+        else:
+            lr_scheduler = None
+        return {"optimizer": optimizer, "lr_scheduler": lr_scheduler}
+
+    def setup(self):
+        fabric = self.fabric
+
+        self.setup_model()
+        self.setup_data()
+
+        optimizer = self.configure_optimizer()
+        optimizer, lr_scheduler = optimizer["optimizer"], optimizer["lr_scheduler"]
+
+        self.model, self.optimizer = fabric.setup(self.model, optimizer)
+        self.lr_scheduler = lr_scheduler
+
+    def compute_loss(self, batch: Dict[str, Union[Tensor, Any]]) -> Dict[str, Tensor]:
+        """
+        Maximize the likelihood of the winner over the loser using the Bradley-Terry model.
+
+        Args:
+            batch (Dict[str, Union[Tensor, Any]]): A dictionary containing the input token ids and attention masks for the winner and loser.
+        """
+        batch_size = batch["input_ids"].size(0)
+        assert batch_size % 2 == 0, "Batch size must be even."
+
+        outputs = self.model(
+            input_ids=batch["input_ids"],
+            attention_mask=batch["attention_mask"],
+            use_cache=self.use_cache,
+        )
+
+        rewards = outputs[0]
+        chosen_reward = rewards[: batch_size // 2]
+        rejected_rewards = rewards[batch_size // 2 :]
+        loss = -torch.log(torch.sigmoid(chosen_reward - rejected_rewards)).mean()
+
+        return {
+            "chosen_reward": chosen_reward,
+            "rejected_reward": rejected_rewards,
+            "loss": loss,
+        }
+
+    @override
+    def train_epoch(self, *args, **kwargs):
+        fabric = self.fabric
+
+        accumulated_loss = 0
+        accumulated_chosen_reward = 0
+        accumulated_rejected_reward = 0
+        for step_idx, batch in enumerate(
+            pbar := tqdm(
+                self.train_dataloader,
+                desc="Training Batches",
+                dynamic_ncols=True,
+                leave=False,
+                disable=not fabric.is_global_zero,
+            )
+        ):
+            is_accumulating = (step_idx + 1) % self.accumulate_grad_batches != 0
+
+            if self.max_length > 0 and batch["input_ids"].shape[1] > self.max_length:
+                log.warning(
+                    f"Input length exceeds max_length: {batch['input_ids'].shape[1]} > {self.max_length}. Truncating input."
+                )
+                batch["input_ids"] = batch["input_ids"][:, -self.max_length :]
+                batch["attention_mask"] = batch["attention_mask"][:, -self.max_length :]
+
+            # disable gradient synchronization if accumulating gradients across steps for improved performance
+            with fabric.no_backward_sync(self.model, enabled=is_accumulating):
+                # use_cache=True is not compatible with gradient checkpointing, so we disable it here
+                output = self.compute_loss(batch)
+                loss = output["loss"] / self.accumulate_grad_batches
+
+                fabric.backward(loss)
+
+            accumulated_loss += loss.item()
+            accumulated_chosen_reward += output["chosen_reward"].mean().item()
+            accumulated_rejected_reward += output["rejected_reward"].mean().item()
+
+            # 1. update the model parameters if not accumulating gradients
+            # 2. step the lr_scheduler if interval is set to "step" and frequency is met
+            # 3. save the model if interval is set to "step" and frequency is met
+            # 4. log metrics
+            # 5. increase the global step index
+            if not is_accumulating:
+                self.clip_gradients_if_needed(self.model, self.optimizer)
+
+                # run lr_scheduler at the end of the step if interval is set to "step"
+                if (
+                    self.lr_scheduler_interval == "step"
+                    and (self.global_step_idx + 1) % self.lr_scheduler_frequency == 0
+                ):
+                    self.lr_scheduler.step()
+
+                # update the model parameters and zero the gradients
+                self.optimizer.step()
+                self.optimizer.zero_grad()
+
+                metrics = {
+                    "train/loss": accumulated_loss,
+                    "train/chosen_reward": accumulated_chosen_reward
+                    / self.accumulate_grad_batches,
+                    "train/rejected_reward": accumulated_rejected_reward
+                    / self.accumulate_grad_batches,
+                    "train/epoch_idx": self.epoch_idx,
+                    "train/lr": self.optimizer.param_groups[0]["lr"],
+                }
+                metrics["train/chosen_reward-rejected_reward"] = (
+                    metrics["train/chosen_reward"] - metrics["train/rejected_reward"]
+                )
+
+                fabric.log_dict(metrics, step=self.global_step_idx)
+                pbar.set_postfix(metrics)
+
+                # save the model at the end of the step if interval is set to "step" and frequency is met
+                self.conditional_checkpoint_save(stage="end_of_step")
+
+                # break if max_steps_per_epoch is set, and exit epoch
+                if (
+                    self.max_steps_per_epoch > 0
+                    and step_idx + 1 >= self.max_steps_per_epoch
+                ):
+                    break
+                # break if max_steps is set, and exit training
+                if self.max_steps > 0 and self.global_step_idx >= self.max_steps - 1:
+                    self.is_training = False
+                    break
+
+                self.global_step_idx += 1
+                accumulated_loss = 0
+                accumulated_chosen_reward = 0
+                accumulated_rejected_reward = 0
+
+    def save_checkpoint(
+        self,
+        path: Union[str, Path],
+        save_optimizer_state: Optional[bool] = None,
+        overwrite: bool = False,
+    ):
+        if not overwrite and os.path.exists(path):
+            return log.warning(f"Checkpoint already exists at {path}. Skipping save.")
+
+        fabric = self.fabric
+
+        if self.save_ckpt_type == "lightning":
+            state = {"model": self.model}
+
+            # save the optimizer and lr_scheduler state if needed
+            if self.save_optimizer_state and save_optimizer_state is not False:
+                state.update(
+                    {
+                        "optimizer": self.optimizer,
+                        "lr_scheduler": self.lr_scheduler,
+                        "global_step_idx": self.global_step_idx,
+                        "epoch_idx": self.epoch_idx,
+                    }
+                )
+
+            trainable_param_names = set(
+                name
+                for name, param in self.model.state_dict(keep_vars=True).items()
+                if param.requires_grad
+            )
+            filter = (
+                None
+                if self.save_full_model
+                else {"model": lambda k, p: k in trainable_param_names}
+            )
+
+            fabric.save(path, state=state, filter=filter)
+        else:
+            self.model.save_pretrained(path, is_main_process=fabric.is_global_zero)
+
+        self._latest_saved_checkpoint_global_step = self.global_step_idx
+
+    def load_checkpoint(self, path: Union[str, Path]):
+        fabric = self.fabric
+
+        state = {"model": self.model}
+
+        # save the optimizer and lr_scheduler state if needed
+        if self.save_optimizer_state:
+            state.update(
+                {
+                    "optimizer": self.optimizer,
+                    "lr_scheduler": self.lr_scheduler,
+                }
+            )
+
+        fabric.load(path, state)
+
+
+def load_checkpoint(
+    fabric: L.Fabric,
+    ckpt_path: Union[str, Path],
+    model: Union[nn.Module, "LlamaForSequenceClassification"],
+    strict: bool = True,
+    **state_components,
+):
+    """
+    Load a checkpoint into a model.
+    """
+    state = {"model": model}
+    state.update(state_components)
+    fabric.load(ckpt_path, state=state, strict=strict)
+
+
+if __name__ == "__main__":
+    # convert a checkpoint to hf format
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--base-model-path", type=str)
+    parser.add_argument("--ckpt-path", type=str)
+    parser.add_argument("--output-path", type=str)
+
+    args = parser.parse_args()
+
+    fabric = L.Fabric(devices=1, strategy="fsdp")
+    fabric.launch()
+
+    tokenizer = AutoTokenizer.from_pretrained(args.base_model_path)
+    tokenizer.save_pretrained(args.output_path)
+
+    model = AutoModelForSequenceClassification.from_pretrained(
+        args.base_model_path, num_labels=1, torch_dtype=torch.bfloat16
+    )
+    model = fabric.setup_module(model)
+    load_checkpoint(fabric, args.ckpt_path, model=model, strict=True)
+    model.save_pretrained(args.output_path)
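The core of `BradleyTerryRewardModeling.compute_loss` is the pairwise Bradley-Terry objective loss = -log sigmoid(r_chosen - r_rejected), computed over batches where `bradley_terry_rm_collate` places all chosen sequences in the first half of the batch and all rejected sequences in the second half. A self-contained sketch of that arithmetic on made-up reward values (in training the rewards come from the model's sequence-classification head):

```python
# Illustrative sketch of the Bradley-Terry loss computed in compute_loss above.
# The reward values are made up for demonstration.
import torch

rewards = torch.tensor([[1.2], [0.7], [0.3], [-0.1]])  # first half chosen, second half rejected
batch_size = rewards.size(0)

chosen_reward = rewards[: batch_size // 2]
rejected_reward = rewards[batch_size // 2 :]

loss = -torch.log(torch.sigmoid(chosen_reward - rejected_reward)).mean()
print(loss.item())  # lower when chosen rewards exceed rejected rewards
```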