project-llm-trainer 0.7.0__py3-none-any.whl → 0.7.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: 0.7.2 adds a stray 'from sympy import false' import to llm_trainer/checkpoint.py (first hunk below). The imported name appears unused anywhere in the visible diff, and the import makes sympy a hard import-time dependency of the package.

llm_trainer/checkpoint.py CHANGED
@@ -2,6 +2,7 @@ import os
 from typing import Optional, Union, Tuple
 import shutil
 import torch
+from sympy import false
 from torch import nn
 from torch.optim import Optimizer
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -36,6 +37,10 @@ def save_best_checkpoint(
         current_loss: float,
         last_best_checkpoint_loss: Optional[float] = None
 ) -> bool:
+    # Allows disabling the best-checkpoint save
+    if os.environ.get('SAVE_BEST_CHECKPOINT', '1') != '1':
+        return False
+
     need_replace = not last_best_checkpoint_loss or current_loss <= last_best_checkpoint_loss
     if need_replace and TrainerTools().parallel.is_main_process:
         try:
@@ -62,8 +67,7 @@ def save_best_checkpoint(
                 os.remove(best_checkpoint_name)
 
             shutil.copy2(checkpoint_name, best_checkpoint_name)
-        except:
-            pass
+        except: pass
 
     TrainerTools().parallel.wait('save best checkpoint')
     return need_replace
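
Note: the new SAVE_BEST_CHECKPOINT gate returns before the collective wait() at the end of save_best_checkpoint, so it should be set identically on every rank to keep ranks in step. A minimal usage sketch; only the variable name and its default come from the hunk above, everything else is illustrative:

    import os

    # Any value other than the default '1' makes save_best_checkpoint()
    # return False before any loss comparison or file copy happens.
    os.environ['SAVE_BEST_CHECKPOINT'] = '0'
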
llm_trainer/dpo_trainer.py CHANGED
@@ -11,7 +11,7 @@ from .dataset import DPODataset
 from .loss import DPOLoss
 from .tools import TrainerTools
 from .utils import (
-    autocastcontext,
+    autocast,
     get_dpo_collate_fn
 )
 from .partition_utils import sync_model_params
@@ -203,7 +203,7 @@ class DPOTrainer(Trainer):
         if TrainerTools().parallel.parallel_train:
             self.train_model.require_backward_grad_sync = need_update_grad
 
-        with autocastcontext(TrainerTools().parallel.device_type):
+        with autocast(TrainerTools().parallel.device_type):
             policy_outputs = self.train_model(concat_inputs, attention_mask=concat_mask)
             policy_probs = self._logprobs(policy_outputs['logits'], concat_labels, concat_mask)
             aux_loss = policy_outputs.get('aux_loss')
llm_trainer/ds_checkpoint.py CHANGED
@@ -28,9 +28,11 @@ def save_ds_checkpoint(model: nn.Module):
 
     # Only runs on the main rank
     if TrainerTools().parallel.is_main_process:
+        # Maximum number of checkpoints to keep, defaults to 2
+        max_to_keep = int(os.environ.get('CKPT_MAX_TO_KEEP', '2'))
         # Delete historical checkpoints
         ckpt_paths = glob(os.path.join(ckpt_dir, "global_*"))
-        if len(ckpt_paths) > 2:
+        if len(ckpt_paths) > max_to_keep:
             # Sort by modification time to find the oldest directory
             oldest_ckpt = sorted(ckpt_paths, key=os.path.getmtime)[0]
             try:
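
CKPT_MAX_TO_KEEP makes the previously hard-coded retention limit of 2 configurable. A usage sketch; the variable name and default are from the hunk, the value here is illustrative:

    import os

    # Keep up to 5 global_* checkpoint directories before the oldest
    # (by mtime) is deleted on the next save.
    os.environ['CKPT_MAX_TO_KEEP'] = '5'

Note that the cleanup removes only the single oldest directory per save (sorted(...)[0]), so lowering the limit on an existing run trims the backlog one checkpoint at a time.
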
llm_trainer/generate_utils.py CHANGED
@@ -4,7 +4,7 @@ import torch
 from llm_model import VlmModel, KVCache
 from .tools import TrainerTools
 from .utils import (
-    autocastcontext,
+    autocast,
     batch_repeat_image_tok
 )
 
@@ -127,7 +127,6 @@ def _generate(
     If temperature is high but the output is still monotonous, increase k and p
     """
     use_kv_cache = True
-    ctx = autocastcontext(device)
 
     if isinstance(model, VlmModel):
        tokens = batch_repeat_image_tok(tokens, tokens_per_image)
@@ -141,7 +140,7 @@ def _generate(
     with torch.inference_mode():
         for _ in range(max_new_tokens):
             t = tokens  # tokens[:, -max_position_embeddings:]
-            with ctx:
+            with autocast(device):
                 result = model(
                     t,
                     past_key_values=kv_cache,
@@ -327,7 +326,6 @@ def batch_generate(
         device: Union[str, torch.device, int]
 ):
     use_kv_cache = True
-    ctx = autocastcontext(device)
 
     if isinstance(model, VlmModel):
         tokens = batch_repeat_image_tok(tokens, tokens_per_image)
@@ -350,7 +348,7 @@ def batch_generate(
             break
 
         t = tokens  # tokens[:, -max_position_embeddings:]
-        with ctx:
+        with autocast(device):
             result = model(
                 t,
                 attention_mask=attention_mask,
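
Besides the rename, both generation paths stop hoisting the context object and instead build one per decoding step. The two patterns are behaviorally equivalent; a self-contained sketch (autocast here is a stand-in no-op, not the package's helper):

    from contextlib import nullcontext

    def autocast(device_type):
        return nullcontext()  # stand-in for llm_trainer's helper

    device = 'cpu'

    # 0.7.0 pattern: one context object created outside the decode loop
    ctx = autocast(device)
    for _ in range(3):
        with ctx:
            pass  # model forward would run here

    # 0.7.2 pattern: a fresh context per step; same behavior, at the
    # cost of one extra (cheap) object construction per generated token
    for _ in range(3):
        with autocast(device):
            pass  # model forward would run here
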
llm_trainer/grpo_trainer.py CHANGED
@@ -13,7 +13,7 @@ from .loss import GRPOLoss
 from .tools import TrainerTools
 from .generate_utils import batch_generate
 from .log import log
-from .utils import autocastcontext
+from .utils import autocast
 
 from .partition_utils import (
     sync_model_params,
@@ -342,7 +342,7 @@ class GRPOTrainer(Trainer):
         log(f'start train for batch {batch}/{batch_count_per_file}')
 
         for grpo_step in range(self.train_config.grpo_config.grpo_steps):
-            with autocastcontext(TrainerTools().parallel.device_type):
+            with autocast(TrainerTools().parallel.device_type):
                 loss, aux_loss = self._maximize_grpo_objective(rollout_data)
                 if aux_loss_coef and aux_loss:
                     loss += aux_loss_coef * aux_loss
llm_trainer/trainer.py CHANGED
@@ -36,7 +36,7 @@ from .checkpoint import (
 
 from .utils import (
     set_seed,
-    autocastcontext,
+    autocast,
     create_doc_boundary_mask,
     generate_position_ids,
     pretrain_collate_fn,
@@ -556,7 +556,7 @@ class Trainer:
         if TrainerTools().parallel.parallel_train:
             self.train_model.require_backward_grad_sync = need_update_grad
 
-        with autocastcontext(TrainerTools().parallel.device_type):
+        with autocast(TrainerTools().parallel.device_type):
             result = self.train_model(
                 inputs,
                 attention_mask=attention_mask,
llm_trainer/utils.py CHANGED
@@ -15,7 +15,7 @@ def set_seed(seed=42):
     torch.cuda.manual_seed_all(seed)
 
 
-def autocastcontext(device_type):
+def autocast(device_type):
     if TrainerTools().use_amp:
         dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16
         return torch.autocast(
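
The diff truncates the renamed helper after torch.autocast(. A plausible reconstruction as it would appear inside llm_trainer/utils.py; the non-AMP fallback to a no-op context is an assumption, not the package's verbatim code:

    from contextlib import nullcontext
    import torch

    from .tools import TrainerTools

    def autocast(device_type):
        # With AMP enabled, prefer bf16 where the GPU supports it, else fp16
        if TrainerTools().use_amp:
            dtype = (torch.bfloat16
                     if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
                     else torch.float16)
            return torch.autocast(device_type=device_type, dtype=dtype)

        # Assumed fallback: a no-op context so call sites can always `with` it
        return nullcontext()
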
project_llm_trainer-{0.7.0 → 0.7.2}.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: project_llm_trainer
-Version: 0.7.0
+Version: 0.7.2
 Summary: LLM and VLM trainer
 Author: qibin
 Author-email: qibin0506@gmail.com
project_llm_trainer-{0.7.0 → 0.7.2}.dist-info/RECORD CHANGED
@@ -1,11 +1,11 @@
 llm_trainer/__init__.py,sha256=HWgtTEVeQSnZmEyYQm2K6eFEG4X2QAoigMlB5Z2tcXE,260
-llm_trainer/checkpoint.py,sha256=gz31pZbbQvRTYrBhxV-MFaBAIFeqpe7rM6nFsjwT9lY,4328
+llm_trainer/checkpoint.py,sha256=-sHPwhZwJfiSpbHTDto7n_oagnSVmLe8pkcU9x217gs,4459
 llm_trainer/dataset.py,sha256=4QlOo0SFB5816BUYegQjgobUqTUMQvdmZMM_OEAMSjE,4347
-llm_trainer/dpo_trainer.py,sha256=_8ZwOKQH69c6Fa5Cey5hNep7XUoI4jPIXQaQcV3soGw,12367
-llm_trainer/ds_checkpoint.py,sha256=Wzy7PvVVWR794-BW4uragWFTAkkgDvjvkF-qMdyB4fc,2141
+llm_trainer/dpo_trainer.py,sha256=RMfbTsl3eav4yTJ2PK59mi6a0ECVOg8WwYVsHvMbNUE,12353
+llm_trainer/ds_checkpoint.py,sha256=X2IWgpgi0yOtogph7n6DEwvK_0Ceb7juu1WMutv3HSk,2270
 llm_trainer/eval.py,sha256=ZyUfSo2Q8P-lrCdPEnGkoo5pGubd0AabREK5eMISRII,1109
-llm_trainer/generate_utils.py,sha256=zX5218RX4ltahCQCZVVCWQghCWhKslPk2NUnl_CakIE,15050
-llm_trainer/grpo_trainer.py,sha256=0iWvpuMI5CDNIjH08Dd1ihZFqDYenVnHACiMY2GLJtg,16449
+llm_trainer/generate_utils.py,sha256=8K3YFbp7IF_lCkmkzjHhqTW26EBFb2AilQmarVcfMvs,15001
+llm_trainer/grpo_trainer.py,sha256=zxbLIzk34cHFw5yfRH8EBr0wrFTS7qFa5DepcC0WXwk,16435
 llm_trainer/log.py,sha256=XwychwKF6gvFPhthCIZCAEUZ0G3DY3fiQrOHqPWsxz0,463
 llm_trainer/loss.py,sha256=eYvOlCoguKnLvdGuqvQpGUoLVSADQ5coaU3DWYbJEdM,6811
 llm_trainer/parallel.py,sha256=yjStV21DJ26yM8-0O6GTMxdFAcyShY5GsQWSZmbI7HU,4543
@@ -18,16 +18,16 @@ llm_trainer/sft_trainer.py,sha256=LudTRIaqLQYy6ym6jjMX7v9xtFBJelrR3nnPCwb48nM,18
 llm_trainer/tokenizer.py,sha256=SSpgXtb0e1NtQqRW0gCq09TTZi47umggy-Fh5EMHKJg,6708
 llm_trainer/tools.py,sha256=5op5qrjjkK-Lr9oes5VxIVnOVYOYGoAdlIJq9mPUf64,2637
 llm_trainer/train_configs.py,sha256=U4hwXWKI6svDqiDOu6RPTitCzpxEYyjZUN6gwh_co8c,7510
-llm_trainer/trainer.py,sha256=2TC2GJeoGd0fDE6CFodk1chsSkk0v0yO0wrFYim5t4g,27938
-llm_trainer/utils.py,sha256=ox2fWtSOS7F2Nh7_FoHxuQgaps1jGW3q59VXz04wRuA,11491
-project_llm_trainer-0.7.0.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
-project_llm_trainer-0.7.0.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
-project_llm_trainer-0.7.0.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
-project_llm_trainer-0.7.0.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
-project_llm_trainer-0.7.0.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
-project_llm_trainer-0.7.0.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
-project_llm_trainer-0.7.0.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
-project_llm_trainer-0.7.0.dist-info/METADATA,sha256=Q_UU9xBZIIBFOmfQJg1708lFfYn4bu5FA0fuxJCCcxQ,195
-project_llm_trainer-0.7.0.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
-project_llm_trainer-0.7.0.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
-project_llm_trainer-0.7.0.dist-info/RECORD,,
+llm_trainer/trainer.py,sha256=jS31zEXIIj9BoPTPlmaGYq61x72HGCjKfS2u3_gOkDk,27924
+llm_trainer/utils.py,sha256=xcdzpvPvXRKqsOK2yB7PZ9GmOvZMDFcglDPUZY2hJTY,11484
+project_llm_trainer-0.7.2.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
+project_llm_trainer-0.7.2.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
+project_llm_trainer-0.7.2.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
+project_llm_trainer-0.7.2.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
+project_llm_trainer-0.7.2.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
+project_llm_trainer-0.7.2.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
+project_llm_trainer-0.7.2.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
+project_llm_trainer-0.7.2.dist-info/METADATA,sha256=WYohRO3Qb9o9QD3UZWqWmtoEOzoYJNWmj1_Olds6P4c,195
+project_llm_trainer-0.7.2.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
+project_llm_trainer-0.7.2.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
+project_llm_trainer-0.7.2.dist-info/RECORD,,
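
For reference, the sha256= values in RECORD use the wheel hash format (PEP 427): a urlsafe-base64 sha256 digest with the trailing '=' padding stripped. A small checker for files extracted from the wheel; the path in the example is illustrative:

    import base64, hashlib

    def record_hash(path):
        # RECORD hash format per PEP 427: urlsafe base64, '=' padding removed
        with open(path, 'rb') as f:
            digest = hashlib.sha256(f.read()).digest()
        return 'sha256=' + base64.urlsafe_b64encode(digest).rstrip(b'=').decode()

    # e.g. record_hash('llm_trainer/checkpoint.py') should return
    # 'sha256=-sHPwhZwJfiSpbHTDto7n_oagnSVmLe8pkcU9x217gs' for the 0.7.2 wheel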