project-llm-trainer 0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of project-llm-trainer might be problematic. Click here for more details.

llm_trainer/utils.py ADDED
@@ -0,0 +1,262 @@
1
+ import random
2
+ from typing import Tuple, Optional
3
+ import torch
4
+ from torch.nn.utils.rnn import pad_sequence
5
+ import torch.nn.functional as F
6
+ from .tools import TrainerTools
7
+ import numpy as np
8
+
9
+
10
def set_seed(seed=42):
    """Seed every RNG the trainer touches (python, numpy, torch CPU and CUDA)
    so runs are reproducible. The CUDA calls are no-ops on CPU-only hosts."""
    for seeder in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seeder(seed)
16
+
17
+
18
def extra_image_tag_and_repeat_image_tok(
    inputs: list[int],
    tokens_per_image: int
) -> Tuple[list[int], Optional[int]]:
    """Extract the image-tag token that follows the <image> token and expand
    the single <image> token into ``tokens_per_image`` copies.

    e.g. tokens_per_image=3: <image>{image_tag}...xxx -> <image><image><image>...xxx

    Args:
        inputs: token id sequence; left untouched — a new list is returned.
        tokens_per_image: number of <image> placeholder tokens per image.

    Returns:
        (new token list, extracted image tag) — or (inputs, None) unchanged
        when no <image> token is present; the tag is None when <image> is the
        last token.
    """
    image_tok = TrainerTools().tokenizer.image
    if image_tok not in inputs:
        return inputs, None

    image_tok_idx = inputs.index(image_tok)
    tag_idx = image_tok_idx + 1

    # The tag token sits right after <image>; it may be absent at sequence end.
    image_tag = inputs[tag_idx] if tag_idx < len(inputs) else None

    # Rebuild instead of mutating: the original implementation popped the tag
    # out of the caller's list, silently modifying shared input data.
    tail_start = tag_idx + 1 if image_tag is not None else tag_idx
    new_inputs = inputs[:image_tok_idx] + [image_tok] * tokens_per_image + inputs[tail_start:]
    return new_inputs, image_tag
39
+
40
+
41
def batch_extra_image_tag_and_repeat_image_tok(
    tokens: torch.Tensor,
    tokens_per_image: int
) -> Tuple[torch.Tensor, list[int]]:
    """Row-wise extra_image_tag_and_repeat_image_tok over a batch tensor.

    Returns the rebuilt batch tensor (same dtype/device as ``tokens``) and
    the per-row list of extracted image tags (None for rows without <image>).
    """
    rows = tokens.cpu().detach().tolist()
    processed = [
        extra_image_tag_and_repeat_image_tok(row, tokens_per_image)
        for row in rows
    ]
    new_tokens = [new_row for new_row, _ in processed]
    image_tags = [tag for _, tag in processed]
    return torch.tensor(new_tokens, dtype=tokens.dtype, device=tokens.device), image_tags
55
+
56
+
57
def repeat_image_tok(
    tokens: torch.Tensor,
    tokens_per_image: int
) -> torch.Tensor:
    """Expand the <image> token in a 1-D token tensor into tokens_per_image
    copies, e.g. tokens_per_image=3: <image>...xxx -> <image><image><image>...xxx

    Returns the tensor unchanged when it holds no <image> token.
    Assumes at most one <image> token per sequence — TODO confirm.
    """
    image_tok = TrainerTools().tokenizer.image
    if image_tok not in tokens:
        return tokens

    # Position of the (assumed unique) <image> token.
    idx = torch.where(tokens == image_tok)[0].item()
    repeated = torch.full(
        (tokens_per_image,), image_tok, dtype=tokens.dtype, device=tokens.device
    )
    return torch.concat([tokens[:idx], repeated, tokens[idx + 1:]], dim=-1)
72
+
73
+
74
def batch_repeat_image_tok(
    tokens: torch.Tensor,
    tokens_per_image: int
) -> torch.Tensor:
    """Apply repeat_image_tok to every row of a batch and stack the results
    back into one tensor (rows keep a common length, so stacking is valid)."""
    expanded_rows = [repeat_image_tok(row, tokens_per_image) for row in tokens]
    return torch.stack(expanded_rows, dim=0)
84
+
85
+
86
def _pad_sequence(batch_data):
    """Right-pad a list of 1-D token tensors to a common length.

    Returns (inputs, labels): inputs padded with the tokenizer pad id,
    labels padded with -100 — the default ignore_index of cross entropy.
    """
    pad_tok = TrainerTools().tokenizer.pad
    inputs = pad_sequence(batch_data, batch_first=True, padding_value=pad_tok)
    labels = pad_sequence(batch_data, batch_first=True, padding_value=-100)
    return inputs, labels
