project-llm-trainer 0.13.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of project-llm-trainer might be problematic.

@@ -0,0 +1,114 @@
+import os
+from typing import Optional, Union
+import torch
+from torch import nn
+from torch.optim import Optimizer
+from torch.nn.parallel import DistributedDataParallel as DDP
+import torch.distributed as dist
+
+from .parallel import DsParallel
+from .scheduler import LRScheduler
+from .tools import TrainerTools
+
+DEFAULT_CHECKPOINT_NAME = "checkpoint.pth"
+
+def save_checkpoint(
+    model: nn.Module,
+    optimizer: Optional[Optimizer] = None,
+    extra_module: Optional[nn.Module] = None
+):
+    if isinstance(TrainerTools().parallel, DsParallel):
+        from .ds_checkpoint import save_ds_checkpoint
+        save_ds_checkpoint(model, extra_module=extra_module)
+    else:
+        if TrainerTools().parallel.is_main_process:
+            checkpoint_name = os.environ.get('CHECKPOINT_NAME', DEFAULT_CHECKPOINT_NAME)
+
+            raw_model = model if not isinstance(model, DDP) else model.module
+            ckpt = {'model_state_dict': raw_model.state_dict()}
+
+            if optimizer:
+                ckpt.update({'optim_state_dict': optimizer.state_dict()})
+
+            if extra_module:
+                ckpt.update({'extra_module_state_dict': extra_module.state_dict()})
+
+            torch.save(ckpt, checkpoint_name)
+
+
+def load_checkpoint(
+    model: nn.Module,
+    optimizer: Optional[Optimizer] = None,
+    device: Optional[Union[torch.device, str]] = None,
+    load_module_only: bool = False,
+    extra_module: Optional[nn.Module] = None
+):
+    if isinstance(TrainerTools().parallel, DsParallel):
+        from .ds_checkpoint import load_ds_checkpoint
+        load_ds_checkpoint(model, load_module_only=load_module_only, extra_module=extra_module)
+    else:
+        checkpoint_name = os.environ.get('CHECKPOINT_NAME', DEFAULT_CHECKPOINT_NAME)
+
+        if os.path.exists(checkpoint_name):
+            state_dict = torch.load(checkpoint_name, weights_only=True, map_location=device)
+            raw_model = model.module if isinstance(model, DDP) else model
+            raw_model.load_state_dict(state_dict['model_state_dict'])
+
+            if optimizer and 'optim_state_dict' in state_dict:
+                optimizer.load_state_dict(state_dict['optim_state_dict'])
+
+            if extra_module and 'extra_module_state_dict' in state_dict:
+                extra_module.load_state_dict(state_dict['extra_module_state_dict'])
+
+
+def save_steps(
+    epoch: int = 0,
+    file_idx: int = 0,
+    batch_idx: int = 0,
+    lr_scheduler: Optional[LRScheduler] = None
+):
+    # For now, only the main process's training progress is saved
+    if TrainerTools().parallel.is_main_process:
+        steps_checkpoint_name = f"{os.environ.get('LOG_DIR', './')}steps.pt"
+        ckpt = {
+            'epoch': epoch,
+            'file_idx': file_idx,
+            'batch_idx': batch_idx,
+            'cpu_rng_state': torch.get_rng_state(),
+        }
+
+        if torch.cuda.is_available():
+            ckpt['cuda_rng_state'] = torch.cuda.get_rng_state()
+
+        if lr_scheduler:
+            ckpt.update(lr_scheduler.get_ckpt_dict())
+
+        torch.save(ckpt, steps_checkpoint_name)
+
+
+def load_steps() -> Optional[dict]:
+    steps_dict = None
+
+    if TrainerTools().parallel.is_main_process:
+        steps_checkpoint_name = f"{os.environ.get('LOG_DIR', './')}steps.pt"
+        if os.path.exists(steps_checkpoint_name):
+            try:
+                steps_dict = torch.load(steps_checkpoint_name, weights_only=True)
+            except:
+                steps_dict = None
+
+    if TrainerTools().parallel.world_size > 1:
+        object_list = [steps_dict]
+        dist.broadcast_object_list(object_list, src=0)
+        steps_dict = object_list[0]
+        TrainerTools().parallel.wait('broadcast steps_dict')
+
+    if steps_dict:
+        if 'cpu_rng_state' in steps_dict:
+            torch.set_rng_state(steps_dict['cpu_rng_state'])
+        if 'cuda_rng_state' in steps_dict and torch.cuda.is_available():
+            try:
+                torch.cuda.set_rng_state(steps_dict['cuda_rng_state'])
+            except: ...
+
+    return steps_dict
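
The module above wraps plain-PyTorch checkpointing behind save_checkpoint/load_checkpoint (with a DeepSpeed branch via DsParallel) and records training progress plus RNG state via save_steps/load_steps. As a rough usage sketch, assuming the package is installed and that these helpers live in a checkpoint module (the file name for this hunk is not shown in the diff), a non-distributed run might call them like this:

import os
import torch
from torch import nn

# Assumed module path; the hunk above does not show this file's name.
from llm_trainer.checkpoint import save_checkpoint, load_checkpoint

# Where the checkpoint is written; defaults to DEFAULT_CHECKPOINT_NAME ("checkpoint.pth").
os.environ['CHECKPOINT_NAME'] = 'demo_checkpoint.pth'

model = nn.Linear(16, 16)  # toy stand-in for a real model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Non-DeepSpeed path: the main process saves model (and optimizer) state dicts with torch.save.
save_checkpoint(model, optimizer=optimizer)

# Restore the weights (and optimizer state, if present), mapping tensors onto the CPU.
load_checkpoint(model, optimizer=optimizer, device='cpu')
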
llm_trainer/dataset.py ADDED
@@ -0,0 +1,335 @@
+import torch
+from torch.utils.data import Dataset
+import pickle
+import csv
+import json
+import numpy as np
+
+from .tools import TrainerTools
+from .utils import repeat_image_tok
+
+
+"""
+support npy, jsonl and pkl
+"""
+def _get_file_type(file_path: str):
+    if file_path.endswith('.npy'):
+        return 'npy'
+    elif file_path.endswith('.jsonl'):
+        return 'jsonl'
+    elif file_path.endswith('.pkl'):
+        return 'pkl'
+
+    return None
+
+
+class PretrainDataset(Dataset):
+    """
+    For the pretrain stage. Supported data formats: npy, jsonl and pkl; a jsonl file is fully encoded into tokens in __init__.
+    1. npy: [recommended] a numpy array; supports mmap, so memory usage stays very low
+    2. jsonl: {'text': 'text1'}\n{'text': 'text2'}
+    3. pkl: [0, 1, 2, 3 ...]
+    """
+    def __init__(
+        self,
+        file_path,
+        block_size,
+        stride
+    ):
+        super().__init__()
+
+        self.block_size = block_size
+        self.stride = stride
+        self.use_mmap = False
+
+        file_type = _get_file_type(file_path)
+
+        if file_type == 'npy':
+            self.input_ids = np.load(file_path, mmap_mode='r')
+            self.use_mmap = True
+        elif file_type == 'jsonl':
+            tokens = []
+            with open(file_path, 'r') as f:
+                for line in f:
+                    tokens.extend(TrainerTools().tokenizer.encode(json.loads(line.strip())['text']))
+            self.input_ids = torch.tensor(tokens, dtype=torch.int32)
+            del tokens
+        elif file_type == 'pkl':
+            with open(file_path, 'rb') as f:
+                tokens = pickle.load(f)
+            self.input_ids = torch.tensor(tokens, dtype=torch.int32)
+            del tokens
+        else:
+            raise Exception(f'unsupported file type for {file_path}')
+
+        if len(self.input_ids) < block_size:
+            self.length = 0
+        else:
+            self.length = (len(self.input_ids) - block_size) // stride + 1
+
+    def __len__(self):
+        return self.length
+
+    def __getitem__(self, item):
+        if item < 0 or item >= self.length:
+            raise IndexError(f"Index {item} out of range")
+
+        start_idx = item * self.stride
+        end_idx = start_idx + self.block_size
+
+        data = self.input_ids[start_idx:end_idx]
+
+        if self.use_mmap:
+            return torch.from_numpy(data.astype(np.int64))
+        else:
+            return data.long()
+
+
+class SFTDataset(Dataset):
+    """
+    For the SFT stage. Supported data formats: npy, jsonl and pkl; jsonl records are encoded into tokens in __getitem__.
+    npy: [
+        [0, 1, 2, 3],
+        [4, 5, 6, 7]
+    ]
+    jsonl: [
+        {'role': 'system', 'content': 'system_content'},
+        {'role': 'user', 'content': 'user_content'},
+        {'role': 'assistant', 'think': 'think_content', 'content': 'assistant_content'}
+    ]\n
+    [
+        {'role': 'system', 'content': 'system_content'},
+        {'role': 'user', 'content': 'user_content'},
+        {'role': 'assistant', 'think': 'think_content', 'content': 'assistant_content'}
+    ]
+    pkl: [
+        [0, 1, 2, 3],
+        [4, 5, 6, 7]
+    ]
+    """
+    def __init__(
+        self,
+        file_path,
+        block_size,
+        image_tags_file_path=None,
+        tokens_per_image=-1
+    ):
+        super().__init__()
+
+        self.block_size = block_size
+        self.tokens_per_image = tokens_per_image
+        self.input_ids = []
+        self.image_tags = []
+        self.plain_text = False
+
+        file_type = _get_file_type(file_path)
+
+        if file_type == 'npy':
+            try:
+                self.input_ids = np.load(file_path, mmap_mode='r')
+            except ValueError:
+                self.input_ids = np.load(file_path, allow_pickle=True)
+        elif file_type == 'jsonl':
+            self.plain_text = True
+            with open(file_path, 'r') as f:
+                for line in f:
+                    self.input_ids.append(json.loads(line.strip()))
+        elif file_type == 'pkl':
+            with open(file_path, 'rb') as f:
+                self.input_ids = pickle.load(f)
+        else:
+            raise Exception(f'unsupported file type for {file_path}')
+
+        if image_tags_file_path:
+            with open(image_tags_file_path, 'r') as f:
+                csv_reader = csv.reader(f)
+                for line in csv_reader:
+                    self.image_tags.append(line[0])
+
+    def __len__(self):
+        return len(self.input_ids)
+
+    def __getitem__(self, item):
+        if self.plain_text:
+            inputs = TrainerTools().tokenizer.apply_chat_template(self.input_ids[item])
+        else:
+            inputs = self.input_ids[item]
+
+        if isinstance(inputs, np.ndarray):
+            inputs = torch.from_numpy(inputs.astype(np.int64))
+        else:
+            inputs = torch.tensor(inputs).long()
+
+        image_tag = self.image_tags[item] if self.image_tags else None
+
+        if self.tokens_per_image != -1:
+            inputs, _ = repeat_image_tok(inputs, self.tokens_per_image)
+        else:
+            image_tag = None
+
+        inputs = inputs[:self.block_size]
+
+        return {
+            'inputs': inputs,
+            'image_tag': image_tag
+        }
+
+
+class DPODataset(Dataset):
+    """
+    For the DPO stage. Supported data formats: npy, jsonl and pkl; jsonl records are encoded into tokens in __getitem__.
+    npy: [
+        {'chosen': xxx, 'rejected': xxx},
+        {'chosen': xxx, 'rejected': xxx},
+    ]
+    jsonl: {'chosen':
+        [{'role': 'system', 'content': 'system_content'},
+         {'role': 'user', 'content': 'user_content'},
+         {'role': 'assistant', 'think': 'think_content', 'content': 'assistant_content'}],
+    'rejected':
+        [{'role': 'system', 'content': 'system_content'},
+         {'role': 'user', 'content': 'user_content'},
+         {'role': 'assistant', 'think': 'think_content', 'content': 'assistant_content'}],
+    }\n
+    {'chosen':
+        [{'role': 'system', 'content': 'system_content'},
+         {'role': 'user', 'content': 'user_content'},
+         {'role': 'assistant', 'think': 'think_content', 'content': 'assistant_content'}],
+    'rejected':
+        [{'role': 'system', 'content': 'system_content'},
+         {'role': 'user', 'content': 'user_content'},
+         {'role': 'assistant', 'think': 'think_content', 'content': 'assistant_content'}],
+    }
+    pkl: [
+        {'chosen': xxx, 'rejected': xxx},
+        {'chosen': xxx, 'rejected': xxx},
+    ]
+    """
+    def __init__(self, file_path, block_size):
+        self.block_size = block_size
+        self.data = []
+        self.plain_text = False
+
+        file_type = _get_file_type(file_path)
+
+        if file_type == 'npy':
+            try:
+                self.data = np.load(file_path, mmap_mode='r')
+            except ValueError:
+                self.data = np.load(file_path, allow_pickle=True)
+        elif file_type == 'jsonl':
+            self.plain_text = True
+            with open(file_path, 'r') as f:
+                for line in f:
+                    self.data.append(json.loads(line.strip()))
+        elif file_type == 'pkl':
+            with open(file_path, 'rb') as f:
+                self.data = pickle.load(f)
+        else:
+            raise Exception(f'unsupported file type for {file_path}')
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, item):
+        record = self.data[item]
+
+        chosen_raw = record['chosen']
+        rejected_raw = record['rejected']
+
+        if self.plain_text:
+            chosen_id = TrainerTools().tokenizer.apply_chat_template(chosen_raw)
+            rejected_id = TrainerTools().tokenizer.apply_chat_template(rejected_raw)
+        else:
+            chosen_id = chosen_raw
+            rejected_id = rejected_raw
+
+        if isinstance(chosen_id, np.ndarray): chosen_id = chosen_id.tolist()
+        if isinstance(rejected_id, np.ndarray): rejected_id = rejected_id.tolist()
+
+        return {
+            'chosen': chosen_id[:self.block_size],
+            'rejected': rejected_id[:self.block_size]
+        }
+
+
+class RLDataset(Dataset):
+    """
+    For the RL stage (e.g. PPO, GRPO, GSPO). Supported data formats: npy, jsonl and pkl; jsonl records are encoded into tokens in __getitem__.
+    npy: [
+        {'prompt': xxx, 'answer': xxx},
+        {'prompt': xxx, 'answer': xxx},
+    ]
+    jsonl: {'prompt':
+        [{'role': 'system', 'content': 'system_content'},
+         {'role': 'user', 'content': 'user_content'}],
+    'answer': '10'
+    }\n
+    {'prompt':
+        [{'role': 'system', 'content': 'system_content'},
+         {'role': 'user', 'content': 'user_content'}],
+    'answer': '10'
+    }
+    pkl: [
+        {'prompt': xxx, 'answer': xxx},
+        {'prompt': xxx, 'answer': xxx},
+    ]
+    """
+    def __init__(self, file_path):
+        self.data = []
+        self.plain_text = False
+
+        file_type = _get_file_type(file_path)
+
+        if file_type == 'npy':
+            try:
+                self.data = np.load(file_path, mmap_mode='r')
+            except ValueError:
+                self.data = np.load(file_path, allow_pickle=True)
+        elif file_type == 'jsonl':
+            self.plain_text = True
+
+            with open(file_path, 'r') as f:
+                for line in f:
+                    self.data.append(json.loads(line.strip()))
+        elif file_type == 'pkl':
+            with open(file_path, 'rb') as f:
+                self.data = pickle.load(f)
+        else:
+            raise Exception(f'unsupported file type for {file_path}')
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, item):
+        record = self.data[item]
+
+        prompt_raw = record['prompt']
+        answer_raw = record.get('answer', None)
+
+        if self.plain_text:
+            question = TrainerTools().tokenizer.apply_chat_template(prompt_raw)
+            answer = TrainerTools().tokenizer.encode(answer_raw) if answer_raw else None
+        else:
+            question = prompt_raw
+            answer = answer_raw
+
+        # Convert to tensors
+        if isinstance(question, np.ndarray):
+            prompt_tensor = torch.from_numpy(question.astype(np.int64))
+        else:
+            prompt_tensor = torch.tensor(question).long()
+
+        if answer is not None:
+            if isinstance(answer, np.ndarray):
+                answer_tensor = torch.from_numpy(answer.astype(np.int64))
+            else:
+                answer_tensor = torch.tensor(answer).long()
+        else:
+            answer_tensor = None
+
+        return {
+            'prompt': prompt_tensor,
+            'answer': answer_tensor
+        }
+
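
As a rough usage sketch for the dataset classes above: PretrainDataset slides a window of block_size tokens over one flat token stream at the given stride, so an .npy token dump can be read through mmap with very little resident memory, while the SFT/DPO/RL datasets return per-record samples. The token file below is fabricated for illustration, and the .npy path avoids the tokenizer dependency that the jsonl path requires.

import numpy as np
import torch
from torch.utils.data import DataLoader

from llm_trainer.dataset import PretrainDataset

# Fabricated token dump; a real one would come from tokenizing a corpus.
np.save('pretrain_tokens.npy', np.arange(10_000, dtype=np.uint16))

# With block_size=512 and stride=256 the windows overlap by half:
# length = (10_000 - 512) // 256 + 1 = 38 samples.
dataset = PretrainDataset('pretrain_tokens.npy', block_size=512, stride=256)
loader = DataLoader(dataset, batch_size=4, shuffle=True)

batch = next(iter(loader))
print(batch.shape, batch.dtype)  # torch.Size([4, 512]) torch.int64
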