project-llm-trainer: project_llm_trainer-0.3.2-py3-none-any.whl → project_llm_trainer-0.3.3-py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of project-llm-trainer might be problematic. Click here for more details.

@@ -5,8 +5,6 @@ from torch.utils.data import Dataset
5
5
  import torch.distributed as dist
6
6
  import torch.nn.functional as F
7
7
 
8
- from llm_model import LlmModel
9
-
10
8
  from .parallel_ds import DsParallel
11
9
  from .parallel_fsdp import FsdpParallel
12
10
  from .trainer import Trainer
@@ -41,7 +39,7 @@ class DPOTrainer(Trainer):
41
39
  def _init_reference_model(self):
42
40
  parallel = TrainerTools().new_parallel()
43
41
 
44
- reference_model = LlmModel(self.train_config.model_config)
42
+ reference_model = self._new_model(self.train_config)
45
43
  if self.train_config.init_state_dict:
46
44
  reference_model.load_state_dict(self.train_config.init_state_dict, strict=False)
47
45
  self.train_config.init_state_dict = None
@@ -7,8 +7,6 @@ from torch.nn.utils.rnn import pad_sequence
7
7
  import torch.distributed as dist
8
8
  import torch.nn.functional as F
9
9
 
10
- from llm_model import LlmModel
11
-
12
10
  from .parallel_ds import DsParallel
13
11
  from .trainer import Trainer
14
12
  from .train_configs import TrainConfig
@@ -50,7 +48,7 @@ class GRPOTrainer(Trainer):
50
48
  save_checkpoint(self.train_model, self.optimizer)
51
49
 
52
50
  def _init_reference_model(self):
53
- reference_model = LlmModel(self.train_config.model_config)
51
+ reference_model = self._new_model(self.train_config)
54
52
 
55
53
  device = 'cpu' # TrainerTools().parallel.device
56
54
  reference_model.to(device)
@@ -64,7 +62,7 @@ class GRPOTrainer(Trainer):
64
62
 
65
63
  def _init_generate_model(self):
66
64
  return copy.deepcopy(self.reference_model)
67
- # generate_model = LlmModel(self.train_config.model_config)
65
+ # generate_model = self._new_model(self.train_config)
68
66
  #
69
67
  # device = 'cpu' #TrainerTools().parallel.device
70
68
  # generate_model.to(device)
@@ -32,11 +32,15 @@ class SFTTrainer(Trainer):
32
32
  def _create_dataset(self, file_idx) -> Tuple[Dataset, str]:
33
33
  file_path = self.train_config.file_dataset[file_idx]
34
34
  max_position_embeddings = self.train_config.model_config.max_position_embeddings
35
+
36
+ image_tag_file_path = None
37
+ tokens_per_image = -1
38
+
35
39
  if isinstance(self.train_config.model_config, VLMConfig):
36
- image_tag_file_path = self.train_config.image_tags_file_dataset[file_idx]
37
- tokens_per_image = self.train_config.model_config.tokens_per_image
38
- else:
39
- image_tag_file_path = None
40
- tokens_per_image = -1
40
+ if self.train_config.image_tags_file_dataset:
41
+ image_tag_file_path = self.train_config.image_tags_file_dataset[file_idx]
42
+
43
+ if self.train_config.model_config.tokens_per_image:
44
+ tokens_per_image = self.train_config.model_config.tokens_per_image
41
45
 
42
46
  return LineByLineTextDataset(file_path, max_position_embeddings, image_tag_file_path, tokens_per_image), file_path
llm_trainer/trainer.py CHANGED
@@ -1,6 +1,6 @@
1
1
  import time
2
2
  from contextlib import nullcontext
3
- from typing import Optional, Tuple, List, Dict, Any
3
+ from typing import Optional, Tuple, List, Dict, Any, Union
4
4
 
5
5
  import torch
6
6
  from torch import nn
@@ -110,16 +110,19 @@ class Trainer:
110
110
  self.pixel_values_provider = None
111
111
  self.tokens_per_image = -1
112
112
 
