project-llm-trainer 0.7.8__py3-none-any.whl → 0.7.9__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as published to their public registry, and is provided for informational purposes only.


llm_trainer/dpo_trainer.py CHANGED
@@ -12,7 +12,8 @@ from .loss import DPOLoss
 from .tools import TrainerTools
 from .utils import (
     autocast,
-    get_dpo_collate_fn
+    get_dpo_collate_fn,
+    fill_loss_mask
 )
 from .partition_utils import sync_model_params
 
@@ -84,7 +85,6 @@ class DPOTrainer(Trainer):
     def _calc_loss(self, inputs, attention_mask, logits, labels): ...
 
     def _log_probs_from_logits(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
-        # https://github.com/OpenRLHF/OpenRLHF/pull/718#issuecomment-2641081881
         if logits.dtype in [torch.float32, torch.float64]:
             logits_labels = torch.gather(logits, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
             logsumexp_values = torch.stack(
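For context, _log_probs_from_logits avoids materializing a full log_softmax over the vocabulary: the log-probability of each label token is its gathered logit minus the logsumexp of that position's logits. A minimal standalone sketch of the identity (toy shapes; the torch.stack( above suggests the package computes the logsumexp in chunks, which this sketch skips):

import torch

B, T, V = 2, 4, 10
logits = torch.randn(B, T, V)
labels = torch.randint(0, V, (B, T))

# log p(label) = logit[label] - logsumexp(all logits at that position)
logits_labels = torch.gather(logits, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
log_probs_labels = logits_labels - torch.logsumexp(logits, dim=-1)

# same values as the memory-hungry direct computation
expected = torch.log_softmax(logits, dim=-1).gather(-1, labels.unsqueeze(-1)).squeeze(-1)
assert torch.allclose(log_probs_labels, expected, atol=1e-5)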
@@ -102,25 +102,26 @@ class DPOTrainer(Trainer):
         return log_probs_labels
 
 
-    def _logprobs(self, logits, labels, mask):
+    def _logprobs(self, logits, labels, attention_mask):
         """
         Calculate the average log probabilities for a batch of sequences.
 
         Args:
             logits (torch.Tensor): Logits from the model with shape (B, T, V)
             labels (torch.Tensor): Ground truth labels with shape (B, T).
-            mask (torch.Tensor): Mask tensor with shape (B, T) indicating
+            attention_mask (torch.Tensor): Mask tensor with shape (B, T) indicating
                 which tokens are not padding (1 for valid tokens, 0 for padding).
 
         Returns:
             torch.Tensor: Average log probabilities for each sequence in the batch.
                 Shape is (B,) representing the mean log probability for each sequence.
         """
-        labels = labels[:, 1:].clone()
-        logits = logits[:, :-1, :]
+        loss_masks = attention_mask.clone().bool()
+        loss_masks = fill_loss_mask(loss_masks, labels)
 
-        # # Shift mask right by one to align with labels
-        mask = mask[:, 1:].clone()
+        logits = logits[:, :-1, :]
+        labels = labels[:, 1:].clone()
+        loss_masks = loss_masks[:, 1:]
 
         # dummy token; we'll ignore the losses on these tokens later
         labels[labels == -100] = 0
@@ -129,11 +130,10 @@ class DPOTrainer(Trainer):
         per_token_logps = self._log_probs_from_logits(logits, labels)
 
         # Apply the mask to set log-probs of padding tokens to 0
-        logprobs_sums = (per_token_logps * mask).sum(-1)
-
-        # logprobs_means = (per_token_logps * mask).sum(-1) / mask.sum(-1)
+        logprobs_sums = (per_token_logps * loss_masks).sum(-1)
+        logprobs_means = (per_token_logps * loss_masks).sum(-1) / loss_masks.sum(-1)
 
-        return logprobs_sums #, -logprobs_means.mean()
+        return logprobs_sums, logprobs_means
 
     def train(self):
         # gradient accumulation steps
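The sums/means bookkeeping is easy to check on toy values; a sketch with hypothetical numbers (loss_masks zeroes out prompt and padding positions):

import torch

per_token_logps = torch.tensor([[-0.5, -1.0, -2.0],
                                [-0.2, -0.3, -4.0]])
loss_masks = torch.tensor([[1., 1., 1.],
                           [1., 1., 0.]])  # last position of row 1 is masked

logprobs_sums = (per_token_logps * loss_masks).sum(-1)
# tensor([-3.5000, -0.5000])
logprobs_means = (per_token_logps * loss_masks).sum(-1) / loss_masks.sum(-1)
# tensor([-1.1667, -0.2500])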
@@ -147,6 +147,7 @@ class DPOTrainer(Trainer):
         last_best_checkpoint_loss: Optional[float] = None
 
         aux_loss_coef = self.train_config.loss_config.aux_loss_coef
+        nll_loss_coef = self.train_config.dpo_config.nll_loss_coef
 
         for epoch in range(self.train_config.n_epochs):
            self.train_model.train()
@@ -188,36 +189,53 @@ class DPOTrainer(Trainer):
                try:
                    chosen_inputs: torch.Tensor = batch_data['chosen_inputs'].to(TrainerTools().parallel.device)
                    chosen_labels: torch.Tensor = batch_data['chosen_labels'].to(TrainerTools().parallel.device)
+
                    rejected_inputs: torch.Tensor = batch_data['rejected_inputs'].to(TrainerTools().parallel.device)
                    rejected_labels: torch.Tensor = batch_data['rejected_labels'].to(TrainerTools().parallel.device)
 
-                    chosen_attention_mask: torch.Tensor = chosen_inputs != TrainerTools().tokenizer.pad
-                    rejected_attention_mask: torch.Tensor = rejected_inputs != TrainerTools().tokenizer.pad
+                    chosen_attention_masks: torch.Tensor = chosen_inputs != TrainerTools().tokenizer.pad
+                    rejected_attention_masks: torch.Tensor = rejected_inputs != TrainerTools().tokenizer.pad
 
                    # concat along the batch dimension
                    # [chosen, chosen, reject, reject]
                    concat_inputs = torch.concat([chosen_inputs, rejected_inputs], dim=0)
                    concat_labels = torch.concat([chosen_labels, rejected_labels], dim=0)
-                    concat_mask = torch.concat([chosen_attention_mask, rejected_attention_mask], dim=0)
+                    concat_attention_masks = torch.concat([chosen_attention_masks, rejected_attention_masks], dim=0)
 
                    if TrainerTools().parallel.parallel_train:
                        self.train_model.require_backward_grad_sync = need_update_grad
 
                    with autocast(TrainerTools().parallel.device_type):
-                        policy_outputs = self.train_model(concat_inputs, attention_mask=concat_mask)
-                        policy_probs = self._logprobs(policy_outputs['logits'], concat_labels, concat_mask)
+                        policy_outputs = self.train_model(concat_inputs, attention_mask=concat_attention_masks)
+                        policy_logprobs_sums, policy_logprobs_means = self._logprobs(policy_outputs['logits'], concat_labels, concat_attention_masks)
                        aux_loss = policy_outputs.get('aux_loss')
 
                        with torch.no_grad():
-                            ref_outputs = self.ref_model(concat_inputs, attention_mask=concat_mask)
-                            ref_probs = self._logprobs(ref_outputs['logits'], concat_labels, concat_mask)
+                            ref_outputs = self.ref_model(concat_inputs, attention_mask=concat_attention_masks)
+                            ref_logprobs_sums, _ = self._logprobs(ref_outputs['logits'], concat_labels, concat_attention_masks)
+
+                        policy_chosen_logps = policy_logprobs_sums[:chosen_inputs.shape[0]]
+                        policy_rejected_logps = policy_logprobs_sums[chosen_inputs.shape[0]:]
+
+                        ref_chosen_logps = ref_logprobs_sums[:chosen_inputs.shape[0]]
+                        ref_rejected_logps = ref_logprobs_sums[chosen_inputs.shape[0]:]
+
+                        nll_loss = -policy_logprobs_means[:chosen_inputs.shape[0]].mean()
 
                        # calc loss
-                        loss = self.criterion(policy_probs, ref_probs)
+                        loss = self.criterion(
+                            policy_chosen_logps,
+                            policy_rejected_logps,
+                            ref_chosen_logps,
+                            ref_rejected_logps
+                        )
 
                        if aux_loss_coef and aux_loss:
                            loss += aux_loss_coef * aux_loss
 
+                        if nll_loss_coef and nll_loss:
+                            loss += nll_loss_coef * nll_loss
+
                        if gradient_accumulation_steps > 1:
                            loss = loss / gradient_accumulation_steps
 
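For reference, the slicing recovers per-half tensors from the concatenated [chosen, rejected] batch, and the new NLL term is a length-normalized negative log-likelihood over the chosen half only, similar in spirit to the SFT/NLL regularizer used in RPO-style DPO variants. A schematic with hypothetical values (names mirror the diff):

import torch

B = 2  # per-side batch size; the concatenated batch has 2B rows
policy_logprobs_sums = torch.tensor([-3.5, -4.0, -6.0, -7.5])   # hypothetical
policy_logprobs_means = torch.tensor([-1.2, -1.3, -2.0, -2.5])  # hypothetical

policy_chosen_logps = policy_logprobs_sums[:B]    # rows 0..B-1: chosen
policy_rejected_logps = policy_logprobs_sums[B:]  # rows B..2B-1: rejected

# length-normalized NLL over chosen responses, added as nll_loss_coef * nll_loss
nll_loss = -policy_logprobs_means[:B].mean()      # tensor(1.2500)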
llm_trainer/loss.py CHANGED
@@ -92,17 +92,13 @@ class DPOLoss(nn.Module):
 
     def forward(
         self,
-        policy_logps: torch.Tensor,
-        reference_logps: torch.Tensor,
+        policy_chosen_logps: torch.Tensor,
+        policy_reject_logps: torch.Tensor,
+        ref_chosen_logps: torch.Tensor,
+        ref_reject_logps: torch.Tensor
     ) -> torch.Tensor:
-        batch_size = reference_logps.shape[0]
-        ref_chosen_probs = reference_logps[:batch_size//2]
-        ref_reject_probs = reference_logps[batch_size//2:]
-        policy_chosen_probs = policy_logps[:batch_size//2]
-        policy_reject_probs = policy_logps[batch_size//2:]
-
-        pi_logratios = policy_chosen_probs - policy_reject_probs
-        ref_logratios = ref_chosen_probs - ref_reject_probs
+        pi_logratios = policy_chosen_logps - policy_reject_logps
+        ref_logratios = ref_chosen_logps - ref_reject_logps
         logits = pi_logratios - ref_logratios
 
         if self.ipo:
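The log-ratio difference computed here feeds the standard DPO objective; a self-contained sketch of the sigmoid variant, assuming a beta hyperparameter as in the DPO paper (the class's actual reduction and IPO branch may differ):

import torch
import torch.nn.functional as F

def dpo_sigmoid_loss(policy_chosen_logps, policy_reject_logps,
                     ref_chosen_logps, ref_reject_logps, beta=0.1):
    # difference of policy and reference log-ratios between chosen and rejected
    pi_logratios = policy_chosen_logps - policy_reject_logps
    ref_logratios = ref_chosen_logps - ref_reject_logps
    logits = pi_logratios - ref_logratios
    # DPO: -log(sigmoid(beta * logits)); the IPO variant instead
    # penalizes (logits - 1 / (2 * beta)) ** 2
    return -F.logsigmoid(beta * logits).mean()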
llm_trainer/tokenizer.py CHANGED
@@ -3,7 +3,7 @@ import warnings
 from typing import List, Dict, Union
 from transformers import Qwen2TokenizerFast
 from transformers import AddedToken
-from transformers import LlamaTokenizer, LlamaTokenizerFast
+from transformers import LlamaTokenizerFast
 import torch
 
 TOKEN_TYPE_QWEN = 'qwen'
@@ -164,3 +164,18 @@ class Tokenizer:
 
         return chat_template
 
+    def get_special_tokens_dict(self):
+        return {
+            self.text_end: self.end,
+            self.text_pad: self.pad,
+            self.text_unk: self.unk,
+            self.text_user: self.user,
+            self.text_assistant: self.assistant,
+            self.text_think_start: self.think_start,
+            self.text_think_end: self.think_end,
+            self.text_answer_start: self.answer_start,
+            self.text_answer_end: self.answer_end,
+            self.text_system: self.system,
+            self.text_image: self.image,
+        }
+
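A hypothetical use of the new helper, assuming the attributes hold text forms and integer ids as the dict literal implies:

tokenizer = TrainerTools().tokenizer
special = tokenizer.get_special_tokens_dict()  # {text form -> token id}

# invert it to render special ids back to text, e.g. when inspecting batches
id_to_text = {tok_id: text for text, tok_id in special.items()}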
llm_trainer/utils.py CHANGED
@@ -154,16 +154,22 @@ def batch_repeat_image_tok(
 
 
 def pretrain_collate_fn(batch_data):
-    inputs, labels = _pad_sequence(batch_data)
+    # [[x,x,x], [y,y,y]]
+    inputs = pad_sequence(batch_data, batch_first=True, padding_value=TrainerTools().tokenizer.pad)
+    # cross-entropy's default ignore_index is -100
+    labels = pad_sequence(batch_data, batch_first=True, padding_value=-100)
 
     # inputs, labels
-    return {'inputs': inputs, 'labels': labels}
+    return {
+        'inputs': inputs,
+        'labels': labels
+    }
 
 
 def get_sft_collate_fn(mask_prompt: bool):
    def sft_collate_fn(batch_data):
        """
-        如果是sft,则不计算prompt部分的loss, 例如:
+        如果是sft,则不计算prompt部分的loss, 例如:
        logits: [USER]你好[BOT]我好[SEP]
        labels: [USER]你好[BOT]我好[SEP]
 
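Inlining the old _pad_sequence helper makes the two padding passes explicit: inputs are padded with the tokenizer's pad id, labels with -100, which torch.nn.CrossEntropyLoss ignores by default. A standalone sketch (PAD_ID stands in for TrainerTools().tokenizer.pad):

import torch
from torch.nn.utils.rnn import pad_sequence

PAD_ID = 0  # stand-in for the real pad id
batch = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]

inputs = pad_sequence(batch, batch_first=True, padding_value=PAD_ID)
labels = pad_sequence(batch, batch_first=True, padding_value=-100)
# inputs: tensor([[5, 6, 7], [8, 9, 0]])
# labels: tensor([[5, 6, 7], [8, 9, -100]])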
@@ -184,11 +190,19 @@ def get_sft_collate_fn(mask_prompt: bool):
            batch_train_data.append(item['inputs'])
            image_tags.append(item['image_tag'])
 
-        inputs, labels = _pad_sequence(batch_train_data)
+        # [[x,x,x], [y,y,y]]
+        inputs = pad_sequence(batch_train_data, batch_first=True, padding_value=TrainerTools().tokenizer.pad)
+        # cross-entropy's default ignore_index is -100
+        labels = pad_sequence(batch_train_data, batch_first=True, padding_value=-100)
+
        if mask_prompt:
            labels = _mask_prompt(labels)
 
-        return {'inputs': inputs, 'labels': labels, 'image_tags': image_tags}
+        return {
+            'inputs': inputs,
+            'labels': labels,
+            'image_tags': image_tags
+        }
 
    return sft_collate_fn
 
@@ -295,13 +309,24 @@ def join_batch(batch_data: list[dict]) -> dict:
     return result
 
 
-def _pad_sequence(batch_data):
-    # [[x,x,x], [y,y,y]]
-    inputs = pad_sequence(batch_data, batch_first=True, padding_value=TrainerTools().tokenizer.pad)
-    # cross-entropy's default ignore_index is -100
-    labels = pad_sequence(batch_data, batch_first=True, padding_value=-100)
+def fill_loss_mask(loss_masks, labels):
+    """
+    Force the prompt part of loss_masks to False.
+    loss_masks: shape (B, T)
+    labels: shape (B, T)
+    """
+    tokenizer = TrainerTools().tokenizer
+    # supports masks for multi-turn conversations
+    for batch, label in enumerate(labels):
+        start_index = -1
+        for index, token in enumerate(label):
+            if token == tokenizer.system or token == tokenizer.user:
+                start_index = index
+            elif token == tokenizer.end and start_index != -1:
+                loss_masks[batch, start_index:index + 1] = False
+                start_index = -1
 
-    return inputs, labels
+    return loss_masks
 
 
 def _mask_prompt(labels):
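On a toy sequence the scan is easy to trace: each span from a system/user token through the next end token is masked out, which is what makes the mask multi-turn aware. A sketch with hypothetical token ids (USER=1, END=2; the real ids come from the tokenizer):

import torch

USER, END = 1, 2  # hypothetical special-token ids

labels = torch.tensor([[USER, 7, 8, END, 9, 9, END]])
loss_masks = torch.ones_like(labels, dtype=torch.bool)

start_index = -1
for index, token in enumerate(labels[0]):
    if token == USER:
        start_index = index
    elif token == END and start_index != -1:
        # mask the prompt span, inclusive of its end token
        loss_masks[0, start_index:index + 1] = False
        start_index = -1

# loss_masks[0] -> [False, False, False, False, True, True, True]

Note that the trailing END of the response stays unmasked, since no system/user token reopened a span before it.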
project_llm_trainer-0.7.9.data/scripts/ds_train CHANGED
@@ -10,14 +10,15 @@ if __name__ == '__main__':
    if len(arguments) > 1:
        # 0,1,2,3
        cuda_visible_devive = arguments[1]
-    else:
-        cuda_visible_devive = None
 
-    # cuda location
-    if len(arguments) > 2:
-        cuda_loc = arguments[2]
+        # cuda location
+        if len(arguments) > 2:
+            cuda_loc = arguments[2]
+        else:
+            cuda_loc = 'localhost'
    else:
-        cuda_loc = 'localhost'
+        cuda_visible_devive = None
+        cuda_loc = None
 
    os.environ['PARALLEL_TYPE'] = 'ds'
 
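Assuming this hunk is the ds_train launch script (the only script whose hash changes in the RECORD below), the restructuring only changes the no-argument case: cuda_loc no longer falls back to 'localhost' when no device list is given. A condensed equivalent of the new logic:

import sys

arguments = sys.argv

if len(arguments) > 1:
    cuda_visible_devive = arguments[1]  # e.g. '0,1,2,3'
    cuda_loc = arguments[2] if len(arguments) > 2 else 'localhost'
else:
    # previously cuda_loc was still set to 'localhost' on this path
    cuda_visible_devive = None
    cuda_loc = None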
project_llm_trainer-0.7.9.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: project_llm_trainer
-Version: 0.7.8
+Version: 0.7.9
 Summary: LLM and VLM trainer
 Author: qibin
 Author-email: qibin0506@gmail.com
project_llm_trainer-0.7.9.dist-info/RECORD CHANGED
@@ -1,13 +1,13 @@
 llm_trainer/__init__.py,sha256=HWgtTEVeQSnZmEyYQm2K6eFEG4X2QAoigMlB5Z2tcXE,260
 llm_trainer/checkpoint.py,sha256=X5ZeUtJlxVz7pnWQLaS-y7UIZOaOAnZTt2L8rSAPzUs,4428
 llm_trainer/dataset.py,sha256=UL3fGeM4XSlyNQRZH-139u3LujqAQx3YyaxNRewk6LE,8935
-llm_trainer/dpo_trainer.py,sha256=RMfbTsl3eav4yTJ2PK59mi6a0ECVOg8WwYVsHvMbNUE,12353
+llm_trainer/dpo_trainer.py,sha256=Bgds18UWFhzf_UNCFN-iBCdhKf9pcXJBFPEc32oJeXA,13354
 llm_trainer/ds_checkpoint.py,sha256=X2IWgpgi0yOtogph7n6DEwvK_0Ceb7juu1WMutv3HSk,2270
 llm_trainer/eval.py,sha256=ZyUfSo2Q8P-lrCdPEnGkoo5pGubd0AabREK5eMISRII,1109
 llm_trainer/generate_utils.py,sha256=8K3YFbp7IF_lCkmkzjHhqTW26EBFb2AilQmarVcfMvs,15001
 llm_trainer/grpo_trainer.py,sha256=MXnP8Kc9CQJw0CB3uMbHxIYwvpuujai4hgbbpUut_K4,16808
 llm_trainer/log.py,sha256=XwychwKF6gvFPhthCIZCAEUZ0G3DY3fiQrOHqPWsxz0,463
-llm_trainer/loss.py,sha256=glf4IeDWHvA2cJo-QKLRL8P6OxK4QjRJGrYJWOZiTPQ,6929
+llm_trainer/loss.py,sha256=RhTxftLMj1Tqc5pkUvJiZumfbMEPWL8GBGxdTfQggmk,6744
 llm_trainer/parallel.py,sha256=yjStV21DJ26yM8-0O6GTMxdFAcyShY5GsQWSZmbI7HU,4543
 llm_trainer/parallel_ddp.py,sha256=Pob9vUlBZnkL4oP1Re11kFob7nufMSE96pn7m7fuOEM,1345
 llm_trainer/parallel_ds.py,sha256=oy8RRxHud3rACWubFlJqqd0pjPEQhKeAPGPQUSdJX2c,1145
@@ -15,19 +15,19 @@ llm_trainer/parallel_none.py,sha256=TG6Pm829Dg-yQu-97O-EHV3FCARBlNcP47KkGFAs16E,
 llm_trainer/partition_utils.py,sha256=eEYNhfEIF4hGzZ3OLa6sEBIECz261drptEz_n7fZYtk,8396
 llm_trainer/scheduler.py,sha256=LAI_0VxClsIQkix0bRoduRD4vPfVuIZDhZgTAT_KK8k,4901
 llm_trainer/sft_trainer.py,sha256=LudTRIaqLQYy6ym6jjMX7v9xtFBJelrR3nnPCwb48nM,1821
-llm_trainer/tokenizer.py,sha256=SSpgXtb0e1NtQqRW0gCq09TTZi47umggy-Fh5EMHKJg,6708
+llm_trainer/tokenizer.py,sha256=0-xQCMz1xiPTDAZiYsVsiECSoZ_1eIvW9XsZOoFfakQ,7250
 llm_trainer/tools.py,sha256=5op5qrjjkK-Lr9oes5VxIVnOVYOYGoAdlIJq9mPUf64,2637
 llm_trainer/train_configs.py,sha256=N3ykM1uaLHcSNRC8ErYIxp9VYhSP7voJyAP-2D4ZJe0,7574
 llm_trainer/trainer.py,sha256=jS31zEXIIj9BoPTPlmaGYq61x72HGCjKfS2u3_gOkDk,27924
-llm_trainer/utils.py,sha256=xcdzpvPvXRKqsOK2yB7PZ9GmOvZMDFcglDPUZY2hJTY,11484
-project_llm_trainer-0.7.8.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
-project_llm_trainer-0.7.8.data/scripts/ddp_train,sha256=Z-309mM56CN0m3bxoeC5us4LUuwuNnoiOm3-fDdLMjQ,566
-project_llm_trainer-0.7.8.data/scripts/ds_train,sha256=3nXNNKmYI7miqyBdf-Ijl_rW1cGIKrAMZ1CSswN_gGo,665
-project_llm_trainer-0.7.8.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
-project_llm_trainer-0.7.8.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
-project_llm_trainer-0.7.8.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
-project_llm_trainer-0.7.8.data/scripts/smart_train,sha256=3oLIDuuqb4U4TU1lXy9V8lw_0gIf7i8tGsxlQ_s6bro,1220
-project_llm_trainer-0.7.8.dist-info/METADATA,sha256=rSYUrEkdjPCyYUqT2SOw3-hzT40wU3AwEw-ouHh1rBY,195
-project_llm_trainer-0.7.8.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
-project_llm_trainer-0.7.8.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
-project_llm_trainer-0.7.8.dist-info/RECORD,,
+llm_trainer/utils.py,sha256=xC5plG-8-_Al5yIF5xIU5lroOcBBk98TEhtUJrazZPE,12305
+project_llm_trainer-0.7.9.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
+project_llm_trainer-0.7.9.data/scripts/ddp_train,sha256=Z-309mM56CN0m3bxoeC5us4LUuwuNnoiOm3-fDdLMjQ,566
+project_llm_trainer-0.7.9.data/scripts/ds_train,sha256=tME0xmMdX1D9XuVo07D9dilW5VIWavBS3UK9DoY67WI,709
+project_llm_trainer-0.7.9.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
+project_llm_trainer-0.7.9.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
+project_llm_trainer-0.7.9.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
+project_llm_trainer-0.7.9.data/scripts/smart_train,sha256=3oLIDuuqb4U4TU1lXy9V8lw_0gIf7i8tGsxlQ_s6bro,1220
+project_llm_trainer-0.7.9.dist-info/METADATA,sha256=mDGLc1BjmIlOPz85JYB5bFnlXJgJ5VaNesW4z0HDZCA,195
+project_llm_trainer-0.7.9.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
+project_llm_trainer-0.7.9.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
+project_llm_trainer-0.7.9.dist-info/RECORD,,