project-llm-trainer 0.7.2__py3-none-any.whl → 0.7.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of project-llm-trainer has been flagged as potentially problematic.

llm_trainer/checkpoint.py CHANGED
@@ -1,8 +1,7 @@
 import os
-from typing import Optional, Union, Tuple
+from typing import Optional, Union
 import shutil
 import torch
-from sympy import false
 from torch import nn
 from torch.optim import Optimizer
 from torch.nn.parallel import DistributedDataParallel as DDP
llm_trainer/grpo_trainer.py CHANGED
@@ -68,7 +68,8 @@ class GRPOTrainer(Trainer):
     def _init_loss(self):
         criterion = GRPOLoss(
             beta=self.train_config.grpo_config.loss_beta,
-            clip_eps=self.train_config.grpo_config.loss_clip_eps,
+            clip_eps_low=self.train_config.grpo_config.loss_clip_eps,
+            clip_eps_high=self.train_config.grpo_config.loss_clip_eps_high,
             delta=self.train_config.grpo_config.loss_delta,
             importance_sampling_level=self.train_config.grpo_config.loss_importance_sampling_level,
             loss_type=self.train_config.grpo_config.loss_type,
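Note: loss_clip_eps_high is a new GRPOConfig field (see the train_configs.py hunk below). A minimal standalone sketch of the wiring above, modeling only the two fields visible in this diff:

    # Stand-in for the two GRPOConfig fields used by _init_loss above;
    # the real config class has many more fields (elided here).
    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class GRPOConfigSketch:
        loss_clip_eps: float = 3e-4                 # feeds clip_eps_low
        loss_clip_eps_high: Optional[float] = 4e-4  # feeds clip_eps_high

    cfg = GRPOConfigSketch()
    loss_kwargs = dict(clip_eps_low=cfg.loss_clip_eps, clip_eps_high=cfg.loss_clip_eps_high)
    print(loss_kwargs)  # {'clip_eps_low': 0.0003, 'clip_eps_high': 0.0004}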
llm_trainer/loss.py CHANGED
@@ -2,7 +2,6 @@ from typing import List, Optional
 import torch
 from torch import nn
 import torch.nn.functional as F
-from .tools import TrainerTools


 class LMLoss(nn.Module):
@@ -127,7 +126,8 @@ class GRPOLoss(nn.Module):
     def __init__(
         self,
         beta: float,
-        clip_eps: float,
+        clip_eps_low: float,
+        clip_eps_high: Optional[float] = None,
         delta: Optional[float] = None,
         importance_sampling_level: str = 'token',
         loss_type: str = 'grpo',
@@ -136,7 +136,8 @@ class GRPOLoss(nn.Module):
         super().__init__()

         self.beta = beta
-        self.clip_eps = clip_eps
+        self.clip_eps_low = clip_eps_low
+        self.clip_eps_high = clip_eps_high if clip_eps_high else clip_eps_low
         self.delta = delta
         self.importance_sampling_level = importance_sampling_level
         self.loss_type = loss_type
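The fallback above is truthiness-based, so omitting clip_eps_high preserves the old symmetric behavior; note that an explicit clip_eps_high=0.0 would also fall back. A standalone check of the same expression:

    from typing import Optional

    def resolve_clip_eps_high(clip_eps_low: float, clip_eps_high: Optional[float] = None) -> float:
        # Same expression as in GRPOLoss.__init__ above.
        return clip_eps_high if clip_eps_high else clip_eps_low

    print(resolve_clip_eps_high(3e-4))        # 0.0003 -> falls back to the low value
    print(resolve_clip_eps_high(3e-4, 4e-4))  # 0.0004 -> explicit high value is used
    print(resolve_clip_eps_high(3e-4, 0.0))   # 0.0003 -> 0.0 is falsy and also falls back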
@@ -166,7 +167,7 @@ class GRPOLoss(nn.Module):
            log_importance_weights = log_ratio

        coef_1 = torch.exp(log_importance_weights)
-       coef_2 = torch.clamp(coef_1, 1 - self.clip_eps, 1 + self.clip_eps)
+       coef_2 = torch.clamp(coef_1, 1 - self.clip_eps_low, 1 + self.clip_eps_high)

        # Two-sided clipping
        if self.delta is not None:
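With this change the clipping band around 1 becomes asymmetric: [1 - clip_eps_low, 1 + clip_eps_high]. A runnable illustration using the new defaults from this release:

    import torch

    clip_eps_low, clip_eps_high = 3e-4, 4e-4  # new GRPOConfig defaults below
    coef_1 = torch.tensor([0.9990, 0.9999, 1.0002, 1.0010])  # exp(log importance weights)
    coef_2 = torch.clamp(coef_1, 1 - clip_eps_low, 1 + clip_eps_high)
    print(coef_2)  # tensor([0.9997, 0.9999, 1.0002, 1.0004])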
llm_trainer/train_configs.py CHANGED
@@ -138,10 +138,11 @@ class GRPOConfig:
     grpo_steps: int = 1
     group_size: int = 12
     mixup_alpha: float = 1.0
-    loss_beta: float = 0.04
-    loss_clip_eps: float = 0.1
+    loss_beta: float = 0.0  # or 0.04 for grpo
+    loss_clip_eps: float = 3e-4
+    loss_clip_eps_high: Optional[float] = 4e-4
     loss_delta: Optional[float] = None
-    loss_importance_sampling_level: str = 'token'  # token or seq
+    loss_importance_sampling_level: str = 'seq'  # token or seq
     loss_type: str = 'grpo'  # grpo or bnpo or dr_grpo
     gen_max_new_tokens: Optional[int] = None
     gen_temperature: Optional[float] = None
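Taken together, the new defaults (loss_beta=0.0, sequence-level importance sampling, clip epsilons of a few 1e-4) shift the default objective from token-level GRPO toward a GSPO-style sequence-level ratio. A standalone sketch of the usual token-vs-sequence distinction (an assumed implementation; the package's forward pass is not shown in this diff):

    import torch

    log_ratio = torch.randn(2, 5) * 1e-3   # (batch, seq_len) per-token log-prob ratios
    mask = torch.ones(2, 5)                # completion mask

    # 'token': one importance weight per token.
    token_level = log_ratio

    # 'seq': one weight per sequence, the length-normalized mean log ratio,
    # broadcast back over tokens (GSPO-style; assumed, not the package's code).
    seq_mean = (log_ratio * mask).sum(-1, keepdim=True) / mask.sum(-1, keepdim=True)
    seq_level = seq_mean.expand_as(log_ratio)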
project_llm_trainer-0.7.2.dist-info/METADATA → project_llm_trainer-0.7.4.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: project_llm_trainer
-Version: 0.7.2
+Version: 0.7.4
 Summary: LLM and VLM trainer
 Author: qibin
 Author-email: qibin0506@gmail.com
project_llm_trainer-0.7.2.dist-info/RECORD → project_llm_trainer-0.7.4.dist-info/RECORD RENAMED
@@ -1,13 +1,13 @@
 llm_trainer/__init__.py,sha256=HWgtTEVeQSnZmEyYQm2K6eFEG4X2QAoigMlB5Z2tcXE,260
-llm_trainer/checkpoint.py,sha256=-sHPwhZwJfiSpbHTDto7n_oagnSVmLe8pkcU9x217gs,4459
+llm_trainer/checkpoint.py,sha256=X5ZeUtJlxVz7pnWQLaS-y7UIZOaOAnZTt2L8rSAPzUs,4428
 llm_trainer/dataset.py,sha256=4QlOo0SFB5816BUYegQjgobUqTUMQvdmZMM_OEAMSjE,4347
 llm_trainer/dpo_trainer.py,sha256=RMfbTsl3eav4yTJ2PK59mi6a0ECVOg8WwYVsHvMbNUE,12353
 llm_trainer/ds_checkpoint.py,sha256=X2IWgpgi0yOtogph7n6DEwvK_0Ceb7juu1WMutv3HSk,2270
 llm_trainer/eval.py,sha256=ZyUfSo2Q8P-lrCdPEnGkoo5pGubd0AabREK5eMISRII,1109
 llm_trainer/generate_utils.py,sha256=8K3YFbp7IF_lCkmkzjHhqTW26EBFb2AilQmarVcfMvs,15001
-llm_trainer/grpo_trainer.py,sha256=zxbLIzk34cHFw5yfRH8EBr0wrFTS7qFa5DepcC0WXwk,16435
+llm_trainer/grpo_trainer.py,sha256=2mMuRa7UXAgPSgav4Wp9-cs0QOPWQghv2IrW515Gn2Q,16515
 llm_trainer/log.py,sha256=XwychwKF6gvFPhthCIZCAEUZ0G3DY3fiQrOHqPWsxz0,463
-llm_trainer/loss.py,sha256=eYvOlCoguKnLvdGuqvQpGUoLVSADQ5coaU3DWYbJEdM,6811
+llm_trainer/loss.py,sha256=glf4IeDWHvA2cJo-QKLRL8P6OxK4QjRJGrYJWOZiTPQ,6929
 llm_trainer/parallel.py,sha256=yjStV21DJ26yM8-0O6GTMxdFAcyShY5GsQWSZmbI7HU,4543
 llm_trainer/parallel_ddp.py,sha256=Pob9vUlBZnkL4oP1Re11kFob7nufMSE96pn7m7fuOEM,1345
 llm_trainer/parallel_ds.py,sha256=oy8RRxHud3rACWubFlJqqd0pjPEQhKeAPGPQUSdJX2c,1145
@@ -17,17 +17,17 @@ llm_trainer/scheduler.py,sha256=LAI_0VxClsIQkix0bRoduRD4vPfVuIZDhZgTAT_KK8k,4901
 llm_trainer/sft_trainer.py,sha256=LudTRIaqLQYy6ym6jjMX7v9xtFBJelrR3nnPCwb48nM,1821
 llm_trainer/tokenizer.py,sha256=SSpgXtb0e1NtQqRW0gCq09TTZi47umggy-Fh5EMHKJg,6708
 llm_trainer/tools.py,sha256=5op5qrjjkK-Lr9oes5VxIVnOVYOYGoAdlIJq9mPUf64,2637
-llm_trainer/train_configs.py,sha256=U4hwXWKI6svDqiDOu6RPTitCzpxEYyjZUN6gwh_co8c,7510
+llm_trainer/train_configs.py,sha256=N3ykM1uaLHcSNRC8ErYIxp9VYhSP7voJyAP-2D4ZJe0,7574
 llm_trainer/trainer.py,sha256=jS31zEXIIj9BoPTPlmaGYq61x72HGCjKfS2u3_gOkDk,27924
 llm_trainer/utils.py,sha256=xcdzpvPvXRKqsOK2yB7PZ9GmOvZMDFcglDPUZY2hJTY,11484
-project_llm_trainer-0.7.2.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
-project_llm_trainer-0.7.2.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
-project_llm_trainer-0.7.2.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
-project_llm_trainer-0.7.2.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
-project_llm_trainer-0.7.2.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
-project_llm_trainer-0.7.2.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
-project_llm_trainer-0.7.2.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
-project_llm_trainer-0.7.2.dist-info/METADATA,sha256=WYohRO3Qb9o9QD3UZWqWmtoEOzoYJNWmj1_Olds6P4c,195
-project_llm_trainer-0.7.2.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
-project_llm_trainer-0.7.2.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
-project_llm_trainer-0.7.2.dist-info/RECORD,,
+project_llm_trainer-0.7.4.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
+project_llm_trainer-0.7.4.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
+project_llm_trainer-0.7.4.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
+project_llm_trainer-0.7.4.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
+project_llm_trainer-0.7.4.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
+project_llm_trainer-0.7.4.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
+project_llm_trainer-0.7.4.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
+project_llm_trainer-0.7.4.dist-info/METADATA,sha256=Uq-6PqNSqhtsjYPmwAuMW07Y4SLq_caecSoCxYSbJHw,195
+project_llm_trainer-0.7.4.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
+project_llm_trainer-0.7.4.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
+project_llm_trainer-0.7.4.dist-info/RECORD,,