project-llm-trainer 0.7.8__py3-none-any.whl → 0.8.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of project-llm-trainer might be problematic. Click here for more details.
- llm_trainer/dpo_trainer.py +41 -23
- llm_trainer/grpo_trainer.py +3 -3
- llm_trainer/loss.py +6 -10
- llm_trainer/sft_trainer.py +3 -3
- llm_trainer/tokenizer.py +16 -1
- llm_trainer/train_configs.py +5 -4
- llm_trainer/trainer.py +53 -40
- llm_trainer/utils.py +36 -11
- {project_llm_trainer-0.7.8.data → project_llm_trainer-0.8.1.data}/scripts/ds_train +7 -6
- {project_llm_trainer-0.7.8.dist-info → project_llm_trainer-0.8.1.dist-info}/METADATA +1 -1
- {project_llm_trainer-0.7.8.dist-info → project_llm_trainer-0.8.1.dist-info}/RECORD +19 -19
- {project_llm_trainer-0.7.8.data → project_llm_trainer-0.8.1.data}/scripts/calc_intermediate_size +0 -0
- {project_llm_trainer-0.7.8.data → project_llm_trainer-0.8.1.data}/scripts/ddp_train +0 -0
- {project_llm_trainer-0.7.8.data → project_llm_trainer-0.8.1.data}/scripts/plot_loss +0 -0
- {project_llm_trainer-0.7.8.data → project_llm_trainer-0.8.1.data}/scripts/plot_lr +0 -0
- {project_llm_trainer-0.7.8.data → project_llm_trainer-0.8.1.data}/scripts/py_train +0 -0
- {project_llm_trainer-0.7.8.data → project_llm_trainer-0.8.1.data}/scripts/smart_train +0 -0
- {project_llm_trainer-0.7.8.dist-info → project_llm_trainer-0.8.1.dist-info}/WHEEL +0 -0
- {project_llm_trainer-0.7.8.dist-info → project_llm_trainer-0.8.1.dist-info}/top_level.txt +0 -0
llm_trainer/dpo_trainer.py
CHANGED
|
@@ -12,7 +12,8 @@ from .loss import DPOLoss
|
|
|
12
12
|
from .tools import TrainerTools
|
|
13
13
|
from .utils import (
|
|
14
14
|
autocast,
|
|
15
|
-
get_dpo_collate_fn
|
|
15
|
+
get_dpo_collate_fn,
|
|
16
|
+
fill_loss_mask
|
|
16
17
|
)
|
|
17
18
|
from .partition_utils import sync_model_params
|
|
18
19
|
|
|
@@ -69,12 +70,12 @@ class DPOTrainer(Trainer):
|
|
|
69
70
|
|
|
70
71
|
return criterion, None
|
|
71
72
|
|
|
72
|
-
def _convert_train_args(self) -> Tuple[dict, dict, dict
|
|
73
|
+
def _convert_train_args(self) -> Tuple[dict, dict, dict]:
|
|
73
74
|
dpo_collate_fn = get_dpo_collate_fn(self.train_config.mask_prompt)
|
|
74
|
-
parallel_kwargs, data_loader_kwargs, sampler_kwargs
|
|
75
|
+
parallel_kwargs, data_loader_kwargs, sampler_kwargs = super()._convert_train_args()
|
|
75
76
|
data_loader_kwargs.update({"collate_fn": dpo_collate_fn})
|
|
76
77
|
|
|
77
|
-
return parallel_kwargs, data_loader_kwargs, sampler_kwargs
|
|
78
|
+
return parallel_kwargs, data_loader_kwargs, sampler_kwargs
|
|
78
79
|
|
|
79
80
|
def _create_dataset(self, file_idx) -> Tuple[Dataset, str]:
|
|
80
81
|
file_path = self.train_config.file_dataset[file_idx]
|
|
@@ -84,7 +85,6 @@ class DPOTrainer(Trainer):
|
|
|
84
85
|
def _calc_loss(self, inputs, attention_mask, logits, labels): ...
|
|
85
86
|
|
|
86
87
|
def _log_probs_from_logits(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
|
|
87
|
-
# https://github.com/OpenRLHF/OpenRLHF/pull/718#issuecomment-2641081881
|
|
88
88
|
if logits.dtype in [torch.float32, torch.float64]:
|
|
89
89
|
logits_labels = torch.gather(logits, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
|
|
90
90
|
logsumexp_values = torch.stack(
|
|
@@ -102,25 +102,26 @@ class DPOTrainer(Trainer):
|
|
|
102
102
|
return log_probs_labels
|
|
103
103
|
|
|
104
104
|
|
|
105
|
-
def _logprobs(self, logits, labels,
|
|
105
|
+
def _logprobs(self, logits, labels, attention_mask):
|
|
106
106
|
"""
|
|
107
107
|
Calculate the average log probabilities for a batch of sequences.
|
|
108
108
|
|
|
109
109
|
Args:
|
|
110
110
|
logits (torch.Tensor): Logits from the model with shape (B, T, V)
|
|
111
111
|
labels (torch.Tensor): Ground truth labels with shape (B, T).
|
|
112
|
-
|
|
112
|
+
attention_mask (torch.Tensor): Mask tensor with shape (B, T) indicating
|
|
113
113
|
which tokens are not padding (1 for valid tokens, 0 for padding).
|
|
114
114
|
|
|
115
115
|
Returns:
|
|
116
116
|
torch.Tensor: Average log probabilities for each sequence in the batch.
|
|
117
117
|
Shape is (B,) representing the mean log probability for each sequence.
|
|
118
118
|
"""
|
|
119
|
-
|
|
120
|
-
|
|
119
|
+
loss_masks = attention_mask.clone().bool()
|
|
120
|
+
loss_masks = fill_loss_mask(loss_masks, labels)
|
|
121
121
|
|
|
122
|
-
|
|
123
|
-
|
|
122
|
+
logits = logits[:, :-1, :]
|
|
123
|
+
labels = labels[:, 1:].clone()
|
|
124
|
+
loss_masks = loss_masks[:, 1:]
|
|
124
125
|
|
|
125
126
|
# dummy token; we'll ignore the losses on these tokens later
|
|
126
127
|
labels[labels == -100] = 0
|
|
@@ -129,11 +130,10 @@ class DPOTrainer(Trainer):
|
|
|
129
130
|
per_token_logps = self._log_probs_from_logits(logits, labels)
|
|
130
131
|
|
|
131
132
|
# Apply the mask to set log-probs of padding tokens to 0
|
|
132
|
-
logprobs_sums = (per_token_logps *
|
|
133
|
-
|
|
134
|
-
# logprobs_means = (per_token_logps * mask).sum(-1) / mask.sum(-1)
|
|
133
|
+
logprobs_sums = (per_token_logps * loss_masks).sum(-1)
|
|
134
|
+
logprobs_means = (per_token_logps * loss_masks).sum(-1) / loss_masks.sum(-1)
|
|
135
135
|
|
|
136
|
-
return logprobs_sums
|
|
136
|
+
return logprobs_sums, logprobs_means
|
|
137
137
|
|
|
138
138
|
def train(self):
|
|
139
139
|
# 梯度累积步数
|
|
@@ -147,6 +147,7 @@ class DPOTrainer(Trainer):
|
|
|
147
147
|
last_best_checkpoint_loss: Optional[float] = None
|
|
148
148
|
|
|
149
149
|
aux_loss_coef = self.train_config.loss_config.aux_loss_coef
|
|
150
|
+
nll_loss_coef = self.train_config.dpo_config.nll_loss_coef
|
|
150
151
|
|
|
151
152
|
for epoch in range(self.train_config.n_epochs):
|
|
152
153
|
self.train_model.train()
|
|
@@ -188,36 +189,53 @@ class DPOTrainer(Trainer):
|
|
|
188
189
|
try:
|
|
189
190
|
chosen_inputs: torch.Tensor = batch_data['chosen_inputs'].to(TrainerTools().parallel.device)
|
|
190
191
|
chosen_labels: torch.Tensor = batch_data['chosen_labels'].to(TrainerTools().parallel.device)
|
|
192
|
+
|
|
191
193
|
rejected_inputs: torch.Tensor = batch_data['rejected_inputs'].to(TrainerTools().parallel.device)
|
|
192
194
|
rejected_labels: torch.Tensor = batch_data['rejected_labels'].to(TrainerTools().parallel.device)
|
|
193
195
|
|
|
194
|
-
|
|
195
|
-
|
|
196
|
+
chosen_attention_masks: torch.Tensor = chosen_inputs != TrainerTools().tokenizer.pad
|
|
197
|
+
rejected_attention_masks: torch.Tensor = rejected_inputs != TrainerTools().tokenizer.pad
|
|
196
198
|
|
|
197
199
|
# 在batch维度concat
|
|
198
200
|
# [chosen, chosen, reject, reject]
|
|
199
201
|
concat_inputs = torch.concat([chosen_inputs, rejected_inputs], dim=0)
|
|
200
202
|
concat_labels = torch.concat([chosen_labels, rejected_labels], dim=0)
|
|
201
|
-
|
|
203
|
+
concat_attention_masks = torch.concat([chosen_attention_masks, rejected_attention_masks], dim=0)
|
|
202
204
|
|
|
203
205
|
if TrainerTools().parallel.parallel_train:
|
|
204
206
|
self.train_model.require_backward_grad_sync = need_update_grad
|
|
205
207
|
|
|
206
208
|
with autocast(TrainerTools().parallel.device_type):
|
|
207
|
-
policy_outputs = self.train_model(concat_inputs, attention_mask=
|
|
208
|
-
|
|
209
|
+
policy_outputs = self.train_model(concat_inputs, attention_mask=concat_attention_masks)
|
|
210
|
+
policy_logprobs_sums, policy_logprobs_means = self._logprobs(policy_outputs['logits'], concat_labels, concat_attention_masks)
|
|
209
211
|
aux_loss = policy_outputs.get('aux_loss')
|
|
210
212
|
|
|
211
213
|
with torch.no_grad():
|
|
212
|
-
ref_outputs = self.ref_model(concat_inputs, attention_mask=
|
|
213
|
-
|
|
214
|
+
ref_outputs = self.ref_model(concat_inputs, attention_mask=concat_attention_masks)
|
|
215
|
+
ref_logprobs_sums, _ = self._logprobs(ref_outputs['logits'], concat_labels, concat_attention_masks)
|
|
216
|
+
|
|
217
|
+
policy_chosen_logps = policy_logprobs_sums[:chosen_inputs.shape[0]]
|
|
218
|
+
policy_rejected_logps = policy_logprobs_sums[chosen_inputs.shape[0]:]
|
|
219
|
+
|
|
220
|
+
ref_chosen_logps = ref_logprobs_sums[:chosen_inputs.shape[0]]
|
|
221
|
+
ref_rejected_logps = ref_logprobs_sums[chosen_inputs.shape[0]:]
|
|
222
|
+
|
|
223
|
+
nll_loss = -policy_logprobs_means[:chosen_inputs.shape[0]].mean()
|
|
214
224
|
|
|
215
225
|
# calc loss
|
|
216
|
-
loss = self.criterion(
|
|
226
|
+
loss = self.criterion(
|
|
227
|
+
policy_chosen_logps,
|
|
228
|
+
policy_rejected_logps,
|
|
229
|
+
ref_chosen_logps,
|
|
230
|
+
ref_rejected_logps
|
|
231
|
+
)
|
|
217
232
|
|
|
218
233
|
if aux_loss_coef and aux_loss:
|
|
219
234
|
loss += aux_loss_coef * aux_loss
|
|
220
235
|
|
|
236
|
+
if nll_loss_coef and nll_loss:
|
|
237
|
+
loss += nll_loss_coef * nll_loss
|
|
238
|
+
|
|
221
239
|
if gradient_accumulation_steps > 1:
|
|
222
240
|
loss = loss / gradient_accumulation_steps
|
|
223
241
|
|
llm_trainer/grpo_trainer.py
CHANGED
|
@@ -82,11 +82,11 @@ class GRPOTrainer(Trainer):
|
|
|
82
82
|
|
|
83
83
|
return criterion, None
|
|
84
84
|
|
|
85
|
-
def _convert_train_args(self) -> Tuple[dict, dict, dict
|
|
86
|
-
parallel_kwargs, data_loader_kwargs, sampler_kwargs
|
|
85
|
+
def _convert_train_args(self) -> Tuple[dict, dict, dict]:
|
|
86
|
+
parallel_kwargs, data_loader_kwargs, sampler_kwargs = super()._convert_train_args()
|
|
87
87
|
data_loader_kwargs.update({"collate_fn": lambda x: x})
|
|
88
88
|
|
|
89
|
-
return parallel_kwargs, data_loader_kwargs, sampler_kwargs
|
|
89
|
+
return parallel_kwargs, data_loader_kwargs, sampler_kwargs
|
|
90
90
|
|
|
91
91
|
def _create_dataset(self, file_idx) -> Tuple[Dataset, str]:
|
|
92
92
|
file_path = self.train_config.file_dataset[file_idx]
|
llm_trainer/loss.py
CHANGED
|
@@ -92,17 +92,13 @@ class DPOLoss(nn.Module):
|
|
|
92
92
|
|
|
93
93
|
def forward(
|
|
94
94
|
self,
|
|
95
|
-
|
|
96
|
-
|
|
95
|
+
policy_chosen_logps: torch.Tensor,
|
|
96
|
+
policy_reject_logps: torch.Tensor,
|
|
97
|
+
ref_chosen_logps: torch.Tensor,
|
|
98
|
+
ref_reject_logps: torch.Tensor
|
|
97
99
|
) -> torch.Tensor:
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
ref_reject_probs = reference_logps[batch_size//2:]
|
|
101
|
-
policy_chosen_probs = policy_logps[:batch_size//2]
|
|
102
|
-
policy_reject_probs = policy_logps[batch_size//2:]
|
|
103
|
-
|
|
104
|
-
pi_logratios = policy_chosen_probs - policy_reject_probs
|
|
105
|
-
ref_logratios = ref_chosen_probs - ref_reject_probs
|
|
100
|
+
pi_logratios = policy_chosen_logps - policy_reject_logps
|
|
101
|
+
ref_logratios = ref_chosen_logps - ref_reject_logps
|
|
106
102
|
logits = pi_logratios - ref_logratios
|
|
107
103
|
|
|
108
104
|
if self.ipo:
|
llm_trainer/sft_trainer.py
CHANGED
|
@@ -23,12 +23,12 @@ class SFTTrainer(Trainer):
|
|
|
23
23
|
)
|
|
24
24
|
self.packed_sequences = False
|
|
25
25
|
|
|
26
|
-
def _convert_train_args(self) -> Tuple[dict, dict, dict
|
|
26
|
+
def _convert_train_args(self) -> Tuple[dict, dict, dict]:
|
|
27
27
|
sft_collate_fn = get_sft_collate_fn(self.train_config.mask_prompt)
|
|
28
|
-
parallel_kwargs, data_loader_kwargs, sampler_kwargs
|
|
28
|
+
parallel_kwargs, data_loader_kwargs, sampler_kwargs = super()._convert_train_args()
|
|
29
29
|
data_loader_kwargs.update({"collate_fn": sft_collate_fn})
|
|
30
30
|
|
|
31
|
-
return parallel_kwargs, data_loader_kwargs, sampler_kwargs
|
|
31
|
+
return parallel_kwargs, data_loader_kwargs, sampler_kwargs
|
|
32
32
|
|
|
33
33
|
def _create_dataset(self, file_idx) -> Tuple[Dataset, str]:
|
|
34
34
|
file_path = self.train_config.file_dataset[file_idx]
|
llm_trainer/tokenizer.py
CHANGED
|
@@ -3,7 +3,7 @@ import warnings
|
|
|
3
3
|
from typing import List, Dict, Union
|
|
4
4
|
from transformers import Qwen2TokenizerFast
|
|
5
5
|
from transformers import AddedToken
|
|
6
|
-
from transformers import
|
|
6
|
+
from transformers import LlamaTokenizerFast
|
|
7
7
|
import torch
|
|
8
8
|
|
|
9
9
|
TOKEN_TYPE_QWEN = 'qwen'
|
|
@@ -164,3 +164,18 @@ class Tokenizer:
|
|
|
164
164
|
|
|
165
165
|
return chat_template
|
|
166
166
|
|
|
167
|
+
def get_special_tokens_dict(self):
|
|
168
|
+
return {
|
|
169
|
+
self.text_end: self.end,
|
|
170
|
+
self.text_pad: self.pad,
|
|
171
|
+
self.text_unk: self.unk,
|
|
172
|
+
self.text_user: self.user,
|
|
173
|
+
self.text_assistant: self.assistant,
|
|
174
|
+
self.text_think_start: self.think_start,
|
|
175
|
+
self.text_think_end: self.think_end,
|
|
176
|
+
self.text_answer_start: self.answer_start,
|
|
177
|
+
self.text_answer_end: self.answer_end,
|
|
178
|
+
self.text_system: self.system,
|
|
179
|
+
self.text_image: self.image,
|
|
180
|
+
}
|
|
181
|
+
|
llm_trainer/train_configs.py
CHANGED
|
@@ -107,7 +107,8 @@ class DataLoaderConfig:
|
|
|
107
107
|
|
|
108
108
|
|
|
109
109
|
@dataclass(kw_only=True)
|
|
110
|
-
class
|
|
110
|
+
class OptimConfig:
|
|
111
|
+
optim_type: str = 'adam' # or 'lion'
|
|
111
112
|
enable_lr_scheduler: bool = False
|
|
112
113
|
initial_lr: float
|
|
113
114
|
weight_decay: float = 0.1
|
|
@@ -195,8 +196,8 @@ class TrainConfig:
|
|
|
195
196
|
grpo训练时不生效该配置!
|
|
196
197
|
eval_batch_interval (`int`, default is 100):
|
|
197
198
|
每隔多少个batch进行模型eval
|
|
198
|
-
|
|
199
|
-
|
|
199
|
+
optim_config (`OptimConfig`):
|
|
200
|
+
optim配置项
|
|
200
201
|
data_loader_config: (`DataLoaderConfig`):
|
|
201
202
|
data loader配置项
|
|
202
203
|
kd_config: (`KDConfig`, *Optional*, default is None):
|
|
@@ -213,7 +214,7 @@ class TrainConfig:
|
|
|
213
214
|
image_tags_file_dataset: Optional[FileDataset] = None
|
|
214
215
|
|
|
215
216
|
loss_config: LossConfig = field(default_factory=LossConfig)
|
|
216
|
-
|
|
217
|
+
optim_config: OptimConfig = field(default_factory=OptimConfig)
|
|
217
218
|
|
|
218
219
|
ds_config: DsConfig = field(default_factory=DsConfig)
|
|
219
220
|
|
llm_trainer/trainer.py
CHANGED
|
@@ -77,19 +77,15 @@ class Trainer:
|
|
|
77
77
|
if self.eval_image_tags:
|
|
78
78
|
assert len(self.eval_prompts) == len(self.eval_image_tags)
|
|
79
79
|
|
|
80
|
-
parallel_kwargs, data_loader_kwargs, sampler_kwargs
|
|
81
|
-
self.parallel_kwargs = parallel_kwargs
|
|
82
|
-
self.data_loader_kwargs: dict[str, Any] = data_loader_kwargs
|
|
83
|
-
self.sampler_kwargs: dict[str, Any] = sampler_kwargs
|
|
84
|
-
|
|
80
|
+
self.parallel_kwargs, self.data_loader_kwargs, self.sampler_kwargs = self._convert_train_args()
|
|
85
81
|
# initialize a GradScaler. If enabled=False scaler is a no-op
|
|
86
82
|
self.scalar = torch.GradScaler(enabled=TrainerTools().use_amp)
|
|
87
83
|
|
|
88
84
|
# 注意:学习率要根据GPU的数量进行倍增:
|
|
89
85
|
# 在训练的过程中,损失梯度决定下降的方向,学习率决定下降的步长。如果有两块gpu,前进的综合步长为:平均学习率*2
|
|
90
|
-
initial_lr = train_config.
|
|
86
|
+
initial_lr = train_config.optim_config.initial_lr
|
|
91
87
|
|
|
92
|
-
self.train_model, self.optimizer = self._init_train_model_and_optim(initial_lr
|
|
88
|
+
self.train_model, self.optimizer = self._init_train_model_and_optim(initial_lr)
|
|
93
89
|
self.lr_scheduler = self._init_lr_scheduler(initial_lr)
|
|
94
90
|
|
|
95
91
|
self.criterion, self.kd_loss = self._init_loss()
|
|
@@ -127,12 +123,7 @@ class Trainer:
|
|
|
127
123
|
freeze_llm_model = self.train_config.freeze_llm_model
|
|
128
124
|
return model.parameters() if not freeze_llm_model else filter(lambda p: p.requires_grad, model.parameters())
|
|
129
125
|
|
|
130
|
-
def _init_train_model_and_optim(
|
|
131
|
-
self,
|
|
132
|
-
initial_lr: float,
|
|
133
|
-
parallel_kwargs: dict,
|
|
134
|
-
use_ds_optim: bool
|
|
135
|
-
):
|
|
126
|
+
def _init_train_model_and_optim(self, initial_lr: float):
|
|
136
127
|
model = self._new_model(self.train_config)
|
|
137
128
|
|
|
138
129
|
if self.train_config.init_state_dict:
|
|
@@ -161,34 +152,58 @@ class Trainer:
|
|
|
161
152
|
total_size_mb = total_size_bytes / (1024 * 1024)
|
|
162
153
|
log(f"Total size of the model: {total_size_mb:.2f} MB")
|
|
163
154
|
|
|
164
|
-
if use_ds_optim:
|
|
165
|
-
import deepspeed
|
|
166
|
-
origin_optim = deepspeed.ops.adam.DeepSpeedCPUAdam(
|
|
167
|
-
self._get_trainable_params(model),
|
|
168
|
-
lr=initial_lr,
|
|
169
|
-
weight_decay=self.train_config.lr_config.weight_decay
|
|
170
|
-
)
|
|
171
|
-
else:
|
|
172
|
-
origin_optim = torch.optim.AdamW(
|
|
173
|
-
self._get_trainable_params(model),
|
|
174
|
-
lr=initial_lr,
|
|
175
|
-
weight_decay=self.train_config.lr_config.weight_decay
|
|
176
|
-
)
|
|
177
155
|
model, optim = TrainerTools().parallel.process(
|
|
178
156
|
model=model,
|
|
179
|
-
optimizer=
|
|
180
|
-
kwargs=parallel_kwargs
|
|
157
|
+
optimizer=self._get_optim(model, initial_lr),
|
|
158
|
+
kwargs=self.parallel_kwargs
|
|
181
159
|
)
|
|
182
160
|
|
|
183
161
|
return model, optim
|
|
184
162
|
|
|
163
|
+
def _get_optim(self, model, initial_lr):
|
|
164
|
+
optimizer = None
|
|
165
|
+
|
|
166
|
+
if isinstance(TrainerTools().parallel, DsParallel) and self.parallel_kwargs:
|
|
167
|
+
import deepspeed
|
|
168
|
+
if ('zero_optimization' in self.parallel_kwargs
|
|
169
|
+
and 'offload_optimizer' in self.parallel_kwargs['zero_optimization']
|
|
170
|
+
and self.parallel_kwargs['zero_optimization']['offload_optimizer']['device'] == 'cpu'):
|
|
171
|
+
# offline optimizer to cpu
|
|
172
|
+
# 不能使用 deepspeed.ops.lion.cpu_lion.DeepSpeedCPULion???
|
|
173
|
+
# 所以,这里忽略lion判断
|
|
174
|
+
optimizer = deepspeed.ops.adam.DeepSpeedCPUAdam
|
|
175
|
+
if self.train_config.optim_config.optim_type == 'lion':
|
|
176
|
+
log('When set offload_optimizer, lion optim is unsupported, so set optim to adam!!!!!')
|
|
177
|
+
else:
|
|
178
|
+
if self.train_config.optim_config.optim_type == 'lion':
|
|
179
|
+
optimizer = deepspeed.ops.lion.FusedLion
|
|
180
|
+
else:
|
|
181
|
+
optimizer = deepspeed.ops.adam.FusedAdam
|
|
182
|
+
|
|
183
|
+
if not optimizer:
|
|
184
|
+
if self.train_config.optim_config.optim_type == 'lion':
|
|
185
|
+
try:
|
|
186
|
+
import lion_pytorch
|
|
187
|
+
except:
|
|
188
|
+
raise Exception('lion is not detected, please use `pip3 install lion_pytorch` to install or set optim_type to adam')
|
|
189
|
+
|
|
190
|
+
optimizer = lion_pytorch.Lion
|
|
191
|
+
else:
|
|
192
|
+
optimizer = torch.optim.AdamW
|
|
193
|
+
|
|
194
|
+
return optimizer(
|
|
195
|
+
self._get_trainable_params(model),
|
|
196
|
+
lr=initial_lr,
|
|
197
|
+
weight_decay=self.train_config.optim_config.weight_decay
|
|
198
|
+
)
|
|
199
|
+
|
|
185
200
|
def _init_lr_scheduler(self, initial_lr: float) -> LRScheduler:
|
|
186
|
-
if self.train_config.
|
|
187
|
-
warmup_iters = self.train_config.
|
|
188
|
-
min_lr = self.train_config.
|
|
189
|
-
max_lr = self.train_config.
|
|
190
|
-
cosine_annealing_period = self.train_config.
|
|
191
|
-
cosine_annealing_period_mul = self.train_config.
|
|
201
|
+
if self.train_config.optim_config.enable_lr_scheduler:
|
|
202
|
+
warmup_iters = self.train_config.optim_config.warmup_iters
|
|
203
|
+
min_lr = self.train_config.optim_config.min_lr
|
|
204
|
+
max_lr = self.train_config.optim_config.max_lr
|
|
205
|
+
cosine_annealing_period = self.train_config.optim_config.cosine_annealing_period
|
|
206
|
+
cosine_annealing_period_mul = self.train_config.optim_config.cosine_annealing_period_mul
|
|
192
207
|
|
|
193
208
|
return WarmupCosineAnnealingLRScheduler(
|
|
194
209
|
optimizer=self.optimizer,
|
|
@@ -220,9 +235,8 @@ class Trainer:
|
|
|
220
235
|
|
|
221
236
|
return criterion, kd_loss
|
|
222
237
|
|
|
223
|
-
def _convert_train_args(self) -> Tuple[dict, dict, dict
|
|
238
|
+
def _convert_train_args(self) -> Tuple[dict, dict, dict]:
|
|
224
239
|
parallel_kwargs: Optional[Dict[str, Any]] = None
|
|
225
|
-
use_ds_optim: bool = False
|
|
226
240
|
if isinstance(TrainerTools().parallel, DsParallel) and self.train_config.ds_config:
|
|
227
241
|
parallel_kwargs = {
|
|
228
242
|
'gradient_accumulation_steps': 1,
|
|
@@ -253,7 +267,6 @@ class Trainer:
|
|
|
253
267
|
"device": zero_config.offload_optimizer.device,
|
|
254
268
|
"pin_memory": zero_config.offload_optimizer.pin_memory
|
|
255
269
|
}
|
|
256
|
-
use_ds_optim = True
|
|
257
270
|
if zero_config.offload_param is not None:
|
|
258
271
|
zero_optimization['offload_param'] = {
|
|
259
272
|
"device": zero_config.offload_param.device,
|
|
@@ -328,10 +341,10 @@ class Trainer:
|
|
|
328
341
|
"drop_last": dataloader_args.data_loader_drop_last,
|
|
329
342
|
}
|
|
330
343
|
|
|
331
|
-
return parallel_kwargs, data_loader_kwargs, sampler_kwargs
|
|
344
|
+
return parallel_kwargs, data_loader_kwargs, sampler_kwargs
|
|
332
345
|
|
|
333
346
|
def _init_ref_model_args(self) -> dict:
|
|
334
|
-
parallel_kwargs = copy.deepcopy(self.parallel_kwargs)
|
|
347
|
+
parallel_kwargs = copy.deepcopy(self.parallel_kwargs) if self.parallel_kwargs else None
|
|
335
348
|
|
|
336
349
|
if parallel_kwargs and isinstance(TrainerTools().parallel, DsParallel):
|
|
337
350
|
# reference to https://github.com/huggingface/trl/blob/main/trl/models/utils.py:prepare_deepspeed
|
|
@@ -435,7 +448,7 @@ class Trainer:
|
|
|
435
448
|
exception_file = e.__traceback__.tb_frame.f_globals["__file__"]
|
|
436
449
|
exception_line = e.__traceback__.tb_lineno
|
|
437
450
|
log_msg = f"epoch: {epoch}, batch: {batch}, {e} at {exception_file} line {exception_line}\n"
|
|
438
|
-
log(log_msg, f'{log_dir}
|
|
451
|
+
log(log_msg, f'{log_dir}exception.txt')
|
|
439
452
|
|
|
440
453
|
raise e
|
|
441
454
|
|
llm_trainer/utils.py
CHANGED
|
@@ -154,16 +154,22 @@ def batch_repeat_image_tok(
|
|
|
154
154
|
|
|
155
155
|
|
|
156
156
|
def pretrain_collate_fn(batch_data):
|
|
157
|
-
|
|
157
|
+
# [[x,x,x], [y,y,y]]
|
|
158
|
+
inputs = pad_sequence(batch_data, batch_first=True, padding_value=TrainerTools().tokenizer.pad)
|
|
159
|
+
# crossEntropy默认的ignore_index是-100
|
|
160
|
+
labels = pad_sequence(batch_data, batch_first=True, padding_value=-100)
|
|
158
161
|
|
|
159
162
|
# inputs, labels
|
|
160
|
-
return {
|
|
163
|
+
return {
|
|
164
|
+
'inputs': inputs,
|
|
165
|
+
'labels': labels
|
|
166
|
+
}
|
|
161
167
|
|
|
162
168
|
|
|
163
169
|
def get_sft_collate_fn(mask_prompt: bool):
|
|
164
170
|
def sft_collate_fn(batch_data):
|
|
165
171
|
"""
|
|
166
|
-
|
|
172
|
+
如果是sft,则不计算prompt部分的loss, 例如:
|
|
167
173
|
logits: [USER]你好[BOT]我好[SEP]
|
|
168
174
|
labels: [USER]你好[BOT]我好[SEP]
|
|
169
175
|
|
|
@@ -184,11 +190,19 @@ def get_sft_collate_fn(mask_prompt: bool):
|
|
|
184
190
|
batch_train_data.append(item['inputs'])
|
|
185
191
|
image_tags.append(item['image_tag'])
|
|
186
192
|
|
|
187
|
-
|
|
193
|
+
# [[x,x,x], [y,y,y]]
|
|
194
|
+
inputs = pad_sequence(batch_train_data, batch_first=True, padding_value=TrainerTools().tokenizer.pad)
|
|
195
|
+
# crossEntropy默认的ignore_index是-100
|
|
196
|
+
labels = pad_sequence(batch_train_data, batch_first=True, padding_value=-100)
|
|
197
|
+
|
|
188
198
|
if mask_prompt:
|
|
189
199
|
labels = _mask_prompt(labels)
|
|
190
200
|
|
|
191
|
-
return {
|
|
201
|
+
return {
|
|
202
|
+
'inputs': inputs,
|
|
203
|
+
'labels': labels,
|
|
204
|
+
'image_tags': image_tags
|
|
205
|
+
}
|
|
192
206
|
|
|
193
207
|
return sft_collate_fn
|
|
194
208
|
|
|
@@ -295,13 +309,24 @@ def join_batch(batch_data: list[dict]) -> dict:
|
|
|
295
309
|
return result
|
|
296
310
|
|
|
297
311
|
|
|
298
|
-
def
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
labels
|
|
312
|
+
def fill_loss_mask(loss_masks, labels):
|
|
313
|
+
"""
|
|
314
|
+
将loss_mask中prompt部分强制设置为False
|
|
315
|
+
loss_masks: shape (B, T)
|
|
316
|
+
labels: shape (B, T)
|
|
317
|
+
"""
|
|
318
|
+
tokenizer = TrainerTools().tokenizer
|
|
319
|
+
# 支持多轮会话的mask
|
|
320
|
+
for batch, label in enumerate(labels):
|
|
321
|
+
start_index = -1
|
|
322
|
+
for index, token in enumerate(label):
|
|
323
|
+
if token == tokenizer.system or token == tokenizer.user:
|
|
324
|
+
start_index = index
|
|
325
|
+
elif token == tokenizer.end and start_index != -1:
|
|
326
|
+
loss_masks[batch, start_index:index + 1] = False
|
|
327
|
+
start_index = -1
|
|
303
328
|
|
|
304
|
-
return
|
|
329
|
+
return loss_masks
|
|
305
330
|
|
|
306
331
|
|
|
307
332
|
def _mask_prompt(labels):
|
|
{project_llm_trainer-0.7.8.data → project_llm_trainer-0.8.1.data}/scripts/ds_train
CHANGED
|
@@ -10,14 +10,15 @@ if __name__ == '__main__':
|
|
|
10
10
|
if len(arguments) > 1:
|
|
11
11
|
# 0,1,2,3
|
|
12
12
|
cuda_visible_devive = arguments[1]
|
|
13
|
-
else:
|
|
14
|
-
cuda_visible_devive = None
|
|
15
13
|
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
14
|
+
# cuda location
|
|
15
|
+
if len(arguments) > 2:
|
|
16
|
+
cuda_loc = arguments[2]
|
|
17
|
+
else:
|
|
18
|
+
cuda_loc = 'localhost'
|
|
19
19
|
else:
|
|
20
|
-
|
|
20
|
+
cuda_visible_devive = None
|
|
21
|
+
cuda_loc = None
|
|
21
22
|
|
|
22
23
|
os.environ['PARALLEL_TYPE'] = 'ds'
|
|
23
24
|
|
|
{project_llm_trainer-0.7.8.dist-info → project_llm_trainer-0.8.1.dist-info}/RECORD
CHANGED
|
@@ -1,33 +1,33 @@
|
|
|
1
1
|
llm_trainer/__init__.py,sha256=HWgtTEVeQSnZmEyYQm2K6eFEG4X2QAoigMlB5Z2tcXE,260
|
|
2
2
|
llm_trainer/checkpoint.py,sha256=X5ZeUtJlxVz7pnWQLaS-y7UIZOaOAnZTt2L8rSAPzUs,4428
|
|
3
3
|
llm_trainer/dataset.py,sha256=UL3fGeM4XSlyNQRZH-139u3LujqAQx3YyaxNRewk6LE,8935
|
|
4
|
-
llm_trainer/dpo_trainer.py,sha256=
|
|
4
|
+
llm_trainer/dpo_trainer.py,sha256=Qi7WKhFO4fdnj9W8BNIF_so6-F8g_YKUoPU9sNjWK_M,13320
|
|
5
5
|
llm_trainer/ds_checkpoint.py,sha256=X2IWgpgi0yOtogph7n6DEwvK_0Ceb7juu1WMutv3HSk,2270
|
|
6
6
|
llm_trainer/eval.py,sha256=ZyUfSo2Q8P-lrCdPEnGkoo5pGubd0AabREK5eMISRII,1109
|
|
7
7
|
llm_trainer/generate_utils.py,sha256=8K3YFbp7IF_lCkmkzjHhqTW26EBFb2AilQmarVcfMvs,15001
|
|
8
|
-
llm_trainer/grpo_trainer.py,sha256=
|
|
8
|
+
llm_trainer/grpo_trainer.py,sha256=3CcV-cuyV4ZUTymN9vz3au4uf3gZdyo8SGgSj2NEofs,16774
|
|
9
9
|
llm_trainer/log.py,sha256=XwychwKF6gvFPhthCIZCAEUZ0G3DY3fiQrOHqPWsxz0,463
|
|
10
|
-
llm_trainer/loss.py,sha256=
|
|
10
|
+
llm_trainer/loss.py,sha256=RhTxftLMj1Tqc5pkUvJiZumfbMEPWL8GBGxdTfQggmk,6744
|
|
11
11
|
llm_trainer/parallel.py,sha256=yjStV21DJ26yM8-0O6GTMxdFAcyShY5GsQWSZmbI7HU,4543
|
|
12
12
|
llm_trainer/parallel_ddp.py,sha256=Pob9vUlBZnkL4oP1Re11kFob7nufMSE96pn7m7fuOEM,1345
|
|
13
13
|
llm_trainer/parallel_ds.py,sha256=oy8RRxHud3rACWubFlJqqd0pjPEQhKeAPGPQUSdJX2c,1145
|
|
14
14
|
llm_trainer/parallel_none.py,sha256=TG6Pm829Dg-yQu-97O-EHV3FCARBlNcP47KkGFAs16E,676
|
|
15
15
|
llm_trainer/partition_utils.py,sha256=eEYNhfEIF4hGzZ3OLa6sEBIECz261drptEz_n7fZYtk,8396
|
|
16
16
|
llm_trainer/scheduler.py,sha256=LAI_0VxClsIQkix0bRoduRD4vPfVuIZDhZgTAT_KK8k,4901
|
|
17
|
-
llm_trainer/sft_trainer.py,sha256=
|
|
18
|
-
llm_trainer/tokenizer.py,sha256=
|
|
17
|
+
llm_trainer/sft_trainer.py,sha256=rSOGZx53jMgOuJdztfxQASYJ62uD0dVaih4IAnSwGBc,1787
|
|
18
|
+
llm_trainer/tokenizer.py,sha256=0-xQCMz1xiPTDAZiYsVsiECSoZ_1eIvW9XsZOoFfakQ,7250
|
|
19
19
|
llm_trainer/tools.py,sha256=5op5qrjjkK-Lr9oes5VxIVnOVYOYGoAdlIJq9mPUf64,2637
|
|
20
|
-
llm_trainer/train_configs.py,sha256=
|
|
21
|
-
llm_trainer/trainer.py,sha256=
|
|
22
|
-
llm_trainer/utils.py,sha256=
|
|
23
|
-
project_llm_trainer-0.
|
|
24
|
-
project_llm_trainer-0.
|
|
25
|
-
project_llm_trainer-0.
|
|
26
|
-
project_llm_trainer-0.
|
|
27
|
-
project_llm_trainer-0.
|
|
28
|
-
project_llm_trainer-0.
|
|
29
|
-
project_llm_trainer-0.
|
|
30
|
-
project_llm_trainer-0.
|
|
31
|
-
project_llm_trainer-0.
|
|
32
|
-
project_llm_trainer-0.
|
|
33
|
-
project_llm_trainer-0.
|
|
20
|
+
llm_trainer/train_configs.py,sha256=pPZkbliRdTnWSv3TUuTM23x9RDdMhGSPrxbNAyzDklY,7636
|
|
21
|
+
llm_trainer/trainer.py,sha256=diP-1suOf2U5dY_R8QH5arAx4MgBrKW-GBQ2_ScGNM8,28799
|
|
22
|
+
llm_trainer/utils.py,sha256=xC5plG-8-_Al5yIF5xIU5lroOcBBk98TEhtUJrazZPE,12305
|
|
23
|
+
project_llm_trainer-0.8.1.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
|
|
24
|
+
project_llm_trainer-0.8.1.data/scripts/ddp_train,sha256=Z-309mM56CN0m3bxoeC5us4LUuwuNnoiOm3-fDdLMjQ,566
|
|
25
|
+
project_llm_trainer-0.8.1.data/scripts/ds_train,sha256=tME0xmMdX1D9XuVo07D9dilW5VIWavBS3UK9DoY67WI,709
|
|
26
|
+
project_llm_trainer-0.8.1.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
|
|
27
|
+
project_llm_trainer-0.8.1.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
|
|
28
|
+
project_llm_trainer-0.8.1.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
|
|
29
|
+
project_llm_trainer-0.8.1.data/scripts/smart_train,sha256=3oLIDuuqb4U4TU1lXy9V8lw_0gIf7i8tGsxlQ_s6bro,1220
|
|
30
|
+
project_llm_trainer-0.8.1.dist-info/METADATA,sha256=07L7qqkujmk6YAwD5jPKe6dzyWPRu1Jirmp-6BqzMzA,195
|
|
31
|
+
project_llm_trainer-0.8.1.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
|
|
32
|
+
project_llm_trainer-0.8.1.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
|
|
33
|
+
project_llm_trainer-0.8.1.dist-info/RECORD,,
|
{project_llm_trainer-0.7.8.data → project_llm_trainer-0.8.1.data}/scripts/calc_intermediate_size
RENAMED
|
File without changes
|
|
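{project_llm_trainer-0.7.8.data → project_llm_trainer-0.8.1.data}/scripts/ddp_train
RENAMED
|
File without changes
|

{project_llm_trainer-0.7.8.data → project_llm_trainer-0.8.1.data}/scripts/plot_loss
RENAMED
|
File without changes
|

{project_llm_trainer-0.7.8.data → project_llm_trainer-0.8.1.data}/scripts/plot_lr
RENAMED
|
File without changes
|

{project_llm_trainer-0.7.8.data → project_llm_trainer-0.8.1.data}/scripts/py_train
RENAMED
|
File without changes
|

{project_llm_trainer-0.7.8.data → project_llm_trainer-0.8.1.data}/scripts/smart_train
RENAMED
|
File without changes
|

{project_llm_trainer-0.7.8.dist-info → project_llm_trainer-0.8.1.dist-info}/WHEEL
RENAMED
|
File without changes
|

{project_llm_trainer-0.7.8.dist-info → project_llm_trainer-0.8.1.dist-info}/top_level.txt
RENAMED
|
File without changes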
|