project-llm-trainer 0.7.3__py3-none-any.whl → 0.7.4__py3-none-any.whl
This diff compares the contents of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their public registries.
- llm_trainer/grpo_trainer.py +2 -1
- llm_trainer/loss.py +5 -4
- llm_trainer/train_configs.py +4 -3
- {project_llm_trainer-0.7.3.dist-info → project_llm_trainer-0.7.4.dist-info}/METADATA +1 -1
- {project_llm_trainer-0.7.3.dist-info → project_llm_trainer-0.7.4.dist-info}/RECORD +14 -14
- {project_llm_trainer-0.7.3.data → project_llm_trainer-0.7.4.data}/scripts/calc_intermediate_size +0 -0
- {project_llm_trainer-0.7.3.data → project_llm_trainer-0.7.4.data}/scripts/ddp_train +0 -0
- {project_llm_trainer-0.7.3.data → project_llm_trainer-0.7.4.data}/scripts/ds_train +0 -0
- {project_llm_trainer-0.7.3.data → project_llm_trainer-0.7.4.data}/scripts/plot_loss +0 -0
- {project_llm_trainer-0.7.3.data → project_llm_trainer-0.7.4.data}/scripts/plot_lr +0 -0
- {project_llm_trainer-0.7.3.data → project_llm_trainer-0.7.4.data}/scripts/py_train +0 -0
- {project_llm_trainer-0.7.3.data → project_llm_trainer-0.7.4.data}/scripts/smart_train +0 -0
- {project_llm_trainer-0.7.3.dist-info → project_llm_trainer-0.7.4.dist-info}/WHEEL +0 -0
- {project_llm_trainer-0.7.3.dist-info → project_llm_trainer-0.7.4.dist-info}/top_level.txt +0 -0
llm_trainer/grpo_trainer.py
CHANGED
@@ -68,7 +68,8 @@ class GRPOTrainer(Trainer):
     def _init_loss(self):
         criterion = GRPOLoss(
             beta=self.train_config.grpo_config.loss_beta,
-
+            clip_eps_low=self.train_config.grpo_config.loss_clip_eps,
+            clip_eps_high=self.train_config.grpo_config.loss_clip_eps_high,
             delta=self.train_config.grpo_config.loss_delta,
             importance_sampling_level=self.train_config.grpo_config.loss_importance_sampling_level,
             loss_type=self.train_config.grpo_config.loss_type,
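This wires the split clipping bounds from GRPOConfig into the loss. A minimal sketch of the resulting call, assuming the module paths listed in the RECORD below and that GRPOConfig() can be constructed with bare defaults (both assumptions, not verified against the package API):

```python
# Sketch only: mirrors the wiring _init_loss performs in 0.7.4.
# Module paths and the bare GRPOConfig() construction are assumptions.
from llm_trainer.train_configs import GRPOConfig
from llm_trainer.loss import GRPOLoss

cfg = GRPOConfig()  # 0.7.4 defaults: loss_clip_eps=3e-4, loss_clip_eps_high=4e-4

criterion = GRPOLoss(
    beta=cfg.loss_beta,
    clip_eps_low=cfg.loss_clip_eps,
    clip_eps_high=cfg.loss_clip_eps_high,
    delta=cfg.loss_delta,
    importance_sampling_level=cfg.loss_importance_sampling_level,
    loss_type=cfg.loss_type,
)
```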
llm_trainer/loss.py
CHANGED
@@ -2,7 +2,6 @@ from typing import List, Optional
 import torch
 from torch import nn
 import torch.nn.functional as F
-from .tools import TrainerTools
 
 
 class LMLoss(nn.Module):
@@ -127,7 +126,8 @@ class GRPOLoss(nn.Module):
     def __init__(
             self,
             beta: float,
-
+            clip_eps_low: float,
+            clip_eps_high: Optional[float] = None,
             delta: Optional[float] = None,
             importance_sampling_level: str = 'token',
             loss_type: str = 'grpo',
@@ -136,7 +136,8 @@ class GRPOLoss(nn.Module):
         super().__init__()
 
         self.beta = beta
-        self.
+        self.clip_eps_low = clip_eps_low
+        self.clip_eps_high = clip_eps_high if clip_eps_high else clip_eps_low
         self.delta = delta
         self.importance_sampling_level = importance_sampling_level
         self.loss_type = loss_type
@@ -166,7 +167,7 @@ class GRPOLoss(nn.Module):
             log_importance_weights = log_ratio
 
         coef_1 = torch.exp(log_importance_weights)
-        coef_2 = torch.clamp(coef_1, 1 - self.
+        coef_2 = torch.clamp(coef_1, 1 - self.clip_eps_low, 1 + self.clip_eps_high)
 
         # Two-sided clipping
         if self.delta is not None:
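Two details of this change are worth noting. The constructor falls back to symmetric clipping when clip_eps_high is unset, but the check is truthiness (`if clip_eps_high`) rather than `is None`, so an explicit 0.0 also falls back to clip_eps_low. And the new coef_2 line clamps the importance ratio to the asymmetric interval [1 - clip_eps_low, 1 + clip_eps_high], the "clip-higher" scheme (as in DAPO) that loosens only the upper bound. A self-contained sketch of asymmetric clipping in a PPO/GRPO-style surrogate; the pessimistic torch.min objective and tensor shapes are illustrative, not copied from loss.py:

```python
import torch

def clipped_surrogate(log_ratio: torch.Tensor,
                      advantages: torch.Tensor,
                      clip_eps_low: float = 3e-4,
                      clip_eps_high: float = 4e-4) -> torch.Tensor:
    coef_1 = torch.exp(log_ratio)  # importance ratio pi_new / pi_old
    # Asymmetric ("clip-higher") bounds, as in the coef_2 line above.
    coef_2 = torch.clamp(coef_1, 1 - clip_eps_low, 1 + clip_eps_high)
    # Pessimistic min, PPO-style: the clipped ratio only takes over when
    # it would otherwise inflate the objective.
    per_token = torch.min(coef_1 * advantages, coef_2 * advantages)
    return -per_token.mean()  # negate because trainers minimize

log_ratio = torch.randn(4, 16) * 1e-3   # (batch, seq) new-vs-old log-prob deltas
advantages = torch.randn(4, 16)         # e.g. group-normalized advantages
print(clipped_surrogate(log_ratio, advantages))
```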
llm_trainer/train_configs.py
CHANGED
@@ -138,10 +138,11 @@ class GRPOConfig:
     grpo_steps: int = 1
     group_size: int = 12
     mixup_alpha: float = 1.0
-    loss_beta: float = 0.04
-    loss_clip_eps: float =
+    loss_beta: float = 0.0 # or 0.04 for grpo
+    loss_clip_eps: float = 3e-4
+    loss_clip_eps_high: Optional[float] = 4e-4
     loss_delta: Optional[float] = None
-    loss_importance_sampling_level: str = '
+    loss_importance_sampling_level: str = 'seq' # token or seq
     loss_type: str = 'grpo' # grpo or bnpo or dr_grpo
     gen_max_new_tokens: Optional[int] = None
     gen_temperature: Optional[float] = None
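Two defaults change here: loss_beta, the weight on the KL penalty against the reference model, drops from 0.04 to 0.0 (disabling the KL term unless re-enabled), and loss_importance_sampling_level flips from 'token' to 'seq'. At token level every token keeps its own importance ratio; at sequence level the per-token log-ratios are usually length-averaged so one shared ratio covers the whole sequence (the GSPO-style formulation). A hedged sketch of that contrast; the actual seq branch in loss.py is not visible in this diff, so the averaging below is an assumption:

```python
import torch

def importance_log_weights(log_ratio: torch.Tensor,
                           mask: torch.Tensor,
                           level: str = "seq") -> torch.Tensor:
    """log_ratio, mask: (batch, seq); mask is 1.0 on real tokens, 0.0 on padding."""
    if level == "token":
        return log_ratio  # one ratio per token, as in the old default
    # GSPO-style sequence level (assumed): length-averaged log-ratio,
    # broadcast so every token in a sequence shares one ratio.
    seq_mean = (log_ratio * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)
    return seq_mean.unsqueeze(-1).expand_as(log_ratio)

log_ratio = torch.randn(2, 8) * 1e-3
mask = torch.ones(2, 8)
print(importance_log_weights(log_ratio, mask).shape)  # torch.Size([2, 8])
```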
{project_llm_trainer-0.7.3.dist-info → project_llm_trainer-0.7.4.dist-info}/RECORD
RENAMED
@@ -5,9 +5,9 @@ llm_trainer/dpo_trainer.py,sha256=RMfbTsl3eav4yTJ2PK59mi6a0ECVOg8WwYVsHvMbNUE,12
 llm_trainer/ds_checkpoint.py,sha256=X2IWgpgi0yOtogph7n6DEwvK_0Ceb7juu1WMutv3HSk,2270
 llm_trainer/eval.py,sha256=ZyUfSo2Q8P-lrCdPEnGkoo5pGubd0AabREK5eMISRII,1109
 llm_trainer/generate_utils.py,sha256=8K3YFbp7IF_lCkmkzjHhqTW26EBFb2AilQmarVcfMvs,15001
-llm_trainer/grpo_trainer.py,sha256=
+llm_trainer/grpo_trainer.py,sha256=2mMuRa7UXAgPSgav4Wp9-cs0QOPWQghv2IrW515Gn2Q,16515
 llm_trainer/log.py,sha256=XwychwKF6gvFPhthCIZCAEUZ0G3DY3fiQrOHqPWsxz0,463
-llm_trainer/loss.py,sha256=
+llm_trainer/loss.py,sha256=glf4IeDWHvA2cJo-QKLRL8P6OxK4QjRJGrYJWOZiTPQ,6929
 llm_trainer/parallel.py,sha256=yjStV21DJ26yM8-0O6GTMxdFAcyShY5GsQWSZmbI7HU,4543
 llm_trainer/parallel_ddp.py,sha256=Pob9vUlBZnkL4oP1Re11kFob7nufMSE96pn7m7fuOEM,1345
 llm_trainer/parallel_ds.py,sha256=oy8RRxHud3rACWubFlJqqd0pjPEQhKeAPGPQUSdJX2c,1145
@@ -17,17 +17,17 @@ llm_trainer/scheduler.py,sha256=LAI_0VxClsIQkix0bRoduRD4vPfVuIZDhZgTAT_KK8k,4901
 llm_trainer/sft_trainer.py,sha256=LudTRIaqLQYy6ym6jjMX7v9xtFBJelrR3nnPCwb48nM,1821
 llm_trainer/tokenizer.py,sha256=SSpgXtb0e1NtQqRW0gCq09TTZi47umggy-Fh5EMHKJg,6708
 llm_trainer/tools.py,sha256=5op5qrjjkK-Lr9oes5VxIVnOVYOYGoAdlIJq9mPUf64,2637
-llm_trainer/train_configs.py,sha256=
+llm_trainer/train_configs.py,sha256=N3ykM1uaLHcSNRC8ErYIxp9VYhSP7voJyAP-2D4ZJe0,7574
 llm_trainer/trainer.py,sha256=jS31zEXIIj9BoPTPlmaGYq61x72HGCjKfS2u3_gOkDk,27924
 llm_trainer/utils.py,sha256=xcdzpvPvXRKqsOK2yB7PZ9GmOvZMDFcglDPUZY2hJTY,11484
-project_llm_trainer-0.7.
-project_llm_trainer-0.7.
-project_llm_trainer-0.7.
-project_llm_trainer-0.7.
-project_llm_trainer-0.7.
-project_llm_trainer-0.7.
-project_llm_trainer-0.7.
-project_llm_trainer-0.7.
-project_llm_trainer-0.7.
-project_llm_trainer-0.7.
-project_llm_trainer-0.7.
+project_llm_trainer-0.7.4.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
+project_llm_trainer-0.7.4.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
+project_llm_trainer-0.7.4.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
+project_llm_trainer-0.7.4.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
+project_llm_trainer-0.7.4.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
+project_llm_trainer-0.7.4.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
+project_llm_trainer-0.7.4.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
+project_llm_trainer-0.7.4.dist-info/METADATA,sha256=Uq-6PqNSqhtsjYPmwAuMW07Y4SLq_caecSoCxYSbJHw,195
+project_llm_trainer-0.7.4.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
+project_llm_trainer-0.7.4.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
+project_llm_trainer-0.7.4.dist-info/RECORD,,
{project_llm_trainer-0.7.3.data → project_llm_trainer-0.7.4.data}/scripts/calc_intermediate_size
RENAMED
File without changes
{project_llm_trainer-0.7.3.data → project_llm_trainer-0.7.4.data}/scripts/ddp_train
RENAMED
File without changes
{project_llm_trainer-0.7.3.data → project_llm_trainer-0.7.4.data}/scripts/ds_train
RENAMED
File without changes
{project_llm_trainer-0.7.3.data → project_llm_trainer-0.7.4.data}/scripts/plot_loss
RENAMED
File without changes
{project_llm_trainer-0.7.3.data → project_llm_trainer-0.7.4.data}/scripts/plot_lr
RENAMED
File without changes
{project_llm_trainer-0.7.3.data → project_llm_trainer-0.7.4.data}/scripts/py_train
RENAMED
File without changes
{project_llm_trainer-0.7.3.data → project_llm_trainer-0.7.4.data}/scripts/smart_train
RENAMED
File without changes
{project_llm_trainer-0.7.3.dist-info → project_llm_trainer-0.7.4.dist-info}/WHEEL
RENAMED
File without changes
{project_llm_trainer-0.7.3.dist-info → project_llm_trainer-0.7.4.dist-info}/top_level.txt
RENAMED
File without changes