113
+ def _new_model(self, train_config: TrainConfig):
114
+ if isinstance(train_config.model_config, VLMConfig):
115
+ return VlmModel(train_config.model_config)
116
+ else:
117
+ return LlmModel(train_config.model_config)
118
+
113
119
  def _init_train_model_and_optim(
114
120
  self,
115
121
  initial_lr: float,
116
122
  parallel_kwargs: dict,
117
123
  use_ds_optim: bool
118
124
  ):
119
- if isinstance(self.train_config.model_config, VLMConfig):
120
- model = VlmModel(self.train_config.model_config)
121
- else:
122
- model = LlmModel(self.train_config.model_config)
125
+ model = self._new_model(self.train_config)
123
126
 
124
127
  if self.train_config.init_state_dict:
125
128
  model.load_state_dict(self.train_config.init_state_dict, strict=False)
@@ -156,10 +159,7 @@ class Trainer:
156
159
 
157
160
  def _init_eval_model(self) -> Optional[nn.Module]:
158
161
  if TrainerTools().parallel.is_main_process:
159
- if isinstance(self.train_config.model_config, VLMConfig):
160
- return VlmModel(self.train_config.model_config).to('cpu')
161
- else:
162
- return LlmModel(self.train_config.model_config).to('cpu')
162
+ return self._new_model(self.train_config).to('cpu')
163
163
 
164
164
  return None
165
165
 
@@ -400,7 +400,7 @@ class Trainer:
400
400
  ):
401
401
  if TrainerTools().parallel.is_main_process:
402
402
  eval_prompt, eval_image_tag = self._get_eval_data()
403
- if isinstance(self.train_config.model_config, VLMConfig) and eval_image_tag:
403
+ if isinstance(self.train_model, VlmModel) and self.pixel_values_provider and eval_image_tag:
404
404
  eval_pixel_values = self.pixel_values_provider([eval_image_tag])
405
405
  else:
406
406
  eval_pixel_values = None
@@ -422,7 +422,7 @@ class Trainer:
422
422
  ):
423
423
  if TrainerTools().parallel.is_main_process:
424
424
  eval_prompt, eval_image_tag = self._get_eval_data()
425
- if isinstance(self.train_config.model_config, VLMConfig) and eval_image_tag:
425
+ if isinstance(self.train_model, VlmModel) and eval_image_tag:
426
426
  eval_pixel_values = self.pixel_values_provider([eval_image_tag])
427
427
  else:
428
428
  eval_pixel_values = None
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: project_llm_trainer
3
- Version: 0.3.2
3
+ Version: 0.3.3
4
4
  Summary: LLM and VLM trainer
5
5
  Author: qibin
6
6
  Author-email: qibin0506@gmail.com
@@ -2,11 +2,11 @@ llm_trainer/__init__.py,sha256=HWgtTEVeQSnZmEyYQm2K6eFEG4X2QAoigMlB5Z2tcXE,260
2
2
  llm_trainer/checkpoint.py,sha256=Dlkcit0o7Gx6S9QUrIrVp2pTurP9X0zVA7w7ImSuVQU,6049
3
3
  llm_trainer/dataset.py,sha256=4QlOo0SFB5816BUYegQjgobUqTUMQvdmZMM_OEAMSjE,4347
4
4
  llm_trainer/dcp.py,sha256=PkD97DyrOtoTKn4FJsfL3VqAy4dxufgjdzJEz8-Cnoc,3635
5
- llm_trainer/dpo_trainer.py,sha256=q3JZ1iKzmiuwUV-DTrSXUea2d39g6f5x1oUuF1QzBGA,13173
5
+ llm_trainer/dpo_trainer.py,sha256=7Bf6snWcu2fT8QRDI1CSzmrc7Cog6JauIeK2KoW_f8I,13135
6
6
  llm_trainer/ds_checkpoint.py,sha256=_svpzqRaa43--DKPputoXAelc6X9vPM0gNQu-hlh6NI,2153
7
7
  llm_trainer/eval.py,sha256=sCvdYnqWWf5_nuDQN5BHb_YivXLOQW-V0ET9mPu0tPU,2389
8
8
  llm_trainer/generate_utils.py,sha256=4iM0vyc_1C_iTL31GlS9PR4eZtYaELPRZ02KDSPZA9U,15158
