project-llm-trainer 0.7.7__py3-none-any.whl → 0.7.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_trainer/dataset.py +165 -61
- {project_llm_trainer-0.7.7.dist-info → project_llm_trainer-0.7.8.dist-info}/METADATA +1 -1
- {project_llm_trainer-0.7.7.dist-info → project_llm_trainer-0.7.8.dist-info}/RECORD +12 -12
- {project_llm_trainer-0.7.7.data → project_llm_trainer-0.7.8.data}/scripts/calc_intermediate_size +0 -0
- {project_llm_trainer-0.7.7.data → project_llm_trainer-0.7.8.data}/scripts/ddp_train +0 -0
- {project_llm_trainer-0.7.7.data → project_llm_trainer-0.7.8.data}/scripts/ds_train +0 -0
- {project_llm_trainer-0.7.7.data → project_llm_trainer-0.7.8.data}/scripts/plot_loss +0 -0
- {project_llm_trainer-0.7.7.data → project_llm_trainer-0.7.8.data}/scripts/plot_lr +0 -0
- {project_llm_trainer-0.7.7.data → project_llm_trainer-0.7.8.data}/scripts/py_train +0 -0
- {project_llm_trainer-0.7.7.data → project_llm_trainer-0.7.8.data}/scripts/smart_train +0 -0
- {project_llm_trainer-0.7.7.dist-info → project_llm_trainer-0.7.8.dist-info}/WHEEL +0 -0
- {project_llm_trainer-0.7.7.dist-info → project_llm_trainer-0.7.8.dist-info}/top_level.txt +0 -0
llm_trainer/dataset.py
CHANGED
@@ -1,28 +1,30 @@
-import os.path
-
 import torch
 from torch.utils.data import Dataset
 import pickle
 import csv
+import json
 
 from .tools import TrainerTools
 from .utils import repeat_image_tok
 
 
-
-
-
-
-
-
-
-
+"""
+support jsonl and pkl
+"""
+def _get_file_type(file_path: str):
+    if file_path.endswith('.jsonl'):
+        return 'jsonl'
+    elif file_path.endswith('.pkl'):
+        return 'pkl'
+
+    return None
 
 
 class TextDataset(Dataset):
     """
-    For the pretrain stage.
+    For the pretrain stage. Supports jsonl and pkl; a jsonl file is fully encoded into tokens during __init__.
+    jsonl: {'text': 'text1'}\n{'text': 'text2'}
+    pkl: [0, 1, 2, 3 ...]
     """
     def __init__(
         self,
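To make the two accepted formats concrete, here is a minimal sketch (not part of the package; file names and contents are illustrative) that produces inputs _get_file_type classifies as jsonl and pkl:

import json
import pickle

# A pretrain jsonl file: one {'text': ...} object per line, matching
# the TextDataset docstring above.
with open('pretrain.jsonl', 'w') as f:
    f.write(json.dumps({'text': 'first sample'}) + '\n')
    f.write(json.dumps({'text': 'second sample'}) + '\n')

# The pkl alternative: a flat list of already-encoded token ids.
with open('pretrain.pkl', 'wb') as f:
    pickle.dump([0, 1, 2, 3], f)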
@@ -34,19 +36,17 @@ class TextDataset(Dataset):
 
         self.input_ids = []
 
-
-        if
-
-
-
-
-
-
-
-
-
-            with open(cache_file, 'wb') as f:
-                pickle.dump(tokens, f)
+        file_type = _get_file_type(file_path)
+        if file_type == 'jsonl':
+            tokens = []
+            with open(file_path, 'r') as f:
+                for line in f:
+                    tokens.extend(TrainerTools().tokenizer.encode(json.loads(line.strip())['text']))
+        elif file_type == 'pkl':
+            with open(file_path, 'rb') as f:
+                tokens = pickle.load(f)
+        else:
+            raise Exception(f'unsupported file type for {file_path}')
 
         for i in range(0, len(tokens) - block_size + 1, stride):
             self.input_ids.append(tokens[i:i+block_size])
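The loop kept at the end of this hunk slides a fixed-size window over the flat token stream. The same arithmetic in isolation, with hypothetical block_size and stride values:

# Stand-in token ids; block_size and stride are illustrative.
tokens = list(range(10))
block_size, stride = 4, 2

input_ids = []
for i in range(0, len(tokens) - block_size + 1, stride):
    input_ids.append(tokens[i:i + block_size])

# input_ids == [[0, 1, 2, 3], [2, 3, 4, 5], [4, 5, 6, 7], [6, 7, 8, 9]]
# With stride < block_size, consecutive blocks overlap.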
@@ -60,7 +60,21 @@ class TextDataset(Dataset):
 
 class LineByLineTextDataset(Dataset):
     """
-    For the SFT stage.
+    For the SFT stage. Supports jsonl and pkl; jsonl input is encoded into tokens in __getitem__.
+    jsonl: [
+        {'role': 'system', 'content': 'system_content'},
+        {'role': 'user', 'content': 'user_content'},
+        {'role': 'assistant', 'think': 'think_content', 'content': 'assistant_content'}
+    ]\n
+    [
+        {'role': 'system', 'content': 'system_content'},
+        {'role': 'user', 'content': 'user_content'},
+        {'role': 'assistant', 'think': 'think_content', 'content': 'assistant_content'}
+    ]
+    pkl: [
+        [0, 1, 2, 3],
+        [4, 5, 6, 7]
+    ]
     """
     def __init__(
         self,
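A minimal sketch of writing one SFT sample in the jsonl schema documented above (the content strings are illustrative, not from the package):

import json

# One training sample per line: a full chat, with an optional 'think'
# field on the assistant turn.
sample = [
    {'role': 'system', 'content': 'You are a helpful assistant.'},
    {'role': 'user', 'content': 'What is 2 + 2?'},
    {'role': 'assistant', 'think': 'Simple addition.', 'content': '4'},
]

with open('sft.jsonl', 'w') as f:
    f.write(json.dumps(sample, ensure_ascii=False) + '\n')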
@@ -75,22 +89,20 @@ class LineByLineTextDataset(Dataset):
         self.tokens_per_image = tokens_per_image
         self.input_ids = []
         self.image_tags = []
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        self.input_ids = tokens
+        self.plain_text = False
+
+        file_type = _get_file_type(file_path)
+        if file_type == 'jsonl':
+            self.plain_text = True
+
+            with open(file_path, 'r') as f:
+                for line in f:
+                    self.input_ids.append(json.loads(line.strip()))
+        elif file_type == 'pkl':
+            with open(file_path, 'rb') as f:
+                self.input_ids = pickle.load(f)
+        else:
+            raise Exception(f'unsupported file type for {file_path}')
 
         if image_tags_file_path:
             with open(image_tags_file_path, 'r') as f:
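Since plain_text stays False for pkl input, a pre-tokenized file skips apply_chat_template in __getitem__ entirely. A sketch of producing such a file (file name and token ids are placeholders; a real file would come from the project's tokenizer):

import pickle

# One token-id list per sample, matching the pkl layout in the
# LineByLineTextDataset docstring above.
samples = [
    [0, 1, 2, 3],
    [4, 5, 6, 7],
]
with open('sft.pkl', 'wb') as f:
    pickle.dump(samples, f)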
@@ -102,8 +114,14 @@ class LineByLineTextDataset(Dataset):
         return len(self.input_ids)
 
     def __getitem__(self, item):
-
+        if self.plain_text:
+            inputs = TrainerTools().tokenizer.apply_chat_template(self.input_ids[item])
+        else:
+            inputs = self.input_ids[item]
+
+        inputs = torch.tensor(inputs).long()
         image_tag = self.image_tags[item] if self.image_tags else None
+
         if self.tokens_per_image != -1:
             inputs = repeat_image_tok(inputs, self.tokens_per_image)
         else:
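A quick sketch of what the new branch does with a pkl-style sample, ignoring the image-token handling (values are placeholders):

import torch

# Pre-tokenized samples skip apply_chat_template and go straight to a
# LongTensor, then get truncated to max_len further down.
max_len = 6                              # illustrative value
sample = [4, 5, 6, 7, 8, 9, 10, 11]      # pkl-style token ids
inputs = torch.tensor(sample).long()[:max_len]
print(inputs)                            # tensor([4, 5, 6, 7, 8, 9])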
@@ -111,48 +129,134 @@ class LineByLineTextDataset(Dataset):
 
         inputs = inputs[:self.max_len]
 
-        return {
+        return {
+            'inputs': inputs,
+            'image_tag': image_tag
+        }
 
 
 class DPODataset(Dataset):
+    """
+    For the DPO stage. Supports jsonl and pkl; jsonl input is encoded into tokens in __getitem__.
+    jsonl: {'chosen':
+               [{'role': 'system', 'content': 'system_content'},
+                {'role': 'user', 'content': 'user_content'},
+                {'role': 'assistant', 'think': 'think_content', 'content': 'assistant_content'}],
+            'rejected':
+               [{'role': 'system', 'content': 'system_content'},
+                {'role': 'user', 'content': 'user_content'},
+                {'role': 'assistant', 'think': 'think_content', 'content': 'assistant_content'}],
+           }\n
+           {'chosen':
+               [{'role': 'system', 'content': 'system_content'},
+                {'role': 'user', 'content': 'user_content'},
+                {'role': 'assistant', 'think': 'think_content', 'content': 'assistant_content'}],
+            'rejected':
+               [{'role': 'system', 'content': 'system_content'},
+                {'role': 'user', 'content': 'user_content'},
+                {'role': 'assistant', 'think': 'think_content', 'content': 'assistant_content'}],
+           }
+    pkl: [
+        {'chosen': xxx, 'rejected': xxx},
+        {'chosen': xxx, 'rejected': xxx},
+    ]
+    """
     def __init__(self, file_path, max_len):
         self.max_len = max_len
         self.chosen_ids = []
         self.rejected_ids = []
-
-
-
-
-        self.
-
+        self.plain_text = False
+
+        file_type = _get_file_type(file_path)
+        if file_type == 'jsonl':
+            self.plain_text = True
+
+            with open(file_path, 'r') as f:
+                for line in f:
+                    json_ = json.loads(line.strip())
+                    self.chosen_ids.append(json_['chosen'])
+                    self.rejected_ids.append(json_['rejected'])
+        elif file_type == 'pkl':
+            with open(file_path, 'rb') as f:
+                tokens = pickle.load(f)
+
+            for token in tokens:
+                self.chosen_ids.append(token['chosen'])
+                self.rejected_ids.append(token['rejected'])
+        else:
+            raise Exception(f'unsupported file type for {file_path}')
 
     def __len__(self):
         return len(self.chosen_ids)
 
     def __getitem__(self, item):
-
-
+        if self.plain_text:
+            chosen_id = TrainerTools().tokenizer.apply_chat_template(self.chosen_ids[item])
+            rejected_id = TrainerTools().tokenizer.apply_chat_template(self.rejected_ids[item])
+        else:
+            chosen_id = self.chosen_ids[item]
+            rejected_id = self.rejected_ids[item]
 
-        return {
+        return {
+            'chosen': chosen_id[:self.max_len],
+            'rejected': rejected_id[:self.max_len]
+        }
 
 
 class GRPORolloutDataset(Dataset):
+    """
+    For the GRPO (GSPO) stage. Supports jsonl and pkl; jsonl input is encoded into tokens in __getitem__.
+    jsonl: {'prompt':
+               [{'role': 'system', 'content': 'system_content'},
+                {'role': 'user', 'content': 'user_content'},
+                {'role': 'assistant', 'think': 'think_content', 'content': 'assistant_content'}],
+            'answer': '10'
+           }\n
+           {'prompt':
+               [{'role': 'system', 'content': 'system_content'},
+                {'role': 'user', 'content': 'user_content'},
+                {'role': 'assistant', 'think': 'think_content', 'content': 'assistant_content'}],
+            'answer': '10'
+           }
+    pkl: [
+        {'prompt': xxx, 'answer': xxx},
+        {'prompt': xxx, 'answer': xxx},
+    ]
+    """
     def __init__(self, file_path):
         self.questions = []
         self.answers = []
-
-
-
-
-        self.
-
+        self.plain_text = False
+
+        file_type = _get_file_type(file_path)
+        if file_type == 'jsonl':
+            self.plain_text = True
+
+            with open(file_path, 'r') as f:
+                for line in f:
+                    json_ = json.loads(line.strip())
+                    self.questions.append(json_['prompt'])
+                    self.answers.append(json_['answer'])
+        elif file_type == 'pkl':
+            with open(file_path, 'rb') as f:
+                tokens = pickle.load(f)
+
+            for token in tokens:
+                self.questions.append(token['prompt'])
+                self.answers.append(token['answer'])
+        else:
+            raise Exception(f'unsupported file type for {file_path}')
 
     def __len__(self):
         return len(self.questions)
 
     def __getitem__(self, item):
-
-
+        if self.plain_text:
+            question = TrainerTools().tokenizer.apply_chat_template(self.questions[item])
+            answer = TrainerTools().tokenizer.encode(self.answers[item])
+        else:
+            question = self.questions[item]
+            answer = self.answers[item]
 
         return {
             'prompt': torch.tensor(question).long(),
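A hedged end-to-end sketch of the new DPODataset, assuming the package is installed and importable as llm_trainer; the pkl path needs no tokenizer setup because plain_text stays False:

import pickle
from llm_trainer.dataset import DPODataset

# Tiny pre-tokenized preference file; token ids are placeholders.
pairs = [{'chosen': [1, 2, 3, 4], 'rejected': [1, 2, 9, 9]}]
with open('dpo.pkl', 'wb') as f:
    pickle.dump(pairs, f)

ds = DPODataset('dpo.pkl', max_len=512)
print(len(ds))   # 1
print(ds[0])     # {'chosen': [1, 2, 3, 4], 'rejected': [1, 2, 9, 9]}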
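And the jsonl counterpart for GRPORolloutDataset, matching the docstring schema (contents are illustrative, and the message list is shortened to system and user turns):

import json

row = {
    'prompt': [
        {'role': 'system', 'content': 'Answer with a number only.'},
        {'role': 'user', 'content': 'What is 3 + 7?'},
    ],
    'answer': '10',   # a verifiable string answer for reward checking
}
with open('grpo.jsonl', 'w') as f:
    f.write(json.dumps(row, ensure_ascii=False) + '\n')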
{project_llm_trainer-0.7.7.dist-info → project_llm_trainer-0.7.8.dist-info}/RECORD
CHANGED

@@ -1,6 +1,6 @@
 llm_trainer/__init__.py,sha256=HWgtTEVeQSnZmEyYQm2K6eFEG4X2QAoigMlB5Z2tcXE,260
 llm_trainer/checkpoint.py,sha256=X5ZeUtJlxVz7pnWQLaS-y7UIZOaOAnZTt2L8rSAPzUs,4428
-llm_trainer/dataset.py,sha256=
+llm_trainer/dataset.py,sha256=UL3fGeM4XSlyNQRZH-139u3LujqAQx3YyaxNRewk6LE,8935
 llm_trainer/dpo_trainer.py,sha256=RMfbTsl3eav4yTJ2PK59mi6a0ECVOg8WwYVsHvMbNUE,12353
 llm_trainer/ds_checkpoint.py,sha256=X2IWgpgi0yOtogph7n6DEwvK_0Ceb7juu1WMutv3HSk,2270
 llm_trainer/eval.py,sha256=ZyUfSo2Q8P-lrCdPEnGkoo5pGubd0AabREK5eMISRII,1109
@@ -20,14 +20,14 @@ llm_trainer/tools.py,sha256=5op5qrjjkK-Lr9oes5VxIVnOVYOYGoAdlIJq9mPUf64,2637
 llm_trainer/train_configs.py,sha256=N3ykM1uaLHcSNRC8ErYIxp9VYhSP7voJyAP-2D4ZJe0,7574
 llm_trainer/trainer.py,sha256=jS31zEXIIj9BoPTPlmaGYq61x72HGCjKfS2u3_gOkDk,27924
 llm_trainer/utils.py,sha256=xcdzpvPvXRKqsOK2yB7PZ9GmOvZMDFcglDPUZY2hJTY,11484
-project_llm_trainer-0.7.
-project_llm_trainer-0.7.
-project_llm_trainer-0.7.
-project_llm_trainer-0.7.
-project_llm_trainer-0.7.
-project_llm_trainer-0.7.
-project_llm_trainer-0.7.
-project_llm_trainer-0.7.
-project_llm_trainer-0.7.
-project_llm_trainer-0.7.
-project_llm_trainer-0.7.
+project_llm_trainer-0.7.8.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
+project_llm_trainer-0.7.8.data/scripts/ddp_train,sha256=Z-309mM56CN0m3bxoeC5us4LUuwuNnoiOm3-fDdLMjQ,566
+project_llm_trainer-0.7.8.data/scripts/ds_train,sha256=3nXNNKmYI7miqyBdf-Ijl_rW1cGIKrAMZ1CSswN_gGo,665
+project_llm_trainer-0.7.8.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
+project_llm_trainer-0.7.8.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
+project_llm_trainer-0.7.8.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
+project_llm_trainer-0.7.8.data/scripts/smart_train,sha256=3oLIDuuqb4U4TU1lXy9V8lw_0gIf7i8tGsxlQ_s6bro,1220
+project_llm_trainer-0.7.8.dist-info/METADATA,sha256=rSYUrEkdjPCyYUqT2SOw3-hzT40wU3AwEw-ouHh1rBY,195
+project_llm_trainer-0.7.8.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
+project_llm_trainer-0.7.8.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
+project_llm_trainer-0.7.8.dist-info/RECORD,,
{project_llm_trainer-0.7.7.data → project_llm_trainer-0.7.8.data}/scripts/calc_intermediate_size
RENAMED
File without changes

{project_llm_trainer-0.7.7.data → project_llm_trainer-0.7.8.data}/scripts/ddp_train
RENAMED
File without changes

{project_llm_trainer-0.7.7.data → project_llm_trainer-0.7.8.data}/scripts/ds_train
RENAMED
File without changes

{project_llm_trainer-0.7.7.data → project_llm_trainer-0.7.8.data}/scripts/plot_loss
RENAMED
File without changes

{project_llm_trainer-0.7.7.data → project_llm_trainer-0.7.8.data}/scripts/plot_lr
RENAMED
File without changes

{project_llm_trainer-0.7.7.data → project_llm_trainer-0.7.8.data}/scripts/py_train
RENAMED
File without changes

{project_llm_trainer-0.7.7.data → project_llm_trainer-0.7.8.data}/scripts/smart_train
RENAMED
File without changes

{project_llm_trainer-0.7.7.dist-info → project_llm_trainer-0.7.8.dist-info}/WHEEL
RENAMED
File without changes

{project_llm_trainer-0.7.7.dist-info → project_llm_trainer-0.7.8.dist-info}/top_level.txt
RENAMED
File without changes