project-llm-trainer 0.7.3__py3-none-any.whl → 0.7.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of project-llm-trainer might be problematic.
- llm_trainer/grpo_trainer.py +18 -9
- llm_trainer/loss.py +5 -4
- llm_trainer/train_configs.py +4 -3
- {project_llm_trainer-0.7.3.dist-info → project_llm_trainer-0.7.5.dist-info}/METADATA +1 -1
- {project_llm_trainer-0.7.3.dist-info → project_llm_trainer-0.7.5.dist-info}/RECORD +14 -14
- {project_llm_trainer-0.7.3.data → project_llm_trainer-0.7.5.data}/scripts/calc_intermediate_size +0 -0
- {project_llm_trainer-0.7.3.data → project_llm_trainer-0.7.5.data}/scripts/ddp_train +0 -0
- {project_llm_trainer-0.7.3.data → project_llm_trainer-0.7.5.data}/scripts/ds_train +0 -0
- {project_llm_trainer-0.7.3.data → project_llm_trainer-0.7.5.data}/scripts/plot_loss +0 -0
- {project_llm_trainer-0.7.3.data → project_llm_trainer-0.7.5.data}/scripts/plot_lr +0 -0
- {project_llm_trainer-0.7.3.data → project_llm_trainer-0.7.5.data}/scripts/py_train +0 -0
- {project_llm_trainer-0.7.3.data → project_llm_trainer-0.7.5.data}/scripts/smart_train +0 -0
- {project_llm_trainer-0.7.3.dist-info → project_llm_trainer-0.7.5.dist-info}/WHEEL +0 -0
- {project_llm_trainer-0.7.3.dist-info → project_llm_trainer-0.7.5.dist-info}/top_level.txt +0 -0
llm_trainer/grpo_trainer.py
CHANGED

@@ -50,6 +50,10 @@ class GRPOTrainer(Trainer):
         self._use_origin_pad_sequence = True
 
     def _init_ref_model(self):
+        # beta == 0: no ref_model is needed
+        if self.train_config.grpo_config.loss_beta == 0.0:
+            return None
+
         ref_model = self._new_model(self.train_config)
 
         ref_model, _ = TrainerTools().parallel.process(
@@ -68,7 +72,8 @@ class GRPOTrainer(Trainer):
     def _init_loss(self):
         criterion = GRPOLoss(
             beta=self.train_config.grpo_config.loss_beta,
-            clip_eps=self.train_config.grpo_config.loss_clip_eps,
+            clip_eps_low=self.train_config.grpo_config.loss_clip_eps,
+            clip_eps_high=self.train_config.grpo_config.loss_clip_eps_high,
             delta=self.train_config.grpo_config.loss_delta,
             importance_sampling_level=self.train_config.grpo_config.loss_importance_sampling_level,
             loss_type=self.train_config.grpo_config.loss_type,
@@ -229,8 +234,11 @@ class GRPOTrainer(Trainer):
         # Compute old_log_probs from the current model, with gradients disabled.
         old_log_probs, _ = self._compute_log_probabilities(generate_model, input_ids, attention_mask, logits_to_keep)
 
-        # Compute ref_log_probs from the reference model, which remains static.
-        ref_log_probs, _ = self._compute_log_probabilities(self.ref_model, input_ids, attention_mask, logits_to_keep)
+        if self.ref_model:
+            # Compute ref_log_probs from the reference model, which remains static.
+            ref_log_probs, _ = self._compute_log_probabilities(self.ref_model, input_ids, attention_mask, logits_to_keep)
+        else:
+            ref_log_probs = None
 
         repeated_prompts = [p for p in prompts for _ in range(group_size)]
         repeated_answers = [a for a in answers for _ in range(group_size)]
@@ -293,11 +301,12 @@ class GRPOTrainer(Trainer):
         aux_loss_coef = self.train_config.loss_config.aux_loss_coef
 
         for epoch in range(self.train_config.n_epochs):
-            sync_model_params(
-                _from=self.train_model,
-                _to=self.ref_model,
-                mixup_alpha=self.train_config.grpo_config.mixup_alpha
-            )
+            if self.ref_model:
+                sync_model_params(
+                    _from=self.train_model,
+                    _to=self.ref_model,
+                    mixup_alpha=self.train_config.grpo_config.mixup_alpha
+                )
 
             file_count = len(self.train_config.file_dataset)
 
@@ -365,7 +374,7 @@ class GRPOTrainer(Trainer):
                     self._log_loss(
                         epoch_tag=f'epoch: {epoch}',
                         file_tag=f'file: {file_idx + 1}/{file_count}',
-                        batch_tag=f'batch: {batch}/{batch_count_per_file}',
+                        batch_tag=f'batch: {batch}/{batch_count_per_file}, grpo_step={grpo_step}',
                         loss=current_loss
                     )
                 except Exception as e:
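In 0.7.5 the reference model becomes optional: _init_ref_model returns None when loss_beta == 0.0 (no KL penalty), and every use of self.ref_model is guarded, so ref_log_probs is simply None in that mode. The per-epoch sync_model_params call only appears at its call site in this diff; below is a minimal sketch of what it could be doing, assuming mixup_alpha blends policy weights into the reference model EMA-style (with mixup_alpha=1.0, the GRPOConfig default, reducing to a plain copy). The function body is an assumption, not the package's code:

import torch

@torch.no_grad()
def sync_model_params(_from: torch.nn.Module, _to: torch.nn.Module, mixup_alpha: float = 1.0) -> None:
    # Assumed semantics: ref <- mixup_alpha * policy + (1 - mixup_alpha) * ref
    # (requires both modules to have identical architectures)
    for src, dst in zip(_from.parameters(), _to.parameters()):
        dst.mul_(1.0 - mixup_alpha).add_(src, alpha=mixup_alpha)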
llm_trainer/loss.py
CHANGED

@@ -2,7 +2,6 @@ from typing import List, Optional
 import torch
 from torch import nn
 import torch.nn.functional as F
-from .tools import TrainerTools
 
 
 class LMLoss(nn.Module):
@@ -127,7 +126,8 @@ class GRPOLoss(nn.Module):
     def __init__(
         self,
         beta: float,
-        clip_eps: float,
+        clip_eps_low: float,
+        clip_eps_high: Optional[float] = None,
         delta: Optional[float] = None,
         importance_sampling_level: str = 'token',
         loss_type: str = 'grpo',
@@ -136,7 +136,8 @@ class GRPOLoss(nn.Module):
         super().__init__()
 
         self.beta = beta
-        self.clip_eps = clip_eps
+        self.clip_eps_low = clip_eps_low
+        self.clip_eps_high = clip_eps_high if clip_eps_high else clip_eps_low
         self.delta = delta
         self.importance_sampling_level = importance_sampling_level
         self.loss_type = loss_type
@@ -166,7 +167,7 @@ class GRPOLoss(nn.Module):
             log_importance_weights = log_ratio
 
         coef_1 = torch.exp(log_importance_weights)
-        coef_2 = torch.clamp(coef_1, 1 - self.clip_eps, 1 + self.clip_eps)
+        coef_2 = torch.clamp(coef_1, 1 - self.clip_eps_low, 1 + self.clip_eps_high)
 
         # Two-sided clipping
        if self.delta is not None:
llm_trainer/train_configs.py
CHANGED

@@ -138,10 +138,11 @@ class GRPOConfig:
     grpo_steps: int = 1
     group_size: int = 12
     mixup_alpha: float = 1.0
-    loss_beta: float = 0.04
-    loss_clip_eps: float =
+    loss_beta: float = 0.0 # or 0.04 for grpo
+    loss_clip_eps: float = 3e-4
+    loss_clip_eps_high: Optional[float] = 4e-4
     loss_delta: Optional[float] = None
-    loss_importance_sampling_level: str = '
+    loss_importance_sampling_level: str = 'seq' # token or seq
     loss_type: str = 'grpo' # grpo or bnpo or dr_grpo
     gen_max_new_tokens: Optional[int] = None
     gen_temperature: Optional[float] = None
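Taken together, the new defaults switch the trainer to a KL-free, sequence-level setup. A hypothetical sketch of how these fields might be set, assuming GRPOConfig is constructed like a plain dataclass with defaults for the fields omitted here:

from llm_trainer.train_configs import GRPOConfig

config = GRPOConfig(
    loss_beta=0.0,                          # 0.0 disables the KL term and the ref model
    loss_clip_eps=3e-4,                     # lower clip bound: ratio >= 1 - 3e-4
    loss_clip_eps_high=4e-4,                # upper clip bound: ratio <= 1 + 4e-4
    loss_importance_sampling_level='seq',   # 'token' or 'seq'
    loss_type='grpo',                       # 'grpo', 'bnpo' or 'dr_grpo'
)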
{project_llm_trainer-0.7.3.dist-info → project_llm_trainer-0.7.5.dist-info}/RECORD
CHANGED

@@ -5,9 +5,9 @@ llm_trainer/dpo_trainer.py,sha256=RMfbTsl3eav4yTJ2PK59mi6a0ECVOg8WwYVsHvMbNUE,12
 llm_trainer/ds_checkpoint.py,sha256=X2IWgpgi0yOtogph7n6DEwvK_0Ceb7juu1WMutv3HSk,2270
 llm_trainer/eval.py,sha256=ZyUfSo2Q8P-lrCdPEnGkoo5pGubd0AabREK5eMISRII,1109
 llm_trainer/generate_utils.py,sha256=8K3YFbp7IF_lCkmkzjHhqTW26EBFb2AilQmarVcfMvs,15001
-llm_trainer/grpo_trainer.py,sha256=
+llm_trainer/grpo_trainer.py,sha256=MXnP8Kc9CQJw0CB3uMbHxIYwvpuujai4hgbbpUut_K4,16808
 llm_trainer/log.py,sha256=XwychwKF6gvFPhthCIZCAEUZ0G3DY3fiQrOHqPWsxz0,463
-llm_trainer/loss.py,sha256=
+llm_trainer/loss.py,sha256=glf4IeDWHvA2cJo-QKLRL8P6OxK4QjRJGrYJWOZiTPQ,6929
 llm_trainer/parallel.py,sha256=yjStV21DJ26yM8-0O6GTMxdFAcyShY5GsQWSZmbI7HU,4543
 llm_trainer/parallel_ddp.py,sha256=Pob9vUlBZnkL4oP1Re11kFob7nufMSE96pn7m7fuOEM,1345
 llm_trainer/parallel_ds.py,sha256=oy8RRxHud3rACWubFlJqqd0pjPEQhKeAPGPQUSdJX2c,1145
@@ -17,17 +17,17 @@ llm_trainer/scheduler.py,sha256=LAI_0VxClsIQkix0bRoduRD4vPfVuIZDhZgTAT_KK8k,4901
 llm_trainer/sft_trainer.py,sha256=LudTRIaqLQYy6ym6jjMX7v9xtFBJelrR3nnPCwb48nM,1821
 llm_trainer/tokenizer.py,sha256=SSpgXtb0e1NtQqRW0gCq09TTZi47umggy-Fh5EMHKJg,6708
 llm_trainer/tools.py,sha256=5op5qrjjkK-Lr9oes5VxIVnOVYOYGoAdlIJq9mPUf64,2637
-llm_trainer/train_configs.py,sha256=
+llm_trainer/train_configs.py,sha256=N3ykM1uaLHcSNRC8ErYIxp9VYhSP7voJyAP-2D4ZJe0,7574
 llm_trainer/trainer.py,sha256=jS31zEXIIj9BoPTPlmaGYq61x72HGCjKfS2u3_gOkDk,27924
 llm_trainer/utils.py,sha256=xcdzpvPvXRKqsOK2yB7PZ9GmOvZMDFcglDPUZY2hJTY,11484
-project_llm_trainer-0.7.3.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
-project_llm_trainer-0.7.3.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
-project_llm_trainer-0.7.3.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
-project_llm_trainer-0.7.3.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
-project_llm_trainer-0.7.3.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
-project_llm_trainer-0.7.3.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
-project_llm_trainer-0.7.3.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
-project_llm_trainer-0.7.3.dist-info/METADATA,sha256=
-project_llm_trainer-0.7.3.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
-project_llm_trainer-0.7.3.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
-project_llm_trainer-0.7.3.dist-info/RECORD,,
+project_llm_trainer-0.7.5.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
+project_llm_trainer-0.7.5.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
+project_llm_trainer-0.7.5.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
+project_llm_trainer-0.7.5.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
+project_llm_trainer-0.7.5.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
+project_llm_trainer-0.7.5.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
+project_llm_trainer-0.7.5.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
+project_llm_trainer-0.7.5.dist-info/METADATA,sha256=9DcoFVuXDrhxZOVWF1Ouzk7NF6NTEnpBTkg1n6bMCYQ,195
+project_llm_trainer-0.7.5.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
+project_llm_trainer-0.7.5.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
+project_llm_trainer-0.7.5.dist-info/RECORD,,
{project_llm_trainer-0.7.3.data → project_llm_trainer-0.7.5.data}/scripts/calc_intermediate_size
RENAMED
File without changes
{project_llm_trainer-0.7.3.data → project_llm_trainer-0.7.5.data}/scripts/ddp_train
RENAMED
File without changes
{project_llm_trainer-0.7.3.data → project_llm_trainer-0.7.5.data}/scripts/ds_train
RENAMED
File without changes
{project_llm_trainer-0.7.3.data → project_llm_trainer-0.7.5.data}/scripts/plot_loss
RENAMED
File without changes
{project_llm_trainer-0.7.3.data → project_llm_trainer-0.7.5.data}/scripts/plot_lr
RENAMED
File without changes
{project_llm_trainer-0.7.3.data → project_llm_trainer-0.7.5.data}/scripts/py_train
RENAMED
File without changes
{project_llm_trainer-0.7.3.data → project_llm_trainer-0.7.5.data}/scripts/smart_train
RENAMED
File without changes
{project_llm_trainer-0.7.3.dist-info → project_llm_trainer-0.7.5.dist-info}/WHEEL
RENAMED
File without changes
{project_llm_trainer-0.7.3.dist-info → project_llm_trainer-0.7.5.dist-info}/top_level.txt
RENAMED
File without changes