project-llm-trainer 0.5.17__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of project-llm-trainer might be problematic. Click here for more details.

llm_trainer/dpo_trainer.py CHANGED
@@ -1,4 +1,3 @@
1
- import time
2
1
  from typing import Tuple, List, Optional
3
2
  import torch
4
3
  from torch.utils.data import Dataset
@@ -11,11 +10,11 @@ from .train_configs import TrainConfig
11
10
  from .dataset import DPODataset
12
11
  from .loss import DPOLoss
13
12
  from .tools import TrainerTools
14
- from .utils import get_dpo_collate_fn
15
- from .partition_utils import (
16
- sync_model_params,
17
- unwrap_model_for_generation
13
+ from .utils import (
14
+ autocastcontext,
15
+ get_dpo_collate_fn
18
16
  )
17
+ from .partition_utils import sync_model_params
19
18
 
20
19
  from .checkpoint import (
21
20
  save_checkpoint,
@@ -37,7 +36,7 @@ class DPOTrainer(Trainer):
37
36
  eval_prompts=eval_prompts,
38
37
  eval_image_tags=eval_image_tags
39
38
  )
40
-
39
+ self.packed_sequences = False
41
40
  self.ref_model = self._init_ref_model()
42
41
 
43
42
  def _init_ref_model(self):
@@ -204,7 +203,7 @@ class DPOTrainer(Trainer):
204
203
  if TrainerTools().parallel.parallel_train:
205
204
  self.train_model.require_backward_grad_sync = need_update_grad
206
205
 
207
- with self.ctx:
206
+ with autocastcontext(TrainerTools().parallel.device_type):
208
207
  policy_outputs = self.train_model(concat_inputs, attention_mask=concat_mask)
209
208
  policy_probs = self._logprobs(policy_outputs['logits'], concat_labels, concat_mask)
210
209
  aux_loss = policy_outputs.get('aux_loss')
@@ -217,7 +216,7 @@ class DPOTrainer(Trainer):
217
216
  loss = self.criterion(policy_probs, ref_probs)
218
217
 
219
218
  if aux_loss_coef and aux_loss:
220
- loss += aux_loss_coef *aux_loss
219
+ loss += aux_loss_coef * aux_loss
221
220
 
222
221
  if gradient_accumulation_steps > 1:
223
222
  loss = loss / gradient_accumulation_steps
llm_trainer/generate_utils.py CHANGED
@@ -3,7 +3,10 @@ from contextlib import nullcontext
3
3
  import torch
4
4
  from llm_model import VlmModel, KVCache
5
5
  from .tools import TrainerTools
6
- from .utils import batch_repeat_image_tok
6
+ from .utils import (
7
+ autocastcontext,
8
+ batch_repeat_image_tok
9
+ )
7
10
 
8
11
 
9
12
  def _suppress_warper(logits: torch.Tensor, suppress_tokens: List[int]) -> torch.Tensor:
@@ -124,13 +127,7 @@ def _generate(
124
127
  如果temperature很大但内容单一,需要增大k、p
125
128
  """
126
129
  use_kv_cache = True
127
-
128
- ctx = torch.autocast(
129
- device_type=device,
130
- dtype=TrainerTools().dtype,
131
- enabled=True,
132
- cache_enabled=None
133
- ) if TrainerTools().use_amp else nullcontext()
130
+ ctx = autocastcontext(device)
134
131
 
135
132
  if isinstance(model, VlmModel):
136
133
  tokens = batch_repeat_image_tok(tokens, tokens_per_image)
@@ -330,13 +327,7 @@ def batch_generate(
330
327
  device: Union[str, torch.device, int]
331
328
  ):
332
329
  use_kv_cache = True
333
-
334
- ctx = torch.autocast(
335
- device_type=device,
336
- dtype=TrainerTools().dtype,
337
- enabled=True,
338
- cache_enabled=None
339
- ) if TrainerTools().use_amp else nullcontext()
330
+ ctx = autocastcontext(device)
340
331
 
341
332
  if isinstance(model, VlmModel):
342
333
  tokens = batch_repeat_image_tok(tokens, tokens_per_image)
llm_trainer/grpo_trainer.py CHANGED
@@ -1,4 +1,3 @@
1
- import time
2
1
  from typing import Tuple, List, Union, Callable, Optional
3
2
  import torch
4
3
  from torch.utils.data import Dataset
@@ -14,6 +13,7 @@ from .loss import GRPOLoss
14
13
  from .tools import TrainerTools
15
14
  from .generate_utils import batch_generate
16
15
  from .log import log
16
+ from .utils import autocastcontext
17
17
 
18
18
  from .partition_utils import (
19
19
  sync_model_params,
@@ -41,6 +41,7 @@ class GRPOTrainer(Trainer):
41
41
  eval_image_tags=eval_image_tags
42
42
  )
43
43
 
44
+ self.packed_sequences = False
44
45
  self.reward_func = reward_func
45
46
  self.ref_model = self._init_ref_model()
46
47
 
@@ -66,8 +67,12 @@ class GRPOTrainer(Trainer):
66
67
 
67
68
  def _init_loss(self):
68
69
  criterion = GRPOLoss(
69
- clip_eps=self.train_config.grpo_config.clip_eps,
70
- kl_weight=self.train_config.grpo_config.kl_weight
70
+ beta=self.train_config.grpo_config.loss_beta,
71
+ clip_eps=self.train_config.grpo_config.loss_clip_eps,
72
+ delta=self.train_config.grpo_config.loss_delta,
73
+ importance_sampling_level=self.train_config.grpo_config.loss_importance_sampling_level,
74
+ loss_type=self.train_config.grpo_config.loss_type,
75
+ gen_max_new_tokens=self.train_config.grpo_config.gen_max_new_tokens
71
76
  )
72
77
 
73
78
  return criterion, None
@@ -337,7 +342,7 @@ class GRPOTrainer(Trainer):
337
342
  log(f'start train for batch {batch}/{batch_count_per_file}')
338
343
 
339
344
  for grpo_step in range(self.train_config.grpo_config.grpo_steps):
340
- with self.ctx:
345
+ with autocastcontext(TrainerTools().parallel.device_type):
341
346
  loss, aux_loss = self._maximize_grpo_objective(rollout_data)
342
347
  if aux_loss_coef and aux_loss:
343
348
  loss += aux_loss_coef * aux_loss
llm_trainer/log.py CHANGED
@@ -7,6 +7,7 @@ def get_log_dir() -> str:
7
7
 
8
8
  return f'{log_dir}/' if not log_dir.endswith('/') else log_dir
9
9
 
10
+
10
11
  def log(msg: str, log_file=None):
11
12
  cur_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
12
13
  if not log_file:
llm_trainer/loss.py CHANGED
@@ -2,6 +2,7 @@ from typing import List, Optional
2
2
  import torch
3
3
  from torch import nn
4
4
  import torch.nn.functional as F
5
+ from .tools import TrainerTools
5
6
 
6
7
 
7
8
  class LMLoss(nn.Module):
@@ -115,6 +116,7 @@ class DPOLoss(nn.Module):
115
116
  )
116
117
 
117
118
  loss = losses.mean()
119
+
118
120
  # chosen_rewards = self.beta * (policy_chosen_probs - ref_chosen_probs).detach()
119
121
  # rejected_rewards = self.beta * (policy_reject_probs - ref_reject_probs).detach()
120
122
 
@@ -124,12 +126,21 @@ class DPOLoss(nn.Module):
124
126
  class GRPOLoss(nn.Module):
125
127
  def __init__(
126
128
  self,
129
+ beta: float,
127
130
  clip_eps: float,
128
- kl_weight: float
131
+ delta: Optional[float] = None,
132
+ importance_sampling_level: str = 'token',
133
+ loss_type: str = 'grpo',
134
+ gen_max_new_tokens: Optional[float] = None
129
135
  ):
130
136
  super().__init__()
137
+
138
+ self.beta = beta
131
139
  self.clip_eps = clip_eps
132
- self.kl_weight = kl_weight
140
+ self.delta = delta
141
+ self.importance_sampling_level = importance_sampling_level
142
+ self.loss_type = loss_type
143
+ self.gen_max_new_tokens = gen_max_new_tokens
133
144
 
134
145
  def forward(
135
146
  self,
@@ -139,33 +150,41 @@ class GRPOLoss(nn.Module):
139
150
  completion_mask: torch.Tensor,
140
151
  advantages: torch.Tensor
141
152
  ) -> torch.Tensor:
142
- # Compute policy ratio
143
- ratio = torch.exp(log_probs - old_log_probs)
144
153
 
145
- # Compute surrogate loss with clipping
146
- surrogate1 = ratio * advantages
147
- surrogate2 = torch.clamp(ratio, 1 - self.clip_eps, 1 + self.clip_eps) * advantages
148
- surrogate_loss = torch.min(surrogate1, surrogate2)
154
+ if self.beta != 0.0:
155
+ per_token_kl = torch.exp(ref_log_probs - log_probs) - (ref_log_probs - log_probs) - 1
156
+ else:
157
+ per_token_kl = None
158
+
159
+ log_ratio = log_probs - old_log_probs
160
+ if self.importance_sampling_level == "seq":
161
+ # GSPO
162
+ log_importance_weights = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)
163
+ log_importance_weights = log_importance_weights.unsqueeze(-1)
164
+ else:
165
+ # GRPO
166
+ log_importance_weights = log_ratio
149
167
 
150
- # Compute KL divergence penalty
151
- kl_div = torch.exp(ref_log_probs - log_probs) - (ref_log_probs - log_probs) - 1
168
+ coef_1 = torch.exp(log_importance_weights)
169
+ coef_2 = torch.clamp(coef_1, 1 - self.clip_eps, 1 + self.clip_eps)
152
170
 
153
- # Combine losses
154
- per_token_loss = surrogate_loss - self.kl_weight * kl_div
155
- loss = -((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
171
+ # Two-sided clipping
172
+ if self.delta is not None:
173
+ coef_1 = torch.clamp(coef_1, max=self.delta)
156
174
 
157
- return loss
175
+ per_token_loss1 = coef_1 * advantages
176
+ per_token_loss2 = coef_2 * advantages
177
+ per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
178
+
179
+ if self.beta != 0.0:
180
+ per_token_loss = per_token_loss + self.beta * per_token_kl
158
181
 
159
- # kl = self._approx_kl_divergence(
160
- # log_probs=log_probs,
161
- # ref_log_probs=ref_log_probs,
162
- # mask=mask,
163
- # )
164
- #
165
- # ratio = (log_probs - old_log_probs).exp()
166
- # surr1 = ratio * advantages
167
- # surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages
168
- # loss = -torch.min(surr1, surr2) + self.kl_weight * kl
169
- #
170
- # loss = self._masked_mean(loss, mask, dim=-1).mean()
171
- # return loss, kl.mean()
182
+ if self.loss_type == "bnpo":
183
+ loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1.0)
184
+ elif self.loss_type == "dr_grpo":
185
+ assert self.gen_max_new_tokens is not None
186
+ loss = (per_token_loss * completion_mask).sum() / (per_token_loss.size(0) * self.gen_max_new_tokens)
187
+ else:
188
+ loss = ((per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)).mean()
189
+
190
+ return loss
llm_trainer/parallel.py CHANGED
@@ -140,6 +140,9 @@ class Parallel(ABC):
140
140
  return 1
141
141
 
142
142
  def wait(self, msg=None):
143
+ if self.world_size == 1:
144
+ return
145
+
143
146
  msg = f' for {msg}' if msg else ''
144
147
  log(f'wait at {self.device}{msg}')
145
148
  dist.barrier()
llm_trainer/partition_utils.py CHANGED
@@ -4,6 +4,7 @@ import itertools
4
4
  from packaging import version
5
5
  from torch import nn
6
6
  from torch.nn.parallel import DistributedDataParallel as DDP
7
+ import torch.distributed as dist
7
8
 
8
9
  from .tools import TrainerTools
9
10
  from .parallel_ds import DsParallel
@@ -45,12 +46,40 @@ def unwrap_model_for_generation(model: nn.Module):
45
46
 
46
47
 
47
48
  def sync_model_params(_from: nn.Module, _to: Optional[nn.Module], mixup_alpha: float = 1.0):
49
+ """
50
+ 必须在所有rank上调用,非rank0, _to 可以设置为None.
51
+ 当前函数不适用于_to是一个zero3模型
52
+ """
48
53
  if isinstance(TrainerTools().parallel, DsParallel):
49
- _sync_ds_model_params(_from, _to, mixup_alpha)
50
- elif isinstance(TrainerTools().parallel, DdpParallel):
51
- _sync_ddp_model_params(_from, _to, mixup_alpha)
54
+ state_dict = _get_ds_model_params(_from, only_rank0=_to is None)
55
+ elif isinstance(_from, DDP):
56
+ state_dict = _from.module.state_dict()
57
+ else:
58
+ state_dict = _from.state_dict()
59
+
60
+ if not _to or not state_dict:
61
+ return
62
+
63
+ unwrap_to_model = unwrap_model(_to)
64
+ if mixup_alpha == 1.0:
65
+ # 直接覆盖
66
+ unwrap_to_model.load_state_dict(state_dict)
52
67
  else:
53
- _copy_params(_from, _to, mixup_alpha)
68
+ # 混合参数
69
+ for param_name, target_param in unwrap_to_model.named_parameters():
70
+ if param_name in state_dict:
71
+ from_param_tensor = state_dict[param_name]
72
+ target_param.data.mul_(1.0 - mixup_alpha).add_(
73
+ from_param_tensor.data.to(target_param.device),
74
+ alpha=mixup_alpha
75
+ )
76
+
77
+ # if isinstance(TrainerTools().parallel, DsParallel):
78
+ # _sync_ds_model_params(_from, _to, mixup_alpha)
79
+ # elif isinstance(TrainerTools().parallel, DdpParallel):
80
+ # _sync_ddp_model_params(_from, _to, mixup_alpha)
81
+ # else:
82
+ # _copy_params(_from, _to, mixup_alpha)
54
83
 
55
84
 
56
85
  def unwrap_model(model) -> nn.Module:
@@ -66,6 +95,57 @@ def unwrap_model(model) -> nn.Module:
66
95
  return model
67
96
 
68
97
 
98
+ def _get_ds_full_state_dict_on_rank0(model: nn.Module) -> Optional[dict]:
99
+ """
100
+ 需要在所有rank上调用,然后只有rank0有值
101
+ """
102
+ import deepspeed
103
+ assert isinstance(model, deepspeed.DeepSpeedEngine)
104
+
105
+ if model.zero_optimization_stage() != 3:
106
+ if TrainerTools().parallel.is_main_process:
107
+ return {k: v.cpu().clone() for k, v in model.module.state_dict().items()}
108
+ return None
109
+
110
+ # --- ZeRO-3 ---
111
+ # 只调用一次 GatheredParameters,传入所有参数
112
+ with deepspeed.zero.GatheredParameters(model.parameters(), modifier_rank=0):
113
+ if TrainerTools().parallel.is_main_process:
114
+ # 在这个 'with' 代码块内,rank 0 上的 model.module 拥有完整的参数
115
+ # 所以我们可以像操作普通模型一样直接调用 state_dict()
116
+ full_state_dict = model.module.state_dict()
117
+
118
+ # 将其克隆到 CPU 并返回
119
+ return {k: v.cpu().clone() for k, v in full_state_dict.items()}
120
+
121
+ # 其他 rank 执行到这里时,上下文结束,直接返回 None
122
+ return None
123
+
124
+
125
+ def _get_ds_model_params(model: nn.Module, only_rank0=False):
126
+ """
127
+ 从一个正在运行的 DeepSpeedEngine 中高效地提取完整的 FP32 state_dict,
128
+ 兼容 ZeRO Stages 0, 1, 2, 3。
129
+ 包含了对 ZeRO-3 中分片参数的正确处理。
130
+ """
131
+
132
+ import deepspeed
133
+ assert isinstance(model, deepspeed.DeepSpeedEngine)
134
+ state_dict = _get_ds_full_state_dict_on_rank0(model)
135
+
136
+ # 现在,只有 rank 0 上的 state_dict 是一个有效的字典,其他 rank 上是 None。
137
+ # 我们需要将其广播给所有进程。
138
+ if not only_rank0 and TrainerTools().parallel.world_size > 1:
139
+ # 准备一个列表,rank 0 有数据,其他 rank 是占位符
140
+ object_list = [state_dict] if TrainerTools().parallel.is_main_process else [None]
141
+ # 执行广播,这个操作是阻塞的,会同步所有进程
142
+ dist.broadcast_object_list(object_list, src=0)
143
+ # 所有进程从列表中获取广播后的 state_dict 副本
144
+ state_dict = object_list[0]
145
+
146
+ return state_dict
147
+
148
+
69
149
  def _copy_params(model, target_model, mixup_alpha):
70
150
  for target_param, copy_param in zip(target_model.parameters(), model.parameters()):
71
151
  target_param.data.mul_(1.0 - mixup_alpha).add_(copy_param.data, alpha=mixup_alpha)
@@ -79,6 +159,7 @@ def _sync_ds_model_params(_from: nn.Module, _to: Optional[nn.Module], mixup_alph
79
159
 
80
160
  if _from.zero_optimization_stage() == 3:
81
161
  with deepspeed.zero.GatheredParameters(list(origin_from.parameters()) + list(_to.parameters()), modifier_rank=0):
162
+ # why only rank 0?
82
163
  if TrainerTools().parallel.is_main_process:
83
164
  _copy_params(origin_from, _to, mixup_alpha)
84
165
  else:
llm_trainer/sft_trainer.py CHANGED
@@ -21,6 +21,7 @@ class SFTTrainer(Trainer):
21
21
  eval_prompts=eval_prompts,
22
22
  eval_image_tags=eval_image_tags
23
23
  )
24
+ self.packed_sequences = False
24
25
 
25
26
  def _convert_train_args(self) -> Tuple[dict, dict, dict, bool]:
26
27
  sft_collate_fn = get_sft_collate_fn(self.train_config.mask_prompt)
llm_trainer/tools.py CHANGED
@@ -31,15 +31,7 @@ class TrainerTools:
31
31
  self.tokenizer = Tokenizer(os.environ.get('TOKENIZERS_TYPE', 'zh_llama'))
32
32
  self.use_amp = 'cuda' in self.parallel.device and not isinstance(self.parallel, DsParallel)
33
33
 
34
- dtype = os.environ.get('DTYPE', None)
35
- self.dtype = dtypes[dtype] if dtype in dtypes else None
36
-
37
- if not self.dtype:
38
- self.dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16
39
-
40
- log(f'word_size={self.parallel.world_size},'
41
- f' use_amp={self.use_amp},'
42
- f' dtype={self.dtype}')
34
+ log(f'word_size={self.parallel.world_size}, use_amp={self.use_amp}')
43
35
 
44
36
  def _new_parallel(self):
45
37
  parallel_type = os.environ.get('PARALLEL_TYPE', 'none')
llm_trainer/train_configs.py CHANGED
@@ -136,10 +136,13 @@ class DPOConfig:
136
136
  @dataclass(kw_only=True)
137
137
  class GRPOConfig:
138
138
  grpo_steps: int = 1
139
- clip_eps: float = 0.1
140
- kl_weight: float = 0.04
141
139
  group_size: int = 12
142
140
  mixup_alpha: float = 1.0
141
+ loss_beta: float = 0.04
142
+ loss_clip_eps: float = 0.1
143
+ loss_delta: Optional[float] = None
144
+ loss_importance_sampling_level: str = 'token' # token or seq
145
+ loss_type: str = 'grpo' # grpo or bnpo or dr_grpo
143
146
  gen_max_new_tokens: Optional[int] = None
144
147
  gen_temperature: Optional[float] = None
145
148
  gen_k: Optional[int] = None
llm_trainer/trainer.py CHANGED
@@ -1,4 +1,3 @@
1
- from contextlib import nullcontext
2
1
  from typing import Optional, Tuple, List, Dict, Any
3
2
  import copy
4
3
 
@@ -37,6 +36,9 @@ from .checkpoint import (
37
36
 
38
37
  from .utils import (
39
38
  set_seed,
39
+ autocastcontext,
40
+ create_doc_boundary_mask,
41
+ generate_position_ids,
40
42
  pretrain_collate_fn,
41
43
  )
42
44
 
@@ -55,6 +57,17 @@ class Trainer:
55
57
  ):
56
58
  set_seed()
57
59
 
60
+ # 是否打包序列,仅pretrain阶段需要打包序列,
61
+ # [[1, 1, eos, 2, 2, eos]]
62
+ # doc_boundary_mask=[[[[0., 0., 0., 0., 0., 0.],
63
+ # [0., 0., 0., 0., 0., 0.],
64
+ # [0., 0., 0., 0., 0., 0.],
65
+ # [-inf, -inf, -inf, 0., 0., 0.],
66
+ # [-inf, -inf, -inf, 0., 0., 0.],
67
+ # [-inf, -inf, -inf, 0., 0., 0.]]]]
68
+ # position_ids=[[0, 1, 2, 0, 1, 2]]
69
+ self.packed_sequences = True
70
+
58
71
  self.train_config: TrainConfig = train_config
59
72
  self.eval_prompts = eval_prompts
60
73
  self.eval_image_tags = eval_image_tags
@@ -81,13 +94,6 @@ class Trainer:
81
94
 
82
95
  self.criterion, self.kd_loss = self._init_loss()
83
96
 
84
- self.ctx = torch.autocast(
85
- device_type=TrainerTools().parallel.device_type,
86
- dtype=TrainerTools().dtype,
87
- enabled=True,
88
- cache_enabled=None
89
- ) if TrainerTools().use_amp else nullcontext()
90
-
91
97
  load_checkpoint(
92
98
  self.train_model,
93
99
  optimizer=self.optimizer,
@@ -433,6 +439,14 @@ class Trainer:
433
439
 
434
440
  raise e
435
441
 
442
+ def _get_model_dtype(self):
443
+ if isinstance(TrainerTools().parallel, DsParallel):
444
+ import deepspeed
445
+ assert isinstance(self.train_model, deepspeed.DeepSpeedEngine)
446
+ return self.train_model.get_data_types()[0]
447
+ else:
448
+ return torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16
449
+
436
450
  def _eval(self, tag: str):
437
451
  with unwrap_model_for_generation(self.train_model) as generate_model:
438
452
  if TrainerTools().parallel.is_main_process:
@@ -526,8 +540,12 @@ class Trainer:
526
540
  inputs, labels = inputs.to(TrainerTools().parallel.device), labels.to(TrainerTools().parallel.device)
527
541
  attention_mask = inputs != TrainerTools().tokenizer.pad
528
542
 
529
- if TrainerTools().parallel.parallel_train:
530
- self.train_model.require_backward_grad_sync = need_update_grad
543
+ if self.packed_sequences:
544
+ doc_boundary_mask = create_doc_boundary_mask(inputs, self._get_model_dtype())
545
+ position_ids = generate_position_ids(inputs)
546
+ else:
547
+ doc_boundary_mask = None
548
+ position_ids = None
531
549
 
532
550
  if self.pixel_values_provider and 'image_tags' in batch_data:
533
551
  image_tags = batch_data['image_tags']
@@ -535,10 +553,15 @@ class Trainer:
535
553
  else:
536
554
  pixel_values = None
537
555
 
538
- with self.ctx:
556
+ if TrainerTools().parallel.parallel_train:
557
+ self.train_model.require_backward_grad_sync = need_update_grad
558
+
559
+ with autocastcontext(TrainerTools().parallel.device_type):
539
560
  result = self.train_model(
540
561
  inputs,
541
562
  attention_mask=attention_mask,
563
+ doc_boundary_mask=doc_boundary_mask,
564
+ position_ids=position_ids,
542
565
  pixel_values=pixel_values
543
566
  )
544
567
 
llm_trainer/utils.py CHANGED
@@ -1,4 +1,5 @@
1
1
  import random
2
+ from contextlib import nullcontext
2
3
  import torch
3
4
  from torch.nn.utils.rnn import pad_sequence
4
5
  import torch.nn.functional as F
@@ -14,6 +15,115 @@ def set_seed(seed=42):
14
15
  torch.cuda.manual_seed_all(seed)
15
16
 
16
17
 
18
+ def autocastcontext(device_type):
19
+ if TrainerTools().use_amp:
20
+ dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16
21
+ return torch.autocast(
22
+ device_type=device_type,
23
+ dtype=dtype,
24
+ enabled=True,
25
+ cache_enabled=None
26
+ )
27
+ else:
28
+ return nullcontext()
29
+
30
+
31
+ def create_doc_boundary_mask(
32
+ input_ids: torch.Tensor,
33
+ dtype: torch.dtype
34
+ ) -> torch.Tensor:
35
+ """
36
+ 根据文档结束符 (eot) 的位置,创建一个 attention mask 来阻止跨文档的注意力。
37
+
38
+ 这个函数生成的 mask 会阻止一个 token 关注 (attend to) 属于前面文档的 tokens。
39
+ 例如,对于输入 `[[1, 2, eot, 3, 4, eot]]`,
40
+ tokens `3` 和 `4` 将无法关注 `1`, `2`, 和第一个 `eot`。
41
+
42
+ Args:
43
+ input_ids (torch.Tensor): 输入的 token ID 张量,形状为 (bsz, seq_len)。
44
+ dtype (torch.dtype): 数据类型。
45
+
46
+ Returns:
47
+ torch.Tensor: 符合 attention 机制要求的 mask 张量,
48
+ 形状为 (bsz, 1, seq_len, seq_len)。
49
+ 值为 -inf 的位置表示被屏蔽,值为 0 的位置表示允许注意力。
50
+ """
51
+ # 获取 batch size 和 sequence length
52
+ bsz, seq_len = input_ids.shape
53
+
54
+ # 1. 确定每个 eot_token 的位置
55
+ # is_eot 是一个布尔张量,形状为 (bsz, seq_len)
56
+ is_eot = (input_ids == TrainerTools().tokenizer.end)
57
+
58
+ # 2. 为每个 token 分配一个文档 ID
59
+ # 我们使用 cumsum (累加和) 来创建递增的文档 ID。一个 token 所属的文档 ID,
60
+ # 取决于它前面有多少个 eot。
61
+ # 示例:
62
+ # input_ids: [[1, 2, 3, eot, 4, 5, eot]]
63
+ # is_eot: [F, F, F, T, F, F, T] -> [0, 0, 0, 1, 0, 0, 1]
64
+ # doc_ids_ending: [0, 0, 0, 1, 1, 1, 2] (cumsum 的结果)
65
+ # doc_ids: [0, 0, 0, 0, 1, 1, 1] (向右移位后的结果)
66
+ # 这个结果正确地将文档 0 分配给了前四个 token,将文档 1 分配给了后三个 token。
67
+ doc_ids_ending = torch.cumsum(is_eot, dim=-1)
68
+ doc_ids = F.pad(doc_ids_ending[:, :-1], (1, 0), value=0)
69
+
70
+ # 3. 通过比较 query 和 key 的文档 ID 来创建 mask
71
+ # 我们的目标是:当 query token 所在的文档 ID 大于 key token 所在的文档 ID 时,进行屏蔽。
72
+ # query_doc_ids 形状: (bsz, seq_len, 1)
73
+ # key_doc_ids 形状: (bsz, 1, seq_len)
74
+ query_doc_ids = doc_ids.unsqueeze(2)
75
+ key_doc_ids = doc_ids.unsqueeze(1)
76
+
77
+ # 利用 PyTorch 的广播机制,`query_doc_ids > key_doc_ids` 会创建一个
78
+ # 形状为 (bsz, seq_len, seq_len) 的布尔张量。
79
+ # 当 query 的文档 ID 大于 key 的文档 ID 时,值为 True,这正是我们需要屏蔽的位置。
80
+ boundary_mask = query_doc_ids > key_doc_ids
81
+
82
+ # 4. 将布尔 mask 转换为 attention 机制所需的浮点数 mask (-inf 和 0)
83
+ final_mask = torch.zeros(
84
+ (bsz, seq_len, seq_len), device=input_ids.device, dtype=dtype
85
+ )
86
+ final_mask.masked_fill_(boundary_mask, torch.finfo(dtype).min)
87
+
88
+ # 5. 增加一个维度以匹配 attention head 的输入要求 (bsz, num_heads, seq_len, seq_len)
89
+ # 这里我们只生成一个 mask,它可以被广播到所有的 head。
90
+ return final_mask.unsqueeze(1)
91
+
92
+
93
+ def generate_position_ids(input_ids: torch.Tensor):
94
+ """
95
+ 为打包序列生成 position_ids 张量。
96
+
97
+ 参数:
98
+ input_ids (torch.Tensor): 输入的 token ID 张量 (batch_size, sequence_length)。
99
+ end_of_text_id (int): 代表文本结束的特殊 token ID。
100
+
101
+ 返回:
102
+ torch.Tensor: 生成的 position_ids 张量。
103
+ """
104
+ # 获取输入张量的形状
105
+ batch_size, seq_length = input_ids.shape
106
+
107
+ # 创建一个与输入形状相同,全为0的张量来存储position_ids
108
+ # 第一个token的位置永远是0,所以这个初始化是正确的
109
+ position_ids = torch.zeros_like(input_ids, dtype=torch.long)
110
+
111
+ # 从第二个时间步 (t=1) 开始遍历整个序列
112
+ for t in range(1, seq_length):
113
+ # 检查前一个时间步 (t-1) 的token是否为 EOT token
114
+ # 这会为批次中的每个序列生成一个布尔值
115
+ is_reset_token = (input_ids[:, t - 1] == TrainerTools().tokenizer.end)
116
+
117
+ # 获取前一个时间步的位置ID
118
+ prev_position_ids = position_ids[:, t - 1]
119
+
120
+ # 如果前一个token是EOT,当前位置重置为0;否则,在前一个位置上加1
121
+ # torch.where 会根据 is_reset_token 的布尔值进行选择
122
+ position_ids[:, t] = torch.where(is_reset_token, 0, prev_position_ids + 1)
123
+
124
+ return position_ids
125
+
126
+
17
127
  def repeat_image_tok(
18
128
  tokens: torch.Tensor,
19
129
  tokens_per_image: int
@@ -43,43 +153,6 @@ def batch_repeat_image_tok(
43
153
  return torch.stack(new_tokens, dim=0)
44
154
 
45
155
 
46
- def _pad_sequence(batch_data):
47
- # [[x,x,x], [y,y,y]]
48
- inputs = pad_sequence(batch_data, batch_first=True, padding_value=TrainerTools().tokenizer.pad)
49
- # crossEntropy默认的ignore_index是-100
50
- labels = pad_sequence(batch_data, batch_first=True, padding_value=-100)
51
-
52
- return inputs, labels
53
-
54
-
55
- def _mask_prompt(labels):
56
- tokenizer = TrainerTools().tokenizer
57
- # 支持多轮会话的mask
58
- for batch, label in enumerate(labels):
59
- start_index = -1
60
- for index, token in enumerate(label):
61
- if token == tokenizer.system or token == tokenizer.user:
62
- start_index = index
63
- elif token == tokenizer.end and start_index != -1:
64
- labels[batch, start_index:index + 1] = -100
65
- start_index = -1
66
-
67
- return labels
68
-
69
-
70
- def _zero_pad_sequences(
71
- sequences: list[torch.Tensor], side: str = "left"
72
- ) -> torch.Tensor:
73
- assert side in ("left", "right")
74
- max_len = max(seq.size(0) for seq in sequences)
75
- padded_sequences = []
76
- for seq in sequences:
77
- pad_len = max_len - seq.size(0)
78
- padding = (pad_len, 0) if side == "left" else (0, pad_len)
79
- padded_sequences.append(F.pad(seq, padding))
80
- return torch.stack(padded_sequences, dim=0)
81
-
82
-
83
156
  def pretrain_collate_fn(batch_data):
84
157
  inputs, labels = _pad_sequence(batch_data)
85
158
 
@@ -219,4 +292,41 @@ def join_batch(batch_data: list[dict]) -> dict:
219
292
  data = None
220
293
  result[key] = data
221
294
 
222
- return result
295
+ return result
296
+
297
+
298
+ def _pad_sequence(batch_data):
299
+ # [[x,x,x], [y,y,y]]
300
+ inputs = pad_sequence(batch_data, batch_first=True, padding_value=TrainerTools().tokenizer.pad)
301
+ # crossEntropy默认的ignore_index是-100
302
+ labels = pad_sequence(batch_data, batch_first=True, padding_value=-100)
303
+
304
+ return inputs, labels
305
+
306
+
307
+ def _mask_prompt(labels):
308
+ tokenizer = TrainerTools().tokenizer
309
+ # 支持多轮会话的mask
310
+ for batch, label in enumerate(labels):
311
+ start_index = -1
312
+ for index, token in enumerate(label):
313
+ if token == tokenizer.system or token == tokenizer.user:
314
+ start_index = index
315
+ elif token == tokenizer.end and start_index != -1:
316
+ labels[batch, start_index:index + 1] = -100
317
+ start_index = -1
318
+
319
+ return labels
320
+
321
+
322
+ def _zero_pad_sequences(
323
+ sequences: list[torch.Tensor], side: str = "left"
324
+ ) -> torch.Tensor:
325
+ assert side in ("left", "right")
326
+ max_len = max(seq.size(0) for seq in sequences)
327
+ padded_sequences = []
328
+ for seq in sequences:
329
+ pad_len = max_len - seq.size(0)
330
+ padding = (pad_len, 0) if side == "left" else (0, pad_len)
331
+ padded_sequences.append(F.pad(seq, padding))
332
+ return torch.stack(padded_sequences, dim=0)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: project_llm_trainer
3
- Version: 0.5.17
3
+ Version: 0.7.0
4
4
  Summary: LLM and VLM trainer
5
5
  Author: qibin
6
6
  Author-email: qibin0506@gmail.com
@@ -0,0 +1,33 @@
1
+ llm_trainer/__init__.py,sha256=HWgtTEVeQSnZmEyYQm2K6eFEG4X2QAoigMlB5Z2tcXE,260
2
+ llm_trainer/checkpoint.py,sha256=gz31pZbbQvRTYrBhxV-MFaBAIFeqpe7rM6nFsjwT9lY,4328
3
+ llm_trainer/dataset.py,sha256=4QlOo0SFB5816BUYegQjgobUqTUMQvdmZMM_OEAMSjE,4347
4
+ llm_trainer/dpo_trainer.py,sha256=_8ZwOKQH69c6Fa5Cey5hNep7XUoI4jPIXQaQcV3soGw,12367
5
+ llm_trainer/ds_checkpoint.py,sha256=Wzy7PvVVWR794-BW4uragWFTAkkgDvjvkF-qMdyB4fc,2141
6
+ llm_trainer/eval.py,sha256=ZyUfSo2Q8P-lrCdPEnGkoo5pGubd0AabREK5eMISRII,1109
7
+ llm_trainer/generate_utils.py,sha256=zX5218RX4ltahCQCZVVCWQghCWhKslPk2NUnl_CakIE,15050
8
+ llm_trainer/grpo_trainer.py,sha256=0iWvpuMI5CDNIjH08Dd1ihZFqDYenVnHACiMY2GLJtg,16449
9
+ llm_trainer/log.py,sha256=XwychwKF6gvFPhthCIZCAEUZ0G3DY3fiQrOHqPWsxz0,463
10
+ llm_trainer/loss.py,sha256=eYvOlCoguKnLvdGuqvQpGUoLVSADQ5coaU3DWYbJEdM,6811
11
+ llm_trainer/parallel.py,sha256=yjStV21DJ26yM8-0O6GTMxdFAcyShY5GsQWSZmbI7HU,4543
12
+ llm_trainer/parallel_ddp.py,sha256=Pob9vUlBZnkL4oP1Re11kFob7nufMSE96pn7m7fuOEM,1345
13
+ llm_trainer/parallel_ds.py,sha256=oy8RRxHud3rACWubFlJqqd0pjPEQhKeAPGPQUSdJX2c,1145
14
+ llm_trainer/parallel_none.py,sha256=TG6Pm829Dg-yQu-97O-EHV3FCARBlNcP47KkGFAs16E,676
15
+ llm_trainer/partition_utils.py,sha256=eEYNhfEIF4hGzZ3OLa6sEBIECz261drptEz_n7fZYtk,8396
16
+ llm_trainer/scheduler.py,sha256=LAI_0VxClsIQkix0bRoduRD4vPfVuIZDhZgTAT_KK8k,4901
17
+ llm_trainer/sft_trainer.py,sha256=LudTRIaqLQYy6ym6jjMX7v9xtFBJelrR3nnPCwb48nM,1821
18
+ llm_trainer/tokenizer.py,sha256=SSpgXtb0e1NtQqRW0gCq09TTZi47umggy-Fh5EMHKJg,6708
19
+ llm_trainer/tools.py,sha256=5op5qrjjkK-Lr9oes5VxIVnOVYOYGoAdlIJq9mPUf64,2637
20
+ llm_trainer/train_configs.py,sha256=U4hwXWKI6svDqiDOu6RPTitCzpxEYyjZUN6gwh_co8c,7510
21
+ llm_trainer/trainer.py,sha256=2TC2GJeoGd0fDE6CFodk1chsSkk0v0yO0wrFYim5t4g,27938
22
+ llm_trainer/utils.py,sha256=ox2fWtSOS7F2Nh7_FoHxuQgaps1jGW3q59VXz04wRuA,11491
23
+ project_llm_trainer-0.7.0.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
24
+ project_llm_trainer-0.7.0.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
25
+ project_llm_trainer-0.7.0.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
26
+ project_llm_trainer-0.7.0.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
27
+ project_llm_trainer-0.7.0.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
28
+ project_llm_trainer-0.7.0.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
29
+ project_llm_trainer-0.7.0.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
30
+ project_llm_trainer-0.7.0.dist-info/METADATA,sha256=Q_UU9xBZIIBFOmfQJg1708lFfYn4bu5FA0fuxJCCcxQ,195
31
+ project_llm_trainer-0.7.0.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
32
+ project_llm_trainer-0.7.0.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
33
+ project_llm_trainer-0.7.0.dist-info/RECORD,,
@@ -1,33 +0,0 @@
1
- llm_trainer/__init__.py,sha256=HWgtTEVeQSnZmEyYQm2K6eFEG4X2QAoigMlB5Z2tcXE,260
2
- llm_trainer/checkpoint.py,sha256=gz31pZbbQvRTYrBhxV-MFaBAIFeqpe7rM6nFsjwT9lY,4328
3
- llm_trainer/dataset.py,sha256=4QlOo0SFB5816BUYegQjgobUqTUMQvdmZMM_OEAMSjE,4347
4
- llm_trainer/dpo_trainer.py,sha256=pNJaXvk-g0lGkZoRhbODNH34hTiz8EdP4Z12ws4W0t8,12309
5
- llm_trainer/ds_checkpoint.py,sha256=Wzy7PvVVWR794-BW4uragWFTAkkgDvjvkF-qMdyB4fc,2141
6
- llm_trainer/eval.py,sha256=ZyUfSo2Q8P-lrCdPEnGkoo5pGubd0AabREK5eMISRII,1109
7
- llm_trainer/generate_utils.py,sha256=wrZoG2g7CsOyG4sb3px9vURHQFV6_9j5kQmpFc5A8yg,15335
8
- llm_trainer/grpo_trainer.py,sha256=tuzcSi1uBzUPVKojEheJ3-Tx8-g99mf6LYYxC5nsNiw,16040
9
- llm_trainer/log.py,sha256=LxqTGRNZUGMTSQCePRpk-rYyxSnSIbT4kOdP8Fbzr0M,462
10
- llm_trainer/loss.py,sha256=NZCQeUXnLSj__mmDflE8g89KgE0emAJXIab0IERCLno,6023
11
- llm_trainer/parallel.py,sha256=G9X0FddIJwd9j-5XOknB4AlBe4G2W6fUCaQH6ycC2Fo,4490
12
- llm_trainer/parallel_ddp.py,sha256=Pob9vUlBZnkL4oP1Re11kFob7nufMSE96pn7m7fuOEM,1345
13
- llm_trainer/parallel_ds.py,sha256=oy8RRxHud3rACWubFlJqqd0pjPEQhKeAPGPQUSdJX2c,1145
14
- llm_trainer/parallel_none.py,sha256=TG6Pm829Dg-yQu-97O-EHV3FCARBlNcP47KkGFAs16E,676
15
- llm_trainer/partition_utils.py,sha256=xzv8kwlbKp3dai2pBwX89gN5ymeHk1bGbTkGru5H-UM,5167
16
- llm_trainer/scheduler.py,sha256=LAI_0VxClsIQkix0bRoduRD4vPfVuIZDhZgTAT_KK8k,4901
17
- llm_trainer/sft_trainer.py,sha256=gxQA7T1o1QGUsHp2CX1Qb_fO5LppBJuNbc0H4ixCYUA,1783
18
- llm_trainer/tokenizer.py,sha256=SSpgXtb0e1NtQqRW0gCq09TTZi47umggy-Fh5EMHKJg,6708
19
- llm_trainer/tools.py,sha256=yF17lp6oOfLe2XJeKDQ1juZcbv-6vFamJSLwEeArduA,2975
20
- llm_trainer/train_configs.py,sha256=992wy0YhBG2WvxwdLEPL4_-JUl4NkwMPT-jj_BIHo6A,7347
21
- llm_trainer/trainer.py,sha256=Q821nlLDKRZVpaRoiZ7DiJplpAJRRLtvR_33FbClGA0,26729
22
- llm_trainer/utils.py,sha256=LWNhyQ0NDEZ9mZtk2Ryvh6EulvHIaUGIflugSpqmeFI,6791
23
- project_llm_trainer-0.5.17.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
24
- project_llm_trainer-0.5.17.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
25
- project_llm_trainer-0.5.17.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
26
- project_llm_trainer-0.5.17.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
27
- project_llm_trainer-0.5.17.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
28
- project_llm_trainer-0.5.17.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
29
- project_llm_trainer-0.5.17.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
30
- project_llm_trainer-0.5.17.dist-info/METADATA,sha256=BVzwe45PQXSE-f5-BCZulqWCK3PIpKzxv9z__moTEJY,196
31
- project_llm_trainer-0.5.17.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
32
- project_llm_trainer-0.5.17.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
33
- project_llm_trainer-0.5.17.dist-info/RECORD,,