project-llm-trainer 0.12.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,126 @@
1
+ import os
2
+ from typing import Optional, Union
3
+ import shutil
4
+ import torch
5
+ from torch import nn
6
+ from torch.optim import Optimizer
7
+ from torch.nn.parallel import DistributedDataParallel as DDP
8
+
9
+ from .parallel import DsParallel
10
+ from .scheduler import LRScheduler
11
+ from .tools import TrainerTools
12
+
13
+ DEFAULT_CHECKPOINT_NAME = "checkpoint.pth"
14
+
15
+ def save_checkpoint(
16
+ model: nn.Module,
17
+ optimizer: Optional[Optimizer] = None
18
+ ):
19
+ if isinstance(TrainerTools().parallel, DsParallel):
20
+ from .ds_checkpoint import save_ds_checkpoint
21
+ save_ds_checkpoint(model)
22
+ else:
23
+ if TrainerTools().parallel.is_main_process:
24
+ checkpoint_name = os.environ.get('CHECKPOINT_NAME', DEFAULT_CHECKPOINT_NAME)
25
+
26
+ raw_model = model if not isinstance(model, DDP) else model.module
27
+ ckpt = {'model_state_dict': raw_model.state_dict()}
28
+
29
+ if optimizer:
30
+ ckpt.update({'optim_state_dict': optimizer.state_dict()})
31
+
32
+ torch.save(ckpt, checkpoint_name)
33
+
34
+
35
+ def save_best_checkpoint(
36
+ current_loss: float,
37
+ last_best_checkpoint_loss: Optional[float] = None
38
+ ) -> bool:
39
+ # 指定不保存最佳checkpoint
40
+ if os.environ.get('SAVE_BEST_CHECKPOINT', '1') != '1':
41
+ return False
42
+
43
+ need_replace = not last_best_checkpoint_loss or current_loss <= last_best_checkpoint_loss
44
+ if need_replace and TrainerTools().parallel.is_main_process:
45
+ try:
46
+ if isinstance(TrainerTools().parallel, DsParallel):
47
+ checkpoint_dir = os.environ.get('DIST_CHECKPOINT_DIR', 'checkpoint')
48
+
49
+ if checkpoint_dir.endswith('/'):
50
+ best_checkpoint_dir = f'{checkpoint_dir[:-1]}_best'
51
+ else:
52
+ best_checkpoint_dir = f'{checkpoint_dir}_best'
53
+
54
+ if not os.path.exists(best_checkpoint_dir):
55
+ os.makedirs(best_checkpoint_dir)
56
+
57
+ if os.path.exists(checkpoint_dir):
58
+ shutil.rmtree(best_checkpoint_dir)
59
+ shutil.copytree(checkpoint_dir, best_checkpoint_dir)
60
+ else:
61
+ checkpoint_name = os.environ.get('CHECKPOINT_NAME', DEFAULT_CHECKPOINT_NAME)
62
+ best_checkpoint_name = f'{checkpoint_name}_best'
63
+
64
+ if os.path.exists(checkpoint_name):
65
+ if os.path.exists(best_checkpoint_name):
66
+ os.remove(best_checkpoint_name)
67
+
68
+ shutil.copy2(checkpoint_name, best_checkpoint_name)
69
+ except: pass
70
+
71
+ TrainerTools().parallel.wait('save best checkpoint')
72
+ return need_replace
73
+
74
+
75
def load_checkpoint(
        model: nn.Module,
        optimizer: Optional[Optimizer] = None,
        device: Optional[Union[torch.device, str]] = None,
        load_module_only: bool = False
):
    """Restore model (and optionally optimizer) state from the latest checkpoint.

    Under DeepSpeed (``DsParallel``) loading is delegated to
    ``load_ds_checkpoint(model, load_module_only=...)``. Otherwise the
    ``CHECKPOINT_NAME`` file (default ``DEFAULT_CHECKPOINT_NAME``) is read with
    ``torch.load(weights_only=True)``. If that file does not exist, nothing is
    loaded.

    Args:
        model: model to load into; a DDP wrapper is unwrapped to ``.module``.
        optimizer: optional optimizer to restore. It is only restored when the
            checkpoint contains ``'optim_state_dict'`` and ``load_module_only``
            is False.
        device: ``map_location`` passed to ``torch.load``.
        load_module_only: load only the model weights, never the optimizer.
    """
    if isinstance(TrainerTools().parallel, DsParallel):
        from .ds_checkpoint import load_ds_checkpoint
        load_ds_checkpoint(model, load_module_only=load_module_only)
        return

    checkpoint_name = os.environ.get('CHECKPOINT_NAME', DEFAULT_CHECKPOINT_NAME)
    if not os.path.exists(checkpoint_name):
        # Nothing saved yet (e.g. a fresh run): leave model/optimizer untouched.
        return

    state_dict = torch.load(checkpoint_name, weights_only=True, map_location=device)
    raw_model = model.module if isinstance(model, DDP) else model
    raw_model.load_state_dict(state_dict['model_state_dict'])

    # Honour load_module_only on this path too (it is also forwarded to the
    # DeepSpeed loader). Also tolerate checkpoints written without an
    # optimizer: save_checkpoint only stores 'optim_state_dict' when one
    # was passed.
    if (
            optimizer is not None
            and not load_module_only
            and 'optim_state_dict' in state_dict
    ):
        optimizer.load_state_dict(state_dict['optim_state_dict'])