93
+
94
+
95
def _mask_prompt(labels):
    """Mask prompt spans out of the loss by setting them to -100.

    Supports multi-turn conversations: every span from a system/user token
    up to and including the next end token is masked; labels is modified
    in place and also returned.
    """
    tokenizer = TrainerTools().tokenizer
    prompt_openers = (tokenizer.system, tokenizer.user)
    for row, label in enumerate(labels):
        span_start = -1
        for pos, token in enumerate(label):
            if token in prompt_openers:
                span_start = pos
            elif token == tokenizer.end and span_start != -1:
                labels[row, span_start:pos + 1] = -100
                span_start = -1
    return labels
108
+
109
+
110
def _zero_pad_sequences(
    sequences: list[torch.Tensor], side: str = "left"
) -> torch.Tensor:
    """Zero-pad 1-D tensors to the longest length and stack them.

    ``side`` selects whether the zeros go before ("left") or after ("right")
    each sequence.
    """
    assert side in ("left", "right")
    target_len = max(seq.size(0) for seq in sequences)
    pad_left = side == "left"
    padded = []
    for seq in sequences:
        gap = target_len - seq.size(0)
        padded.append(F.pad(seq, (gap, 0) if pad_left else (0, gap)))
    return torch.stack(padded, dim=0)
121
+
122
+
123
def pretrain_collate_fn(batch_data):
    """Collate for pretraining: pad the batch and expose inputs/labels."""
    padded_inputs, padded_labels = _pad_sequence(batch_data)
    return {'inputs': padded_inputs, 'labels': padded_labels}
128
+
129
+
130
def get_sft_collate_fn(mask_prompt: bool):
    """Build an SFT collate_fn; when ``mask_prompt`` is True the prompt part
    of each conversation is excluded from the loss."""
    def sft_collate_fn(batch_data):
        """
        For SFT the prompt tokens should not contribute to the loss, e.g.:
        logits: [USER]hello[BOT]hi there[SEP]
        labels: [USER]hello[BOT]hi there[SEP]

        shift_logits: [USER]hello[BOT]hi there
        shift_labels: hello[BOT]hi there[SEP]

        mask_labels: mask mask mask mask hi there[SEP]
        * mask=-100, treated the same way as pad


        Multi-turn conversations:
        [USER]q1[BOT]a1[SEP][USER]q2[BOT]a2[SEP]
        mask: mask mask mask mask a1[SEP] mask mask mask mask a2[SEP]
        """
        batch_train_data = [item['inputs'] for item in batch_data]
        image_tags = [item['image_tag'] for item in batch_data]

        inputs, labels = _pad_sequence(batch_train_data)
        if mask_prompt:
            labels = _mask_prompt(labels)

        return {'inputs': inputs, 'labels': labels, 'image_tags': image_tags}

    return sft_collate_fn
161
+
162
+
163
def get_dpo_collate_fn(mask_prompt: bool):
    """Build a DPO collate_fn.

    batch_data items look like {'chosen': [ids...], 'rejected': [ids...]}.
    Chosen and rejected sequences are padded to one shared max length so the
    two batches line up; labels use -100 padding (cross entropy ignore_index).
    When ``mask_prompt`` is True, prompt spans are masked out of both label
    sets via _mask_prompt.
    """
    def dpo_collate_fn(batch_data):
        # Hoisted: the original re-read TrainerTools().tokenizer.pad for
        # every sequence inside the loop.
        pad_tok = TrainerTools().tokenizer.pad

        # One shared max length across both chosen and rejected sequences
        # (single pass instead of two scans over batch_data).
        max_len = max(
            len(item[key]) for item in batch_data for key in ('chosen', 'rejected')
        )

        chosen_inputs = []
        chosen_labels = []
        rejected_inputs = []
        rejected_labels = []

        for item in batch_data:
            chosen_sequence = item['chosen']
            pad_count = max_len - len(chosen_sequence)
            chosen_inputs.append(chosen_sequence + [pad_tok] * pad_count)
            chosen_labels.append(chosen_sequence + [-100] * pad_count)

            rejected_sequence = item['rejected']
            pad_count = max_len - len(rejected_sequence)
            rejected_inputs.append(rejected_sequence + [pad_tok] * pad_count)
            rejected_labels.append(rejected_sequence + [-100] * pad_count)

        chosen_inputs = torch.tensor(chosen_inputs).long()
        chosen_labels = torch.tensor(chosen_labels).long()
        rejected_inputs = torch.tensor(rejected_inputs).long()
        rejected_labels = torch.tensor(rejected_labels).long()

        if mask_prompt:
            chosen_labels = _mask_prompt(chosen_labels)
            rejected_labels = _mask_prompt(rejected_labels)

        return {
            'chosen_inputs': chosen_inputs,
            'chosen_labels': chosen_labels,
            'rejected_inputs': rejected_inputs,
            'rejected_labels': rejected_labels
        }

    return dpo_collate_fn
