project-llm-trainer 0.3.1__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release. This version of project-llm-trainer might be problematic.

llm_trainer/dataset.py CHANGED
@@ -3,9 +3,10 @@ import os.path
 import torch
 from torch.utils.data import Dataset
 import pickle
+import csv
 
 from .tools import TrainerTools
-from .utils import extra_image_tag_and_repeat_image_tok
+from .utils import repeat_image_tok
 
 
 def _try_load_pkl(file_path: str):
@@ -23,7 +24,12 @@ class TextDataset(Dataset):
     """
     Used in the pretrain stage
     """
-    def __init__(self, file_path, block_size, stride):
+    def __init__(
+            self,
+            file_path,
+            block_size,
+            stride
+    ):
         super().__init__()
 
         self.input_ids = []
@@ -56,12 +62,19 @@ class LineByLineTextDataset(Dataset):
     """
    Used in the SFT stage
     """
-    def __init__(self, file_path, max_len, tokens_per_image=-1):
+    def __init__(
+            self,
+            file_path,
+            max_len,
+            image_tags_file_path=None,
+            tokens_per_image=-1
+    ):
         super().__init__()
 
         self.max_len = max_len
         self.tokens_per_image = tokens_per_image
         self.input_ids = []
+        self.image_tags = []
 
         tokens = _try_load_pkl(file_path)
         if not tokens:
@@ -79,19 +92,26 @@ class LineByLineTextDataset(Dataset):
 
         self.input_ids = tokens
 
+        if image_tags_file_path:
+            with open(image_tags_file_path, 'r') as f:
+                csv_reader = csv.reader(f)
+                for line in csv_reader:
+                    self.image_tags.append(line[0])
+
     def __len__(self):
         return len(self.input_ids)
 
     def __getitem__(self, item):
-        inputs = self.input_ids[item]
+        inputs = torch.tensor(self.input_ids[item]).long()
+        image_tag = self.image_tags[item] if self.image_tags else None
         if self.tokens_per_image != -1:
-            inputs, image_tag = extra_image_tag_and_repeat_image_tok(inputs, self.tokens_per_image)
+            inputs = repeat_image_tok(inputs, self.tokens_per_image)
         else:
             image_tag = None
 
         inputs = inputs[:self.max_len]
 
-        return {'inputs': torch.tensor(inputs).long(), 'image_tag': image_tag}
+        return {'inputs': inputs, 'image_tag': image_tag}
 
 
 class DPODataset(Dataset):
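
For readers following the dataset change: `LineByLineTextDataset` now reads per-sample image tags from a one-column CSV instead of extracting a tag token from the input ids. A minimal sketch of the expected input; the file names and tag values are hypothetical, only the constructor signature comes from the diff above:

```python
import csv

# One row per training sample; the loader above keeps only column 0 (line[0]).
with open('image_tags.csv', 'w', newline='') as f:
    csv.writer(f).writerows([['img_0001'], ['img_0002'], ['img_0003']])

# Constructing the dataset then pairs input_ids[i] with image_tags[i]:
# ds = LineByLineTextDataset(
#     'sft_tokens.pkl',                       # pickled token sequences
#     max_len=1024,
#     image_tags_file_path='image_tags.csv',
#     tokens_per_image=64,                    # VLMConfig.tokens_per_image
# )
# ds[0] -> {'inputs': LongTensor, 'image_tag': 'img_0001'}
```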
llm_trainer/dpo_trainer.py CHANGED
@@ -28,7 +28,7 @@ class DPOTrainer(Trainer):
         *,
         train_config: TrainConfig,
         eval_prompts: List[str],
-        eval_image_tags: Optional[List[int]] = None
+        eval_image_tags: Optional[List[str]] = None
     ):
         super().__init__(
             train_config=train_config,
@@ -112,9 +112,10 @@ class DPOTrainer(Trainer):
 
         return parallel_kwargs, data_loader_kwargs, sampler_kwargs, use_ds_optim
 
-    def _create_dataset(self, file_path) -> Dataset:
+    def _create_dataset(self, file_idx) -> Tuple[Dataset, str]:
+        file_path = self.train_config.file_dataset[file_idx]
         max_position_embeddings = self.train_config.model_config.max_position_embeddings
-        return DPODataset(file_path, max_position_embeddings)
+        return DPODataset(file_path, max_position_embeddings), file_path
 
     def _calc_loss(self, inputs, attention_mask, logits, labels): ...
 
@@ -184,9 +185,7 @@ class DPOTrainer(Trainer):
         file_count = len(self.train_config.file_dataset)
 
         for file_idx in range(file_count):
-            file_path = self.train_config.file_dataset[file_idx]
-
-            dataset = self._create_dataset(file_path)
+            dataset, file_path = self._create_dataset(file_idx)
             train_data_loader = TrainerTools().parallel.process_dataloader(
                 dataset=dataset,
                 data_loader_kwargs=self.data_loader_kwargs,
llm_trainer/grpo_trainer.py CHANGED
@@ -30,7 +30,7 @@ class GRPOTrainer(Trainer):
         train_config: TrainConfig,
         reward_func: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], List[float]],
         eval_prompts: List[str],
-        eval_image_tags: Optional[List[int]] = None
+        eval_image_tags: Optional[List[str]] = None
     ):
         super().__init__(
             train_config=train_config,
@@ -90,8 +90,9 @@ class GRPOTrainer(Trainer):
 
         return parallel_kwargs, data_loader_kwargs, sampler_kwargs, use_ds_optim
 
-    def _create_dataset(self, file_path) -> Dataset:
-        return GRPORolloutDataset(file_path)
+    def _create_dataset(self, file_idx) -> Tuple[Dataset, str]:
+        file_path = self.train_config.file_dataset[file_idx]
+        return GRPORolloutDataset(file_path), file_path
 
     def _calc_loss(self, inputs, attention_mask, logits, labels): ...
 
@@ -302,8 +303,7 @@ class GRPOTrainer(Trainer):
         file_count = len(self.train_config.file_dataset)
 
         for file_idx in range(file_count):
-            file_path = self.train_config.file_dataset[file_idx]
-            dataset = self._create_dataset(file_path)
+            dataset, file_path = self._create_dataset(file_idx)
 
             train_data_loader = TrainerTools().parallel.process_dataloader(
                 dataset=dataset,
llm_trainer/sft_trainer.py CHANGED
@@ -14,7 +14,7 @@ class SFTTrainer(Trainer):
         *,
         train_config: TrainConfig,
         eval_prompts: List[str],
-        eval_image_tags: Optional[List[int]] = None
+        eval_image_tags: Optional[List[str]] = None
     ):
         super().__init__(
             train_config=train_config,
@@ -29,11 +29,14 @@ class SFTTrainer(Trainer):
 
         return parallel_kwargs, data_loader_kwargs, sampler_kwargs, use_ds_optim
 
-    def _create_dataset(self, file_path) -> Dataset:
+    def _create_dataset(self, file_idx) -> Tuple[Dataset, str]:
+        file_path = self.train_config.file_dataset[file_idx]
         max_position_embeddings = self.train_config.model_config.max_position_embeddings
         if isinstance(self.train_config.model_config, VLMConfig):
+            image_tag_file_path = self.train_config.image_tags_file_dataset[file_idx]
             tokens_per_image = self.train_config.model_config.tokens_per_image
         else:
+            image_tag_file_path = None
             tokens_per_image = -1
 
-        return LineByLineTextDataset(file_path, max_position_embeddings, tokens_per_image)
+        return LineByLineTextDataset(file_path, max_position_embeddings, image_tag_file_path, tokens_per_image), file_path
llm_trainer/train_configs.py CHANGED
@@ -408,6 +408,7 @@ class TrainConfig:
         *,
         model_config: Union[ModelConfig, VLMConfig],
         file_dataset: FileDataset,
+        image_tags_file_dataset: Optional[FileDataset] = None,
         mask_prompt: bool = True,
         gradient_accumulation_steps: int = 0,
         eval_batch_interval: int = 100,
@@ -419,7 +420,7 @@ class TrainConfig:
         fsdp_config: FsdpConfig = FsdpConfig(),
         data_loader_config: DataLoaderConfig = DataLoaderConfig(),
         kd_config: Optional[KDConfig] = None,
-        pixel_values_provider: Optional[Callable[[list[int]], torch.Tensor]] = None,
+        pixel_values_provider: Optional[Callable[[list[str]], torch.Tensor]] = None,
         init_state_dict: Optional[Mapping[str, Any]] = None,
         eval_config: EvalConfig = EvalConfig()
     ):
@@ -427,6 +428,7 @@ class TrainConfig:
         self.batch_size = batch_size
         self.model_config = model_config
         self.file_dataset = file_dataset
+        self.image_tags_file_dataset = image_tags_file_dataset
         self.mask_prompt = mask_prompt
         self.gradient_accumulation_steps = gradient_accumulation_steps
         self.eval_batch_interval = eval_batch_interval
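
Because image tags are now strings rather than token ids, `pixel_values_provider` receives the tag strings and must map them to image tensors. A sketch under that assumption; `load_image` and the commented wiring names are placeholders, not part of the package:

```python
import torch

def load_image(tag: str) -> torch.Tensor:
    # Placeholder: look up and preprocess the image identified by `tag`.
    return torch.zeros(3, 224, 224)

def pixel_values_provider(image_tags: list[str]) -> torch.Tensor:
    # Stack one preprocessed image per tag into [batch, C, H, W].
    return torch.stack([load_image(tag) for tag in image_tags])

# config = TrainConfig(
#     model_config=vlm_config,
#     file_dataset=file_dataset,                        # token files, indexed by file_idx
#     image_tags_file_dataset=image_tags_file_dataset,  # CSVs aligned 1:1 with file_dataset
#     pixel_values_provider=pixel_values_provider,
# )
```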
llm_trainer/trainer.py CHANGED
@@ -52,7 +52,7 @@ class Trainer:
         *,
         train_config: TrainConfig,
         eval_prompts: List[str],
-        eval_image_tags: Optional[List[int]] = None
+        eval_image_tags: Optional[List[str]] = None
     ):
         set_seed()
 
@@ -318,9 +318,10 @@ class Trainer:
 
         return parallel_kwargs, data_loader_kwargs, sampler_kwargs, use_ds_optim
 
-    def _create_dataset(self, file_path) -> Dataset:
+    def _create_dataset(self, file_idx) -> Tuple[Dataset, str]:
+        file_path = self.train_config.file_dataset[file_idx]
         max_position_embeddings = self.train_config.model_config.max_position_embeddings
-        return TextDataset(file_path, max_position_embeddings, max_position_embeddings)
+        return TextDataset(file_path, max_position_embeddings, max_position_embeddings), file_path
 
     def _calc_loss(self, inputs, attention_mask, logits, labels):
         # calc loss
@@ -353,7 +354,7 @@ class Trainer:
 
         TrainerTools().parallel.synchronize()
 
-    def _get_eval_data(self) -> Tuple[str, Optional[int]]:
+    def _get_eval_data(self) -> Tuple[str, Optional[str]]:
         if len(self.eval_prompts) == 0:
             return '', None
 
@@ -458,9 +459,7 @@ class Trainer:
         file_count = len(self.train_config.file_dataset)
 
         for file_idx in range(file_count):
-            file_path = self.train_config.file_dataset[file_idx]
-
-            dataset = self._create_dataset(file_path)
+            dataset, file_path = self._create_dataset(file_idx)
             train_data_loader = TrainerTools().parallel.process_dataloader(
                 dataset=dataset,
                 data_loader_kwargs=self.data_loader_kwargs,
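
The refactor repeated across DPOTrainer, GRPOTrainer, and SFTTrainer originates here in the base class: `_create_dataset` now takes a file index and returns `(dataset, file_path)`, so overrides can resolve per-file side inputs (such as the SFT image-tag CSVs) from the same index. A self-contained sketch of the control flow, using stand-in classes rather than the package's real ones:

```python
from typing import List, Tuple

class TinyDataset:
    """Stand-in for TextDataset / LineByLineTextDataset / DPODataset."""
    def __init__(self, file_path: str):
        self.file_path = file_path

class TinyTrainer:
    def __init__(self, file_dataset: List[str]):
        self.file_dataset = file_dataset

    def _create_dataset(self, file_idx: int) -> Tuple[TinyDataset, str]:
        # Resolve the path from the index; overrides can also resolve
        # index-aligned extras (e.g. image_tags_file_dataset[file_idx]).
        file_path = self.file_dataset[file_idx]
        return TinyDataset(file_path), file_path

    def train(self):
        for file_idx in range(len(self.file_dataset)):
            dataset, file_path = self._create_dataset(file_idx)
            print(f'epoch file: {file_path}')

TinyTrainer(['chunk_0.pkl', 'chunk_1.pkl']).train()
```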
llm_trainer/utils.py CHANGED
@@ -15,45 +15,6 @@ def set_seed(seed=42):
     torch.cuda.manual_seed_all(seed)
 
 
-def extra_image_tag_and_repeat_image_tok(
-        inputs: list[int],
-        tokens_per_image: int
-) -> Tuple[list[int], Optional[int]]:
-    # tokens_per_image=3 -> <image>{image_tag}...xxxx -> <image><image><image>...xxx
-    image_tok = TrainerTools().tokenizer.image
-    if image_tok not in inputs:
-        return inputs, None
-
-    image_tok_idx = inputs.index(image_tok)
-    image_tag_idx = image_tok_idx + 1
-
-    if image_tag_idx < len(inputs):
-        # remove it
-        image_tag = inputs.pop(image_tag_idx)
-    else:
-        image_tag = None
-
-    # repeat image_tok
-    new_inputs = inputs[:image_tok_idx] + [image_tok] * tokens_per_image + inputs[image_tok_idx + 1:]
-    return new_inputs, image_tag
-
-
-def batch_extra_image_tag_and_repeat_image_tok(
-        tokens: torch.Tensor,
-        tokens_per_image: int
-) -> Tuple[torch.Tensor, list[int]]:
-    new_tokens = []
-    image_tags = []
-
-    tokens_list = tokens.cpu().detach().tolist()
-    for token in tokens_list:
-        new_token, image_tag = extra_image_tag_and_repeat_image_tok(token, tokens_per_image)
-        new_tokens.append(new_token)
-        image_tags.append(image_tag)
-
-    return torch.tensor(new_tokens, dtype=tokens.dtype, device=tokens.device), image_tags
-
-
 def repeat_image_tok(
         tokens: torch.Tensor,
         tokens_per_image: int
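
The removed helpers both extracted an inline image-tag token and repeated the `<image>` placeholder; with tags now supplied via CSV, only the repetition remains in `repeat_image_tok`, whose body is not shown in this diff. An illustrative re-implementation of the repetition step on a single sequence, with a hypothetical token id:

```python
import torch

IMAGE_TOK = 7  # hypothetical id of the <image> token

def repeat_image_tok_sketch(tokens: torch.Tensor, tokens_per_image: int) -> torch.Tensor:
    # tokens_per_image=3: <image> xxx -> <image><image><image> xxx
    toks = tokens.tolist()
    if IMAGE_TOK not in toks:
        return tokens
    i = toks.index(IMAGE_TOK)
    out = toks[:i] + [IMAGE_TOK] * tokens_per_image + toks[i + 1:]
    return torch.tensor(out, dtype=tokens.dtype)

print(repeat_image_tok_sketch(torch.tensor([1, 7, 2, 3]), 3))
# -> tensor([1, 7, 7, 7, 2, 3])
```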
{project_llm_trainer-0.3.1.dist-info → project_llm_trainer-0.3.2.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: project_llm_trainer
-Version: 0.3.1
+Version: 0.3.2
 Summary: LLM and VLM trainer
 Author: qibin
 Author-email: qibin0506@gmail.com
{project_llm_trainer-0.3.1.dist-info → project_llm_trainer-0.3.2.dist-info}/RECORD CHANGED
@@ -1,12 +1,12 @@
 llm_trainer/__init__.py,sha256=HWgtTEVeQSnZmEyYQm2K6eFEG4X2QAoigMlB5Z2tcXE,260
 llm_trainer/checkpoint.py,sha256=Dlkcit0o7Gx6S9QUrIrVp2pTurP9X0zVA7w7ImSuVQU,6049
-llm_trainer/dataset.py,sha256=uz1TTd87ikf7CZPdGxmR95TSQTFWPPTilgWLBWO46_I,3916
+llm_trainer/dataset.py,sha256=4QlOo0SFB5816BUYegQjgobUqTUMQvdmZMM_OEAMSjE,4347
 llm_trainer/dcp.py,sha256=PkD97DyrOtoTKn4FJsfL3VqAy4dxufgjdzJEz8-Cnoc,3635
-llm_trainer/dpo_trainer.py,sha256=6rm8Jq0rI0xazcl_bCOun8rnd34Tb_PKgezowhwoiCM,13150
+llm_trainer/dpo_trainer.py,sha256=q3JZ1iKzmiuwUV-DTrSXUea2d39g6f5x1oUuF1QzBGA,13173
 llm_trainer/ds_checkpoint.py,sha256=_svpzqRaa43--DKPputoXAelc6X9vPM0gNQu-hlh6NI,2153
 llm_trainer/eval.py,sha256=sCvdYnqWWf5_nuDQN5BHb_YivXLOQW-V0ET9mPu0tPU,2389
 llm_trainer/generate_utils.py,sha256=4iM0vyc_1C_iTL31GlS9PR4eZtYaELPRZ02KDSPZA9U,15158
-llm_trainer/grpo_trainer.py,sha256=gWDX8vRZ7hLKl_483X5ua92nst1m617BrqnzLhwr87g,16390
+llm_trainer/grpo_trainer.py,sha256=_k9pik-kpbE8g9taQyG9w3dTLAHilgVBTUa4Y90Wae4,16414
 llm_trainer/log.py,sha256=LxqTGRNZUGMTSQCePRpk-rYyxSnSIbT4kOdP8Fbzr0M,462
 llm_trainer/loss.py,sha256=Yv3fsaVuZ5AhnGPJOr5vEMb_tM2urR6mCb4DBbrHHI8,6030
 llm_trainer/parallel.py,sha256=2VJtW3Gq2c1yS_LdcrNhk7B12prFwBmFnKhvV8FS2d8,4428
@@ -15,20 +15,20 @@ llm_trainer/parallel_ds.py,sha256=W_PkczyAlgffCRcQadN-Pf7H7HM7TU26v5W63jKELFM,99
 llm_trainer/parallel_fsdp.py,sha256=u9XbbVTzcsMcaf-aQFrC_QwWsDRGoEpRmgvu1cKNtgk,3887
 llm_trainer/parallel_none.py,sha256=a6tt3aBmCq5rSP7n2I-sF-hsZ992BbLbpbxutDCFJfs,607
 llm_trainer/scheduler.py,sha256=Xz8HhwoRMjRe41sf_NHhpZfkTlEs0I2MYusvMY6hCVw,3531
-llm_trainer/sft_trainer.py,sha256=T9CujoEp8D5I65fLF2wgV6SPjzhGFbAI4We5NwL4O-M,1443
+llm_trainer/sft_trainer.py,sha256=WWmg8YOwr-w90otmeMjXvK9sa_DSPKlfgAPg3kHyRF4,1672
 llm_trainer/tokenizer.py,sha256=A7TYYUbtPf75kjCvWP7yBui4xZBObMk2aPem62YpwpY,6776
 llm_trainer/tools.py,sha256=AhfjN9oln5Pyif1SgCWwgQg-Q5acTCd9xpz4L26QUjA,3039
-llm_trainer/train_configs.py,sha256=FAlylSYVeh_oJGTy2fcMNUV8JLD6B70hMuk-iKx14iI,15748
-llm_trainer/trainer.py,sha256=mq51d-2ADUpcWCArszhYnOSTveatt3_x43hcC7IZgYk,24330
-llm_trainer/utils.py,sha256=04XiMENVotNgbNRBn9wadHu-cJHPxj0Xq-zzLJmNgZQ,8062
-project_llm_trainer-0.3.1.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
-project_llm_trainer-0.3.1.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
-project_llm_trainer-0.3.1.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
-project_llm_trainer-0.3.1.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
-project_llm_trainer-0.3.1.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
-project_llm_trainer-0.3.1.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
-project_llm_trainer-0.3.1.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
-project_llm_trainer-0.3.1.dist-info/METADATA,sha256=LJl2lNqTIIQZpTt7iVqzQJ2NhAvTUOwS9w44_XxIn0Y,195
-project_llm_trainer-0.3.1.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
-project_llm_trainer-0.3.1.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
-project_llm_trainer-0.3.1.dist-info/RECORD,,
+llm_trainer/train_configs.py,sha256=cadfo8RgxNUR-L3ZLyjiRXTQvhjUl4A1qENaq-ol8h4,15878
+llm_trainer/trainer.py,sha256=153F8FzsKh6k9XLm9i6JzmwN4Vwva5mWr9rVoge_3bY,24353
+llm_trainer/utils.py,sha256=-ivhMF0d999va13S1wt2uBvtVw8Nvr3uBzhaUFKL04Q,6826
+project_llm_trainer-0.3.2.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
+project_llm_trainer-0.3.2.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
+project_llm_trainer-0.3.2.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
+project_llm_trainer-0.3.2.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
+project_llm_trainer-0.3.2.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
+project_llm_trainer-0.3.2.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
+project_llm_trainer-0.3.2.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
+project_llm_trainer-0.3.2.dist-info/METADATA,sha256=NQpGh0Xy09euhzVTSBcC6m5P23ATvRKQ-zmkE0o__6g,195
+project_llm_trainer-0.3.2.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
+project_llm_trainer-0.3.2.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
+project_llm_trainer-0.3.2.dist-info/RECORD,,