project-llm-trainer 0.5.7__py3-none-any.whl → 0.5.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of project-llm-trainer might be problematic.
- llm_trainer/generate_utils.py +23 -10
- llm_trainer/train_configs.py +1 -1
- llm_trainer/trainer.py +1 -0
- {project_llm_trainer-0.5.7.dist-info → project_llm_trainer-0.5.9.dist-info}/METADATA +1 -1
- {project_llm_trainer-0.5.7.dist-info → project_llm_trainer-0.5.9.dist-info}/RECORD +14 -14
- {project_llm_trainer-0.5.7.data → project_llm_trainer-0.5.9.data}/scripts/calc_intermediate_size +0 -0
- {project_llm_trainer-0.5.7.data → project_llm_trainer-0.5.9.data}/scripts/ddp_train +0 -0
- {project_llm_trainer-0.5.7.data → project_llm_trainer-0.5.9.data}/scripts/ds_train +0 -0
- {project_llm_trainer-0.5.7.data → project_llm_trainer-0.5.9.data}/scripts/plot_loss +0 -0
- {project_llm_trainer-0.5.7.data → project_llm_trainer-0.5.9.data}/scripts/plot_lr +0 -0
- {project_llm_trainer-0.5.7.data → project_llm_trainer-0.5.9.data}/scripts/py_train +0 -0
- {project_llm_trainer-0.5.7.data → project_llm_trainer-0.5.9.data}/scripts/smart_train +0 -0
- {project_llm_trainer-0.5.7.dist-info → project_llm_trainer-0.5.9.dist-info}/WHEEL +0 -0
- {project_llm_trainer-0.5.7.dist-info → project_llm_trainer-0.5.9.dist-info}/top_level.txt +0 -0
llm_trainer/generate_utils.py
CHANGED
@@ -143,8 +143,7 @@ def _generate(
 
     with torch.inference_mode():
         for _ in range(max_new_tokens):
-            #
-            t = tokens[:, -max_position_embeddings:]
+            t = tokens # tokens[:, -max_position_embeddings:]
             with ctx:
                 result = model(
                     t,
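This hunk drops the per-step sliding-window crop: `t = tokens[:, -max_position_embeddings:]` becomes `t = tokens`, with the old slice kept as a trailing comment, so the full prompt-plus-generated buffer is now fed to the model on every decoding step. A minimal sketch of the two behaviors on a dummy buffer (variable names mirror the hunk; nothing else is from the package):

import torch

tokens = torch.randint(0, 100, (1, 12))  # hypothetical (batch, seq_len) buffer
max_position_embeddings = 8

# 0.5.7 behavior: sliding window keeps the input within the context size.
t_old = tokens[:, -max_position_embeddings:]  # shape (1, 8)

# 0.5.9 behavior: the full buffer every step; it can exceed the context size
# once prompt length + generated tokens pass max_position_embeddings.
t_new = tokens  # shape (1, 12)

print(t_old.shape, t_new.shape)  # torch.Size([1, 8]) torch.Size([1, 12])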
@@ -202,7 +201,7 @@ def _generate(
 def _streaming_generate(
     model: torch.nn.Module,
     *,
-    prompt: str,
+    prompt: Union[str, torch.Tensor],
     max_position_embeddings: int,
     max_new_tokens: int,
     temperature: Optional[float] = 1.0,
@@ -214,7 +213,11 @@ def _streaming_generate(
     device: Union[str, torch.device, int] = None
 ):
     device = TrainerTools().parallel.device if not device else device
-
+
+    if isinstance(prompt, torch.Tensor):
+        encoded_tokens = prompt.to(device)
+    else:
+        encoded_tokens = TrainerTools().tokenizer.encode(prompt, unsqueeze=True, covert_tensor=True).to(device)
 
     generate_text_iterator = _generate(
         model=model,
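Rather than always tokenizing, `_streaming_generate` now dispatches on the prompt type: a `torch.Tensor` is moved to the device as-is, while a string still goes through the tokenizer (note that `covert_tensor` is the package's own spelling of the keyword). A hedged usage sketch, assuming `model` is an initialized `torch.nn.Module` and that the public `streaming_generate` wrapper (changed the same way below) forwards the prompt unchanged:

# Encode once, then reuse the tensor across calls instead of re-tokenizing.
prompt_tokens = TrainerTools().tokenizer.encode(
    "hello", unsqueeze=True, covert_tensor=True
)

for piece in streaming_generate(
    model=model,
    prompt=prompt_tokens,          # tensor path: used as-is after .to(device)
    max_position_embeddings=1024,  # illustrative values
    max_new_tokens=64,
):
    print(piece, end="")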
@@ -237,7 +240,7 @@ def _streaming_generate(
 def streaming_generate(
     model: torch.nn.Module,
     *,
-    prompt: str,
+    prompt: Union[str, torch.Tensor],
     max_position_embeddings: int,
     max_new_tokens: int,
     temperature: Optional[float] = 1.0,
@@ -246,7 +249,8 @@ def streaming_generate(
     pixel_values: Optional[torch.Tensor] = None,
     tokens_per_image: int = -1,
     suppress_tokens: Optional[List[int]] = None,
-    device: Union[str, torch.device, int] = None
+    device: Union[str, torch.device, int] = None,
+    return_token: bool = False
 ):
     text_iterator = _streaming_generate(
         model=model,
@@ -264,13 +268,16 @@ def streaming_generate(
 
     for (token, is_full_result) in text_iterator:
         if not is_full_result:
-            yield TrainerTools().tokenizer.decode(token.squeeze(0))
+            if return_token:
+                yield token.squeeze(0)
+            else:
+                yield TrainerTools().tokenizer.decode(token.squeeze(0))
 
 
 def generate(
     model: torch.nn.Module,
     *,
-    prompt: str,
+    prompt: Union[str, torch.Tensor],
     max_position_embeddings: int,
     max_new_tokens: int,
     temperature: Optional[float] = 1.0,
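With `return_token=True`, the stream yields raw token tensors (squeezed to 1-D) instead of decoded strings, leaving decoding to the caller. A minimal sketch of both modes, under the same assumptions as above:

# Default mode: decoded text, one chunk per generated token.
for chunk in streaming_generate(model=model, prompt="hi",
                                max_position_embeddings=1024,
                                max_new_tokens=16):
    print(chunk, end="")

# Token mode: collect ids and decode them in one pass at the end.
ids = list(streaming_generate(model=model, prompt="hi",
                              max_position_embeddings=1024,
                              max_new_tokens=16,
                              return_token=True))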
@@ -279,7 +286,8 @@ def generate(
     pixel_values: Optional[torch.Tensor] = None,
     tokens_per_image: int = -1,
     suppress_tokens: Optional[List[int]] = None,
-    device: Union[str, torch.device, int] = None
+    device: Union[str, torch.device, int] = None,
+    return_token: bool = False
 ):
     text_iterator = _streaming_generate(
         model=model,
@@ -297,7 +305,12 @@ def generate(
 
     for (token, is_full_result) in text_iterator:
         if is_full_result:
-            return TrainerTools().tokenizer.decode(token.squeeze(0))
+            if return_token:
+                return token.squeeze(0)
+            else:
+                return TrainerTools().tokenizer.decode(token.squeeze(0))
+
+    return None
 
 
 def batch_generate(
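`generate` gains the same `return_token` switch and, new in this release, an explicit `return None` when the iterator never produces a full result, so callers should now handle a `None` return. A hedged sketch under the same assumptions as the examples above:

result = generate(model=model, prompt="hi",
                  max_position_embeddings=1024, max_new_tokens=16,
                  return_token=True)
if result is None:
    print("no full result produced")  # 0.5.9 returns None explicitly
else:
    print(TrainerTools().tokenizer.decode(result))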
llm_trainer/train_configs.py
CHANGED
llm_trainer/trainer.py
CHANGED
{project_llm_trainer-0.5.7.dist-info → project_llm_trainer-0.5.9.dist-info}/RECORD
CHANGED
@@ -4,7 +4,7 @@ llm_trainer/dataset.py,sha256=4QlOo0SFB5816BUYegQjgobUqTUMQvdmZMM_OEAMSjE,4347
 llm_trainer/dpo_trainer.py,sha256=1A_4QP2_xqM_YeqdXy-0RaMvEL80gim-pgnPQyHww9U,12052
 llm_trainer/ds_checkpoint.py,sha256=D092fkS1Up4QmpV9YCpqbSzfX_caCAeX-UiOrhOE1I8,1947
 llm_trainer/eval.py,sha256=fjASCILU3fSPJxo9cP3rIXEEnkc5ZlUyHqXlZtUiHrw,888
-llm_trainer/generate_utils.py,sha256=
+llm_trainer/generate_utils.py,sha256=wrZoG2g7CsOyG4sb3px9vURHQFV6_9j5kQmpFc5A8yg,15335
 llm_trainer/grpo_trainer.py,sha256=sCYjvksdm9f7TpN23KXuCmua_8VFTZEfVEcflL89P_I,16058
 llm_trainer/log.py,sha256=LxqTGRNZUGMTSQCePRpk-rYyxSnSIbT4kOdP8Fbzr0M,462
 llm_trainer/loss.py,sha256=NZCQeUXnLSj__mmDflE8g89KgE0emAJXIab0IERCLno,6023
@@ -17,17 +17,17 @@ llm_trainer/scheduler.py,sha256=LAI_0VxClsIQkix0bRoduRD4vPfVuIZDhZgTAT_KK8k,4901
 llm_trainer/sft_trainer.py,sha256=gxQA7T1o1QGUsHp2CX1Qb_fO5LppBJuNbc0H4ixCYUA,1783
 llm_trainer/tokenizer.py,sha256=SSpgXtb0e1NtQqRW0gCq09TTZi47umggy-Fh5EMHKJg,6708
 llm_trainer/tools.py,sha256=yF17lp6oOfLe2XJeKDQ1juZcbv-6vFamJSLwEeArduA,2975
-llm_trainer/train_configs.py,sha256=
-llm_trainer/trainer.py,sha256=
+llm_trainer/train_configs.py,sha256=guV8xkG5TSGvYwFvsQV_mA8mDHLLVhL5L0xo_WMsMME,7347
+llm_trainer/trainer.py,sha256=U26dZc22nByfTZUzKeEiqqYVexBzgw0ep7N0Z2zIcWI,26141
 llm_trainer/utils.py,sha256=LWNhyQ0NDEZ9mZtk2Ryvh6EulvHIaUGIflugSpqmeFI,6791
-project_llm_trainer-0.5.
-project_llm_trainer-0.5.
-project_llm_trainer-0.5.
-project_llm_trainer-0.5.
-project_llm_trainer-0.5.
-project_llm_trainer-0.5.
-project_llm_trainer-0.5.
-project_llm_trainer-0.5.
-project_llm_trainer-0.5.
-project_llm_trainer-0.5.
-project_llm_trainer-0.5.
+project_llm_trainer-0.5.9.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
+project_llm_trainer-0.5.9.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
+project_llm_trainer-0.5.9.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
+project_llm_trainer-0.5.9.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
+project_llm_trainer-0.5.9.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
+project_llm_trainer-0.5.9.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
+project_llm_trainer-0.5.9.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
+project_llm_trainer-0.5.9.dist-info/METADATA,sha256=YfFvnbVfUyNMCByDKCJ1rB4Mj0uxGQ2wquSe4QKaiF4,195
+project_llm_trainer-0.5.9.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
+project_llm_trainer-0.5.9.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
+project_llm_trainer-0.5.9.dist-info/RECORD,,
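Each RECORD line has the form path,sha256=&lt;digest&gt;,&lt;size&gt;, where the digest is the urlsafe-base64 SHA-256 of the file with the "=" padding stripped (per the wheel spec). A small sketch for recomputing an entry to check it against the diff above (the helper name is illustrative):

import base64, hashlib

def record_entry(path: str) -> str:
    # Hash the file and format the line exactly as it appears in RECORD.
    data = open(path, "rb").read()
    digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest()).rstrip(b"=")
    return f"{path},sha256={digest.decode()},{len(data)}"

print(record_entry("llm_trainer/generate_utils.py"))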
{project_llm_trainer-0.5.7.data → project_llm_trainer-0.5.9.data}/scripts/calc_intermediate_size
RENAMED
File without changes
{project_llm_trainer-0.5.7.data → project_llm_trainer-0.5.9.data}/scripts/ddp_train
RENAMED
File without changes
{project_llm_trainer-0.5.7.data → project_llm_trainer-0.5.9.data}/scripts/ds_train
RENAMED
File without changes
{project_llm_trainer-0.5.7.data → project_llm_trainer-0.5.9.data}/scripts/plot_loss
RENAMED
File without changes
{project_llm_trainer-0.5.7.data → project_llm_trainer-0.5.9.data}/scripts/plot_lr
RENAMED
File without changes
{project_llm_trainer-0.5.7.data → project_llm_trainer-0.5.9.data}/scripts/py_train
RENAMED
File without changes
{project_llm_trainer-0.5.7.data → project_llm_trainer-0.5.9.data}/scripts/smart_train
RENAMED
File without changes
{project_llm_trainer-0.5.7.dist-info → project_llm_trainer-0.5.9.dist-info}/WHEEL
RENAMED
File without changes
{project_llm_trainer-0.5.7.dist-info → project_llm_trainer-0.5.9.dist-info}/top_level.txt
RENAMED
File without changes