202
+
203
+
204
def split_batch(data_per_batch: dict) -> list[dict]:
    """Split a grouped batch dict into per-sample dicts.

    from: data_per_batch("sequences": [group_size, max_generate_len] ...)
    to:   [dict("sequences": [max_generate_len] ...) ... group_size]

    Keys whose batch value is None stay None in every per-sample dict.
    """
    keys = (
        'sequence_ids',
        'old_log_probs',
        'ref_log_probs',
        'advantages',
        'attention_mask',
        'mask',
    )

    group_size = data_per_batch['sequence_ids'].size(0)

    # Build one dict per group member, e.g.
    # [{"sequence_ids": xxx, "old_log_probs": xxx...}, ...]
    group_data = []
    for i in range(group_size):
        entry = {}
        for key in keys:
            value = data_per_batch[key]
            entry[key] = None if value is None else value[i]
        group_data.append(entry)

    return group_data
234
+
235
+
236
def join_batch(batch_data: list[dict]) -> dict:
    """Merge per-sample dicts back into one batch dict.

    from: [dict("sequences": [max_generate_len] ...), ...]
    to:   dict("sequences": [batch_size, seq_len], ...)

    A key maps to None when any sample carries None for it; otherwise the
    tensors are left-zero-padded to a common length and stacked.
    """
    keys = (
        'sequence_ids',
        'old_log_probs',
        'ref_log_probs',
        'advantages',
        'attention_mask',
        'mask',
    )

    result = {}
    for key in keys:
        vals = [item[key] for item in batch_data]
        if any(v is None for v in vals):
            result[key] = None
        else:
            result[key] = _zero_pad_sequences(vals, "left")

    return result
@@ -0,0 +1,15 @@
1
#!python