94
+
95
+
96
def load_checkpoint_for_eval(
        model: nn.Module,
        device: Optional[Union[torch.device, str]] = None
):
    """Load weights into ``model`` for evaluation (no optimizer state).

    DeepSpeed runs go through ``load_ds_checkpoint_for_eval(model)``.
    Everything else reuses ``load_checkpoint`` without an optimizer, mapping
    tensors to ``device``.
    """
    if not isinstance(TrainerTools().parallel, DsParallel):
        load_checkpoint(model, optimizer=None, device=device)
        return

    from .ds_checkpoint import load_ds_checkpoint_for_eval
    load_ds_checkpoint_for_eval(model)
105
+
106
+
107
def save_steps(
        global_steps: int,
        lr_scheduler: Optional[LRScheduler] = None,
):
    """Persist the global step counter (plus LR scheduler state) for resuming.

    Only the main process writes. The target is ``steps.pt`` appended directly
    to the ``LOG_DIR`` environment variable (default ``'./'``) by string
    concatenation. A ``LOG_DIR`` without a trailing separator therefore
    produces e.g. ``logssteps.pt``.

    Args:
        global_steps: the step counter to store under ``'global_steps'``.
        lr_scheduler: optional scheduler whose ``get_ckpt_dict()`` entries are
            merged into the saved dict.
    """
    # For now only the main process's steps are saved.
    if not TrainerTools().parallel.is_main_process:
        return

    log_dir = os.environ.get('LOG_DIR', './')
    payload = {'global_steps': global_steps}
    if lr_scheduler:
        payload.update(lr_scheduler.get_ckpt_dict())

    torch.save(payload, f"{log_dir}steps.pt")
119
+
120
+
121
def load_steps() -> Optional[dict]:
    """Load the step checkpoint written by ``save_steps``.

    Reads ``steps.pt`` from the ``LOG_DIR`` environment variable (default
    ``'./'``). The directory and file name are joined by plain string
    concatenation, so ``LOG_DIR`` should end with a path separator.

    Returns:
        The saved dict (``'global_steps'`` plus any LR scheduler entries), or
        ``None`` when the file does not exist.
    """
    log_dir = os.environ.get('LOG_DIR', './')
    path = f"{log_dir}steps.pt"
    if not os.path.exists(path):
        return None
    return torch.load(path, weights_only=True)
