project-llm-trainer 0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


@@ -0,0 +1,393 @@
+ import time
+ import copy
+ from typing import Tuple, List, Union, Callable, Optional
+ import torch
+ from torch.utils.data import Dataset
+ from torch.nn.utils.rnn import pad_sequence
+ import torch.distributed as dist
+ import torch.nn.functional as F
+
+ from llm_model import LlmModel
+
+ from .parallel_ds import DsParallel
+ from .trainer import Trainer
+ from .train_configs import TrainConfig
+ from .dataset import GRPORolloutDataset
+ from .loss import GRPOLoss
+ from .tools import TrainerTools
+ from .generate_utils import batch_generate
+
+ from .checkpoint import (
+     save_checkpoint,
+     load_checkpoint_for_eval,
+     save_steps,
+ )
+
+ class GRPOTrainer(Trainer):
+     def __init__(
+         self,
+         *,
+         train_config: TrainConfig,
+         reward_func: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], List[float]],
+         eval_prompts: List[str],
+         eval_image_tags: Optional[List[int]] = None
+     ):
+         super().__init__(
+             train_config=train_config,
+             eval_prompts=eval_prompts,
+             eval_image_tags=eval_image_tags
+         )
+
+         self.reward_func = reward_func
+         self.reference_model = self._init_reference_model()
+         self.generate_model = self._init_generate_model()
+
+         # Use torch's pad_sequence by default.
+         # If pad_sequence does not support the padding_side argument, this flag is set to
+         # False and left padding falls back to the flip-based approach below.
+         self._use_origin_pad_sequence = True
+
+         # Save a checkpoint of the train model so that reference_model can use it below.
+         save_checkpoint(self.train_model, self.optimizer)
+
+     def _init_reference_model(self):
+         reference_model = LlmModel(self.train_config.model_config)
+
+         device = 'cpu'  # TrainerTools().parallel.device
+         reference_model.to(device)
+         # load_checkpoint_for_eval(model=reference_model, device=device)
+
+         reference_model.eval()
+         for param in reference_model.parameters():
+             param.requires_grad = False
+
+         return reference_model
+
+     def _init_generate_model(self):
+         return copy.deepcopy(self.reference_model)
+         # generate_model = LlmModel(self.train_config.model_config)
+         #
+         # device = 'cpu'  # TrainerTools().parallel.device
+         # generate_model.to(device)
+         # # load_checkpoint_for_eval(model=generate_model, device=device)
+         #
+         # generate_model.eval()
+         # for param in generate_model.parameters():
+         #     param.requires_grad = False
+         #
+         # return generate_model
+
+     def _init_loss(self):
+         criterion = GRPOLoss(
+             clip_eps=self.train_config.grpo_config.clip_eps,
+             kl_weight=self.train_config.grpo_config.kl_weight
+         )
+
+         return criterion, None
+
+     def _convert_train_args(self) -> Tuple[dict, dict, dict, bool]:
+         parallel_kwargs, data_loader_kwargs, sampler_kwargs, use_ds_optim = super()._convert_train_args()
+         data_loader_kwargs.update({"collate_fn": lambda x: x})
+
+         return parallel_kwargs, data_loader_kwargs, sampler_kwargs, use_ds_optim
+
+     def _create_dataset(self, file_path) -> Dataset:
+         return GRPORolloutDataset(file_path)
+
+     def _calc_loss(self, inputs, attention_mask, logits, labels): ...
+
+     def _left_pad_sequence(
+         self,
+         sequences: Union[torch.Tensor, List[torch.Tensor]],
+         padding_value: float,
+     ) -> torch.Tensor:
+         if self._use_origin_pad_sequence:
+             try:
+                 return pad_sequence(sequences, batch_first=True, padding_value=padding_value, padding_side='left')
+             except TypeError:
+                 # Older torch versions do not accept padding_side; fall back permanently.
+                 self._use_origin_pad_sequence = False
+                 return self._left_pad_sequence(sequences, padding_value)
+         else:
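+             # Example (illustrative values): sequences [1, 2, 3] and [4, 5] with
+             # padding_value=0 come out as [[1, 2, 3], [0, 4, 5]] after the
+             # flip / pad / flip round trip below.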
+             # Reverse each sequence (e.g. [1, 2, 3] -> [3, 2, 1])
+             reversed_sequences = [seq.flip(dims=(0,)) for seq in sequences]
+             # Use the default right padding
+             padded_reversed = pad_sequence(reversed_sequences, batch_first=True, padding_value=padding_value)
+             # Flip back to restore the original order, leaving the padding on the left
+             return padded_reversed.flip(dims=(1,))
+
+     def _selective_log_softmax(self, logits, input_ids):
+         # Convert raw logits into log probabilities along the vocabulary axis.
+         # [batch_size, seq_len, vocab_size]
+         log_probs = F.log_softmax(logits, dim=-1)
+
+         # Reshape input_ids from (batch_size, seq_len) to (batch_size, seq_len, 1) for gathering.
+         # Then, gather the log probability for each token in input_ids.
+         selected_log_probs = log_probs.gather(dim=-1, index=input_ids.unsqueeze(-1))
+
+         # Remove the extra last dimension to get back to shape (batch_size, seq_len).
+         return selected_log_probs.squeeze(-1)
+
+     def _compute_log_probabilities(
+         self,
+         model,
+         input_ids,
+         attention_mask,
+         logits_to_keep
+     ):
+         # Example: the prompt is [1, 2, 3] and the generated completion is [4, 5], so logits_to_keep=2.
+         # The input below is [1, 2, 3, 4, 5]; normally the predictions would be [2, 3, 4, 5, 6],
+         # and with logits_to_keep=2 the model would return the predictions for [5, 6].
+         # What we actually want are the predictions for the [4, 5] part,
+         # so we pass logits_to_keep=2+1 to get [4, 5, 6] and drop the trailing position below.
+
+         # [batch_size, total_seq_len, vocab_size]
+         outputs = model(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             logits_to_keep=logits_to_keep + 1
+         )
+
+         # [batch_size, total_seq_len - 1, vocab_size]
+         logits = outputs['logits'][:, :-1, :]
+
+         input_ids = input_ids[:, -logits_to_keep:]
+         logits = logits[:, -logits_to_keep:, :]
+
+         # Compute and return the log probabilities for the selected tokens.
+         return self._selective_log_softmax(logits, input_ids), outputs['aux_loss']
+
+     def _compute_group_relative_advantages(self, rewards):
+         group_size = self.train_config.grpo_config.group_size
+
+         # Reshape rewards to group by prompt
+         # [batch, group_size]
+         rewards_by_group = rewards.view(-1, group_size)
+
+         # Compute mean and standard deviation for each prompt group
+         # [batch]
+         group_means = rewards_by_group.mean(dim=1)
+         group_stds = rewards_by_group.std(dim=1)
+
+         # Expand the means and stds to match the original flat rewards tensor shape
+         # [batch*group_size]
+         expanded_means = group_means.repeat_interleave(group_size)
+         expanded_stds = group_stds.repeat_interleave(group_size)
+
+         # Normalize rewards to get advantages
+         # [batch*group_size]
+         advantages = (rewards - expanded_means) / (expanded_stds + 1e-4)
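+         # Example (illustrative values): with group_size=2 and rewards [1.0, 0.0, 3.0, 1.0],
+         # the group means are [0.5, 2.0], the (unbiased) stds are about [0.71, 1.41],
+         # and the advantages come out to roughly [0.71, -0.71, 0.71, -0.71].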
+
+         # [batch*group_size, 1]
+         return advantages.unsqueeze(1)  # Add dimension for token-wise operations
+
+     def _generate_completions(self, prompts, group_size: int):
+         pad_token_id = TrainerTools().tokenizer.pad
+         device = TrainerTools().parallel.device
+
+         # Pad on the left so that all prompts are aligned to the same length
+         # [batch, max_prompt_len]
+         prompt_ids = self._left_pad_sequence(prompts, padding_value=pad_token_id)
+         prompt_ids = prompt_ids.to(device)
+
+         prompt_len = prompt_ids.shape[1]
+
+         # [batch*group_size, max_prompt_len]
+         prompt_ids = prompt_ids.repeat_interleave(group_size, 0)
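+         # repeat_interleave keeps the copies of each prompt adjacent, e.g. prompts [A, B]
+         # with group_size=3 become [A, A, A, B, B, B], which matches the view(-1, group_size)
+         # grouping used later for the group-relative advantages.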
196
+ # [batch*group_size, max_prompt_len]
197
+ prompt_masks = prompt_ids != pad_token_id
198
+
199
+ # [batch*group_size, max_prompt_len+max_gen_len]
200
+ outputs: torch.Tensor = batch_generate(
201
+ # model=self.train_model,
202
+ model=self.generate_model,
203
+ tokens=prompt_ids,
204
+ pad_token_id=pad_token_id,
205
+ attention_mask=prompt_masks,
206
+ max_position_embeddings=self.train_config.model_config.max_position_embeddings,
207
+ max_new_tokens=self.train_config.grpo_config.gen_max_new_tokens,
208
+ temperature=self.train_config.grpo_config.gen_temperature,
209
+ k=self.train_config.grpo_config.gen_k,
210
+ p=self.train_config.grpo_config.gen_p,
211
+ device=device,
212
+ suppress_tokens=self.train_config.grpo_config.gen_suppress_tokens
213
+ )
214
+
215
+ # [batch*group_size, max_gen_len]
216
+ completion_ids = outputs[:, prompt_len:]
217
+ # [batch*group_size, max_gen_len]
218
+ completion_masks = (completion_ids != pad_token_id).int()
219
+
220
+ return prompt_ids, prompt_masks, completion_ids, completion_masks
221
+
+     def _generate_rollout_data(self, batch_data: List[dict]):
+         prompts = [item["prompt"] for item in batch_data]
+         answers = [item["answer"] for item in batch_data]
+         group_size = self.train_config.grpo_config.group_size
+
+         # Use no_grad instead of inference_mode.
+         # Works around: "Inference tensors cannot be saved for backward. To work around you can make a clone to get a normal"
+         with torch.no_grad():
+         # with torch.inference_mode():
+             prompt_ids, prompt_mask, completion_ids, completion_mask = self._generate_completions(prompts, group_size)
+             input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
+             attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
+             logits_to_keep = completion_ids.shape[1]
+
+             # Compute old_log_probs from the current model, with gradients disabled.
+             old_log_probs, _ = self._compute_log_probabilities(self.generate_model, input_ids, attention_mask, logits_to_keep)
+
+             # Compute ref_log_probs from the reference model, which remains static.
+             ref_log_probs, _ = self._compute_log_probabilities(self.reference_model, input_ids, attention_mask, logits_to_keep)
+
+             repeated_prompts = [p for p in prompts for _ in range(group_size)]
+             repeated_answers = [a for a in answers for _ in range(group_size)]
+
+             return {
+                 'input_ids': input_ids,
+                 'attention_mask': attention_mask,
+                 'completion_mask': completion_mask,
+                 'old_log_probs': old_log_probs,
+                 'ref_log_probs': ref_log_probs,
+                 'completion_ids': completion_ids,
+                 'repeated_prompts': repeated_prompts,
+                 'repeated_answers': repeated_answers,
+                 'logits_to_keep': logits_to_keep
+             }
+
+     def _maximize_grpo_objective(self, rollout_data):
+         device = TrainerTools().parallel.device
+
+         input_ids = rollout_data['input_ids']
+         attention_mask = rollout_data['attention_mask']
+         completion_mask = rollout_data['completion_mask']
+         old_log_probs = rollout_data['old_log_probs']
+         ref_log_probs = rollout_data['ref_log_probs']
+         logits_to_keep = rollout_data['logits_to_keep']
+         completion_ids = rollout_data['completion_ids']
+         repeated_prompts = rollout_data['repeated_prompts']
+         repeated_answers = rollout_data['repeated_answers']
+
+         # [batch*group_size]
+         rewards = torch.tensor(
+             self.reward_func(repeated_prompts, completion_ids, repeated_answers),
+             dtype=torch.float32,
+             device=device
+         )
+
+         # [batch*group_size, 1]
+         advantages = self._compute_group_relative_advantages(rewards)
+
+         # Compute current log probabilities
+         log_probs, aux_loss = self._compute_log_probabilities(self.train_model, input_ids, attention_mask, logits_to_keep)
+
+         loss = self.criterion(
+             log_probs=log_probs,
+             old_log_probs=old_log_probs,
+             ref_log_probs=ref_log_probs,
+             completion_mask=completion_mask,
+             advantages=advantages
+         )
+
+         return loss, aux_loss
+
+     def train(self):
+         global_steps = 0
+         skipping_train = False
+         device = TrainerTools().parallel.device
+         aux_loss_coef = self.train_config.loss_config.aux_loss_coef
+
+         for epoch in range(self.train_config.n_epochs):
+             load_checkpoint_for_eval(model=self.reference_model, device=device)
+             self.train_model.train()
+             file_count = len(self.train_config.file_dataset)
+
+             for file_idx in range(file_count):
+                 file_path = self.train_config.file_dataset[file_idx]
+                 dataset = self._create_dataset(file_path)
+
+                 train_data_loader = TrainerTools().parallel.process_dataloader(
+                     dataset=dataset,
+                     data_loader_kwargs=self.data_loader_kwargs,
+                     sampler_kwargs=self.sampler_kwargs
+                 )
+
+                 last_ckpt_batch = 0
+                 batch_count_per_file = len(train_data_loader)
+
+                 TrainerTools().parallel.on_epoch_start(epoch)
+                 self._on_file_start(epoch, file_path)
+
+                 for batch, batch_data in enumerate(train_data_loader):
+                     global_steps += 1
+                     if global_steps < self.last_global_steps:
+                         skipping_train = True
+                         continue
+
+                     skipping_train = False
+
+                     # start generate
+                     # Generate data with a separate model: under deepspeed parallel training,
+                     # generating with train_model hangs.
+                     self.generate_model.to(TrainerTools().parallel.device)
+                     self.reference_model.to(TrainerTools().parallel.device)
+
+                     # The train_model checkpoint has been saved, so loading it here keeps
+                     # the generation model's parameters up to date.
+                     load_checkpoint_for_eval(self.generate_model, TrainerTools().parallel.device)
+                     # Generate rollout data
+                     rollout_data = self._generate_rollout_data(batch_data)
+
+                     # Offload to CPU; move back to the GPU the next time they are needed
+                     self.generate_model.to('cpu')
+                     self.reference_model.to('cpu')
+                     torch.cuda.empty_cache()
+                     # end generate
+
+                     try:
+                         for grpo_step in range(self.train_config.grpo_config.grpo_steps):
+                             with self.ctx:
+                                 loss, aux_loss = self._maximize_grpo_objective(rollout_data)
+                                 if aux_loss_coef and aux_loss:
+                                     loss += aux_loss_coef * aux_loss
+
+                             self._backward_loss(loss)
+
+                             if TrainerTools().parallel.parallel_train:
+                                 dist.all_reduce(loss, dist.ReduceOp.AVG)
+
+                             # DeepSpeed mode already integrates gradient clipping
+                             if not isinstance(TrainerTools().parallel, DsParallel) and self.lr_scheduler.can_clip_grad():
+                                 # clip grad
+                                 self.scalar.unscale_(self.optimizer)
+                                 torch.nn.utils.clip_grad_norm_(self.train_model.parameters(), 1.0)
+
+                             self._step()
+
+                             self._log_loss(
+                                 epoch_tag=f'epoch: {epoch}',
+                                 file_tag=f'file: {file_idx + 1}/{file_count}',
+                                 batch_tag=f'batch: {batch}/{batch_count_per_file}',
+                                 loss=loss.detach().item()
+                             )
+                     except Exception as e:
+                         self._on_exception(e, epoch, batch)
+                     finally:
+                         save_steps(global_steps=global_steps, lr_scheduler=self.lr_scheduler)
+
+                     if (batch - last_ckpt_batch) >= self.train_config.eval_batch_interval:
+                         save_checkpoint(model=self.train_model, optimizer=self.optimizer)
+                         last_ckpt_batch = batch
+                         self._on_batch_end(tag=f'epoch:{epoch}/batch:{batch}')
+
+                     try:
+                         del loss
+                     except UnboundLocalError: ...
+
+             # end epoch
+             if not skipping_train:
+                 save_checkpoint(model=self.train_model, optimizer=self.optimizer)
+                 save_steps(global_steps=global_steps, lr_scheduler=self.lr_scheduler)
+             TrainerTools().parallel.on_epoch_end(epoch)
+             self._on_epoch_end(tag=f'epoch:{epoch}')
+
+         # Wait for the checkpoint save to finish
+         time.sleep(10)
+         TrainerTools().parallel.destroy()
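
The reward_func passed to GRPOTrainer receives the repeated prompts, the generated completion_ids tensor, and the repeated answers, and must return one float per row. A minimal sketch of a compatible function follows; the scoring rule is illustrative only, not part of this package, and the module path is inferred from the file layout above:

    import torch
    from llm_trainer.tools import TrainerTools

    def toy_reward_func(prompts, completion_ids: torch.Tensor, answers) -> list:
        # Toy rule: prefer shorter completions (padding excluded).
        # A real reward would score each completion against its entry in `answers`.
        pad_token_id = TrainerTools().tokenizer.pad
        lengths = (completion_ids != pad_token_id).sum(dim=1).float()
        return (1.0 / (1.0 + lengths)).tolist()
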
llm_trainer/log.py ADDED
@@ -0,0 +1,16 @@
+ import time, os
+
+ def get_log_dir() -> str:
+     log_dir = os.environ['LOG_DIR']
+     if not os.path.exists(log_dir):
+         os.mkdir(log_dir)
+
+     return f'{log_dir}/' if not log_dir.endswith('/') else log_dir
+
+ def log(msg: str, log_file=None):
+     cur_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
+     if not log_file:
+         print(f'[{cur_time}] {msg}')
+     else:
+         with open(log_file, 'a') as f:
+             f.write(f"[{cur_time}] {msg}\n")
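
A quick usage sketch for these helpers; get_log_dir() requires the LOG_DIR environment variable, and the ./logs value and file name below are examples only:

    import os
    os.environ.setdefault('LOG_DIR', './logs')

    from llm_trainer.log import get_log_dir, log

    log('training started')                                     # prints with a timestamp
    log('step 10 done', log_file=f'{get_log_dir()}train.log')   # appends to a file instead
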
llm_trainer/loss.py ADDED
@@ -0,0 +1,171 @@
+ from typing import List, Optional, Tuple
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+
+
+ class LMLoss(nn.Module):
+     """
+     llm loss
+     """
+     def __init__(
+         self,
+         ignore_index: int = -100,
+         *,
+         critical_tokens: Optional[List[int]] = None,
+         critical_alpha: float = 1.0,
+         vocab_size: int = 0
+     ):
+         super().__init__()
+         self.ignore_index = ignore_index
+         self.critical_tokens = critical_tokens
+         self.critical_alpha = critical_alpha
+
+         if critical_tokens and vocab_size > 0:
+             self.register_buffer('weights', torch.ones(vocab_size))
+             # Set the weight for critical tokens
+             self.weights[self.critical_tokens] = critical_alpha
+
+
+     def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
+         # logits shape (batch, seq_len, vocab_size)
+         # labels shape (batch, seq_len)
+         shift_logits = logits[..., :-1, :].contiguous()
+         shift_labels = labels[..., 1:].contiguous()
+
+         logits = shift_logits.reshape(-1, logits.shape[-1])
+         targets = shift_labels.reshape(-1)
+
+         ce_loss = F.cross_entropy(
+             logits,
+             targets,
+             ignore_index=self.ignore_index,
+             # The weights buffer only exists when both critical_tokens and vocab_size were provided.
+             weight=self.weights.to(logits.device, dtype=logits.dtype) if hasattr(self, 'weights') else None
+         )
+
+         # Optional extra penalty term
+         # if self.critical_tokens:
+         #     crit_mask = torch.isin(targets, torch.tensor(self.critical_tokens).to(targets.device))
+         #     crit_logits = logits[crit_mask]
+         #     crit_targets = targets[crit_mask]
+         #     extra_loss = F.cross_entropy(crit_logits, crit_targets, ignore_index=self.ignore_index)
+         #     return ce_loss + extra_loss * (self.critical_alpha - 1)  # strengthen the penalty
+
+         return ce_loss
+
+
+ class KDLoss(nn.Module):
+     """
+     Language Model Knowledge Distillation Loss
+     https://github.com/OpenRLHF/OpenRLHF/blob/main/openrlhf/models/loss.py#L266
+     """
+     def __init__(self, ignore_index: int = -100):
+         super().__init__()
+         self.ignore_index = ignore_index
+
+     def forward(self, logits: torch.Tensor, teacher_logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
+         teacher_probs = F.softmax(teacher_logits, dim=-1, dtype=torch.float32)
+         inf_mask = torch.isinf(logits)
+
+         logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
+         prod_probs = torch.masked_fill(teacher_probs * logprobs, inf_mask, 0)
+
+         x = torch.sum(prod_probs, dim=-1).view(-1)
+         mask = (labels != self.ignore_index).int()
+
+         distil_loss = -torch.sum(x * mask.view(-1), dim=0) / torch.sum(mask.view(-1), dim=0)
+
+         return distil_loss
+
+
+ class DPOLoss(nn.Module):
+     def __init__(
+         self,
+         beta: float,
+         label_smoothing: float = 0.0,
+         ipo: bool = False
+     ):
+         super().__init__()
+         self.beta = beta
+         self.label_smoothing = label_smoothing
+         self.ipo = ipo
+
+     def forward(
+         self,
+         policy_logps: torch.Tensor,
+         reference_logps: torch.Tensor,
+     ) -> torch.Tensor:
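+         # Convention: the batch packs all chosen samples first and all rejected samples
+         # second, so splitting at batch_size // 2 pairs chosen row i with rejected row
+         # i + batch_size // 2 (e.g. with batch_size=4, rows 0/2 and 1/3 form the pairs).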
+         batch_size = reference_logps.shape[0]
+         ref_chosen_probs = reference_logps[:batch_size//2]
+         ref_reject_probs = reference_logps[batch_size//2:]
+         policy_chosen_probs = policy_logps[:batch_size//2]
+         policy_reject_probs = policy_logps[batch_size//2:]
+
+         pi_logratios = policy_chosen_probs - policy_reject_probs
+         ref_logratios = ref_chosen_probs - ref_reject_probs
+         logits = pi_logratios - ref_logratios
+
+         if self.ipo:
+             losses = (logits - 1 / (2 * self.beta)) ** 2  # Eq. 17 of https://arxiv.org/pdf/2310.12036v2.pdf
+         else:
+             # Eq. 3 of https://ericmitchell.ai/cdpo.pdf; label_smoothing=0 gives the original DPO (Eq. 7 of https://arxiv.org/pdf/2305.18290.pdf)
+             losses = (
+                 -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
+                 - F.logsigmoid(-self.beta * logits) * self.label_smoothing
+             )
+
+         loss = losses.mean()
+         # chosen_rewards = self.beta * (policy_chosen_probs - ref_chosen_probs).detach()
+         # rejected_rewards = self.beta * (policy_reject_probs - ref_reject_probs).detach()
+
+         return loss
+
+
+ class GRPOLoss(nn.Module):
+     def __init__(
+         self,
+         clip_eps: float,
+         kl_weight: float
+     ):
+         super().__init__()
+         self.clip_eps = clip_eps
+         self.kl_weight = kl_weight
+
+     def forward(
+         self,
+         log_probs: torch.Tensor,
+         old_log_probs: torch.Tensor,
+         ref_log_probs: torch.Tensor,
+         completion_mask: torch.Tensor,
+         advantages: torch.Tensor
+     ) -> torch.Tensor:
+         # Compute policy ratio
+         ratio = torch.exp(log_probs - old_log_probs)
+
+         # Compute surrogate loss with clipping
+         surrogate1 = ratio * advantages
+         surrogate2 = torch.clamp(ratio, 1 - self.clip_eps, 1 + self.clip_eps) * advantages
+         surrogate_loss = torch.min(surrogate1, surrogate2)
+
+         # Compute KL divergence penalty
+         kl_div = torch.exp(ref_log_probs - log_probs) - (ref_log_probs - log_probs) - 1
+
+         # Combine losses
+         per_token_loss = surrogate_loss - self.kl_weight * kl_div
+         loss = -((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
+
+         return loss
+
+         # kl = self._approx_kl_divergence(
+         #     log_probs=log_probs,
+         #     ref_log_probs=ref_log_probs,
+         #     mask=mask,
+         # )
+         #
+         # ratio = (log_probs - old_log_probs).exp()
+         # surr1 = ratio * advantages
+         # surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages
+         # loss = -torch.min(surr1, surr2) + self.kl_weight * kl
+         #
+         # loss = self._masked_mean(loss, mask, dim=-1).mean()
+         # return loss, kl.mean()
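
For a quick shape check of GRPOLoss, a throwaway snippet along these lines can be used; the hyper-parameter values and random tensors are illustrative only, and the module path is taken from the file name above:

    import torch
    from llm_trainer.loss import GRPOLoss

    loss_fn = GRPOLoss(clip_eps=0.2, kl_weight=0.04)
    B, T = 4, 8                              # batch*group_size rows, completion length T
    log_probs = torch.randn(B, T)
    old_log_probs = log_probs + 0.01 * torch.randn(B, T)
    ref_log_probs = log_probs + 0.01 * torch.randn(B, T)
    completion_mask = torch.ones(B, T)       # 1 for real completion tokens, 0 for padding
    advantages = torch.randn(B, 1)           # one group-relative advantage per row
    print(loss_fn(log_probs, old_log_probs, ref_log_probs, completion_mask, advantages))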