def compute_intermediate_size(hidden_size, multiple_of=64):
    """Return the FFN intermediate size for ``hidden_size``.

    Starts from 4*hidden, shrinks it to 2/3, then rounds UP to the nearest
    multiple of ``multiple_of`` (hardware-friendly shapes). This matches the
    LLaMA-style FFN sizing heuristic.
    """
    intermediate_size = 4 * hidden_size
    intermediate_size = int(2 * intermediate_size / 3)
    return multiple_of * ((intermediate_size + multiple_of - 1) // multiple_of)


if __name__ == '__main__':
    import sys
    # Usage: calc_intermediate_size <hidden_size> [multiple_of=64]
    arguments = sys.argv[1:]
    hidden_size = int(arguments[0])
    multiple_of = int(arguments[1]) if len(arguments) > 1 else 64
    print(f'intermediate_size={compute_intermediate_size(hidden_size, multiple_of)}')
@@ -0,0 +1,12 @@
1
#!python

if __name__ == '__main__':
    import os, sys

    # Script to launch is the first CLI argument.
    run_file_name = sys.argv[1:][0]

    # Tell the trainer to use DDP, then hand off to torchrun.
    os.environ['PARALLEL_TYPE'] = 'ddp'
    command = f'torchrun --standalone --nproc_per_node=gpu {run_file_name}'

    print(f'real command is {command}')
    os.system(command)
@@ -0,0 +1,12 @@
1
#!python

if __name__ == '__main__':
    import os, sys

    # Script to launch is the first CLI argument.
    run_file_name = sys.argv[1:][0]

    # Tell the trainer to use deepspeed, then hand off to its launcher.
    os.environ['PARALLEL_TYPE'] = 'ds'
    command = f'deepspeed {run_file_name}'

    print(f'real command is {command}')
    os.system(command)
@@ -0,0 +1,39 @@
1
#!python

if __name__ == '__main__':
    import os, sys
    import matplotlib.pyplot as plt

    # Log file to plot is the first CLI argument.
    loss_file = sys.argv[1:][0]

    if not os.path.exists(loss_file):
        print(f'{loss_file} not found')
        exit(0)

    steps = []
    losses = []
    with open(loss_file, 'r') as f:
        step = 0
        for line in f:
            if not line or 'loss:' not in line:
                # NOTE(review): any non-loss line other than a 'start train'
                # marker clears the collected data. Confirm this condition is
                # not inverted — it looks like the reset may have been meant
                # for lines that DO contain 'start train' (a new run).
                if 'start train' not in line:
                    steps.clear()
                    losses.clear()
                    step = 0
                continue

            # Example line:
            # (2025-03-19 20:13:44) epoch: 0, file: 1/1, batch: 623/1099, loss: 0.12186837196350098
            loss = float(line.rpartition('loss:')[2].strip())

            steps.append(step)
            losses.append(loss)
            step += 1

    plt.xlabel('steps')
    plt.ylabel('loss')

    plt.plot(steps, losses)
    plt.show()
38
+
39
+
@@ -0,0 +1,41 @@
1
#!python

if __name__ == '__main__':
    import os, sys
    import matplotlib.pyplot as plt

    # Log file to plot is the first CLI argument.
    arguments = sys.argv[1:]
    lr_file = arguments[0]

    if not os.path.exists(lr_file):
        print(f'{lr_file} not found')
        exit(0)

    steps = []
    lrs = []
    # O(1) duplicate check; `step in steps` on a list was O(n) per line.
    seen_steps = set()
    with open(lr_file, 'r') as f:
        for line in f:
            if not line:
                continue
            # line: (2025-03-19 18:15:30) step=159,lr=2.159680442248444e-05
            # data: 159,lr=2.159680442248444e-05
            data = line.split('step=')[-1]
            # [159, 2.159680442248444e-05]
            data = data.split(',lr=')

            step = int(data[0])
            lr = float(data[1])

            # keep only the first lr recorded for each step
            if step in seen_steps:
                continue
            seen_steps.add(step)

            steps.append(step)
            lrs.append(lr)

    plt.xlabel('steps')
    plt.ylabel('lr')

    plt.plot(steps, lrs)
    plt.show()
40
+
41
+
@@ -0,0 +1,12 @@
1
#!python

if __name__ == '__main__':
    import os, sys

    # Script to launch is the first CLI argument.
    run_file_name = sys.argv[1:][0]

    # No parallelism: run the script with plain python3.
    os.environ['PARALLEL_TYPE'] = 'none'
    command = f'python3 {run_file_name}'

    print(f'real command is {command}')
    os.system(command)
@@ -0,0 +1,28 @@
1
#!python

if __name__ == '__main__':
    import os, sys, torch

    # Script to launch is the first CLI argument.
    arguments = sys.argv[1:]
    run_file_name = arguments[0]

    # Pick the best available launcher: deepspeed when importable, otherwise
    # torchrun for multi-GPU, plain python for <=1 GPU.
    try:
        import deepspeed  # noqa: F401 -- only probing availability
        parallel_type = 'ds'
    except Exception:
        # Was a bare `except:`, which also swallowed KeyboardInterrupt and
        # SystemExit. Keep the best-effort fallback but let interrupts through.
        gpu_count = torch.cuda.device_count()
        if gpu_count <= 1:
            parallel_type = 'none'
        else:
            parallel_type = 'ddp'

    os.environ['PARALLEL_TYPE'] = parallel_type

    if parallel_type == 'ds':
        command = f'deepspeed {run_file_name}'
    elif parallel_type == 'ddp':
        command = f'torchrun --standalone --nproc_per_node=gpu {run_file_name}'
    else:
        command = f'python3 {run_file_name}'

    print(f'real command is {command}')
    os.system(command)
@@ -0,0 +1,9 @@
1
+ Metadata-Version: 2.4
2
+ Name: project_llm_trainer
3
+ Version: 0.3
4
+ Summary: LLM and VLM trainer
5
+ Author: qibin
6
+ Author-email: qibin0506@gmail.com
7
+ Dynamic: author
8
+ Dynamic: author-email
9
+ Dynamic: summary
@@ -0,0 +1,34 @@
1
+ llm_trainer/__init__.py,sha256=HWgtTEVeQSnZmEyYQm2K6eFEG4X2QAoigMlB5Z2tcXE,260
2
+ llm_trainer/checkpoint.py,sha256=iTbnmVrT0Ql4DpD178UI95zCmfBUdYtoJS5wIvf8_4k,6099
3
+ llm_trainer/dataset.py,sha256=uz1TTd87ikf7CZPdGxmR95TSQTFWPPTilgWLBWO46_I,3916
4
+ llm_trainer/dcp.py,sha256=PkD97DyrOtoTKn4FJsfL3VqAy4dxufgjdzJEz8-Cnoc,3635
5
+ llm_trainer/dpo_trainer.py,sha256=6rm8Jq0rI0xazcl_bCOun8rnd34Tb_PKgezowhwoiCM,13150
6
+ llm_trainer/ds_checkpoint.py,sha256=_svpzqRaa43--DKPputoXAelc6X9vPM0gNQu-hlh6NI,2153
7
+ llm_trainer/eval.py,sha256=sCvdYnqWWf5_nuDQN5BHb_YivXLOQW-V0ET9mPu0tPU,2389
8
+ llm_trainer/generate_utils.py,sha256=4iM0vyc_1C_iTL31GlS9PR4eZtYaELPRZ02KDSPZA9U,15158
9
+ llm_trainer/grpo_trainer.py,sha256=gWDX8vRZ7hLKl_483X5ua92nst1m617BrqnzLhwr87g,16390
10
+ llm_trainer/log.py,sha256=LxqTGRNZUGMTSQCePRpk-rYyxSnSIbT4kOdP8Fbzr0M,462
11
+ llm_trainer/loss.py,sha256=Yv3fsaVuZ5AhnGPJOr5vEMb_tM2urR6mCb4DBbrHHI8,6030
12
+ llm_trainer/parallel.py,sha256=2VJtW3Gq2c1yS_LdcrNhk7B12prFwBmFnKhvV8FS2d8,4428
13
+ llm_trainer/parallel_ddp.py,sha256=Gz-3LZ6LKmqlNwxrnGRC4uKoqoSxCvp9JHejIBSQp3c,1238
14
+ llm_trainer/parallel_ds.py,sha256=W_PkczyAlgffCRcQadN-Pf7H7HM7TU26v5W63jKELFM,990
15
+ llm_trainer/parallel_fsdp.py,sha256=u9XbbVTzcsMcaf-aQFrC_QwWsDRGoEpRmgvu1cKNtgk,3887
16
+ llm_trainer/parallel_none.py,sha256=a6tt3aBmCq5rSP7n2I-sF-hsZ992BbLbpbxutDCFJfs,607
17
+ llm_trainer/scheduler.py,sha256=Xz8HhwoRMjRe41sf_NHhpZfkTlEs0I2MYusvMY6hCVw,3531
18
+ llm_trainer/sft_trainer.py,sha256=T9CujoEp8D5I65fLF2wgV6SPjzhGFbAI4We5NwL4O-M,1443
19
+ llm_trainer/tokenizer.py,sha256=A7TYYUbtPf75kjCvWP7yBui4xZBObMk2aPem62YpwpY,6776
20
+ llm_trainer/tools.py,sha256=AhfjN9oln5Pyif1SgCWwgQg-Q5acTCd9xpz4L26QUjA,3039
21
+ llm_trainer/train_configs.py,sha256=FAlylSYVeh_oJGTy2fcMNUV8JLD6B70hMuk-iKx14iI,15748
22
+ llm_trainer/trainer.py,sha256=mq51d-2ADUpcWCArszhYnOSTveatt3_x43hcC7IZgYk,24330
23
+ llm_trainer/utils.py,sha256=04XiMENVotNgbNRBn9wadHu-cJHPxj0Xq-zzLJmNgZQ,8062
24
+ project_llm_trainer-0.3.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
25
+ project_llm_trainer-0.3.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
26
+ project_llm_trainer-0.3.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
27
+ project_llm_trainer-0.3.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
28
+ project_llm_trainer-0.3.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
29
+ project_llm_trainer-0.3.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
30
+ project_llm_trainer-0.3.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
31
+ project_llm_trainer-0.3.dist-info/METADATA,sha256=P64NiFbJzSd4QkFJ5udQ4qMyHUorPp3ex4F3eIdtVdU,193
32
+ project_llm_trainer-0.3.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
33
+ project_llm_trainer-0.3.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
34
+ project_llm_trainer-0.3.dist-info/RECORD,,
@@ -0,0 +1,5 @@
1
+ Wheel-Version: 1.0
2
+ Generator: setuptools (80.7.1)
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
5
+
@@ -0,0 +1 @@
1
+ llm_trainer