llm_trainer/dataset.py ADDED
@@ -0,0 +1,335 @@
1
+ import torch
2
+ from torch.utils.data import Dataset
3
+ import pickle
4
+ import csv
5
+ import json
6
+ import numpy as np
7
+
8
+ from .tools import TrainerTools
9
+ from .utils import repeat_image_tok
10
+
11
+
12
+ """
13
+ support jsonl and pkl
14
+ """
15
+ def _get_file_type(file_path: str):
16
+ if file_path.endswith('.npy'):
17
+ return 'npy'
18
+ elif file_path.endswith('.jsonl'):
19
+ return 'jsonl'
20
+ elif file_path.endswith('.pkl'):
21
+ return 'pkl'
22
+
23
+ return None
24
+
25
+
26
+ class PretrainDataset(Dataset):
27
+ """
28
+ 适用于pretrain阶段,数据格式支持jsonl和pkl,如果是jsonl会在init阶段全部encode成token
29
+ 1. npy:【推荐】numpy 数组,支持 mmap,内存占用极低
30
+ 2. jsonl: {'text': 'text1'}\n{'text': 'text2'}
31
+ 3. pkl: [0, 1, 2, 3 ...]
32
+ """
33
+ def __init__(
34
+ self,
35
+ file_path,
36
+ block_size,
37
+ stride
38
+ ):
39
+ super().__init__()
40
+
41
+ self.block_size = block_size
42
+ self.stride = stride
43
+ self.use_mmap = False
44
+
45
+ file_type = _get_file_type(file_path)
46
+
47
+ if file_type == 'npy':
48
+ self.input_ids = np.load(file_path, mmap_mode='r')
49
+ self.use_mmap = True
50
+ elif file_type == 'jsonl':
51
+ tokens = []
52
+ with open(file_path, 'r') as f:
53
+ for line in f:
54
+ tokens.extend(TrainerTools().tokenizer.encode(json.loads(line.strip())['text']))
55
+ self.input_ids = torch.tensor(tokens, dtype=torch.int32)
56
+ del tokens
57
+ elif file_type == 'pkl':
58
+ with open(file_path, 'rb') as f:
59
+ tokens = pickle.load(f)
60
+ self.input_ids = torch.tensor(tokens, dtype=torch.int32)
61
+ del tokens
62
+ else:
63
+ raise Exception(f'unsupported file type for {file_path}')
64
+
65
+ if len(self.input_ids) < block_size:
66
+ self.length = 0
67
+ else:
68
+ self.length = (len(self.input_ids) - block_size) // stride + 1
69
+
70
+ def __len__(self):
71
+ return self.length
72
+
73
+ def __getitem__(self, item):
74
+ if item < 0 or item >= self.length:
75
+ raise IndexError(f"Index {item} out of range")
76
+
77
+ start_idx = item * self.stride
78
+ end_idx = start_idx + self.block_size
79
+
80
+ data = self.input_ids[start_idx:end_idx]
81
+
82
+ if self.use_mmap:
83
+ return torch.from_numpy(data.astype(np.int64))
84
+ else:
85
+ return data.long()
86
+
87
+
88
+ class SFTDataset(Dataset):
89
+ """
90
+ 适用于sft阶段,数据格式支持jsonl和pkl,如果是jsonl,则会在getitem阶段encode成token
91
+ npy: [
92
+ [0, 1, 2, 3],
93
+ [4, 5, 6, 7]
94
+ ]
95
+ jsonl: [
96
+ {'role': 'system', 'content': 'system_content'},
97
+ {'role': 'user', 'content': 'user_content'},
98
+ {'role': 'assistant', 'think': 'think_content', 'content': 'assistant_content'}
99
+ ]\n
100
+ [
101
+ {'role': 'system', 'content': 'system_content'},
102
+ {'role': 'user', 'content': 'user_content'},
103
+ {'role': 'assistant', 'think': 'think_content', 'content': 'assistant_content'}
104
+ ]
105
+ pkl: [
106
+ [0, 1, 2, 3],
107
+ [4, 5, 6, 7]
108
+ ]
109
+ """
110
+ def __init__(
111
+ self,
112
+ file_path,
113
+ max_len,
114
+ image_tags_file_path=None,
115
+ tokens_per_image=-1
116
+ ):
117
+ super().__init__()
118
+
119
+ self.max_len = max_len
120
+ self.tokens_per_image = tokens_per_image
121
+ self.input_ids = []
122
+ self.image_tags = []
123
+ self.plain_text = False
124
+
125
+ file_type = _get_file_type(file_path)
126
+
127
+ if file_type == 'npy':
128
+ try:
129
+ self.input_ids = np.load(file_path, mmap_mode='r')
130
+ except ValueError:
131
+ self.input_ids = np.load(file_path, allow_pickle=True)
132
+ elif file_type == 'jsonl':
133
+ self.plain_text = True
134
+ with open(file_path, 'r') as f:
135
+ for line in f:
136
+ self.input_ids.append(json.loads(line.strip()))
137
+ elif file_type == 'pkl':
138
+ with open(file_path, 'rb') as f:
139
+ self.input_ids = pickle.load(f)
140
+ else:
141
+ raise Exception(f'unsupported file type for {file_path}')
142
+
143
+ if image_tags_file_path:
144
+ with open(image_tags_file_path, 'r') as f:
145
+ csv_reader = csv.reader(f)
146
+ for line in csv_reader:
147
+ self.image_tags.append(line[0])
148
+
149
+ def __len__(self):
150
+ return len(self.input_ids)
151
+
152
+ def __getitem__(self, item):
153
+ if self.plain_text:
154
+ inputs = TrainerTools().tokenizer.apply_chat_template(self.input_ids[item])
155
+ else:
156
+ inputs = self.input_ids[item]
157
+
158
+ if isinstance(inputs, np.ndarray):
159
+ inputs = torch.from_numpy(inputs.astype(np.int64))
160
+ else:
161
+ inputs = torch.tensor(inputs).long()
162
+
163
+ image_tag = self.image_tags[item] if self.image_tags else None
164
+
165
+ if self.tokens_per_image != -1:
166
+ inputs = repeat_image_tok(inputs, self.tokens_per_image)
167
+ else:
168
+ image_tag = None
169
+
170
+ inputs = inputs[:self.max_len]
171
+
172
+ return {
173
+ 'inputs': inputs,
174
+ 'image_tag': image_tag
175
+ }
176
+
177
+
178
+ class DPODataset(Dataset):
179
+ """
180
+ 适用于dpo阶段,数据格式支持jsonl和pkl,如果是jsonl,则会在getitem阶段encode成token
181
+ npy: [
182
+ {'chosen': xxx, 'rejected': xxx},
183
+ {'chosen': xxx, 'rejected': xxx},
184
+ ]
185
+ jsonl: {'chosen':
186
+ [{'role': 'system', 'content': 'system_content'},
187
+ {'role': 'user', 'content': 'user_content'},
188
+ {'role': 'assistant', 'think': 'think_content', 'content': 'assistant_content'}],
189
+ 'rejected':
190
+ [{'role': 'system', 'content': 'system_content'},
191
+ {'role': 'user', 'content': 'user_content'},
192
+ {'role': 'assistant', 'think': 'think_content', 'content': 'assistant_content'}],
193
+ }\n
194
+ {'chosen':
195
+ [{'role': 'system', 'content': 'system_content'},
196
+ {'role': 'user', 'content': 'user_content'},
197
+ {'role': 'assistant', 'think': 'think_content', 'content': 'assistant_content'}],
198
+ 'rejected':
199
+ [{'role': 'system', 'content': 'system_content'},
200
+ {'role': 'user', 'content': 'user_content'},
201
+ 'role': 'assistant', 'think': 'think_content', 'content': 'assistant_content'}],
202
+ }
203
+ pkl: [
204
+ {'chosen': xxx, 'rejected': xxx},
205
+ {'chosen': xxx, 'rejected': xxx},
206
+ ]
207
+ """
208
+ def __init__(self, file_path, max_len):
209
+ self.max_len = max_len
210
+ self.data = []
211
+ self.plain_text = False
212
+
213
+ file_type = _get_file_type(file_path)
214
+
215
+ if file_type == 'npy':
216
+ try:
217
+ self.data = np.load(file_path, mmap_mode='r')
218
+ except ValueError:
219
+ self.data = np.load(file_path, allow_pickle=True)
220
+ elif file_type == 'jsonl':
221
+ self.plain_text = True
222
+ with open(file_path, 'r') as f:
223
+ for line in f:
224
+ self.data.append(json.loads(line.strip()))
225
+ elif file_type == 'pkl':
226
+ with open(file_path, 'rb') as f:
227
+ self.data = pickle.load(f)
228
+ else:
229
+ raise Exception(f'unsupported file type for {file_path}')
230
+
231
+ def __len__(self):
232
+ return len(self.data)
233
+
234
+ def __getitem__(self, item):
235
+ record = self.data[item]
236
+
237
+ chosen_raw = record['chosen']
238
+ rejected_raw = record['rejected']
239
+
240
+ if self.plain_text:
241
+ chosen_id = TrainerTools().tokenizer.apply_chat_template(chosen_raw)
242
+ rejected_id = TrainerTools().tokenizer.apply_chat_template(rejected_raw)
243
+ else:
244
+ chosen_id = chosen_raw
245
+ rejected_id = rejected_raw
246
+
247
+ if isinstance(chosen_id, np.ndarray): chosen_id = chosen_id.tolist()
248
+ if isinstance(rejected_id, np.ndarray): rejected_id = rejected_id.tolist()
249
+
250
+ return {
251
+ 'chosen': chosen_id[:self.max_len],
252
+ 'rejected': rejected_id[:self.max_len]
253
+ }
254
+
255
+
256
+ class RLDataset(Dataset):
257
+ """
258
+ 适用于RL阶段(例如:PPO、GRPO、GSPO),数据格式支持jsonl和pkl,如果是jsonl,则会在getitem阶段encode成token
259
+ npy: [
260
+ {'prompt': xxx, 'answer': xxx},
261
+ {'prompt': xxx, 'answer': xxx},
262
+ ]
263
+ jsonl: {'prompt':
264
+ [{'role': 'system', 'content': 'system_content'},
265
+ {'role': 'user', 'content': 'user_content'}]
266
+ 'answer': '10'
267
+ }\n
268
+ {'prompt':
269
+ [{'role': 'system', 'content': 'system_content'},
270
+ {'role': 'user', 'content': 'user_content'}]
271
+ 'answer': '10'
272
+ }
273
+ pkl: [
274
+ {'prompt': xxx, 'answer': xxx},
275
+ {'prompt': xxx, 'answer': xxx},
276
+ ]
277
+ """
278
+ def __init__(self, file_path):
279
+ self.data = []
280
+ self.plain_text = False
281
+
282
+ file_type = _get_file_type(file_path)
283
+
284
+ if file_type == 'npy':
285
+ try:
286
+ self.data = np.load(file_path, mmap_mode='r')
287
+ except ValueError:
288
+ self.data = np.load(file_path, allow_pickle=True)
289
+ elif file_type == 'jsonl':
290
+ self.plain_text = True
291
+
292
+ with open(file_path, 'r') as f:
293
+ for line in f:
294
+ self.data.append(json.loads(line.strip()))
295
+ elif file_type == 'pkl':
296
+ with open(file_path, 'rb') as f:
297
+ self.data = pickle.load(f)
298
+ else:
299
+ raise Exception(f'unsupported file type for {file_path}')
300
+
301
+ def __len__(self):
302
+ return len(self.data)
303
+
304
+ def __getitem__(self, item):
305
+ record = self.data[item]
306
+
307
+ prompt_raw = record['prompt']
308
+ answer_raw = record.get('answer', None)
309
+
310
+ if self.plain_text:
311
+ question = TrainerTools().tokenizer.apply_chat_template(prompt_raw)
312
+ answer = TrainerTools().tokenizer.encode(answer_raw) if answer_raw else None
313
+ else:
314
+ question = prompt_raw
315
+ answer = answer_raw
316
+
317
+ # 转换为 Tensor
318
+ if isinstance(question, np.ndarray):
319
+ prompt_tensor = torch.from_numpy(question.astype(np.int64))
320
+ else:
321
+ prompt_tensor = torch.tensor(question).long()
322
+
323
+ if answer is not None:
324
+ if isinstance(answer, np.ndarray):
325
+ answer_tensor = torch.from_numpy(answer.astype(np.int64))
326
+ else:
327
+ answer_tensor = torch.tensor(answer).long()
328
+ else:
329
+ answer_tensor = None
330
+
331
+ return {
332
+ 'prompt': prompt_tensor,
333
+ 'answer': answer_tensor
334
+ }
335
+