project-llm-trainer: 0.7.0 (py3-none-any.whl) → 0.7.1 (py3-none-any.whl)
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of project-llm-trainer might be problematic. Click here for more details.
- llm_trainer/dpo_trainer.py +2 -2
- llm_trainer/generate_utils.py +3 -5
- llm_trainer/grpo_trainer.py +2 -2
- llm_trainer/trainer.py +2 -2
- llm_trainer/utils.py +1 -1
- {project_llm_trainer-0.7.0.dist-info → project_llm_trainer-0.7.1.dist-info}/METADATA +1 -1
- {project_llm_trainer-0.7.0.dist-info → project_llm_trainer-0.7.1.dist-info}/RECORD +16 -16
- {project_llm_trainer-0.7.0.data → project_llm_trainer-0.7.1.data}/scripts/calc_intermediate_size +0 -0
- {project_llm_trainer-0.7.0.data → project_llm_trainer-0.7.1.data}/scripts/ddp_train +0 -0
- {project_llm_trainer-0.7.0.data → project_llm_trainer-0.7.1.data}/scripts/ds_train +0 -0
- {project_llm_trainer-0.7.0.data → project_llm_trainer-0.7.1.data}/scripts/plot_loss +0 -0
- {project_llm_trainer-0.7.0.data → project_llm_trainer-0.7.1.data}/scripts/plot_lr +0 -0
- {project_llm_trainer-0.7.0.data → project_llm_trainer-0.7.1.data}/scripts/py_train +0 -0
- {project_llm_trainer-0.7.0.data → project_llm_trainer-0.7.1.data}/scripts/smart_train +0 -0
- {project_llm_trainer-0.7.0.dist-info → project_llm_trainer-0.7.1.dist-info}/WHEEL +0 -0
- {project_llm_trainer-0.7.0.dist-info → project_llm_trainer-0.7.1.dist-info}/top_level.txt +0 -0
llm_trainer/dpo_trainer.py
CHANGED
|
@@ -11,7 +11,7 @@ from .dataset import DPODataset
|
|
|
11
11
|
from .loss import DPOLoss
|
|
12
12
|
from .tools import TrainerTools
|
|
13
13
|
from .utils import (
|
|
14
|
-
|
|
14
|
+
autocast,
|
|
15
15
|
get_dpo_collate_fn
|
|
16
16
|
)
|
|
17
17
|
from .partition_utils import sync_model_params
|
|
@@ -203,7 +203,7 @@ class DPOTrainer(Trainer):
|
|
|
203
203
|
if TrainerTools().parallel.parallel_train:
|
|
204
204
|
self.train_model.require_backward_grad_sync = need_update_grad
|
|
205
205
|
|
|
206
|
-
with
|
|
206
|
+
with autocast(TrainerTools().parallel.device_type):
|
|
207
207
|
policy_outputs = self.train_model(concat_inputs, attention_mask=concat_mask)
|
|
208
208
|
policy_probs = self._logprobs(policy_outputs['logits'], concat_labels, concat_mask)
|
|
209
209
|
aux_loss = policy_outputs.get('aux_loss')
|
llm_trainer/generate_utils.py
CHANGED
|
@@ -4,7 +4,7 @@ import torch
|
|
|
4
4
|
from llm_model import VlmModel, KVCache
|
|
5
5
|
from .tools import TrainerTools
|
|
6
6
|
from .utils import (
|
|
7
|
-
|
|
7
|
+
autocast,
|
|
8
8
|
batch_repeat_image_tok
|
|
9
9
|
)
|
|
10
10
|
|
|
@@ -127,7 +127,6 @@ def _generate(
|
|
|
127
127
|
如果temperature很大但内容单一,需要增大k、p
|
|
128
128
|
"""
|
|
129
129
|
use_kv_cache = True
|
|
130
|
-
ctx = autocastcontext(device)
|
|
131
130
|
|
|
132
131
|
if isinstance(model, VlmModel):
|
|
133
132
|
tokens = batch_repeat_image_tok(tokens, tokens_per_image)
|
|
@@ -141,7 +140,7 @@ def _generate(
|
|
|
141
140
|
with torch.inference_mode():
|
|
142
141
|
for _ in range(max_new_tokens):
|
|
143
142
|
t = tokens # tokens[:, -max_position_embeddings:]
|
|
144
|
-
with
|
|
143
|
+
with autocast(device):
|
|
145
144
|
result = model(
|
|
146
145
|
t,
|
|
147
146
|
past_key_values=kv_cache,
|
|
@@ -327,7 +326,6 @@ def batch_generate(
|
|
|
327
326
|
device: Union[str, torch.device, int]
|
|
328
327
|
):
|
|
329
328
|
use_kv_cache = True
|
|
330
|
-
ctx = autocastcontext(device)
|
|
331
329
|
|
|
332
330
|
if isinstance(model, VlmModel):
|
|
333
331
|
tokens = batch_repeat_image_tok(tokens, tokens_per_image)
|
|
@@ -350,7 +348,7 @@ def batch_generate(
|
|
|
350
348
|
break
|
|
351
349
|
|
|
352
350
|
t = tokens #tokens[:, -max_position_embeddings:]
|
|
353
|
-
with
|
|
351
|
+
with autocast(device):
|
|
354
352
|
result = model(
|
|
355
353
|
t,
|
|
356
354
|
attention_mask=attention_mask,
|
llm_trainer/grpo_trainer.py
CHANGED
|
@@ -13,7 +13,7 @@ from .loss import GRPOLoss
|
|
|
13
13
|
from .tools import TrainerTools
|
|
14
14
|
from .generate_utils import batch_generate
|
|
15
15
|
from .log import log
|
|
16
|
-
from .utils import
|
|
16
|
+
from .utils import autocast
|
|
17
17
|
|
|
18
18
|
from .partition_utils import (
|
|
19
19
|
sync_model_params,
|
|
@@ -342,7 +342,7 @@ class GRPOTrainer(Trainer):
|
|
|
342
342
|
log(f'start train for batch {batch}/{batch_count_per_file}')
|
|
343
343
|
|
|
344
344
|
for grpo_step in range(self.train_config.grpo_config.grpo_steps):
|
|
345
|
-
with
|
|
345
|
+
with autocast(TrainerTools().parallel.device_type):
|
|
346
346
|
loss, aux_loss = self._maximize_grpo_objective(rollout_data)
|
|
347
347
|
if aux_loss_coef and aux_loss:
|
|
348
348
|
loss += aux_loss_coef * aux_loss
|
llm_trainer/trainer.py
CHANGED
|
@@ -36,7 +36,7 @@ from .checkpoint import (
|
|
|
36
36
|
|
|
37
37
|
from .utils import (
|
|
38
38
|
set_seed,
|
|
39
|
-
|
|
39
|
+
autocast,
|
|
40
40
|
create_doc_boundary_mask,
|
|
41
41
|
generate_position_ids,
|
|
42
42
|
pretrain_collate_fn,
|
|
@@ -556,7 +556,7 @@ class Trainer:
|
|
|
556
556
|
if TrainerTools().parallel.parallel_train:
|
|
557
557
|
self.train_model.require_backward_grad_sync = need_update_grad
|
|
558
558
|
|
|
559
|
-
with
|
|
559
|
+
with autocast(TrainerTools().parallel.device_type):
|
|
560
560
|
result = self.train_model(
|
|
561
561
|
inputs,
|
|
562
562
|
attention_mask=attention_mask,
|
llm_trainer/utils.py
CHANGED
|
@@ -15,7 +15,7 @@ def set_seed(seed=42):
|
|
|
15
15
|
torch.cuda.manual_seed_all(seed)
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
def
|
|
18
|
+
def autocast(device_type):
|
|
19
19
|
if TrainerTools().use_amp:
|
|
20
20
|
dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16
|
|
21
21
|
return torch.autocast(
|
|
@@ -1,11 +1,11 @@
|
|
|
1
1
|
llm_trainer/__init__.py,sha256=HWgtTEVeQSnZmEyYQm2K6eFEG4X2QAoigMlB5Z2tcXE,260
|
|
2
2
|
llm_trainer/checkpoint.py,sha256=gz31pZbbQvRTYrBhxV-MFaBAIFeqpe7rM6nFsjwT9lY,4328
|
|
3
3
|
llm_trainer/dataset.py,sha256=4QlOo0SFB5816BUYegQjgobUqTUMQvdmZMM_OEAMSjE,4347
|
|
4
|
-
llm_trainer/dpo_trainer.py,sha256=
|
|
4
|
+
llm_trainer/dpo_trainer.py,sha256=RMfbTsl3eav4yTJ2PK59mi6a0ECVOg8WwYVsHvMbNUE,12353
|
|
5
5
|
llm_trainer/ds_checkpoint.py,sha256=Wzy7PvVVWR794-BW4uragWFTAkkgDvjvkF-qMdyB4fc,2141
|
|
6
6
|
llm_trainer/eval.py,sha256=ZyUfSo2Q8P-lrCdPEnGkoo5pGubd0AabREK5eMISRII,1109
|
|
7
|
-
llm_trainer/generate_utils.py,sha256=
|
|
8
|
-
llm_trainer/grpo_trainer.py,sha256=
|
|
7
|
+
llm_trainer/generate_utils.py,sha256=8K3YFbp7IF_lCkmkzjHhqTW26EBFb2AilQmarVcfMvs,15001
|
|
8
|
+
llm_trainer/grpo_trainer.py,sha256=zxbLIzk34cHFw5yfRH8EBr0wrFTS7qFa5DepcC0WXwk,16435
|
|
9
9
|
llm_trainer/log.py,sha256=XwychwKF6gvFPhthCIZCAEUZ0G3DY3fiQrOHqPWsxz0,463
|
|
10
10
|
llm_trainer/loss.py,sha256=eYvOlCoguKnLvdGuqvQpGUoLVSADQ5coaU3DWYbJEdM,6811
|
|
11
11
|
llm_trainer/parallel.py,sha256=yjStV21DJ26yM8-0O6GTMxdFAcyShY5GsQWSZmbI7HU,4543
|
|
@@ -18,16 +18,16 @@ llm_trainer/sft_trainer.py,sha256=LudTRIaqLQYy6ym6jjMX7v9xtFBJelrR3nnPCwb48nM,18
|
|
|
18
18
|
llm_trainer/tokenizer.py,sha256=SSpgXtb0e1NtQqRW0gCq09TTZi47umggy-Fh5EMHKJg,6708
|
|
19
19
|
llm_trainer/tools.py,sha256=5op5qrjjkK-Lr9oes5VxIVnOVYOYGoAdlIJq9mPUf64,2637
|
|
20
20
|
llm_trainer/train_configs.py,sha256=U4hwXWKI6svDqiDOu6RPTitCzpxEYyjZUN6gwh_co8c,7510
|
|
21
|
-
llm_trainer/trainer.py,sha256=
|
|
22
|
-
llm_trainer/utils.py,sha256=
|
|
23
|
-
project_llm_trainer-0.7.
|
|
24
|
-
project_llm_trainer-0.7.
|
|
25
|
-
project_llm_trainer-0.7.
|
|
26
|
-
project_llm_trainer-0.7.
|
|
27
|
-
project_llm_trainer-0.7.
|
|
28
|
-
project_llm_trainer-0.7.
|
|
29
|
-
project_llm_trainer-0.7.
|
|
30
|
-
project_llm_trainer-0.7.
|
|
31
|
-
project_llm_trainer-0.7.
|
|
32
|
-
project_llm_trainer-0.7.
|
|
33
|
-
project_llm_trainer-0.7.
|
|
21
|
+
llm_trainer/trainer.py,sha256=jS31zEXIIj9BoPTPlmaGYq61x72HGCjKfS2u3_gOkDk,27924
|
|
22
|
+
llm_trainer/utils.py,sha256=xcdzpvPvXRKqsOK2yB7PZ9GmOvZMDFcglDPUZY2hJTY,11484
|
|
23
|
+
project_llm_trainer-0.7.1.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
|
|
24
|
+
project_llm_trainer-0.7.1.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
|
|
25
|
+
project_llm_trainer-0.7.1.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
|
|
26
|
+
project_llm_trainer-0.7.1.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
|
|
27
|
+
project_llm_trainer-0.7.1.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
|
|
28
|
+
project_llm_trainer-0.7.1.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
|
|
29
|
+
project_llm_trainer-0.7.1.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
|
|
30
|
+
project_llm_trainer-0.7.1.dist-info/METADATA,sha256=5O5GDggubLuaVquiTdCwB3K2v8dD2EwqVVFvsgeSyZM,195
|
|
31
|
+
project_llm_trainer-0.7.1.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
|
|
32
|
+
project_llm_trainer-0.7.1.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
|
|
33
|
+
project_llm_trainer-0.7.1.dist-info/RECORD,,
|
{project_llm_trainer-0.7.0.data → project_llm_trainer-0.7.1.data}/scripts/calc_intermediate_size
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|