project-llm-trainer 0.3.2__py3-none-any.whl → 0.3.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of project-llm-trainer might be problematic. Click here for more details.
- llm_trainer/dpo_trainer.py +1 -3
- llm_trainer/grpo_trainer.py +2 -4
- llm_trainer/sft_trainer.py +9 -5
- llm_trainer/trainer.py +11 -11
- {project_llm_trainer-0.3.2.dist-info → project_llm_trainer-0.3.4.dist-info}/METADATA +1 -1
- {project_llm_trainer-0.3.2.dist-info → project_llm_trainer-0.3.4.dist-info}/RECORD +15 -15
- {project_llm_trainer-0.3.2.data → project_llm_trainer-0.3.4.data}/scripts/calc_intermediate_size +0 -0
- {project_llm_trainer-0.3.2.data → project_llm_trainer-0.3.4.data}/scripts/ddp_train +0 -0
- {project_llm_trainer-0.3.2.data → project_llm_trainer-0.3.4.data}/scripts/ds_train +0 -0
- {project_llm_trainer-0.3.2.data → project_llm_trainer-0.3.4.data}/scripts/plot_loss +0 -0
- {project_llm_trainer-0.3.2.data → project_llm_trainer-0.3.4.data}/scripts/plot_lr +0 -0
- {project_llm_trainer-0.3.2.data → project_llm_trainer-0.3.4.data}/scripts/py_train +0 -0
- {project_llm_trainer-0.3.2.data → project_llm_trainer-0.3.4.data}/scripts/smart_train +0 -0
- {project_llm_trainer-0.3.2.dist-info → project_llm_trainer-0.3.4.dist-info}/WHEEL +0 -0
- {project_llm_trainer-0.3.2.dist-info → project_llm_trainer-0.3.4.dist-info}/top_level.txt +0 -0
llm_trainer/dpo_trainer.py
CHANGED
|
@@ -5,8 +5,6 @@ from torch.utils.data import Dataset
|
|
|
5
5
|
import torch.distributed as dist
|
|
6
6
|
import torch.nn.functional as F
|
|
7
7
|
|
|
8
|
-
from llm_model import LlmModel
|
|
9
|
-
|
|
10
8
|
from .parallel_ds import DsParallel
|
|
11
9
|
from .parallel_fsdp import FsdpParallel
|
|
12
10
|
from .trainer import Trainer
|
|
@@ -41,7 +39,7 @@ class DPOTrainer(Trainer):
|
|
|
41
39
|
def _init_reference_model(self):
|
|
42
40
|
parallel = TrainerTools().new_parallel()
|
|
43
41
|
|
|
44
|
-
reference_model =
|
|
42
|
+
reference_model = self._new_model(self.train_config)
|
|
45
43
|
if self.train_config.init_state_dict:
|
|
46
44
|
reference_model.load_state_dict(self.train_config.init_state_dict, strict=False)
|
|
47
45
|
self.train_config.init_state_dict = None
|
llm_trainer/grpo_trainer.py
CHANGED
|
@@ -7,8 +7,6 @@ from torch.nn.utils.rnn import pad_sequence
|
|
|
7
7
|
import torch.distributed as dist
|
|
8
8
|
import torch.nn.functional as F
|
|
9
9
|
|
|
10
|
-
from llm_model import LlmModel
|
|
11
|
-
|
|
12
10
|
from .parallel_ds import DsParallel
|
|
13
11
|
from .trainer import Trainer
|
|
14
12
|
from .train_configs import TrainConfig
|
|
@@ -50,7 +48,7 @@ class GRPOTrainer(Trainer):
|
|
|
50
48
|
save_checkpoint(self.train_model, self.optimizer)
|
|
51
49
|
|
|
52
50
|
def _init_reference_model(self):
|
|
53
|
-
reference_model =
|
|
51
|
+
reference_model = self._new_model(self.train_config)
|
|
54
52
|
|
|
55
53
|
device = 'cpu' # TrainerTools().parallel.device
|
|
56
54
|
reference_model.to(device)
|
|
@@ -64,7 +62,7 @@ class GRPOTrainer(Trainer):
|
|
|
64
62
|
|
|
65
63
|
def _init_generate_model(self):
|
|
66
64
|
return copy.deepcopy(self.reference_model)
|
|
67
|
-
# generate_model =
|
|
65
|
+
# generate_model = self._new_model(self.train_config)
|
|
68
66
|
#
|
|
69
67
|
# device = 'cpu' #TrainerTools().parallel.device
|
|
70
68
|
# generate_model.to(device)
|
llm_trainer/sft_trainer.py
CHANGED
|
@@ -32,11 +32,15 @@ class SFTTrainer(Trainer):
|
|
|
32
32
|
def _create_dataset(self, file_idx) -> Tuple[Dataset, str]:
|
|
33
33
|
file_path = self.train_config.file_dataset[file_idx]
|
|
34
34
|
max_position_embeddings = self.train_config.model_config.max_position_embeddings
|
|
35
|
+
|
|
36
|
+
image_tag_file_path = None
|
|
37
|
+
tokens_per_image = -1
|
|
38
|
+
|
|
35
39
|
if isinstance(self.train_config.model_config, VLMConfig):
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
40
|
+
if self.train_config.image_tags_file_dataset:
|
|
41
|
+
image_tag_file_path = self.train_config.image_tags_file_dataset[file_idx]
|
|
42
|
+
|
|
43
|
+
if self.train_config.model_config.tokens_per_image:
|
|
44
|
+
tokens_per_image = self.train_config.model_config.tokens_per_image
|
|
41
45
|
|
|
42
46
|
return LineByLineTextDataset(file_path, max_position_embeddings, image_tag_file_path, tokens_per_image), file_path
|
llm_trainer/trainer.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
import time
|
|
2
2
|
from contextlib import nullcontext
|
|
3
|
-
from typing import Optional, Tuple, List, Dict, Any
|
|
3
|
+
from typing import Optional, Tuple, List, Dict, Any, Union
|
|
4
4
|
|
|
5
5
|
import torch
|
|
6
6
|
from torch import nn
|
|
@@ -110,16 +110,19 @@ class Trainer:
|
|
|
110
110
|
self.pixel_values_provider = None
|
|
111
111
|
self.tokens_per_image = -1
|
|
112
112
|
|
|
113
|
+
def _new_model(self, train_config: TrainConfig):
|
|
114
|
+
if isinstance(train_config.model_config, VLMConfig):
|
|
115
|
+
return VlmModel(train_config.model_config)
|
|
116
|
+
else:
|
|
117
|
+
return LlmModel(train_config.model_config)
|
|
118
|
+
|
|
113
119
|
def _init_train_model_and_optim(
|
|
114
120
|
self,
|
|
115
121
|
initial_lr: float,
|
|
116
122
|
parallel_kwargs: dict,
|
|
117
123
|
use_ds_optim: bool
|
|
118
124
|
):
|
|
119
|
-
|
|
120
|
-
model = VlmModel(self.train_config.model_config)
|
|
121
|
-
else:
|
|
122
|
-
model = LlmModel(self.train_config.model_config)
|
|
125
|
+
model = self._new_model(self.train_config)
|
|
123
126
|
|
|
124
127
|
if self.train_config.init_state_dict:
|
|
125
128
|
model.load_state_dict(self.train_config.init_state_dict, strict=False)
|
|
@@ -156,10 +159,7 @@ class Trainer:
|
|
|
156
159
|
|
|
157
160
|
def _init_eval_model(self) -> Optional[nn.Module]:
|
|
158
161
|
if TrainerTools().parallel.is_main_process:
|
|
159
|
-
|
|
160
|
-
return VlmModel(self.train_config.model_config).to('cpu')
|
|
161
|
-
else:
|
|
162
|
-
return LlmModel(self.train_config.model_config).to('cpu')
|
|
162
|
+
return self._new_model(self.train_config).to('cpu')
|
|
163
163
|
|
|
164
164
|
return None
|
|
165
165
|
|
|
@@ -400,7 +400,7 @@ class Trainer:
|
|
|
400
400
|
):
|
|
401
401
|
if TrainerTools().parallel.is_main_process:
|
|
402
402
|
eval_prompt, eval_image_tag = self._get_eval_data()
|
|
403
|
-
if isinstance(self.
|
|
403
|
+
if isinstance(self.train_model, VlmModel) and self.pixel_values_provider and eval_image_tag:
|
|
404
404
|
eval_pixel_values = self.pixel_values_provider([eval_image_tag])
|
|
405
405
|
else:
|
|
406
406
|
eval_pixel_values = None
|
|
@@ -422,7 +422,7 @@ class Trainer:
|
|
|
422
422
|
):
|
|
423
423
|
if TrainerTools().parallel.is_main_process:
|
|
424
424
|
eval_prompt, eval_image_tag = self._get_eval_data()
|
|
425
|
-
if isinstance(self.
|
|
425
|
+
if isinstance(self.train_model, VlmModel) and self.pixel_values_provider and eval_image_tag:
|
|
426
426
|
eval_pixel_values = self.pixel_values_provider([eval_image_tag])
|
|
427
427
|
else:
|
|
428
428
|
eval_pixel_values = None
|
|
@@ -2,11 +2,11 @@ llm_trainer/__init__.py,sha256=HWgtTEVeQSnZmEyYQm2K6eFEG4X2QAoigMlB5Z2tcXE,260
|
|
|
2
2
|
llm_trainer/checkpoint.py,sha256=Dlkcit0o7Gx6S9QUrIrVp2pTurP9X0zVA7w7ImSuVQU,6049
|
|
3
3
|
llm_trainer/dataset.py,sha256=4QlOo0SFB5816BUYegQjgobUqTUMQvdmZMM_OEAMSjE,4347
|
|
4
4
|
llm_trainer/dcp.py,sha256=PkD97DyrOtoTKn4FJsfL3VqAy4dxufgjdzJEz8-Cnoc,3635
|
|
5
|
-
llm_trainer/dpo_trainer.py,sha256=
|
|
5
|
+
llm_trainer/dpo_trainer.py,sha256=7Bf6snWcu2fT8QRDI1CSzmrc7Cog6JauIeK2KoW_f8I,13135
|
|
6
6
|
llm_trainer/ds_checkpoint.py,sha256=_svpzqRaa43--DKPputoXAelc6X9vPM0gNQu-hlh6NI,2153
|
|
7
7
|
llm_trainer/eval.py,sha256=sCvdYnqWWf5_nuDQN5BHb_YivXLOQW-V0ET9mPu0tPU,2389
|
|
8
8
|
llm_trainer/generate_utils.py,sha256=4iM0vyc_1C_iTL31GlS9PR4eZtYaELPRZ02KDSPZA9U,15158
|
|
9
|
-
llm_trainer/grpo_trainer.py,sha256=
|
|
9
|
+
llm_trainer/grpo_trainer.py,sha256=M6vp6QjxhBQVaw3e_3BJ4earuezQNKQ3JeZfQLBaSLQ,16370
|
|
10
10
|
llm_trainer/log.py,sha256=LxqTGRNZUGMTSQCePRpk-rYyxSnSIbT4kOdP8Fbzr0M,462
|
|
11
11
|
llm_trainer/loss.py,sha256=Yv3fsaVuZ5AhnGPJOr5vEMb_tM2urR6mCb4DBbrHHI8,6030
|
|
12
12
|
llm_trainer/parallel.py,sha256=2VJtW3Gq2c1yS_LdcrNhk7B12prFwBmFnKhvV8FS2d8,4428
|
|
@@ -15,20 +15,20 @@ llm_trainer/parallel_ds.py,sha256=W_PkczyAlgffCRcQadN-Pf7H7HM7TU26v5W63jKELFM,99
|
|
|
15
15
|
llm_trainer/parallel_fsdp.py,sha256=u9XbbVTzcsMcaf-aQFrC_QwWsDRGoEpRmgvu1cKNtgk,3887
|
|
16
16
|
llm_trainer/parallel_none.py,sha256=a6tt3aBmCq5rSP7n2I-sF-hsZ992BbLbpbxutDCFJfs,607
|
|
17
17
|
llm_trainer/scheduler.py,sha256=Xz8HhwoRMjRe41sf_NHhpZfkTlEs0I2MYusvMY6hCVw,3531
|
|
18
|
-
llm_trainer/sft_trainer.py,sha256=
|
|
18
|
+
llm_trainer/sft_trainer.py,sha256=gxQA7T1o1QGUsHp2CX1Qb_fO5LppBJuNbc0H4ixCYUA,1783
|
|
19
19
|
llm_trainer/tokenizer.py,sha256=A7TYYUbtPf75kjCvWP7yBui4xZBObMk2aPem62YpwpY,6776
|
|
20
20
|
llm_trainer/tools.py,sha256=AhfjN9oln5Pyif1SgCWwgQg-Q5acTCd9xpz4L26QUjA,3039
|
|
21
21
|
llm_trainer/train_configs.py,sha256=cadfo8RgxNUR-L3ZLyjiRXTQvhjUl4A1qENaq-ol8h4,15878
|
|
22
|
-
llm_trainer/trainer.py,sha256=
|
|
22
|
+
llm_trainer/trainer.py,sha256=5DgDzg0TReZrXsIaM6A4DzeJnzePNybGdfoVSDybQ2U,24308
|
|
23
23
|
llm_trainer/utils.py,sha256=-ivhMF0d999va13S1wt2uBvtVw8Nvr3uBzhaUFKL04Q,6826
|
|
24
|
-
project_llm_trainer-0.3.
|
|
25
|
-
project_llm_trainer-0.3.
|
|
26
|
-
project_llm_trainer-0.3.
|
|
27
|
-
project_llm_trainer-0.3.
|
|
28
|
-
project_llm_trainer-0.3.
|
|
29
|
-
project_llm_trainer-0.3.
|
|
30
|
-
project_llm_trainer-0.3.
|
|
31
|
-
project_llm_trainer-0.3.
|
|
32
|
-
project_llm_trainer-0.3.
|
|
33
|
-
project_llm_trainer-0.3.
|
|
34
|
-
project_llm_trainer-0.3.
|
|
24
|
+
project_llm_trainer-0.3.4.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
|
|
25
|
+
project_llm_trainer-0.3.4.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
|
|
26
|
+
project_llm_trainer-0.3.4.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
|
|
27
|
+
project_llm_trainer-0.3.4.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
|
|
28
|
+
project_llm_trainer-0.3.4.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
|
|
29
|
+
project_llm_trainer-0.3.4.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
|
|
30
|
+
project_llm_trainer-0.3.4.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
|
|
31
|
+
project_llm_trainer-0.3.4.dist-info/METADATA,sha256=Y8XjOGdQb7VxN5QKHyKICkkOzjGcXJuI6hPziULJNfc,195
|
|
32
|
+
project_llm_trainer-0.3.4.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
|
|
33
|
+
project_llm_trainer-0.3.4.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
|
|
34
|
+
project_llm_trainer-0.3.4.dist-info/RECORD,,
|
{project_llm_trainer-0.3.2.data → project_llm_trainer-0.3.4.data}/scripts/calc_intermediate_size
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|