project-llm-trainer 0.7.8__py3-none-any.whl → 0.8.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of project-llm-trainer might be problematic.

llm_trainer/dpo_trainer.py CHANGED
@@ -12,7 +12,8 @@ from .loss import DPOLoss
  from .tools import TrainerTools
  from .utils import (
      autocast,
-     get_dpo_collate_fn
+     get_dpo_collate_fn,
+     fill_loss_mask
  )
  from .partition_utils import sync_model_params
 
@@ -69,12 +70,12 @@ class DPOTrainer(Trainer):
 
          return criterion, None
 
-     def _convert_train_args(self) -> Tuple[dict, dict, dict, bool]:
+     def _convert_train_args(self) -> Tuple[dict, dict, dict]:
          dpo_collate_fn = get_dpo_collate_fn(self.train_config.mask_prompt)
-         parallel_kwargs, data_loader_kwargs, sampler_kwargs, use_ds_optim = super()._convert_train_args()
+         parallel_kwargs, data_loader_kwargs, sampler_kwargs = super()._convert_train_args()
          data_loader_kwargs.update({"collate_fn": dpo_collate_fn})
 
-         return parallel_kwargs, data_loader_kwargs, sampler_kwargs, use_ds_optim
+         return parallel_kwargs, data_loader_kwargs, sampler_kwargs
 
      def _create_dataset(self, file_idx) -> Tuple[Dataset, str]:
          file_path = self.train_config.file_dataset[file_idx]
@@ -84,7 +85,6 @@ class DPOTrainer(Trainer):
      def _calc_loss(self, inputs, attention_mask, logits, labels): ...
 
      def _log_probs_from_logits(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
-         # https://github.com/OpenRLHF/OpenRLHF/pull/718#issuecomment-2641081881
          if logits.dtype in [torch.float32, torch.float64]:
              logits_labels = torch.gather(logits, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
              logsumexp_values = torch.stack(
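
For background: the gather/logsumexp pattern above computes each label token's log-probability via the identity

$$\log p_\theta(y_t \mid y_{<t}) = z_{t,y_t} - \log\sum_{v=1}^{V} e^{z_{t,v}}$$

without materializing a full (B, T, V) log-softmax tensor, a memory-saving trick also discussed in the OpenRLHF thread referenced by the removed comment.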
@@ -102,25 +102,26 @@ class DPOTrainer(Trainer):
          return log_probs_labels
 
-     def _logprobs(self, logits, labels, mask):
+     def _logprobs(self, logits, labels, attention_mask):
          """
          Calculate the average log probabilities for a batch of sequences.
 
          Args:
              logits (torch.Tensor): Logits from the model with shape (B, T, V)
              labels (torch.Tensor): Ground truth labels with shape (B, T).
-             mask (torch.Tensor): Mask tensor with shape (B, T) indicating
+             attention_mask (torch.Tensor): Mask tensor with shape (B, T) indicating
                  which tokens are not padding (1 for valid tokens, 0 for padding).
 
          Returns:
              torch.Tensor: Average log probabilities for each sequence in the batch.
                  Shape is (B,) representing the mean log probability for each sequence.
          """
-         labels = labels[:, 1:].clone()
-         logits = logits[:, :-1, :]
+         loss_masks = attention_mask.clone().bool()
+         loss_masks = fill_loss_mask(loss_masks, labels)
 
-         # Shift mask right by one to align with labels
-         mask = mask[:, 1:].clone()
+         logits = logits[:, :-1, :]
+         labels = labels[:, 1:].clone()
+         loss_masks = loss_masks[:, 1:]
 
          # dummy token; we'll ignore the losses on these tokens later
          labels[labels == -100] = 0
@@ -129,11 +130,10 @@ class DPOTrainer(Trainer):
          per_token_logps = self._log_probs_from_logits(logits, labels)
 
          # Apply the mask to set log-probs of padding tokens to 0
-         logprobs_sums = (per_token_logps * mask).sum(-1)
-
-         # logprobs_means = (per_token_logps * mask).sum(-1) / mask.sum(-1)
+         logprobs_sums = (per_token_logps * loss_masks).sum(-1)
+         logprobs_means = (per_token_logps * loss_masks).sum(-1) / loss_masks.sum(-1)
 
-         return logprobs_sums #, -logprobs_means.mean()
+         return logprobs_sums, logprobs_means
 
      def train(self):
          # gradient accumulation steps
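
With per-token log-probs $\ell_{b,t}$ and the boolean mask $m_{b,t}$ produced by fill_loss_mask, the two return values are

$$S_b = \sum_t m_{b,t}\,\ell_{b,t}, \qquad M_b = \frac{S_b}{\sum_t m_{b,t}}$$

The masked sums feed the DPO criterion, while the chosen-half means drive the new NLL term in the training loop below.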
@@ -147,6 +147,7 @@ class DPOTrainer(Trainer):
          last_best_checkpoint_loss: Optional[float] = None
 
          aux_loss_coef = self.train_config.loss_config.aux_loss_coef
+         nll_loss_coef = self.train_config.dpo_config.nll_loss_coef
 
          for epoch in range(self.train_config.n_epochs):
              self.train_model.train()
@@ -188,36 +189,53 @@ class DPOTrainer(Trainer):
              try:
                  chosen_inputs: torch.Tensor = batch_data['chosen_inputs'].to(TrainerTools().parallel.device)
                  chosen_labels: torch.Tensor = batch_data['chosen_labels'].to(TrainerTools().parallel.device)
+
                  rejected_inputs: torch.Tensor = batch_data['rejected_inputs'].to(TrainerTools().parallel.device)
                  rejected_labels: torch.Tensor = batch_data['rejected_labels'].to(TrainerTools().parallel.device)
 
-                 chosen_attention_mask: torch.Tensor = chosen_inputs != TrainerTools().tokenizer.pad
-                 rejected_attention_mask: torch.Tensor = rejected_inputs != TrainerTools().tokenizer.pad
+                 chosen_attention_masks: torch.Tensor = chosen_inputs != TrainerTools().tokenizer.pad
+                 rejected_attention_masks: torch.Tensor = rejected_inputs != TrainerTools().tokenizer.pad
 
                  # concat along the batch dimension
                  # [chosen, chosen, reject, reject]
                  concat_inputs = torch.concat([chosen_inputs, rejected_inputs], dim=0)
                  concat_labels = torch.concat([chosen_labels, rejected_labels], dim=0)
-                 concat_mask = torch.concat([chosen_attention_mask, rejected_attention_mask], dim=0)
+                 concat_attention_masks = torch.concat([chosen_attention_masks, rejected_attention_masks], dim=0)
 
                  if TrainerTools().parallel.parallel_train:
                      self.train_model.require_backward_grad_sync = need_update_grad
 
                  with autocast(TrainerTools().parallel.device_type):
-                     policy_outputs = self.train_model(concat_inputs, attention_mask=concat_mask)
-                     policy_probs = self._logprobs(policy_outputs['logits'], concat_labels, concat_mask)
+                     policy_outputs = self.train_model(concat_inputs, attention_mask=concat_attention_masks)
+                     policy_logprobs_sums, policy_logprobs_means = self._logprobs(policy_outputs['logits'], concat_labels, concat_attention_masks)
                      aux_loss = policy_outputs.get('aux_loss')
 
                      with torch.no_grad():
-                         ref_outputs = self.ref_model(concat_inputs, attention_mask=concat_mask)
-                         ref_probs = self._logprobs(ref_outputs['logits'], concat_labels, concat_mask)
+                         ref_outputs = self.ref_model(concat_inputs, attention_mask=concat_attention_masks)
+                         ref_logprobs_sums, _ = self._logprobs(ref_outputs['logits'], concat_labels, concat_attention_masks)
+
+                     policy_chosen_logps = policy_logprobs_sums[:chosen_inputs.shape[0]]
+                     policy_rejected_logps = policy_logprobs_sums[chosen_inputs.shape[0]:]
+
+                     ref_chosen_logps = ref_logprobs_sums[:chosen_inputs.shape[0]]
+                     ref_rejected_logps = ref_logprobs_sums[chosen_inputs.shape[0]:]
+
+                     nll_loss = -policy_logprobs_means[:chosen_inputs.shape[0]].mean()
 
                      # calc loss
-                     loss = self.criterion(policy_probs, ref_probs)
+                     loss = self.criterion(
+                         policy_chosen_logps,
+                         policy_rejected_logps,
+                         ref_chosen_logps,
+                         ref_rejected_logps
+                     )
 
                      if aux_loss_coef and aux_loss:
                          loss += aux_loss_coef * aux_loss
 
+                     if nll_loss_coef and nll_loss:
+                         loss += nll_loss_coef * nll_loss
+
                  if gradient_accumulation_steps > 1:
                      loss = loss / gradient_accumulation_steps
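
Putting the pieces together, the trainer's total objective now appears to be

$$\mathcal{L} = \mathcal{L}_{\mathrm{DPO}} + c_{\mathrm{aux}}\,\mathcal{L}_{\mathrm{aux}} + c_{\mathrm{nll}}\,\mathcal{L}_{\mathrm{NLL}}, \qquad \mathcal{L}_{\mathrm{NLL}} = -\frac{1}{B}\sum_b M_b^{\mathrm{chosen}}$$

with $c_{\mathrm{aux}}$ = aux_loss_coef and $c_{\mathrm{nll}}$ = nll_loss_coef, i.e. an RPO-style negative-log-likelihood regularizer on the chosen responses layered on top of the preference loss.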
 
llm_trainer/grpo_trainer.py CHANGED
@@ -82,11 +82,11 @@ class GRPOTrainer(Trainer):
 
          return criterion, None
 
-     def _convert_train_args(self) -> Tuple[dict, dict, dict, bool]:
-         parallel_kwargs, data_loader_kwargs, sampler_kwargs, use_ds_optim = super()._convert_train_args()
+     def _convert_train_args(self) -> Tuple[dict, dict, dict]:
+         parallel_kwargs, data_loader_kwargs, sampler_kwargs = super()._convert_train_args()
          data_loader_kwargs.update({"collate_fn": lambda x: x})
 
-         return parallel_kwargs, data_loader_kwargs, sampler_kwargs, use_ds_optim
+         return parallel_kwargs, data_loader_kwargs, sampler_kwargs
 
      def _create_dataset(self, file_idx) -> Tuple[Dataset, str]:
          file_path = self.train_config.file_dataset[file_idx]
llm_trainer/loss.py CHANGED
@@ -92,17 +92,13 @@ class DPOLoss(nn.Module):
 
      def forward(
          self,
-         policy_logps: torch.Tensor,
-         reference_logps: torch.Tensor,
+         policy_chosen_logps: torch.Tensor,
+         policy_reject_logps: torch.Tensor,
+         ref_chosen_logps: torch.Tensor,
+         ref_reject_logps: torch.Tensor
      ) -> torch.Tensor:
-         batch_size = reference_logps.shape[0]
-         ref_chosen_probs = reference_logps[:batch_size//2]
-         ref_reject_probs = reference_logps[batch_size//2:]
-         policy_chosen_probs = policy_logps[:batch_size//2]
-         policy_reject_probs = policy_logps[batch_size//2:]
-
-         pi_logratios = policy_chosen_probs - policy_reject_probs
-         ref_logratios = ref_chosen_probs - ref_reject_probs
+         pi_logratios = policy_chosen_logps - policy_reject_logps
+         ref_logratios = ref_chosen_logps - ref_reject_logps
          logits = pi_logratios - ref_logratios
 
          if self.ipo:
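
For reference, with β the DPOLoss temperature and logits as computed above, the sigmoid-DPO objective this forward presumably completes is

$$\mathcal{L}_{\mathrm{DPO}} = -\log\sigma\Big(\beta\big[(\log\pi_\theta(y_w) - \log\pi_\theta(y_l)) - (\log\pi_{\mathrm{ref}}(y_w) - \log\pi_{\mathrm{ref}}(y_l))\big]\Big)$$

while the self.ipo branch would use the IPO squared form $(\text{logits} - \frac{1}{2\beta})^2$. The refactor changes only how the chosen/rejected halves reach the function; the math is unchanged.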
llm_trainer/sft_trainer.py CHANGED
@@ -23,12 +23,12 @@ class SFTTrainer(Trainer):
          )
          self.packed_sequences = False
 
-     def _convert_train_args(self) -> Tuple[dict, dict, dict, bool]:
+     def _convert_train_args(self) -> Tuple[dict, dict, dict]:
          sft_collate_fn = get_sft_collate_fn(self.train_config.mask_prompt)
-         parallel_kwargs, data_loader_kwargs, sampler_kwargs, use_ds_optim = super()._convert_train_args()
+         parallel_kwargs, data_loader_kwargs, sampler_kwargs = super()._convert_train_args()
          data_loader_kwargs.update({"collate_fn": sft_collate_fn})
 
-         return parallel_kwargs, data_loader_kwargs, sampler_kwargs, use_ds_optim
+         return parallel_kwargs, data_loader_kwargs, sampler_kwargs
 
      def _create_dataset(self, file_idx) -> Tuple[Dataset, str]:
          file_path = self.train_config.file_dataset[file_idx]
llm_trainer/tokenizer.py CHANGED
@@ -3,7 +3,7 @@ import warnings
  from typing import List, Dict, Union
  from transformers import Qwen2TokenizerFast
  from transformers import AddedToken
- from transformers import LlamaTokenizer, LlamaTokenizerFast
+ from transformers import LlamaTokenizerFast
  import torch
 
  TOKEN_TYPE_QWEN = 'qwen'
@@ -164,3 +164,18 @@ class Tokenizer:
 
          return chat_template
 
+     def get_special_tokens_dict(self):
+         return {
+             self.text_end: self.end,
+             self.text_pad: self.pad,
+             self.text_unk: self.unk,
+             self.text_user: self.user,
+             self.text_assistant: self.assistant,
+             self.text_think_start: self.think_start,
+             self.text_think_end: self.think_end,
+             self.text_answer_start: self.answer_start,
+             self.text_answer_end: self.answer_end,
+             self.text_system: self.system,
+             self.text_image: self.image,
+         }
+
llm_trainer/train_configs.py CHANGED
@@ -107,7 +107,8 @@ class DataLoaderConfig:
 
 
  @dataclass(kw_only=True)
- class LrConfig:
+ class OptimConfig:
+     optim_type: str = 'adam' # or 'lion'
      enable_lr_scheduler: bool = False
      initial_lr: float
      weight_decay: float = 0.1
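
A minimal construction sketch (an assumption: the remaining scheduler fields keep their former LrConfig defaults; values are illustrative):

# hypothetical usage; import path inferred from the wheel's RECORD
from llm_trainer.train_configs import OptimConfig

optim_config = OptimConfig(
    optim_type='lion',  # default is 'adam'
    initial_lr=1e-4,    # required: the dataclass declares no default
    weight_decay=0.1,
)

Since the dataclass is declared kw_only, all fields must be passed by keyword.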
@@ -195,8 +196,8 @@ class TrainConfig:
              this configuration has no effect during grpo training!
          eval_batch_interval (`int`, default is 100):
              run a model eval every this many batches
-         lr_config (`LrConfig`):
-             lr options
+         optim_config (`OptimConfig`):
+             optim options
          data_loader_config: (`DataLoaderConfig`):
              data loader options
          kd_config: (`KDConfig`, *Optional*, default is None):
@@ -213,7 +214,7 @@ class TrainConfig:
      image_tags_file_dataset: Optional[FileDataset] = None
 
      loss_config: LossConfig = field(default_factory=LossConfig)
-     lr_config: LrConfig = field(default_factory=LrConfig)
+     optim_config: OptimConfig = field(default_factory=OptimConfig)
 
      ds_config: DsConfig = field(default_factory=DsConfig)
 
llm_trainer/trainer.py CHANGED
@@ -77,19 +77,15 @@ class Trainer:
          if self.eval_image_tags:
              assert len(self.eval_prompts) == len(self.eval_image_tags)
 
-         parallel_kwargs, data_loader_kwargs, sampler_kwargs, use_ds_optim = self._convert_train_args()
-         self.parallel_kwargs = parallel_kwargs
-         self.data_loader_kwargs: dict[str, Any] = data_loader_kwargs
-         self.sampler_kwargs: dict[str, Any] = sampler_kwargs
-
+         self.parallel_kwargs, self.data_loader_kwargs, self.sampler_kwargs = self._convert_train_args()
          # initialize a GradScaler. If enabled=False scaler is a no-op
          self.scalar = torch.GradScaler(enabled=TrainerTools().use_amp)
 
          # Note: the learning rate should be scaled with the number of GPUs:
          # during training, the loss gradient sets the descent direction and the learning rate sets the step size; with two GPUs the combined step is: average lr * 2
-         initial_lr = train_config.lr_config.initial_lr
+         initial_lr = train_config.optim_config.initial_lr
 
-         self.train_model, self.optimizer = self._init_train_model_and_optim(initial_lr, parallel_kwargs, use_ds_optim)
+         self.train_model, self.optimizer = self._init_train_model_and_optim(initial_lr)
          self.lr_scheduler = self._init_lr_scheduler(initial_lr)
 
          self.criterion, self.kd_loss = self._init_loss()
@@ -127,12 +123,7 @@ class Trainer:
          freeze_llm_model = self.train_config.freeze_llm_model
          return model.parameters() if not freeze_llm_model else filter(lambda p: p.requires_grad, model.parameters())
 
-     def _init_train_model_and_optim(
-         self,
-         initial_lr: float,
-         parallel_kwargs: dict,
-         use_ds_optim: bool
-     ):
+     def _init_train_model_and_optim(self, initial_lr: float):
          model = self._new_model(self.train_config)
 
          if self.train_config.init_state_dict:
@@ -161,34 +152,58 @@ class Trainer:
          total_size_mb = total_size_bytes / (1024 * 1024)
          log(f"Total size of the model: {total_size_mb:.2f} MB")
 
-         if use_ds_optim:
-             import deepspeed
-             origin_optim = deepspeed.ops.adam.DeepSpeedCPUAdam(
-                 self._get_trainable_params(model),
-                 lr=initial_lr,
-                 weight_decay=self.train_config.lr_config.weight_decay
-             )
-         else:
-             origin_optim = torch.optim.AdamW(
-                 self._get_trainable_params(model),
-                 lr=initial_lr,
-                 weight_decay=self.train_config.lr_config.weight_decay
-             )
          model, optim = TrainerTools().parallel.process(
              model=model,
-             optimizer=origin_optim,
-             kwargs=parallel_kwargs
+             optimizer=self._get_optim(model, initial_lr),
+             kwargs=self.parallel_kwargs
          )
 
          return model, optim
 
+     def _get_optim(self, model, initial_lr):
+         optimizer = None
+
+         if isinstance(TrainerTools().parallel, DsParallel) and self.parallel_kwargs:
+             import deepspeed
+             if ('zero_optimization' in self.parallel_kwargs
+                     and 'offload_optimizer' in self.parallel_kwargs['zero_optimization']
+                     and self.parallel_kwargs['zero_optimization']['offload_optimizer']['device'] == 'cpu'):
+                 # offload optimizer to cpu
+                 # deepspeed.ops.lion.cpu_lion.DeepSpeedCPULion cannot be used???
+                 # so the lion check is skipped here
+                 optimizer = deepspeed.ops.adam.DeepSpeedCPUAdam
+                 if self.train_config.optim_config.optim_type == 'lion':
+                     log('When set offload_optimizer, lion optim is unsupported, so set optim to adam!!!!!')
+             else:
+                 if self.train_config.optim_config.optim_type == 'lion':
+                     optimizer = deepspeed.ops.lion.FusedLion
+                 else:
+                     optimizer = deepspeed.ops.adam.FusedAdam
+
+         if not optimizer:
+             if self.train_config.optim_config.optim_type == 'lion':
+                 try:
+                     import lion_pytorch
+                 except:
+                     raise Exception('lion is not detected, please use `pip3 install lion_pytorch` to install or set optim_type to adam')
+
+                 optimizer = lion_pytorch.Lion
+             else:
+                 optimizer = torch.optim.AdamW
+
+         return optimizer(
+             self._get_trainable_params(model),
+             lr=initial_lr,
+             weight_decay=self.train_config.optim_config.weight_decay
+         )
+
      def _init_lr_scheduler(self, initial_lr: float) -> LRScheduler:
-         if self.train_config.lr_config.enable_lr_scheduler:
-             warmup_iters = self.train_config.lr_config.warmup_iters
-             min_lr = self.train_config.lr_config.min_lr
-             max_lr = self.train_config.lr_config.max_lr
-             cosine_annealing_period = self.train_config.lr_config.cosine_annealing_period
-             cosine_annealing_period_mul = self.train_config.lr_config.cosine_annealing_period_mul
+         if self.train_config.optim_config.enable_lr_scheduler:
+             warmup_iters = self.train_config.optim_config.warmup_iters
+             min_lr = self.train_config.optim_config.min_lr
+             max_lr = self.train_config.optim_config.max_lr
+             cosine_annealing_period = self.train_config.optim_config.cosine_annealing_period
+             cosine_annealing_period_mul = self.train_config.optim_config.cosine_annealing_period_mul
 
          return WarmupCosineAnnealingLRScheduler(
              optimizer=self.optimizer,
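
Note the precedence: under DeepSpeed with CPU optimizer offload, DeepSpeedCPUAdam is forced regardless of optim_type. A sketch of the parallel_kwargs shape that triggers that branch, inferred from the key checks above (illustrative values):

# hypothetical kwargs as assembled by _convert_train_args
parallel_kwargs = {
    'zero_optimization': {
        'offload_optimizer': {'device': 'cpu', 'pin_memory': True},
    },
}
# -> deepspeed.ops.adam.DeepSpeedCPUAdam (a warning is logged if optim_type == 'lion')
# without cpu offload: FusedLion or FusedAdam; outside DeepSpeed: lion_pytorch.Lion or torch.optim.AdamW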
@@ -220,9 +235,8 @@ class Trainer:
 
          return criterion, kd_loss
 
-     def _convert_train_args(self) -> Tuple[dict, dict, dict, bool]:
+     def _convert_train_args(self) -> Tuple[dict, dict, dict]:
          parallel_kwargs: Optional[Dict[str, Any]] = None
-         use_ds_optim: bool = False
          if isinstance(TrainerTools().parallel, DsParallel) and self.train_config.ds_config:
              parallel_kwargs = {
                  'gradient_accumulation_steps': 1,
@@ -253,7 +267,6 @@ class Trainer:
                      "device": zero_config.offload_optimizer.device,
                      "pin_memory": zero_config.offload_optimizer.pin_memory
                  }
-                 use_ds_optim = True
              if zero_config.offload_param is not None:
                  zero_optimization['offload_param'] = {
                      "device": zero_config.offload_param.device,
@@ -328,10 +341,10 @@ class Trainer:
              "drop_last": dataloader_args.data_loader_drop_last,
          }
 
-         return parallel_kwargs, data_loader_kwargs, sampler_kwargs, use_ds_optim
+         return parallel_kwargs, data_loader_kwargs, sampler_kwargs
 
      def _init_ref_model_args(self) -> dict:
-         parallel_kwargs = copy.deepcopy(self.parallel_kwargs)
+         parallel_kwargs = copy.deepcopy(self.parallel_kwargs) if self.parallel_kwargs else None
 
          if parallel_kwargs and isinstance(TrainerTools().parallel, DsParallel):
              # reference to https://github.com/huggingface/trl/blob/main/trl/models/utils.py:prepare_deepspeed
@@ -435,7 +448,7 @@ class Trainer:
          exception_file = e.__traceback__.tb_frame.f_globals["__file__"]
          exception_line = e.__traceback__.tb_lineno
          log_msg = f"epoch: {epoch}, batch: {batch}, {e} at {exception_file} line {exception_line}\n"
-         log(log_msg, f'{log_dir}log.txt')
+         log(log_msg, f'{log_dir}exception.txt')
 
          raise e
 
llm_trainer/utils.py CHANGED
@@ -154,16 +154,22 @@ def batch_repeat_image_tok(
 
 
  def pretrain_collate_fn(batch_data):
-     inputs, labels = _pad_sequence(batch_data)
+     # [[x,x,x], [y,y,y]]
+     inputs = pad_sequence(batch_data, batch_first=True, padding_value=TrainerTools().tokenizer.pad)
+     # crossEntropy's default ignore_index is -100
+     labels = pad_sequence(batch_data, batch_first=True, padding_value=-100)
 
      # inputs, labels
-     return {'inputs': inputs, 'labels': labels}
+     return {
+         'inputs': inputs,
+         'labels': labels
+     }
 
 
  def get_sft_collate_fn(mask_prompt: bool):
      def sft_collate_fn(batch_data):
          """
-         For sft, skip computing the loss on the prompt part, e.g.:
+         For sft, skip computing the loss on the prompt part, e.g.:
          logits: [USER]你好[BOT]我好[SEP]
          labels: [USER]你好[BOT]我好[SEP]
@@ -184,11 +190,19 @@ def get_sft_collate_fn(mask_prompt: bool):
              batch_train_data.append(item['inputs'])
              image_tags.append(item['image_tag'])
 
-         inputs, labels = _pad_sequence(batch_train_data)
+         # [[x,x,x], [y,y,y]]
+         inputs = pad_sequence(batch_train_data, batch_first=True, padding_value=TrainerTools().tokenizer.pad)
+         # crossEntropy's default ignore_index is -100
+         labels = pad_sequence(batch_train_data, batch_first=True, padding_value=-100)
+
          if mask_prompt:
              labels = _mask_prompt(labels)
 
-         return {'inputs': inputs, 'labels': labels, 'image_tags': image_tags}
+         return {
+             'inputs': inputs,
+             'labels': labels,
+             'image_tags': image_tags
+         }
 
      return sft_collate_fn
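
A toy run of the inlined padding (a sketch; pad id 0 is an assumption, the real value comes from TrainerTools().tokenizer.pad):

import torch
from torch.nn.utils.rnn import pad_sequence

batch = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]
inputs = pad_sequence(batch, batch_first=True, padding_value=0)     # tensor([[5, 6, 7], [8, 9, 0]])
labels = pad_sequence(batch, batch_first=True, padding_value=-100)  # tensor([[5, 6, 7], [8, 9, -100]])

Padding labels with -100 lets cross-entropy skip the padded positions via its default ignore_index.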
 
@@ -295,13 +309,24 @@ def join_batch(batch_data: list[dict]) -> dict:
      return result
 
 
- def _pad_sequence(batch_data):
-     # [[x,x,x], [y,y,y]]
-     inputs = pad_sequence(batch_data, batch_first=True, padding_value=TrainerTools().tokenizer.pad)
-     # crossEntropy's default ignore_index is -100
-     labels = pad_sequence(batch_data, batch_first=True, padding_value=-100)
+ def fill_loss_mask(loss_masks, labels):
+     """
+     Force the prompt spans in loss_mask to False.
+     loss_masks: shape (B, T)
+     labels: shape (B, T)
+     """
+     tokenizer = TrainerTools().tokenizer
+     # supports masking multi-turn conversations
+     for batch, label in enumerate(labels):
+         start_index = -1
+         for index, token in enumerate(label):
+             if token == tokenizer.system or token == tokenizer.user:
+                 start_index = index
+             elif token == tokenizer.end and start_index != -1:
+                 loss_masks[batch, start_index:index + 1] = False
+                 start_index = -1
 
-     return inputs, labels
+     return loss_masks
 
 
  def _mask_prompt(labels):
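
A standalone replica of the masking loop on a toy sequence (hypothetical ids user=1, end=2, system=3; the real ids come from TrainerTools().tokenizer):

import torch

USER, END, SYSTEM = 1, 2, 3
# [user, q, q, end, a, a, end]: the prompt span user..end is masked out,
# the assistant span after it keeps its mask
labels = torch.tensor([[USER, 10, 11, END, 20, 21, END]])
loss_masks = torch.ones_like(labels, dtype=torch.bool)
for b, label in enumerate(labels):
    start = -1
    for i, tok in enumerate(label.tolist()):
        if tok in (USER, SYSTEM):
            start = i
        elif tok == END and start != -1:
            loss_masks[b, start:i + 1] = False
            start = -1
print(loss_masks)  # [[False, False, False, False, True, True, True]]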
project_llm_trainer-0.8.1.data/scripts/ds_train CHANGED
@@ -10,14 +10,15 @@ if __name__ == '__main__':
      if len(arguments) > 1:
          # 0,1,2,3
          cuda_visible_devive = arguments[1]
-     else:
-         cuda_visible_devive = None
 
-     # cuda location
-     if len(arguments) > 2:
-         cuda_loc = arguments[2]
+         # cuda location
+         if len(arguments) > 2:
+             cuda_loc = arguments[2]
+         else:
+             cuda_loc = 'localhost'
      else:
-         cuda_loc = 'localhost'
+         cuda_visible_devive = None
+         cuda_loc = None
 
      os.environ['PARALLEL_TYPE'] = 'ds'
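
With the restructured branches, the location argument is only consulted when a device list was given: `ds_train 0,1,2,3 node01` sets both values, `ds_train 0,1,2,3` falls back to cuda_loc = 'localhost', and a bare `ds_train` now leaves both cuda_visible_devive and cuda_loc as None (previously cuda_loc defaulted to 'localhost' even with no arguments). The script name here is an assumption based on the changed ds_train entry in RECORD.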
 
project_llm_trainer-0.8.1.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: project_llm_trainer
- Version: 0.7.8
+ Version: 0.8.1
  Summary: LLM and VLM trainer
  Author: qibin
  Author-email: qibin0506@gmail.com
project_llm_trainer-0.8.1.dist-info/RECORD CHANGED
@@ -1,33 +1,33 @@
  llm_trainer/__init__.py,sha256=HWgtTEVeQSnZmEyYQm2K6eFEG4X2QAoigMlB5Z2tcXE,260
  llm_trainer/checkpoint.py,sha256=X5ZeUtJlxVz7pnWQLaS-y7UIZOaOAnZTt2L8rSAPzUs,4428
  llm_trainer/dataset.py,sha256=UL3fGeM4XSlyNQRZH-139u3LujqAQx3YyaxNRewk6LE,8935
- llm_trainer/dpo_trainer.py,sha256=RMfbTsl3eav4yTJ2PK59mi6a0ECVOg8WwYVsHvMbNUE,12353
+ llm_trainer/dpo_trainer.py,sha256=Qi7WKhFO4fdnj9W8BNIF_so6-F8g_YKUoPU9sNjWK_M,13320
  llm_trainer/ds_checkpoint.py,sha256=X2IWgpgi0yOtogph7n6DEwvK_0Ceb7juu1WMutv3HSk,2270
  llm_trainer/eval.py,sha256=ZyUfSo2Q8P-lrCdPEnGkoo5pGubd0AabREK5eMISRII,1109
  llm_trainer/generate_utils.py,sha256=8K3YFbp7IF_lCkmkzjHhqTW26EBFb2AilQmarVcfMvs,15001
- llm_trainer/grpo_trainer.py,sha256=MXnP8Kc9CQJw0CB3uMbHxIYwvpuujai4hgbbpUut_K4,16808
+ llm_trainer/grpo_trainer.py,sha256=3CcV-cuyV4ZUTymN9vz3au4uf3gZdyo8SGgSj2NEofs,16774
  llm_trainer/log.py,sha256=XwychwKF6gvFPhthCIZCAEUZ0G3DY3fiQrOHqPWsxz0,463
- llm_trainer/loss.py,sha256=glf4IeDWHvA2cJo-QKLRL8P6OxK4QjRJGrYJWOZiTPQ,6929
+ llm_trainer/loss.py,sha256=RhTxftLMj1Tqc5pkUvJiZumfbMEPWL8GBGxdTfQggmk,6744
  llm_trainer/parallel.py,sha256=yjStV21DJ26yM8-0O6GTMxdFAcyShY5GsQWSZmbI7HU,4543
  llm_trainer/parallel_ddp.py,sha256=Pob9vUlBZnkL4oP1Re11kFob7nufMSE96pn7m7fuOEM,1345
  llm_trainer/parallel_ds.py,sha256=oy8RRxHud3rACWubFlJqqd0pjPEQhKeAPGPQUSdJX2c,1145
  llm_trainer/parallel_none.py,sha256=TG6Pm829Dg-yQu-97O-EHV3FCARBlNcP47KkGFAs16E,676
  llm_trainer/partition_utils.py,sha256=eEYNhfEIF4hGzZ3OLa6sEBIECz261drptEz_n7fZYtk,8396
  llm_trainer/scheduler.py,sha256=LAI_0VxClsIQkix0bRoduRD4vPfVuIZDhZgTAT_KK8k,4901
- llm_trainer/sft_trainer.py,sha256=LudTRIaqLQYy6ym6jjMX7v9xtFBJelrR3nnPCwb48nM,1821
- llm_trainer/tokenizer.py,sha256=SSpgXtb0e1NtQqRW0gCq09TTZi47umggy-Fh5EMHKJg,6708
+ llm_trainer/sft_trainer.py,sha256=rSOGZx53jMgOuJdztfxQASYJ62uD0dVaih4IAnSwGBc,1787
+ llm_trainer/tokenizer.py,sha256=0-xQCMz1xiPTDAZiYsVsiECSoZ_1eIvW9XsZOoFfakQ,7250
  llm_trainer/tools.py,sha256=5op5qrjjkK-Lr9oes5VxIVnOVYOYGoAdlIJq9mPUf64,2637
- llm_trainer/train_configs.py,sha256=N3ykM1uaLHcSNRC8ErYIxp9VYhSP7voJyAP-2D4ZJe0,7574
- llm_trainer/trainer.py,sha256=jS31zEXIIj9BoPTPlmaGYq61x72HGCjKfS2u3_gOkDk,27924
- llm_trainer/utils.py,sha256=xcdzpvPvXRKqsOK2yB7PZ9GmOvZMDFcglDPUZY2hJTY,11484
- project_llm_trainer-0.7.8.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
- project_llm_trainer-0.7.8.data/scripts/ddp_train,sha256=Z-309mM56CN0m3bxoeC5us4LUuwuNnoiOm3-fDdLMjQ,566
- project_llm_trainer-0.7.8.data/scripts/ds_train,sha256=3nXNNKmYI7miqyBdf-Ijl_rW1cGIKrAMZ1CSswN_gGo,665
- project_llm_trainer-0.7.8.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
- project_llm_trainer-0.7.8.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
- project_llm_trainer-0.7.8.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
- project_llm_trainer-0.7.8.data/scripts/smart_train,sha256=3oLIDuuqb4U4TU1lXy9V8lw_0gIf7i8tGsxlQ_s6bro,1220
- project_llm_trainer-0.7.8.dist-info/METADATA,sha256=rSYUrEkdjPCyYUqT2SOw3-hzT40wU3AwEw-ouHh1rBY,195
- project_llm_trainer-0.7.8.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
- project_llm_trainer-0.7.8.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
- project_llm_trainer-0.7.8.dist-info/RECORD,,
+ llm_trainer/train_configs.py,sha256=pPZkbliRdTnWSv3TUuTM23x9RDdMhGSPrxbNAyzDklY,7636
+ llm_trainer/trainer.py,sha256=diP-1suOf2U5dY_R8QH5arAx4MgBrKW-GBQ2_ScGNM8,28799
+ llm_trainer/utils.py,sha256=xC5plG-8-_Al5yIF5xIU5lroOcBBk98TEhtUJrazZPE,12305
+ project_llm_trainer-0.8.1.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
+ project_llm_trainer-0.8.1.data/scripts/ddp_train,sha256=Z-309mM56CN0m3bxoeC5us4LUuwuNnoiOm3-fDdLMjQ,566
+ project_llm_trainer-0.8.1.data/scripts/ds_train,sha256=tME0xmMdX1D9XuVo07D9dilW5VIWavBS3UK9DoY67WI,709
+ project_llm_trainer-0.8.1.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
+ project_llm_trainer-0.8.1.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
+ project_llm_trainer-0.8.1.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
+ project_llm_trainer-0.8.1.data/scripts/smart_train,sha256=3oLIDuuqb4U4TU1lXy9V8lw_0gIf7i8tGsxlQ_s6bro,1220
+ project_llm_trainer-0.8.1.dist-info/METADATA,sha256=07L7qqkujmk6YAwD5jPKe6dzyWPRu1Jirmp-6BqzMzA,195
+ project_llm_trainer-0.8.1.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
+ project_llm_trainer-0.8.1.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
+ project_llm_trainer-0.8.1.dist-info/RECORD,,