project-llm-trainer 0.7.7__py3-none-any.whl → 0.7.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of project-llm-trainer might be problematic. Click here for more details.
- llm_trainer/dataset.py +165 -61
- llm_trainer/dpo_trainer.py +38 -20
- llm_trainer/loss.py +6 -10
- llm_trainer/tokenizer.py +16 -1
- llm_trainer/utils.py +36 -11
- {project_llm_trainer-0.7.7.data → project_llm_trainer-0.7.9.data}/scripts/ds_train +7 -6
- {project_llm_trainer-0.7.7.dist-info → project_llm_trainer-0.7.9.dist-info}/METADATA +1 -1
- {project_llm_trainer-0.7.7.dist-info → project_llm_trainer-0.7.9.dist-info}/RECORD +16 -16
- {project_llm_trainer-0.7.7.data → project_llm_trainer-0.7.9.data}/scripts/calc_intermediate_size +0 -0
- {project_llm_trainer-0.7.7.data → project_llm_trainer-0.7.9.data}/scripts/ddp_train +0 -0
- {project_llm_trainer-0.7.7.data → project_llm_trainer-0.7.9.data}/scripts/plot_loss +0 -0
- {project_llm_trainer-0.7.7.data → project_llm_trainer-0.7.9.data}/scripts/plot_lr +0 -0
- {project_llm_trainer-0.7.7.data → project_llm_trainer-0.7.9.data}/scripts/py_train +0 -0
- {project_llm_trainer-0.7.7.data → project_llm_trainer-0.7.9.data}/scripts/smart_train +0 -0
- {project_llm_trainer-0.7.7.dist-info → project_llm_trainer-0.7.9.dist-info}/WHEEL +0 -0
- {project_llm_trainer-0.7.7.dist-info → project_llm_trainer-0.7.9.dist-info}/top_level.txt +0 -0
llm_trainer/dataset.py
CHANGED
|
@@ -1,28 +1,30 @@
|
|
|
1
|
-
import os.path
|
|
2
|
-
|
|
3
1
|
import torch
|
|
4
2
|
from torch.utils.data import Dataset
|
|
5
3
|
import pickle
|
|
6
4
|
import csv
|
|
5
|
+
import json
|
|
7
6
|
|
|
8
7
|
from .tools import TrainerTools
|
|
9
8
|
from .utils import repeat_image_tok
|
|
10
9
|
|
|
11
10
|
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
11
|
+
"""
|
|
12
|
+
support jsonl and pkl
|
|
13
|
+
"""
|
|
14
|
+
def _get_file_type(file_path: str):
|
|
15
|
+
if file_path.endswith('.jsonl'):
|
|
16
|
+
return 'jsonl'
|
|
17
|
+
elif file_path.endswith('.pkl'):
|
|
18
|
+
return 'pkl'
|
|
19
|
+
|
|
20
|
+
return None
|
|
21
21
|
|
|
22
22
|
|
|
23
23
|
class TextDataset(Dataset):
|
|
24
24
|
"""
|
|
25
|
-
适用于pretrain
|
|
25
|
+
适用于pretrain阶段,数据格式支持jsonl和pkl,如果是jsonl会在init阶段全部encode成token
|
|
26
|
+
jsonl: {'text': 'text1'}\n{'text': 'text2'}
|
|
27
|
+
pkl: [0, 1, 2, 3 ...]
|
|
26
28
|
"""
|
|
27
29
|
def __init__(
|
|
28
30
|
self,
|
|
@@ -34,19 +36,17 @@ class TextDataset(Dataset):
|
|
|
34
36
|
|
|
35
37
|
self.input_ids = []
|
|
36
38
|
|
|
37
|
-
|
|
38
|
-
if
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
with open(cache_file, 'wb') as f:
|
|
49
|
-
pickle.dump(tokens, f)
|
|
39
|
+
file_type = _get_file_type(file_path)
|
|
40
|
+
if file_type == 'jsonl':
|
|
41
|
+
tokens = []
|
|
42
|
+
with open(file_path, 'r') as f:
|
|
43
|
+
for line in f:
|
|
44
|
+
tokens.extend(TrainerTools().tokenizer.encode(json.loads(line.strip())['text']))
|
|
45
|
+
elif file_type == 'pkl':
|
|
46
|
+
with open(file_path, 'rb') as f:
|
|
47
|
+
tokens = pickle.load(f)
|
|
48
|
+
else:
|
|
49
|
+
raise Exception(f'unsupported file type for {file_path}')
|
|
50
50
|
|
|
51
51
|
for i in range(0, len(tokens) - block_size + 1, stride):
|
|
52
52
|
self.input_ids.append(tokens[i:i+block_size])
|
|
@@ -60,7 +60,21 @@ class TextDataset(Dataset):
|
|
|
60
60
|
|
|
61
61
|
class LineByLineTextDataset(Dataset):
|
|
62
62
|
"""
|
|
63
|
-
适用于sft阶段
|
|
63
|
+
适用于sft阶段,数据格式支持jsonl和pkl,如果是jsonl,则会在getitem阶段encode成token
|
|
64
|
+
jsonl: [
|
|
65
|
+
{'role': 'system', 'content': 'system_content'},
|
|
66
|
+
{'role': 'user', 'content': 'user_content'},
|
|
67
|
+
{'role': 'assistant', 'think': 'think_content', 'content': 'assistant_content'}
|
|
68
|
+
]\n
|
|
69
|
+
[
|
|
70
|
+
{'role': 'system', 'content': 'system_content'},
|
|
71
|
+
{'role': 'user', 'content': 'user_content'},
|
|
72
|
+
{'role': 'assistant', 'think': 'think_content', 'content': 'assistant_content'}
|
|
73
|
+
]
|
|
74
|
+
pkl: [
|
|
75
|
+
[0, 1, 2, 3],
|
|
76
|
+
[4, 5, 6, 7]
|
|
77
|
+
]
|
|
64
78
|
"""
|
|
65
79
|
def __init__(
|
|
66
80
|
self,
|
|
@@ -75,22 +89,20 @@ class LineByLineTextDataset(Dataset):
|
|
|
75
89
|
self.tokens_per_image = tokens_per_image
|
|
76
90
|
self.input_ids = []
|
|
77
91
|
self.image_tags = []
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
self.input_ids = tokens
|
|
92
|
+
self.plain_text = False
|
|
93
|
+
|
|
94
|
+
file_type = _get_file_type(file_path)
|
|
95
|
+
if file_type == 'jsonl':
|
|
96
|
+
self.plain_text = True
|
|
97
|
+
|
|
98
|
+
with open(file_path, 'r') as f:
|
|
99
|
+
for line in f:
|
|
100
|
+
self.input_ids.append(json.loads(line.strip()))
|
|
101
|
+
elif file_type == 'pkl':
|
|
102
|
+
with open(file_path, 'rb') as f:
|
|
103
|
+
self.input_ids = pickle.load(f)
|
|
104
|
+
else:
|
|
105
|
+
raise Exception(f'unsupported file type for {file_path}')
|
|
94
106
|
|
|
95
107
|
if image_tags_file_path:
|
|
96
108
|
with open(image_tags_file_path, 'r') as f:
|
|
@@ -102,8 +114,14 @@ class LineByLineTextDataset(Dataset):
|
|
|
102
114
|
return len(self.input_ids)
|
|
103
115
|
|
|
104
116
|
def __getitem__(self, item):
|
|
105
|
-
|
|
117
|
+
if self.plain_text:
|
|
118
|
+
inputs = TrainerTools().tokenizer.apply_chat_template(self.input_ids[item])
|
|
119
|
+
else:
|
|
120
|
+
inputs = self.input_ids[item]
|
|
121
|
+
|
|
122
|
+
inputs = torch.tensor(inputs).long()
|
|
106
123
|
image_tag = self.image_tags[item] if self.image_tags else None
|
|
124
|
+
|
|
107
125
|
if self.tokens_per_image != -1:
|
|
108
126
|
inputs = repeat_image_tok(inputs, self.tokens_per_image)
|
|
109
127
|
else:
|
|
@@ -111,48 +129,134 @@ class LineByLineTextDataset(Dataset):
|
|
|
111
129
|
|
|
112
130
|
inputs = inputs[:self.max_len]
|
|
113
131
|
|
|
114
|
-
return {
|
|
132
|
+
return {
|
|
133
|
+
'inputs': inputs,
|
|
134
|
+
'image_tag': image_tag
|
|
135
|
+
}
|
|
115
136
|
|
|
116
137
|
|
|
117
138
|
class DPODataset(Dataset):
|
|
139
|
+
"""
|
|
140
|
+
适用于dpo阶段,数据格式支持jsonl和pkl,如果是jsonl,则会在getitem阶段encode成token
|
|
141
|
+
jsonl: {'chosen':
|
|
142
|
+
[{'role': 'system', 'content': 'system_content'},
|
|
143
|
+
{'role': 'user', 'content': 'user_content'},
|
|
144
|
+
{'role': 'assistant', 'think': 'think_content', 'content': 'assistant_content'}],
|
|
145
|
+
'rejected':
|
|
146
|
+
[{'role': 'system', 'content': 'system_content'},
|
|
147
|
+
{'role': 'user', 'content': 'user_content'},
|
|
148
|
+
{'role': 'assistant', 'think': 'think_content', 'content': 'assistant_content'}],
|
|
149
|
+
}\n
|
|
150
|
+
{'chosen':
|
|
151
|
+
[{'role': 'system', 'content': 'system_content'},
|
|
152
|
+
{'role': 'user', 'content': 'user_content'},
|
|
153
|
+
{'role': 'assistant', 'think': 'think_content', 'content': 'assistant_content'}],
|
|
154
|
+
'rejected':
|
|
155
|
+
[{'role': 'system', 'content': 'system_content'},
|
|
156
|
+
{'role': 'user', 'content': 'user_content'},
|
|
157
|
+
'role': 'assistant', 'think': 'think_content', 'content': 'assistant_content'}],
|
|
158
|
+
}
|
|
159
|
+
pkl: [
|
|
160
|
+
{'chosen': xxx, 'rejected': xxx},
|
|
161
|
+
{'chosen': xxx, 'rejected': xxx},
|
|
162
|
+
]
|
|
163
|
+
"""
|
|
118
164
|
def __init__(self, file_path, max_len):
|
|
119
165
|
self.max_len = max_len
|
|
120
166
|
self.chosen_ids = []
|
|
121
167
|
self.rejected_ids = []
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
self.
|
|
127
|
-
|
|
168
|
+
self.plain_text = False
|
|
169
|
+
|
|
170
|
+
file_type = _get_file_type(file_path)
|
|
171
|
+
if file_type == 'jsonl':
|
|
172
|
+
self.plain_text = True
|
|
173
|
+
|
|
174
|
+
with open(file_path, 'r') as f:
|
|
175
|
+
for line in f:
|
|
176
|
+
json_ = json.loads(line.strip())
|
|
177
|
+
self.chosen_ids.append(json_['chosen'])
|
|
178
|
+
self.rejected_ids.append(json_['rejected'])
|
|
179
|
+
elif file_type == 'pkl':
|
|
180
|
+
with open(file_path, 'rb') as f:
|
|
181
|
+
tokens = pickle.load(f)
|
|
182
|
+
|
|
183
|
+
for token in tokens:
|
|
184
|
+
self.chosen_ids.append(token['chosen'])
|
|
185
|
+
self.rejected_ids.append(token['rejected'])
|
|
186
|
+
else:
|
|
187
|
+
raise Exception(f'unsupported file type for {file_path}')
|
|
128
188
|
|
|
129
189
|
def __len__(self):
|
|
130
190
|
return len(self.chosen_ids)
|
|
131
191
|
|
|
132
192
|
def __getitem__(self, item):
|
|
133
|
-
|
|
134
|
-
|
|
193
|
+
if self.plain_text:
|
|
194
|
+
chosen_id = TrainerTools().tokenizer.apply_chat_template(self.chosen_ids[item])
|
|
195
|
+
rejected_id = TrainerTools().tokenizer.apply_chat_template(self.rejected_ids[item])
|
|
196
|
+
else:
|
|
197
|
+
chosen_id = self.chosen_ids[item]
|
|
198
|
+
rejected_id = self.rejected_ids[item]
|
|
135
199
|
|
|
136
|
-
return {
|
|
200
|
+
return {
|
|
201
|
+
'chosen': chosen_id[:self.max_len],
|
|
202
|
+
'rejected': rejected_id[:self.max_len]
|
|
203
|
+
}
|
|
137
204
|
|
|
138
205
|
|
|
139
206
|
class GRPORolloutDataset(Dataset):
|
|
207
|
+
"""
|
|
208
|
+
适用于grpo(gspo)阶段,数据格式支持jsonl和pkl,如果是jsonl,则会在getitem阶段encode成token
|
|
209
|
+
jsonl: {'prompt':
|
|
210
|
+
[{'role': 'system', 'content': 'system_content'},
|
|
211
|
+
{'role': 'user', 'content': 'user_content'},
|
|
212
|
+
{'role': 'assistant', 'think': 'think_content', 'content': 'assistant_content'}],
|
|
213
|
+
'answer': '10'
|
|
214
|
+
}\n
|
|
215
|
+
{'prompt':
|
|
216
|
+
[{'role': 'system', 'content': 'system_content'},
|
|
217
|
+
{'role': 'user', 'content': 'user_content'},
|
|
218
|
+
{'role': 'assistant', 'think': 'think_content', 'content': 'assistant_content'}],
|
|
219
|
+
'answer': '10'
|
|
220
|
+
}
|
|
221
|
+
pkl: [
|
|
222
|
+
{'prompt': xxx, 'answer': xxx},
|
|
223
|
+
{'prompt': xxx, 'answer': xxx},
|
|
224
|
+
]
|
|
225
|
+
"""
|
|
140
226
|
def __init__(self, file_path):
|
|
141
227
|
self.questions = []
|
|
142
228
|
self.answers = []
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
self.
|
|
148
|
-
|
|
229
|
+
self.plain_text = False
|
|
230
|
+
|
|
231
|
+
file_type = _get_file_type(file_path)
|
|
232
|
+
if file_type == 'jsonl':
|
|
233
|
+
self.plain_text = True
|
|
234
|
+
|
|
235
|
+
with open(file_path, 'r') as f:
|
|
236
|
+
for line in f:
|
|
237
|
+
json_ = json.loads(line.strip())
|
|
238
|
+
self.questions.append(json_['prompt'])
|
|
239
|
+
self.answers.append(json_['answer'])
|
|
240
|
+
elif file_type == 'pkl':
|
|
241
|
+
with open(file_path, 'rb') as f:
|
|
242
|
+
tokens = pickle.load(f)
|
|
243
|
+
|
|
244
|
+
for token in tokens:
|
|
245
|
+
self.questions.append(token['prompt'])
|
|
246
|
+
self.answers.append(token['answer'])
|
|
247
|
+
else:
|
|
248
|
+
raise Exception(f'unsupported file type for {file_path}')
|
|
149
249
|
|
|
150
250
|
def __len__(self):
|
|
151
251
|
return len(self.questions)
|
|
152
252
|
|
|
153
253
|
def __getitem__(self, item):
|
|
154
|
-
|
|
155
|
-
|
|
254
|
+
if self.plain_text:
|
|
255
|
+
question = TrainerTools().tokenizer.apply_chat_template(self.questions[item])
|
|
256
|
+
answer = TrainerTools().tokenizer.encode(self.answers[item])
|
|
257
|
+
else:
|
|
258
|
+
question = self.questions[item]
|
|
259
|
+
answer = self.answers[item]
|
|
156
260
|
|
|
157
261
|
return {
|
|
158
262
|
'prompt': torch.tensor(question).long(),
|
llm_trainer/dpo_trainer.py
CHANGED
|
@@ -12,7 +12,8 @@ from .loss import DPOLoss
|
|
|
12
12
|
from .tools import TrainerTools
|
|
13
13
|
from .utils import (
|
|
14
14
|
autocast,
|
|
15
|
-
get_dpo_collate_fn
|
|
15
|
+
get_dpo_collate_fn,
|
|
16
|
+
fill_loss_mask
|
|
16
17
|
)
|
|
17
18
|
from .partition_utils import sync_model_params
|
|
18
19
|
|
|
@@ -84,7 +85,6 @@ class DPOTrainer(Trainer):
|
|
|
84
85
|
def _calc_loss(self, inputs, attention_mask, logits, labels): ...
|
|
85
86
|
|
|
86
87
|
def _log_probs_from_logits(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
|
|
87
|
-
# https://github.com/OpenRLHF/OpenRLHF/pull/718#issuecomment-2641081881
|
|
88
88
|
if logits.dtype in [torch.float32, torch.float64]:
|
|
89
89
|
logits_labels = torch.gather(logits, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
|
|
90
90
|
logsumexp_values = torch.stack(
|
|
@@ -102,25 +102,26 @@ class DPOTrainer(Trainer):
|
|
|
102
102
|
return log_probs_labels
|
|
103
103
|
|
|
104
104
|
|
|
105
|
-
def _logprobs(self, logits, labels,
|
|
105
|
+
def _logprobs(self, logits, labels, attention_mask):
|
|
106
106
|
"""
|
|
107
107
|
Calculate the average log probabilities for a batch of sequences.
|
|
108
108
|
|
|
109
109
|
Args:
|
|
110
110
|
logits (torch.Tensor): Logits from the model with shape (B, T, V)
|
|
111
111
|
labels (torch.Tensor): Ground truth labels with shape (B, T).
|
|
112
|
-
|
|
112
|
+
attention_mask (torch.Tensor): Mask tensor with shape (B, T) indicating
|
|
113
113
|
which tokens are not padding (1 for valid tokens, 0 for padding).
|
|
114
114
|
|
|
115
115
|
Returns:
|
|
116
116
|
torch.Tensor: Average log probabilities for each sequence in the batch.
|
|
117
117
|
Shape is (B,) representing the mean log probability for each sequence.
|
|
118
118
|
"""
|
|
119
|
-
|
|
120
|
-
|
|
119
|
+
loss_masks = attention_mask.clone().bool()
|
|
120
|
+
loss_masks = fill_loss_mask(loss_masks, labels)
|
|
121
121
|
|
|
122
|
-
|
|
123
|
-
|
|
122
|
+
logits = logits[:, :-1, :]
|
|
123
|
+
labels = labels[:, 1:].clone()
|
|
124
|
+
loss_masks = loss_masks[:, 1:]
|
|
124
125
|
|
|
125
126
|
# dummy token; we'll ignore the losses on these tokens later
|
|
126
127
|
labels[labels == -100] = 0
|
|
@@ -129,11 +130,10 @@ class DPOTrainer(Trainer):
|
|
|
129
130
|
per_token_logps = self._log_probs_from_logits(logits, labels)
|
|
130
131
|
|
|
131
132
|
# Apply the mask to set log-probs of padding tokens to 0
|
|
132
|
-
logprobs_sums = (per_token_logps *
|
|
133
|
-
|
|
134
|
-
# logprobs_means = (per_token_logps * mask).sum(-1) / mask.sum(-1)
|
|
133
|
+
logprobs_sums = (per_token_logps * loss_masks).sum(-1)
|
|
134
|
+
logprobs_means = (per_token_logps * loss_masks).sum(-1) / loss_masks.sum(-1)
|
|
135
135
|
|
|
136
|
-
return logprobs_sums
|
|
136
|
+
return logprobs_sums, logprobs_means
|
|
137
137
|
|
|
138
138
|
def train(self):
|
|
139
139
|
# 梯度累积步数
|
|
@@ -147,6 +147,7 @@ class DPOTrainer(Trainer):
|
|
|
147
147
|
last_best_checkpoint_loss: Optional[float] = None
|
|
148
148
|
|
|
149
149
|
aux_loss_coef = self.train_config.loss_config.aux_loss_coef
|
|
150
|
+
nll_loss_coef = self.train_config.dpo_config.nll_loss_coef
|
|
150
151
|
|
|
151
152
|
for epoch in range(self.train_config.n_epochs):
|
|
152
153
|
self.train_model.train()
|
|
@@ -188,36 +189,53 @@ class DPOTrainer(Trainer):
|
|
|
188
189
|
try:
|
|
189
190
|
chosen_inputs: torch.Tensor = batch_data['chosen_inputs'].to(TrainerTools().parallel.device)
|
|
190
191
|
chosen_labels: torch.Tensor = batch_data['chosen_labels'].to(TrainerTools().parallel.device)
|
|
192
|
+
|
|
191
193
|
rejected_inputs: torch.Tensor = batch_data['rejected_inputs'].to(TrainerTools().parallel.device)
|
|
192
194
|
rejected_labels: torch.Tensor = batch_data['rejected_labels'].to(TrainerTools().parallel.device)
|
|
193
195
|
|
|
194
|
-
|
|
195
|
-
|
|
196
|
+
chosen_attention_masks: torch.Tensor = chosen_inputs != TrainerTools().tokenizer.pad
|
|
197
|
+
rejected_attention_masks: torch.Tensor = rejected_inputs != TrainerTools().tokenizer.pad
|
|
196
198
|
|
|
197
199
|
# 在batch维度concat
|
|
198
200
|
# [chosen, chosen, reject, reject]
|
|
199
201
|
concat_inputs = torch.concat([chosen_inputs, rejected_inputs], dim=0)
|
|
200
202
|
concat_labels = torch.concat([chosen_labels, rejected_labels], dim=0)
|
|
201
|
-
|
|
203
|
+
concat_attention_masks = torch.concat([chosen_attention_masks, rejected_attention_masks], dim=0)
|
|
202
204
|
|
|
203
205
|
if TrainerTools().parallel.parallel_train:
|
|
204
206
|
self.train_model.require_backward_grad_sync = need_update_grad
|
|
205
207
|
|
|
206
208
|
with autocast(TrainerTools().parallel.device_type):
|
|
207
|
-
policy_outputs = self.train_model(concat_inputs, attention_mask=
|
|
208
|
-
|
|
209
|
+
policy_outputs = self.train_model(concat_inputs, attention_mask=concat_attention_masks)
|
|
210
|
+
policy_logprobs_sums, policy_logprobs_means = self._logprobs(policy_outputs['logits'], concat_labels, concat_attention_masks)
|
|
209
211
|
aux_loss = policy_outputs.get('aux_loss')
|
|
210
212
|
|
|
211
213
|
with torch.no_grad():
|
|
212
|
-
ref_outputs = self.ref_model(concat_inputs, attention_mask=
|
|
213
|
-
|
|
214
|
+
ref_outputs = self.ref_model(concat_inputs, attention_mask=concat_attention_masks)
|
|
215
|
+
ref_logprobs_sums, _ = self._logprobs(ref_outputs['logits'], concat_labels, concat_attention_masks)
|
|
216
|
+
|
|
217
|
+
policy_chosen_logps = policy_logprobs_sums[:chosen_inputs.shape[0]]
|
|
218
|
+
policy_rejected_logps = policy_logprobs_sums[chosen_inputs.shape[0]:]
|
|
219
|
+
|
|
220
|
+
ref_chosen_logps = ref_logprobs_sums[:chosen_inputs.shape[0]]
|
|
221
|
+
ref_rejected_logps = ref_logprobs_sums[chosen_inputs.shape[0]:]
|
|
222
|
+
|
|
223
|
+
nll_loss = -policy_logprobs_means[:chosen_inputs.shape[0]].mean()
|
|
214
224
|
|
|
215
225
|
# calc loss
|
|
216
|
-
loss = self.criterion(
|
|
226
|
+
loss = self.criterion(
|
|
227
|
+
policy_chosen_logps,
|
|
228
|
+
policy_rejected_logps,
|
|
229
|
+
ref_chosen_logps,
|
|
230
|
+
ref_rejected_logps
|
|
231
|
+
)
|
|
217
232
|
|
|
218
233
|
if aux_loss_coef and aux_loss:
|
|
219
234
|
loss += aux_loss_coef * aux_loss
|
|
220
235
|
|
|
236
|
+
if nll_loss_coef and nll_loss:
|
|
237
|
+
loss += nll_loss_coef * nll_loss
|
|
238
|
+
|
|
221
239
|
if gradient_accumulation_steps > 1:
|
|
222
240
|
loss = loss / gradient_accumulation_steps
|
|
223
241
|
|
llm_trainer/loss.py
CHANGED
|
@@ -92,17 +92,13 @@ class DPOLoss(nn.Module):
|
|
|
92
92
|
|
|
93
93
|
def forward(
|
|
94
94
|
self,
|
|
95
|
-
|
|
96
|
-
|
|
95
|
+
policy_chosen_logps: torch.Tensor,
|
|
96
|
+
policy_reject_logps: torch.Tensor,
|
|
97
|
+
ref_chosen_logps: torch.Tensor,
|
|
98
|
+
ref_reject_logps: torch.Tensor
|
|
97
99
|
) -> torch.Tensor:
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
ref_reject_probs = reference_logps[batch_size//2:]
|
|
101
|
-
policy_chosen_probs = policy_logps[:batch_size//2]
|
|
102
|
-
policy_reject_probs = policy_logps[batch_size//2:]
|
|
103
|
-
|
|
104
|
-
pi_logratios = policy_chosen_probs - policy_reject_probs
|
|
105
|
-
ref_logratios = ref_chosen_probs - ref_reject_probs
|
|
100
|
+
pi_logratios = policy_chosen_logps - policy_reject_logps
|
|
101
|
+
ref_logratios = ref_chosen_logps - ref_reject_logps
|
|
106
102
|
logits = pi_logratios - ref_logratios
|
|
107
103
|
|
|
108
104
|
if self.ipo:
|
llm_trainer/tokenizer.py
CHANGED
|
@@ -3,7 +3,7 @@ import warnings
|
|
|
3
3
|
from typing import List, Dict, Union
|
|
4
4
|
from transformers import Qwen2TokenizerFast
|
|
5
5
|
from transformers import AddedToken
|
|
6
|
-
from transformers import
|
|
6
|
+
from transformers import LlamaTokenizerFast
|
|
7
7
|
import torch
|
|
8
8
|
|
|
9
9
|
TOKEN_TYPE_QWEN = 'qwen'
|
|
@@ -164,3 +164,18 @@ class Tokenizer:
|
|
|
164
164
|
|
|
165
165
|
return chat_template
|
|
166
166
|
|
|
167
|
+
def get_special_tokens_dict(self):
|
|
168
|
+
return {
|
|
169
|
+
self.text_end: self.end,
|
|
170
|
+
self.text_pad: self.pad,
|
|
171
|
+
self.text_unk: self.unk,
|
|
172
|
+
self.text_user: self.user,
|
|
173
|
+
self.text_assistant: self.assistant,
|
|
174
|
+
self.text_think_start: self.think_start,
|
|
175
|
+
self.text_think_end: self.think_end,
|
|
176
|
+
self.text_answer_start: self.answer_start,
|
|
177
|
+
self.text_answer_end: self.answer_end,
|
|
178
|
+
self.text_system: self.system,
|
|
179
|
+
self.text_image: self.image,
|
|
180
|
+
}
|
|
181
|
+
|
llm_trainer/utils.py
CHANGED
|
@@ -154,16 +154,22 @@ def batch_repeat_image_tok(
|
|
|
154
154
|
|
|
155
155
|
|
|
156
156
|
def pretrain_collate_fn(batch_data):
|
|
157
|
-
|
|
157
|
+
# [[x,x,x], [y,y,y]]
|
|
158
|
+
inputs = pad_sequence(batch_data, batch_first=True, padding_value=TrainerTools().tokenizer.pad)
|
|
159
|
+
# crossEntropy默认的ignore_index是-100
|
|
160
|
+
labels = pad_sequence(batch_data, batch_first=True, padding_value=-100)
|
|
158
161
|
|
|
159
162
|
# inputs, labels
|
|
160
|
-
return {
|
|
163
|
+
return {
|
|
164
|
+
'inputs': inputs,
|
|
165
|
+
'labels': labels
|
|
166
|
+
}
|
|
161
167
|
|
|
162
168
|
|
|
163
169
|
def get_sft_collate_fn(mask_prompt: bool):
|
|
164
170
|
def sft_collate_fn(batch_data):
|
|
165
171
|
"""
|
|
166
|
-
|
|
172
|
+
如果是sft,则不计算prompt部分的loss, 例如:
|
|
167
173
|
logits: [USER]你好[BOT]我好[SEP]
|
|
168
174
|
labels: [USER]你好[BOT]我好[SEP]
|
|
169
175
|
|
|
@@ -184,11 +190,19 @@ def get_sft_collate_fn(mask_prompt: bool):
|
|
|
184
190
|
batch_train_data.append(item['inputs'])
|
|
185
191
|
image_tags.append(item['image_tag'])
|
|
186
192
|
|
|
187
|
-
|
|
193
|
+
# [[x,x,x], [y,y,y]]
|
|
194
|
+
inputs = pad_sequence(batch_train_data, batch_first=True, padding_value=TrainerTools().tokenizer.pad)
|
|
195
|
+
# crossEntropy默认的ignore_index是-100
|
|
196
|
+
labels = pad_sequence(batch_train_data, batch_first=True, padding_value=-100)
|
|
197
|
+
|
|
188
198
|
if mask_prompt:
|
|
189
199
|
labels = _mask_prompt(labels)
|
|
190
200
|
|
|
191
|
-
return {
|
|
201
|
+
return {
|
|
202
|
+
'inputs': inputs,
|
|
203
|
+
'labels': labels,
|
|
204
|
+
'image_tags': image_tags
|
|
205
|
+
}
|
|
192
206
|
|
|
193
207
|
return sft_collate_fn
|
|
194
208
|
|
|
@@ -295,13 +309,24 @@ def join_batch(batch_data: list[dict]) -> dict:
|
|
|
295
309
|
return result
|
|
296
310
|
|
|
297
311
|
|
|
298
|
-
def
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
labels
|
|
312
|
+
def fill_loss_mask(loss_masks, labels):
|
|
313
|
+
"""
|
|
314
|
+
将loss_mask中prompt部分强制设置为False
|
|
315
|
+
loss_masks: shape (B, T)
|
|
316
|
+
labels: shape (B, T)
|
|
317
|
+
"""
|
|
318
|
+
tokenizer = TrainerTools().tokenizer
|
|
319
|
+
# 支持多轮会话的mask
|
|
320
|
+
for batch, label in enumerate(labels):
|
|
321
|
+
start_index = -1
|
|
322
|
+
for index, token in enumerate(label):
|
|
323
|
+
if token == tokenizer.system or token == tokenizer.user:
|
|
324
|
+
start_index = index
|
|
325
|
+
elif token == tokenizer.end and start_index != -1:
|
|
326
|
+
loss_masks[batch, start_index:index + 1] = False
|
|
327
|
+
start_index = -1
|
|
303
328
|
|
|
304
|
-
return
|
|
329
|
+
return loss_masks
|
|
305
330
|
|
|
306
331
|
|
|
307
332
|
def _mask_prompt(labels):
|
|
@@ -10,14 +10,15 @@ if __name__ == '__main__':
|
|
|
10
10
|
if len(arguments) > 1:
|
|
11
11
|
# 0,1,2,3
|
|
12
12
|
cuda_visible_devive = arguments[1]
|
|
13
|
-
else:
|
|
14
|
-
cuda_visible_devive = None
|
|
15
13
|
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
14
|
+
# cuda location
|
|
15
|
+
if len(arguments) > 2:
|
|
16
|
+
cuda_loc = arguments[2]
|
|
17
|
+
else:
|
|
18
|
+
cuda_loc = 'localhost'
|
|
19
19
|
else:
|
|
20
|
-
|
|
20
|
+
cuda_visible_devive = None
|
|
21
|
+
cuda_loc = None
|
|
21
22
|
|
|
22
23
|
os.environ['PARALLEL_TYPE'] = 'ds'
|
|
23
24
|
|
|
@@ -1,13 +1,13 @@
|
|
|
1
1
|
llm_trainer/__init__.py,sha256=HWgtTEVeQSnZmEyYQm2K6eFEG4X2QAoigMlB5Z2tcXE,260
|
|
2
2
|
llm_trainer/checkpoint.py,sha256=X5ZeUtJlxVz7pnWQLaS-y7UIZOaOAnZTt2L8rSAPzUs,4428
|
|
3
|
-
llm_trainer/dataset.py,sha256=
|
|
4
|
-
llm_trainer/dpo_trainer.py,sha256=
|
|
3
|
+
llm_trainer/dataset.py,sha256=UL3fGeM4XSlyNQRZH-139u3LujqAQx3YyaxNRewk6LE,8935
|
|
4
|
+
llm_trainer/dpo_trainer.py,sha256=Bgds18UWFhzf_UNCFN-iBCdhKf9pcXJBFPEc32oJeXA,13354
|
|
5
5
|
llm_trainer/ds_checkpoint.py,sha256=X2IWgpgi0yOtogph7n6DEwvK_0Ceb7juu1WMutv3HSk,2270
|
|
6
6
|
llm_trainer/eval.py,sha256=ZyUfSo2Q8P-lrCdPEnGkoo5pGubd0AabREK5eMISRII,1109
|
|
7
7
|
llm_trainer/generate_utils.py,sha256=8K3YFbp7IF_lCkmkzjHhqTW26EBFb2AilQmarVcfMvs,15001
|
|
8
8
|
llm_trainer/grpo_trainer.py,sha256=MXnP8Kc9CQJw0CB3uMbHxIYwvpuujai4hgbbpUut_K4,16808
|
|
9
9
|
llm_trainer/log.py,sha256=XwychwKF6gvFPhthCIZCAEUZ0G3DY3fiQrOHqPWsxz0,463
|
|
10
|
-
llm_trainer/loss.py,sha256=
|
|
10
|
+
llm_trainer/loss.py,sha256=RhTxftLMj1Tqc5pkUvJiZumfbMEPWL8GBGxdTfQggmk,6744
|
|
11
11
|
llm_trainer/parallel.py,sha256=yjStV21DJ26yM8-0O6GTMxdFAcyShY5GsQWSZmbI7HU,4543
|
|
12
12
|
llm_trainer/parallel_ddp.py,sha256=Pob9vUlBZnkL4oP1Re11kFob7nufMSE96pn7m7fuOEM,1345
|
|
13
13
|
llm_trainer/parallel_ds.py,sha256=oy8RRxHud3rACWubFlJqqd0pjPEQhKeAPGPQUSdJX2c,1145
|
|
@@ -15,19 +15,19 @@ llm_trainer/parallel_none.py,sha256=TG6Pm829Dg-yQu-97O-EHV3FCARBlNcP47KkGFAs16E,
|
|
|
15
15
|
llm_trainer/partition_utils.py,sha256=eEYNhfEIF4hGzZ3OLa6sEBIECz261drptEz_n7fZYtk,8396
|
|
16
16
|
llm_trainer/scheduler.py,sha256=LAI_0VxClsIQkix0bRoduRD4vPfVuIZDhZgTAT_KK8k,4901
|
|
17
17
|
llm_trainer/sft_trainer.py,sha256=LudTRIaqLQYy6ym6jjMX7v9xtFBJelrR3nnPCwb48nM,1821
|
|
18
|
-
llm_trainer/tokenizer.py,sha256=
|
|
18
|
+
llm_trainer/tokenizer.py,sha256=0-xQCMz1xiPTDAZiYsVsiECSoZ_1eIvW9XsZOoFfakQ,7250
|
|
19
19
|
llm_trainer/tools.py,sha256=5op5qrjjkK-Lr9oes5VxIVnOVYOYGoAdlIJq9mPUf64,2637
|
|
20
20
|
llm_trainer/train_configs.py,sha256=N3ykM1uaLHcSNRC8ErYIxp9VYhSP7voJyAP-2D4ZJe0,7574
|
|
21
21
|
llm_trainer/trainer.py,sha256=jS31zEXIIj9BoPTPlmaGYq61x72HGCjKfS2u3_gOkDk,27924
|
|
22
|
-
llm_trainer/utils.py,sha256=
|
|
23
|
-
project_llm_trainer-0.7.
|
|
24
|
-
project_llm_trainer-0.7.
|
|
25
|
-
project_llm_trainer-0.7.
|
|
26
|
-
project_llm_trainer-0.7.
|
|
27
|
-
project_llm_trainer-0.7.
|
|
28
|
-
project_llm_trainer-0.7.
|
|
29
|
-
project_llm_trainer-0.7.
|
|
30
|
-
project_llm_trainer-0.7.
|
|
31
|
-
project_llm_trainer-0.7.
|
|
32
|
-
project_llm_trainer-0.7.
|
|
33
|
-
project_llm_trainer-0.7.
|
|
22
|
+
llm_trainer/utils.py,sha256=xC5plG-8-_Al5yIF5xIU5lroOcBBk98TEhtUJrazZPE,12305
|
|
23
|
+
project_llm_trainer-0.7.9.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
|
|
24
|
+
project_llm_trainer-0.7.9.data/scripts/ddp_train,sha256=Z-309mM56CN0m3bxoeC5us4LUuwuNnoiOm3-fDdLMjQ,566
|
|
25
|
+
project_llm_trainer-0.7.9.data/scripts/ds_train,sha256=tME0xmMdX1D9XuVo07D9dilW5VIWavBS3UK9DoY67WI,709
|
|
26
|
+
project_llm_trainer-0.7.9.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
|
|
27
|
+
project_llm_trainer-0.7.9.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
|
|
28
|
+
project_llm_trainer-0.7.9.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
|
|
29
|
+
project_llm_trainer-0.7.9.data/scripts/smart_train,sha256=3oLIDuuqb4U4TU1lXy9V8lw_0gIf7i8tGsxlQ_s6bro,1220
|
|
30
|
+
project_llm_trainer-0.7.9.dist-info/METADATA,sha256=mDGLc1BjmIlOPz85JYB5bFnlXJgJ5VaNesW4z0HDZCA,195
|
|
31
|
+
project_llm_trainer-0.7.9.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
|
|
32
|
+
project_llm_trainer-0.7.9.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
|
|
33
|
+
project_llm_trainer-0.7.9.dist-info/RECORD,,
|
{project_llm_trainer-0.7.7.data → project_llm_trainer-0.7.9.data}/scripts/calc_intermediate_size
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|