9
- llm_trainer/grpo_trainer.py,sha256=_k9pik-kpbE8g9taQyG9w3dTLAHilgVBTUa4Y90Wae4,16414
9
+ llm_trainer/grpo_trainer.py,sha256=M6vp6QjxhBQVaw3e_3BJ4earuezQNKQ3JeZfQLBaSLQ,16370
10
10
  llm_trainer/log.py,sha256=LxqTGRNZUGMTSQCePRpk-rYyxSnSIbT4kOdP8Fbzr0M,462
11
11
  llm_trainer/loss.py,sha256=Yv3fsaVuZ5AhnGPJOr5vEMb_tM2urR6mCb4DBbrHHI8,6030
12
12
  llm_trainer/parallel.py,sha256=2VJtW3Gq2c1yS_LdcrNhk7B12prFwBmFnKhvV8FS2d8,4428
@@ -15,20 +15,20 @@ llm_trainer/parallel_ds.py,sha256=W_PkczyAlgffCRcQadN-Pf7H7HM7TU26v5W63jKELFM,99
15
15
  llm_trainer/parallel_fsdp.py,sha256=u9XbbVTzcsMcaf-aQFrC_QwWsDRGoEpRmgvu1cKNtgk,3887
16
16
  llm_trainer/parallel_none.py,sha256=a6tt3aBmCq5rSP7n2I-sF-hsZ992BbLbpbxutDCFJfs,607
17
17
  llm_trainer/scheduler.py,sha256=Xz8HhwoRMjRe41sf_NHhpZfkTlEs0I2MYusvMY6hCVw,3531
18
- llm_trainer/sft_trainer.py,sha256=WWmg8YOwr-w90otmeMjXvK9sa_DSPKlfgAPg3kHyRF4,1672
18
+ llm_trainer/sft_trainer.py,sha256=gxQA7T1o1QGUsHp2CX1Qb_fO5LppBJuNbc0H4ixCYUA,1783
19
19
  llm_trainer/tokenizer.py,sha256=A7TYYUbtPf75kjCvWP7yBui4xZBObMk2aPem62YpwpY,6776
20
20
  llm_trainer/tools.py,sha256=AhfjN9oln5Pyif1SgCWwgQg-Q5acTCd9xpz4L26QUjA,3039
21
21
  llm_trainer/train_configs.py,sha256=cadfo8RgxNUR-L3ZLyjiRXTQvhjUl4A1qENaq-ol8h4,15878
22
- llm_trainer/trainer.py,sha256=153F8FzsKh6k9XLm9i6JzmwN4Vwva5mWr9rVoge_3bY,24353
22
+ llm_trainer/trainer.py,sha256=tUi0Xcwgci_Y4T_I3CHR6pcsqURsXIMKbpAXB4DgRWo,24277
23
23
  llm_trainer/utils.py,sha256=-ivhMF0d999va13S1wt2uBvtVw8Nvr3uBzhaUFKL04Q,6826
24
- project_llm_trainer-0.3.2.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
25
- project_llm_trainer-0.3.2.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
26
- project_llm_trainer-0.3.2.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
27
- project_llm_trainer-0.3.2.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
28
- project_llm_trainer-0.3.2.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
29
- project_llm_trainer-0.3.2.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
30
- project_llm_trainer-0.3.2.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
31
- project_llm_trainer-0.3.2.dist-info/METADATA,sha256=NQpGh0Xy09euhzVTSBcC6m5P23ATvRKQ-zmkE0o__6g,195
32
- project_llm_trainer-0.3.2.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
33
- project_llm_trainer-0.3.2.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
34
- project_llm_trainer-0.3.2.dist-info/RECORD,,
24
+ project_llm_trainer-0.3.3.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
25
+ project_llm_trainer-0.3.3.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
26
+ project_llm_trainer-0.3.3.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
27
+ project_llm_trainer-0.3.3.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
28
+ project_llm_trainer-0.3.3.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
29
+ project_llm_trainer-0.3.3.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
30
+ project_llm_trainer-0.3.3.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
31
+ project_llm_trainer-0.3.3.dist-info/METADATA,sha256=plDfgI4_qj6tHaJvEhmhmRFmKt0JJpU37JNDjkDjnRY,195
32
+ project_llm_trainer-0.3.3.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
33
+ project_llm_trainer-0.3.3.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
34
+ project_llm_trainer-0.3.3.dist-info/RECORD,,