project-llm-trainer 0.7.6__py3-none-any.whl → 0.7.8__py3-none-any.whl

This diff shows the changes between publicly available package versions as they appear in their public registry. It is provided for informational purposes only.

This version of project-llm-trainer has been flagged as potentially problematic.

llm_trainer/dataset.py CHANGED
@@ -1,28 +1,30 @@
-import os.path
-
 import torch
 from torch.utils.data import Dataset
 import pickle
 import csv
+import json

 from .tools import TrainerTools
 from .utils import repeat_image_tok


-def _try_load_pkl(file_path: str):
-    tokens = None
-    try:
-        with open(file_path, 'rb') as f:
-            tokens = pickle.load(f)
-    except Exception as e:
-        raise e
-    finally:
-        return tokens
+"""
+support jsonl and pkl
+"""
+def _get_file_type(file_path: str):
+    if file_path.endswith('.jsonl'):
+        return 'jsonl'
+    elif file_path.endswith('.pkl'):
+        return 'pkl'
+
+    return None


 class TextDataset(Dataset):
     """
-    Used in the pretrain stage.
+    Used in the pretrain stage. Supports jsonl and pkl data formats; jsonl data is fully encoded into tokens during init.
+    jsonl: {'text': 'text1'}\n{'text': 'text2'}
+    pkl: [0, 1, 2, 3 ...]
     """
     def __init__(
             self,
@@ -34,19 +36,17 @@ class TextDataset(Dataset):

         self.input_ids = []

-        tokens = _try_load_pkl(file_path)
-        if not tokens:
-            cache_file = f'{file_path}.cache'
-            if os.path.exists(cache_file):
-                tokens = _try_load_pkl(cache_file)
-            else:
-                tokens = []
-                with open(file_path, 'r') as f:
-                    for line in f:
-                        tokens.extend(TrainerTools().tokenizer.encode(line))
-
-                with open(cache_file, 'wb') as f:
-                    pickle.dump(tokens, f)
+        file_type = _get_file_type(file_path)
+        if file_type == 'jsonl':
+            tokens = []
+            with open(file_path, 'r') as f:
+                for line in f:
+                    tokens.extend(TrainerTools().tokenizer.encode(json.loads(line.strip())['text']))
+        elif file_type == 'pkl':
+            with open(file_path, 'rb') as f:
+                tokens = pickle.load(f)
+        else:
+            raise Exception(f'unsupported file type for {file_path}')

         for i in range(0, len(tokens) - block_size + 1, stride):
             self.input_ids.append(tokens[i:i+block_size])
@@ -60,7 +60,21 @@ class TextDataset(Dataset):

 class LineByLineTextDataset(Dataset):
     """
-    Used in the sft stage.
+    Used in the sft stage. Supports jsonl and pkl data formats; jsonl data is encoded into tokens in getitem.
+    jsonl: [
+        {'role': 'system', 'content': 'system_content'},
+        {'role': 'user', 'content': 'user_content'},
+        {'role': 'assistant', 'think': 'think_content', 'content': 'assistant_content'}
+    ]\n
+    [
+        {'role': 'system', 'content': 'system_content'},
+        {'role': 'user', 'content': 'user_content'},
+        {'role': 'assistant', 'think': 'think_content', 'content': 'assistant_content'}
+    ]
+    pkl: [
+        [0, 1, 2, 3],
+        [4, 5, 6, 7]
+    ]
     """
     def __init__(
             self,
@@ -75,22 +89,20 @@ class LineByLineTextDataset(Dataset):
         self.tokens_per_image = tokens_per_image
         self.input_ids = []
         self.image_tags = []
-
-        tokens = _try_load_pkl(file_path)
-        if not tokens:
-            cache_file = f'{file_path}.cache'
-            if os.path.exists(cache_file):
-                tokens = _try_load_pkl(cache_file)
-            else:
-                tokens = []
-                with open(file_path, 'r') as f:
-                    for line in f:
-                        tokens.append(TrainerTools().tokenizer.encode(line))
-
-                with open(cache_file, 'wb') as f:
-                    pickle.dump(tokens, f)
-
-        self.input_ids = tokens
+        self.plain_text = False
+
+        file_type = _get_file_type(file_path)
+        if file_type == 'jsonl':
+            self.plain_text = True
+
+            with open(file_path, 'r') as f:
+                for line in f:
+                    self.input_ids.append(json.loads(line.strip()))
+        elif file_type == 'pkl':
+            with open(file_path, 'rb') as f:
+                self.input_ids = pickle.load(f)
+        else:
+            raise Exception(f'unsupported file type for {file_path}')

         if image_tags_file_path:
             with open(image_tags_file_path, 'r') as f:
@@ -102,8 +114,14 @@ class LineByLineTextDataset(Dataset):
         return len(self.input_ids)

     def __getitem__(self, item):
-        inputs = torch.tensor(self.input_ids[item]).long()
+        if self.plain_text:
+            inputs = TrainerTools().tokenizer.apply_chat_template(self.input_ids[item])
+        else:
+            inputs = self.input_ids[item]
+
+        inputs = torch.tensor(inputs).long()
         image_tag = self.image_tags[item] if self.image_tags else None
+
         if self.tokens_per_image != -1:
             inputs = repeat_image_tok(inputs, self.tokens_per_image)
         else:
@@ -111,48 +129,134 @@

         inputs = inputs[:self.max_len]

-        return {'inputs': inputs, 'image_tag': image_tag}
+        return {
+            'inputs': inputs,
+            'image_tag': image_tag
+        }


 class DPODataset(Dataset):
+    """
+    Used in the dpo stage. Supports jsonl and pkl data formats; jsonl data is encoded into tokens in getitem.
+    jsonl: {'chosen':
+        [{'role': 'system', 'content': 'system_content'},
+        {'role': 'user', 'content': 'user_content'},
+        {'role': 'assistant', 'think': 'think_content', 'content': 'assistant_content'}],
+    'rejected':
+        [{'role': 'system', 'content': 'system_content'},
+        {'role': 'user', 'content': 'user_content'},
+        {'role': 'assistant', 'think': 'think_content', 'content': 'assistant_content'}],
+    }\n
+    {'chosen':
+        [{'role': 'system', 'content': 'system_content'},
+        {'role': 'user', 'content': 'user_content'},
+        {'role': 'assistant', 'think': 'think_content', 'content': 'assistant_content'}],
+    'rejected':
+        [{'role': 'system', 'content': 'system_content'},
+        {'role': 'user', 'content': 'user_content'},
+        'role': 'assistant', 'think': 'think_content', 'content': 'assistant_content'}],
+    }
+    pkl: [
+        {'chosen': xxx, 'rejected': xxx},
+        {'chosen': xxx, 'rejected': xxx},
+    ]
+    """
     def __init__(self, file_path, max_len):
         self.max_len = max_len
         self.chosen_ids = []
         self.rejected_ids = []
-
-        # [{'chosen': xxx, 'rejected': xxx} ...]
-        tokens = _try_load_pkl(file_path)
-        for token in tokens:
-            self.chosen_ids.append(token['chosen'])
-            self.rejected_ids.append(token['rejected'])
+        self.plain_text = False
+
+        file_type = _get_file_type(file_path)
+        if file_type == 'jsonl':
+            self.plain_text = True
+
+            with open(file_path, 'r') as f:
+                for line in f:
+                    json_ = json.loads(line.strip())
+                    self.chosen_ids.append(json_['chosen'])
+                    self.rejected_ids.append(json_['rejected'])
+        elif file_type == 'pkl':
+            with open(file_path, 'rb') as f:
+                tokens = pickle.load(f)
+
+            for token in tokens:
+                self.chosen_ids.append(token['chosen'])
+                self.rejected_ids.append(token['rejected'])
+        else:
+            raise Exception(f'unsupported file type for {file_path}')

     def __len__(self):
         return len(self.chosen_ids)

     def __getitem__(self, item):
-        chosen_id = self.chosen_ids[item]
-        rejected_id = self.rejected_ids[item]
+        if self.plain_text:
+            chosen_id = TrainerTools().tokenizer.apply_chat_template(self.chosen_ids[item])
+            rejected_id = TrainerTools().tokenizer.apply_chat_template(self.rejected_ids[item])
+        else:
+            chosen_id = self.chosen_ids[item]
+            rejected_id = self.rejected_ids[item]

-        return {'chosen': chosen_id[:self.max_len], 'rejected': rejected_id[:self.max_len]}
+        return {
+            'chosen': chosen_id[:self.max_len],
+            'rejected': rejected_id[:self.max_len]
+        }


 class GRPORolloutDataset(Dataset):
+    """
+    Used in the grpo (gspo) stage. Supports jsonl and pkl data formats; jsonl data is encoded into tokens in getitem.
+    jsonl: {'prompt':
+        [{'role': 'system', 'content': 'system_content'},
+        {'role': 'user', 'content': 'user_content'},
+        {'role': 'assistant', 'think': 'think_content', 'content': 'assistant_content'}],
+    'answer': '10'
+    }\n
+    {'prompt':
+        [{'role': 'system', 'content': 'system_content'},
+        {'role': 'user', 'content': 'user_content'},
+        {'role': 'assistant', 'think': 'think_content', 'content': 'assistant_content'}],
+    'answer': '10'
+    }
+    pkl: [
+        {'prompt': xxx, 'answer': xxx},
+        {'prompt': xxx, 'answer': xxx},
+    ]
+    """
     def __init__(self, file_path):
         self.questions = []
         self.answers = []
-
-        # [{'question': xxx, 'answer': ''}]
-        tokens = _try_load_pkl(file_path)
-        for token in tokens:
-            self.questions.append(token['prompt'])
-            self.answers.append(token['answer'])
+        self.plain_text = False
+
+        file_type = _get_file_type(file_path)
+        if file_type == 'jsonl':
+            self.plain_text = True
+
+            with open(file_path, 'r') as f:
+                for line in f:
+                    json_ = json.loads(line.strip())
+                    self.questions.append(json_['prompt'])
+                    self.answers.append(json_['answer'])
+        elif file_type == 'pkl':
+            with open(file_path, 'rb') as f:
+                tokens = pickle.load(f)
+
+            for token in tokens:
+                self.questions.append(token['prompt'])
+                self.answers.append(token['answer'])
+        else:
+            raise Exception(f'unsupported file type for {file_path}')

     def __len__(self):
         return len(self.questions)

     def __getitem__(self, item):
-        question = self.questions[item]
-        answer = self.answers[item]
+        if self.plain_text:
+            question = TrainerTools().tokenizer.apply_chat_template(self.questions[item])
+            answer = TrainerTools().tokenizer.encode(self.answers[item])
+        else:
+            question = self.questions[item]
+            answer = self.answers[item]

         return {
             'prompt': torch.tensor(question).long(),
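For orientation, the jsonl layouts described in the new docstrings can be produced with ordinary json.dumps calls. The sketch below is illustrative only; the file names and sample strings are placeholders and are not part of the package.

import json

# Pretrain data for TextDataset: one {'text': ...} object per line.
with open('pretrain.jsonl', 'w') as f:
    for text in ['text1', 'text2']:
        f.write(json.dumps({'text': text}, ensure_ascii=False) + '\n')

# SFT data for LineByLineTextDataset: one conversation (list of role dicts) per line.
conversation = [
    {'role': 'system', 'content': 'system_content'},
    {'role': 'user', 'content': 'user_content'},
    {'role': 'assistant', 'think': 'think_content', 'content': 'assistant_content'},
]
with open('sft.jsonl', 'w') as f:
    f.write(json.dumps(conversation, ensure_ascii=False) + '\n')

# Per the docstrings above, DPODataset expects {'chosen': [...], 'rejected': [...]} per line,
# and GRPORolloutDataset expects {'prompt': [...], 'answer': '...'} per line, using the same
# conversation structure.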
project_llm_trainer-0.7.8.data/scripts/ddp_train ADDED
@@ -0,0 +1,24 @@
+#!python
+
+if __name__ == '__main__':
+    import os, sys
+    arguments = sys.argv[1:]
+    # file_name
+    run_file_name = arguments[0]
+
+    # cuda_visible_devive
+    if len(arguments) > 1:
+        # 0,1,2,3
+        cuda_visible_devive = arguments[1]
+    else:
+        cuda_visible_devive = None
+
+    os.environ['PARALLEL_TYPE'] = 'ddp'
+
+    if cuda_visible_devive:
+        os.environ['CUDA_VISIBLE_DEVICES'] = cuda_visible_devive
+
+    command = f'torchrun --standalone --nproc_per_node=gpu {run_file_name}'
+
+    print(f'run command {command}')
+    os.system(command)
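As an illustration (not part of the package docs), assuming the wheel installs this file as the ddp_train script listed in the RECORD below, an invocation such as `ddp_train train.py 0,1` with a hypothetical train.py would set PARALLEL_TYPE=ddp and CUDA_VISIBLE_DEVICES=0,1, then run `torchrun --standalone --nproc_per_node=gpu train.py`.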
project_llm_trainer-0.7.8.data/scripts/ds_train ADDED
@@ -0,0 +1,29 @@
+#!python
+
+if __name__ == '__main__':
+    import os, sys
+    arguments = sys.argv[1:]
+    # file_name
+    run_file_name = arguments[0]
+
+    # cuda_visible_devive
+    if len(arguments) > 1:
+        # 0,1,2,3
+        cuda_visible_devive = arguments[1]
+    else:
+        cuda_visible_devive = None
+
+    # cuda location
+    if len(arguments) > 2:
+        cuda_loc = arguments[2]
+    else:
+        cuda_loc = 'localhost'
+
+    os.environ['PARALLEL_TYPE'] = 'ds'
+
+    cuda_ctrl = f' --include {cuda_loc}:{cuda_visible_devive}' if cuda_visible_devive else ''
+
+    command = f'deepspeed{cuda_ctrl} {run_file_name}'
+
+    print(f'run command {command}')
+    os.system(command)
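Similarly, assuming this file is installed as the ds_train script listed in the RECORD below, an invocation such as `ds_train train.py 0,1 localhost` with a hypothetical train.py would set PARALLEL_TYPE=ds and run `deepspeed --include localhost:0,1 train.py`; with no GPU list it falls back to plain `deepspeed train.py`.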
project_llm_trainer-0.7.8.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: project_llm_trainer
-Version: 0.7.6
+Version: 0.7.8
 Summary: LLM and VLM trainer
 Author: qibin
 Author-email: qibin0506@gmail.com
project_llm_trainer-0.7.8.dist-info/RECORD CHANGED
@@ -1,6 +1,6 @@
 llm_trainer/__init__.py,sha256=HWgtTEVeQSnZmEyYQm2K6eFEG4X2QAoigMlB5Z2tcXE,260
 llm_trainer/checkpoint.py,sha256=X5ZeUtJlxVz7pnWQLaS-y7UIZOaOAnZTt2L8rSAPzUs,4428
-llm_trainer/dataset.py,sha256=4QlOo0SFB5816BUYegQjgobUqTUMQvdmZMM_OEAMSjE,4347
+llm_trainer/dataset.py,sha256=UL3fGeM4XSlyNQRZH-139u3LujqAQx3YyaxNRewk6LE,8935
 llm_trainer/dpo_trainer.py,sha256=RMfbTsl3eav4yTJ2PK59mi6a0ECVOg8WwYVsHvMbNUE,12353
 llm_trainer/ds_checkpoint.py,sha256=X2IWgpgi0yOtogph7n6DEwvK_0Ceb7juu1WMutv3HSk,2270
 llm_trainer/eval.py,sha256=ZyUfSo2Q8P-lrCdPEnGkoo5pGubd0AabREK5eMISRII,1109
@@ -20,14 +20,14 @@ llm_trainer/tools.py,sha256=5op5qrjjkK-Lr9oes5VxIVnOVYOYGoAdlIJq9mPUf64,2637
 llm_trainer/train_configs.py,sha256=N3ykM1uaLHcSNRC8ErYIxp9VYhSP7voJyAP-2D4ZJe0,7574
 llm_trainer/trainer.py,sha256=jS31zEXIIj9BoPTPlmaGYq61x72HGCjKfS2u3_gOkDk,27924
 llm_trainer/utils.py,sha256=xcdzpvPvXRKqsOK2yB7PZ9GmOvZMDFcglDPUZY2hJTY,11484
-project_llm_trainer-0.7.6.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
-project_llm_trainer-0.7.6.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
-project_llm_trainer-0.7.6.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
-project_llm_trainer-0.7.6.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
-project_llm_trainer-0.7.6.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
-project_llm_trainer-0.7.6.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
-project_llm_trainer-0.7.6.data/scripts/smart_train,sha256=3oLIDuuqb4U4TU1lXy9V8lw_0gIf7i8tGsxlQ_s6bro,1220
-project_llm_trainer-0.7.6.dist-info/METADATA,sha256=t52f6ahI8WvTnTguykneF91x-ChSZ84sE9PaBjvqb1g,195
-project_llm_trainer-0.7.6.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
-project_llm_trainer-0.7.6.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
-project_llm_trainer-0.7.6.dist-info/RECORD,,
+project_llm_trainer-0.7.8.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
+project_llm_trainer-0.7.8.data/scripts/ddp_train,sha256=Z-309mM56CN0m3bxoeC5us4LUuwuNnoiOm3-fDdLMjQ,566
+project_llm_trainer-0.7.8.data/scripts/ds_train,sha256=3nXNNKmYI7miqyBdf-Ijl_rW1cGIKrAMZ1CSswN_gGo,665
+project_llm_trainer-0.7.8.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
+project_llm_trainer-0.7.8.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
+project_llm_trainer-0.7.8.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
+project_llm_trainer-0.7.8.data/scripts/smart_train,sha256=3oLIDuuqb4U4TU1lXy9V8lw_0gIf7i8tGsxlQ_s6bro,1220
+project_llm_trainer-0.7.8.dist-info/METADATA,sha256=rSYUrEkdjPCyYUqT2SOw3-hzT40wU3AwEw-ouHh1rBY,195
+project_llm_trainer-0.7.8.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
+project_llm_trainer-0.7.8.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
+project_llm_trainer-0.7.8.dist-info/RECORD,,
project_llm_trainer-0.7.6.data/scripts/ddp_train DELETED
@@ -1,12 +0,0 @@
-#!python
-
-if __name__ == '__main__':
-    import os, sys
-    arguments = sys.argv[1:]
-    run_file_name = arguments[0]
-
-    os.environ['PARALLEL_TYPE'] = 'ddp'
-    command = f'torchrun --standalone --nproc_per_node=gpu {run_file_name}'
-
-    print(f'real command is {command}')
-    os.system(command)
project_llm_trainer-0.7.6.data/scripts/ds_train DELETED
@@ -1,12 +0,0 @@
-#!python
-
-if __name__ == '__main__':
-    import os, sys
-    arguments = sys.argv[1:]
-    run_file_name = arguments[0]
-
-    os.environ['PARALLEL_TYPE'] = 'ds'
-    command = f'deepspeed {run_file_name}'
-
-    print(f'real command is {command}')
-    os.system(command)