project-llm-trainer 0.5.8__py3-none-any.whl → 0.5.9__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.

This version of project-llm-trainer has been flagged as potentially problematic.

llm_trainer/generate_utils.py
@@ -143,8 +143,7 @@ def _generate(
 
     with torch.inference_mode():
         for _ in range(max_new_tokens):
-            # Do we need to truncate here??
-            t = tokens[:, -max_position_embeddings:]
+            t = tokens  # tokens[:, -max_position_embeddings:]
             with ctx:
                 result = model(
                     t,
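The change above removes the sliding context window from the generation loop: in 0.5.8 only the most recent max_position_embeddings tokens were fed back to the model each step, while 0.5.9 passes the full accumulated sequence. A minimal sketch of the two behaviours, with tensor sizes and vocabulary assumed purely for illustration:

import torch

tokens = torch.randint(0, 32000, (1, 600))    # (batch, seq_len); sizes assumed
max_position_embeddings = 512
t_058 = tokens[:, -max_position_embeddings:]  # 0.5.8: keep only the last 512 positions
assert t_058.shape == (1, 512)
t_059 = tokens                                # 0.5.9: the full sequence, untruncated

Sequences longer than the model's trained context are therefore no longer clipped at this point.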
@@ -202,7 +201,7 @@ def _generate(
 def _streaming_generate(
         model: torch.nn.Module,
         *,
-        prompt: str,
+        prompt: Union[str, torch.Tensor],
         max_position_embeddings: int,
         max_new_tokens: int,
         temperature: Optional[float] = 1.0,
@@ -214,7 +213,11 @@ def _streaming_generate(
         device: Union[str, torch.device, int] = None
 ):
     device = TrainerTools().parallel.device if not device else device
-    encoded_tokens = TrainerTools().tokenizer.encode(prompt, unsqueeze=True, covert_tensor=True).to(device)
+
+    if isinstance(prompt, torch.Tensor):
+        encoded_tokens = prompt.to(device)
+    else:
+        encoded_tokens = TrainerTools().tokenizer.encode(prompt, unsqueeze=True, covert_tensor=True).to(device)
 
     generate_text_iterator = _generate(
         model=model,
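The new branch lets callers pass an already-encoded prompt: a tensor is used as-is and only moved to the target device, while a string still goes through the tokenizer. A hedged sketch of the dispatch, using a hypothetical helper name _to_tokens that is not part of the package:

import torch

def _to_tokens(prompt, tokenizer, device):
    # Mirrors the 0.5.9 branch above: tensors pass through, strings are encoded;
    # both end up on the target device.
    if isinstance(prompt, torch.Tensor):
        return prompt.to(device)
    return tokenizer.encode(prompt, unsqueeze=True, covert_tensor=True).to(device)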
@@ -237,7 +240,7 @@ def _streaming_generate(
 def streaming_generate(
         model: torch.nn.Module,
         *,
-        prompt: str,
+        prompt: Union[str, torch.Tensor],
         max_position_embeddings: int,
         max_new_tokens: int,
         temperature: Optional[float] = 1.0,
@@ -246,7 +249,8 @@ def streaming_generate(
         pixel_values: Optional[torch.Tensor] = None,
         tokens_per_image: int = -1,
         suppress_tokens: Optional[List[int]] = None,
-        device: Union[str, torch.device, int] = None
+        device: Union[str, torch.device, int] = None,
+        return_token: bool = False
 ):
     text_iterator = _streaming_generate(
         model=model,
@@ -264,13 +268,16 @@ def streaming_generate(
 
     for (token, is_full_result) in text_iterator:
         if not is_full_result:
-            yield TrainerTools().tokenizer.decode(token.squeeze(0))
+            if return_token:
+                yield token.squeeze(0)
+            else:
+                yield TrainerTools().tokenizer.decode(token.squeeze(0))
 
 
 def generate(
         model: torch.nn.Module,
         *,
-        prompt: str,
+        prompt: Union[str, torch.Tensor],
         max_position_embeddings: int,
         max_new_tokens: int,
         temperature: Optional[float] = 1.0,
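With return_token=True, the streaming path now yields raw per-step token tensors instead of decoded strings, so decoding can be deferred and done once. A hedged usage sketch (model is a placeholder, and the per-step tensors are assumed to be concatenable 1-D ids):

import torch

pieces = []
for tok in streaming_generate(
        model=model,                 # assumed: a trained torch.nn.Module
        prompt="hello",
        max_position_embeddings=512,
        max_new_tokens=32,
        return_token=True):          # new in 0.5.9: yields token.squeeze(0) per step
    pieces.append(tok)

text = TrainerTools().tokenizer.decode(torch.cat(pieces))  # decode once at the end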
@@ -279,7 +286,8 @@ def generate(
         pixel_values: Optional[torch.Tensor] = None,
         tokens_per_image: int = -1,
         suppress_tokens: Optional[List[int]] = None,
-        device: Union[str, torch.device, int] = None
+        device: Union[str, torch.device, int] = None,
+        return_token: bool = False
 ):
     text_iterator = _streaming_generate(
         model=model,
@@ -297,7 +305,12 @@ def generate(
 
     for (token, is_full_result) in text_iterator:
         if is_full_result:
-            return TrainerTools().tokenizer.decode(token.squeeze(0))
+            if return_token:
+                return token.squeeze(0)
+            else:
+                return TrainerTools().tokenizer.decode(token.squeeze(0))
+
+    return None
 
 
 def batch_generate(
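Taken together, generate() in 0.5.9 has three possible outcomes: a decoded string (the default), a raw token tensor when return_token=True, and, newly, an explicit None when the iterator never produces a full result. A hedged sketch of a caller handling all three (model and prompt are placeholders):

output = generate(
    model=model,                     # assumed: a trained torch.nn.Module
    prompt="hello",                  # a pre-encoded torch.Tensor also works since 0.5.9
    max_position_embeddings=512,
    max_new_tokens=64,
    return_token=True,               # return raw token ids instead of decoded text
)
if output is None:                   # new in 0.5.9: generation produced no full result
    text = ""
else:
    text = TrainerTools().tokenizer.decode(output)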
llm_trainer/train_configs.py
@@ -164,7 +164,7 @@ class KDConfig:
 
 @dataclass(kw_only=True)
 class EvalConfig:
-    max_new_tokens: int = 512
+    max_new_tokens: Optional[int] = None
     temperature: float = 1.0
     top_p: float = 0.95
     top_k: Optional[float] = None
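Because max_new_tokens no longer defaults to 512, evaluation code that relied on the implicit cap must now set it explicitly or handle None downstream. A minimal sketch, assuming EvalConfig is importable from llm_trainer.train_configs (the module listed in the RECORD below) and that its remaining fields keep their defaults:

from llm_trainer.train_configs import EvalConfig

cfg = EvalConfig(max_new_tokens=512)  # restore the 0.5.8 default explicitly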
project_llm_trainer-0.5.9.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: project_llm_trainer
-Version: 0.5.8
+Version: 0.5.9
 Summary: LLM and VLM trainer
 Author: qibin
 Author-email: qibin0506@gmail.com
project_llm_trainer-0.5.9.dist-info/RECORD
@@ -4,7 +4,7 @@ llm_trainer/dataset.py,sha256=4QlOo0SFB5816BUYegQjgobUqTUMQvdmZMM_OEAMSjE,4347
 llm_trainer/dpo_trainer.py,sha256=1A_4QP2_xqM_YeqdXy-0RaMvEL80gim-pgnPQyHww9U,12052
 llm_trainer/ds_checkpoint.py,sha256=D092fkS1Up4QmpV9YCpqbSzfX_caCAeX-UiOrhOE1I8,1947
 llm_trainer/eval.py,sha256=fjASCILU3fSPJxo9cP3rIXEEnkc5ZlUyHqXlZtUiHrw,888
-llm_trainer/generate_utils.py,sha256=CbJ3mfAD6DkQ0GUHcJQ1AK02m-ocwmd-BPXEpiwvNNQ,14933
+llm_trainer/generate_utils.py,sha256=wrZoG2g7CsOyG4sb3px9vURHQFV6_9j5kQmpFc5A8yg,15335
 llm_trainer/grpo_trainer.py,sha256=sCYjvksdm9f7TpN23KXuCmua_8VFTZEfVEcflL89P_I,16058
 llm_trainer/log.py,sha256=LxqTGRNZUGMTSQCePRpk-rYyxSnSIbT4kOdP8Fbzr0M,462
 llm_trainer/loss.py,sha256=NZCQeUXnLSj__mmDflE8g89KgE0emAJXIab0IERCLno,6023
@@ -17,17 +17,17 @@ llm_trainer/scheduler.py,sha256=LAI_0VxClsIQkix0bRoduRD4vPfVuIZDhZgTAT_KK8k,4901
 llm_trainer/sft_trainer.py,sha256=gxQA7T1o1QGUsHp2CX1Qb_fO5LppBJuNbc0H4ixCYUA,1783
 llm_trainer/tokenizer.py,sha256=SSpgXtb0e1NtQqRW0gCq09TTZi47umggy-Fh5EMHKJg,6708
 llm_trainer/tools.py,sha256=yF17lp6oOfLe2XJeKDQ1juZcbv-6vFamJSLwEeArduA,2975
-llm_trainer/train_configs.py,sha256=c6bgivkkWRYcPD3NzI5uRItAUhZiIBgKVMuMgVFRnFo,7336
+llm_trainer/train_configs.py,sha256=guV8xkG5TSGvYwFvsQV_mA8mDHLLVhL5L0xo_WMsMME,7347
 llm_trainer/trainer.py,sha256=U26dZc22nByfTZUzKeEiqqYVexBzgw0ep7N0Z2zIcWI,26141
 llm_trainer/utils.py,sha256=LWNhyQ0NDEZ9mZtk2Ryvh6EulvHIaUGIflugSpqmeFI,6791
-project_llm_trainer-0.5.8.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
-project_llm_trainer-0.5.8.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
-project_llm_trainer-0.5.8.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
-project_llm_trainer-0.5.8.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
-project_llm_trainer-0.5.8.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
-project_llm_trainer-0.5.8.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
-project_llm_trainer-0.5.8.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
-project_llm_trainer-0.5.8.dist-info/METADATA,sha256=54q4Nl2EWMYwShSqS8cLLqZ0iJVraAvJCGz8QPiVMiE,195
-project_llm_trainer-0.5.8.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
-project_llm_trainer-0.5.8.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
-project_llm_trainer-0.5.8.dist-info/RECORD,,
+project_llm_trainer-0.5.9.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
+project_llm_trainer-0.5.9.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
+project_llm_trainer-0.5.9.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
+project_llm_trainer-0.5.9.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
+project_llm_trainer-0.5.9.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
+project_llm_trainer-0.5.9.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
+project_llm_trainer-0.5.9.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
+project_llm_trainer-0.5.9.dist-info/METADATA,sha256=YfFvnbVfUyNMCByDKCJ1rB4Mj0uxGQ2wquSe4QKaiF4,195
+project_llm_trainer-0.5.9.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
+project_llm_trainer-0.5.9.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
+project_llm_trainer-0.5.9.dist-info/RECORD,,