project-llm-trainer 0.13.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of project-llm-trainer might be problematic; consult the registry's advisory page for this release for details.
- llm_trainer/__init__.py +13 -0
- llm_trainer/base_trainer.py +707 -0
- llm_trainer/checkpoint.py +114 -0
- llm_trainer/dataset.py +335 -0
- llm_trainer/dpo_trainer.py +311 -0
- llm_trainer/ds_checkpoint.py +72 -0
- llm_trainer/eval.py +33 -0
- llm_trainer/generate_utils.py +463 -0
- llm_trainer/grpo_trainer.py +410 -0
- llm_trainer/log.py +65 -0
- llm_trainer/loss.py +266 -0
- llm_trainer/parallel.py +220 -0
- llm_trainer/partition_utils.py +219 -0
- llm_trainer/ppo_trainer.py +686 -0
- llm_trainer/scheduler.py +220 -0
- llm_trainer/sft_trainer.py +97 -0
- llm_trainer/tokenizer.py +162 -0
- llm_trainer/tools.py +116 -0
- llm_trainer/train_configs.py +327 -0
- llm_trainer/trainer.py +34 -0
- llm_trainer/utils.py +630 -0
- project_llm_trainer-0.13.4.data/scripts/calc_intermediate_size +15 -0
- project_llm_trainer-0.13.4.data/scripts/ddp_train +21 -0
- project_llm_trainer-0.13.4.data/scripts/ds_train +17 -0
- project_llm_trainer-0.13.4.data/scripts/py_train +12 -0
- project_llm_trainer-0.13.4.data/scripts/smart_train +37 -0
- project_llm_trainer-0.13.4.data/scripts/vis_log +98 -0
- project_llm_trainer-0.13.4.data/scripts/vis_lr +46 -0
- project_llm_trainer-0.13.4.dist-info/METADATA +9 -0
- project_llm_trainer-0.13.4.dist-info/RECORD +32 -0
- project_llm_trainer-0.13.4.dist-info/WHEEL +5 -0
- project_llm_trainer-0.13.4.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
#!python
|
|
2
|
+
|
|
3
|
+
if __name__ == '__main__':
|
|
4
|
+
import os, sys, torch
|
|
5
|
+
|
|
6
|
+
arguments = sys.argv[1:]
|
|
7
|
+
# file name
|
|
8
|
+
run_file_name = arguments[0]
|
|
9
|
+
|
|
10
|
+
extra_args = ''
|
|
11
|
+
if len(arguments) > 1:
|
|
12
|
+
extra_args = f"{' '.join(arguments[1:])} "
|
|
13
|
+
|
|
14
|
+
try:
|
|
15
|
+
import deepspeed
|
|
16
|
+
parallel_type = 'ds'
|
|
17
|
+
except:
|
|
18
|
+
gpu_count = torch.cuda.device_count()
|
|
19
|
+
if gpu_count <= 1:
|
|
20
|
+
parallel_type = 'none'
|
|
21
|
+
else:
|
|
22
|
+
parallel_type = 'ddp'
|
|
23
|
+
|
|
24
|
+
os.environ['PARALLEL_TYPE'] = parallel_type
|
|
25
|
+
|
|
26
|
+
if parallel_type == 'ds':
|
|
27
|
+
command = f'deepspeed {extra_args}{run_file_name}'
|
|
28
|
+
elif parallel_type == 'ddp':
|
|
29
|
+
if len(extra_args) == 0:
|
|
30
|
+
extra_args = '--standalone --nproc_per_node=gpu '
|
|
31
|
+
|
|
32
|
+
command = f'torchrun {extra_args}{run_file_name}'
|
|
33
|
+
else:
|
|
34
|
+
command = f'python3 {run_file_name}'
|
|
35
|
+
|
|
36
|
+
print(f'run command {command}')
|
|
37
|
+
os.system(command)
|
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
#!python
|
|
2
|
+
|
|
3
|
+
import math
|
|
4
|
+
import os, sys
|
|
5
|
+
import matplotlib.pyplot as plt
|
|
6
|
+
from numpy import ndarray
|
|
7
|
+
from matplotlib.ticker import MaxNLocator
|
|
8
|
+
import re
|
|
9
|
+
|
|
10
|
+
if __name__ == '__main__':
|
|
11
|
+
arguments = sys.argv[1:]
|
|
12
|
+
loss_file = arguments[0]
|
|
13
|
+
|
|
14
|
+
if not os.path.exists(loss_file):
|
|
15
|
+
print(f'{loss_file} not found')
|
|
16
|
+
exit(0)
|
|
17
|
+
|
|
18
|
+
data_map = {}
|
|
19
|
+
all_metric_keys = []
|
|
20
|
+
|
|
21
|
+
with open(loss_file, 'r') as f:
|
|
22
|
+
for line in f:
|
|
23
|
+
if '====' in line:
|
|
24
|
+
continue
|
|
25
|
+
|
|
26
|
+
try:
|
|
27
|
+
meta_part, values_part = line.split(' -> ')
|
|
28
|
+
|
|
29
|
+
epoch = int(re.search(r'epoch:\s*(\d+)', meta_part).group(1))
|
|
30
|
+
file_str = re.search(r'file:\s*(\d+)', meta_part).group(1)
|
|
31
|
+
file_idx = int(file_str)
|
|
32
|
+
batch_str = re.search(r'batch:\s*(\d+)', meta_part).group(1)
|
|
33
|
+
batch_idx = int(batch_str)
|
|
34
|
+
|
|
35
|
+
sort_key = (epoch, file_idx, batch_idx)
|
|
36
|
+
|
|
37
|
+
current_metrics = {}
|
|
38
|
+
values_kvs = values_part.split(', ')
|
|
39
|
+
for values_kv in values_kvs:
|
|
40
|
+
k, v = values_kv.split(': ')
|
|
41
|
+
val = float(v.strip())
|
|
42
|
+
current_metrics[k] = val
|
|
43
|
+
|
|
44
|
+
if k not in all_metric_keys:
|
|
45
|
+
all_metric_keys.append(k)
|
|
46
|
+
|
|
47
|
+
data_map[sort_key] = current_metrics
|
|
48
|
+
|
|
49
|
+
except Exception as e:
|
|
50
|
+
continue
|
|
51
|
+
|
|
52
|
+
sorted_keys = sorted(data_map.keys())
|
|
53
|
+
results = {k: [] for k in all_metric_keys}
|
|
54
|
+
|
|
55
|
+
for key in sorted_keys:
|
|
56
|
+
metrics = data_map[key]
|
|
57
|
+
for k in all_metric_keys:
|
|
58
|
+
if k in metrics:
|
|
59
|
+
results[k].append(metrics[k])
|
|
60
|
+
|
|
61
|
+
if not results:
|
|
62
|
+
print("No valid data found.")
|
|
63
|
+
exit(0)
|
|
64
|
+
|
|
65
|
+
results_size = len(results.keys())
|
|
66
|
+
if results_size <= 4:
|
|
67
|
+
rows = 1
|
|
68
|
+
cols = results_size
|
|
69
|
+
else:
|
|
70
|
+
rows = math.ceil(results_size / 4)
|
|
71
|
+
cols = 4
|
|
72
|
+
|
|
73
|
+
fig, axes = plt.subplots(nrows=rows, ncols=cols, figsize=(4 * cols, 4 * rows))
|
|
74
|
+
|
|
75
|
+
if isinstance(axes, ndarray):
|
|
76
|
+
axes = axes.flatten()
|
|
77
|
+
else:
|
|
78
|
+
axes = [axes]
|
|
79
|
+
|
|
80
|
+
for idx, title in enumerate(results.keys()):
|
|
81
|
+
ax = axes[idx]
|
|
82
|
+
y = results[title]
|
|
83
|
+
x = list(range(len(y)))
|
|
84
|
+
|
|
85
|
+
ax.plot(x, y)
|
|
86
|
+
ax.set_title(title)
|
|
87
|
+
|
|
88
|
+
ax.xaxis.set_major_locator(MaxNLocator(nbins=10))
|
|
89
|
+
ax.tick_params(axis='x', rotation=30)
|
|
90
|
+
ax.set_xlabel("Step")
|
|
91
|
+
ax.set_ylabel(title)
|
|
92
|
+
|
|
93
|
+
total_plots = len(results.keys())
|
|
94
|
+
for i in range(total_plots, len(axes)):
|
|
95
|
+
axes[i].set_visible(False)
|
|
96
|
+
|
|
97
|
+
plt.tight_layout()
|
|
98
|
+
plt.show()
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
#!python
|
|
2
|
+
|
|
3
|
+
import os, sys
|
|
4
|
+
import matplotlib.pyplot as plt
|
|
5
|
+
from matplotlib.ticker import MaxNLocator
|
|
6
|
+
|
|
7
|
+
if __name__ == '__main__':
|
|
8
|
+
arguments = sys.argv[1:]
|
|
9
|
+
lr_file = arguments[0]
|
|
10
|
+
|
|
11
|
+
if not os.path.exists(lr_file):
|
|
12
|
+
print(f'{lr_file} not found')
|
|
13
|
+
exit(0)
|
|
14
|
+
|
|
15
|
+
lrs = {}
|
|
16
|
+
# [time] step: {self.cur_steps}, lr: {lr}
|
|
17
|
+
with open(lr_file, 'r') as f:
|
|
18
|
+
for line in f:
|
|
19
|
+
if not line:
|
|
20
|
+
continue
|
|
21
|
+
|
|
22
|
+
data = line.split('step: ')[-1]
|
|
23
|
+
data = data.split(', lr:')
|
|
24
|
+
|
|
25
|
+
step = int(data[0].strip())
|
|
26
|
+
lr = float(data[1].strip())
|
|
27
|
+
|
|
28
|
+
lrs[step] = lr
|
|
29
|
+
|
|
30
|
+
plt.title('lr')
|
|
31
|
+
plt.xlabel("Step")
|
|
32
|
+
plt.ylabel("Learning Rate")
|
|
33
|
+
|
|
34
|
+
y = lrs.values()
|
|
35
|
+
x = list(range(len(y)))
|
|
36
|
+
|
|
37
|
+
ax = plt.gca()
|
|
38
|
+
plt.plot(x, y)
|
|
39
|
+
ax.xaxis.set_major_locator(MaxNLocator(nbins=20))
|
|
40
|
+
|
|
41
|
+
plt.xticks(rotation=30)
|
|
42
|
+
|
|
43
|
+
plt.tight_layout()
|
|
44
|
+
plt.show()
|
|
45
|
+
|
|
46
|
+
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
llm_trainer/__init__.py,sha256=U_rFD6hqNJuNXjcKJ9QnxnAL3SXhyWdGZEcA5GbrU3s,385
|
|
2
|
+
llm_trainer/base_trainer.py,sha256=62zoWzNajK07cnSLuWovxZSlQOikvK5hGa7nW5Yy9BE,29916
|
|
3
|
+
llm_trainer/checkpoint.py,sha256=vjarm-9J-9HAklpQAxbB3Bgph2HI6gxBQvUkB3LywwI,4009
|
|
4
|
+
llm_trainer/dataset.py,sha256=obbJuFmRS3-ntjF3q7acRYkbKYNqLQFMtZij0mCfCjU,10947
|
|
5
|
+
llm_trainer/dpo_trainer.py,sha256=TI8SZxxiqS3BA8IByQl74fjyjCNe-C6OXAqBNbcO5Yw,13192
|
|
6
|
+
llm_trainer/ds_checkpoint.py,sha256=0XZEdBV50obVmAXK1dX_mNuS-yomZW6RTzt1R0TdCyw,2611
|
|
7
|
+
llm_trainer/eval.py,sha256=6qwkRZQXpWJoGm3173Tx39GbgI0gEjA0VNath5J9ekg,1004
|
|
8
|
+
llm_trainer/generate_utils.py,sha256=Yc6xqS0xIaWx4paJMIHDrvQaLCHi5_R91dKvoEtMXgw,16388
|
|
9
|
+
llm_trainer/grpo_trainer.py,sha256=3dSxFSzxzTciGYjUZ_7VN6SdHZx71RIILq0c7Ph6QfU,15962
|
|
10
|
+
llm_trainer/log.py,sha256=BCb8qzs2TGltBFHNuDeEibT6FgBZZTZ-Ijuu1XNOSes,1746
|
|
11
|
+
llm_trainer/loss.py,sha256=AeiUSIkUV6JqyhH3M5CSrXFY9Y_EscG-kE3aOw4bMBE,10140
|
|
12
|
+
llm_trainer/parallel.py,sha256=eWRcqFkOfWM50Chv6gKpifAkaoxF3h8lr3592QXBmx8,6199
|
|
13
|
+
llm_trainer/partition_utils.py,sha256=EMXVGi-AN2piqbOCQei7WmddwQ07jwC5RWClaofIj9Q,8087
|
|
14
|
+
llm_trainer/ppo_trainer.py,sha256=xXgXNVKxTV1jTuz25J1BMfP6r9I0k-hRVGf-b4yJsyw,28946
|
|
15
|
+
llm_trainer/scheduler.py,sha256=cNRPeApnIrSh0fRDo9qKkrkRSYJb7JWKlWOJ30rmzoM,6448
|
|
16
|
+
llm_trainer/sft_trainer.py,sha256=yAHZp8MUlngKgciEUrcVhdEFjjQKRwQ-NqppaBmhc5Y,3687
|
|
17
|
+
llm_trainer/tokenizer.py,sha256=8Mccp4sCaYWiKVD78dEwBMHlA9uS0xf22FOiVxTVtK4,5875
|
|
18
|
+
llm_trainer/tools.py,sha256=QGYOwjabWEMyOe_N9z1yL9WNEjNrEshpZFjnv_QOZH0,3323
|
|
19
|
+
llm_trainer/train_configs.py,sha256=ZL4M5ap3ndaK8hRnBCJ3mjspBYiDyzU8rZxsu2LXJ4E,10519
|
|
20
|
+
llm_trainer/trainer.py,sha256=X0E5-mU5SZRrpevDhhCuUIVMVs0GhVnY7OwAhEgMo9w,1214
|
|
21
|
+
llm_trainer/utils.py,sha256=4SBse7AXn6R7xiRKpRGOF9xrx_ZP9SidgyANkO22CxU,23346
|
|
22
|
+
project_llm_trainer-0.13.4.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
|
|
23
|
+
project_llm_trainer-0.13.4.data/scripts/ddp_train,sha256=eZSud6KYQAoKLsYB5QB-FI2zq5AZm6Apq1azKdupV3o,477
|
|
24
|
+
project_llm_trainer-0.13.4.data/scripts/ds_train,sha256=41q4rOxwbvZDUY0FDdAIpG13PEaUWBpthhvFvww8uOc,388
|
|
25
|
+
project_llm_trainer-0.13.4.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
|
|
26
|
+
project_llm_trainer-0.13.4.data/scripts/smart_train,sha256=N8dp2n7k6bghGczedBVwOdtf1O66oM_cNPh9QmZt0bM,914
|
|
27
|
+
project_llm_trainer-0.13.4.data/scripts/vis_log,sha256=hn3HinTbmOhn9PTby_vodAWmNHDwRA0a9yoU7DHqMjg,2626
|
|
28
|
+
project_llm_trainer-0.13.4.data/scripts/vis_lr,sha256=mgSOckQrRw_42locxk09TTBEeCqSTiu7j1OJ5_vMLDU,923
|
|
29
|
+
project_llm_trainer-0.13.4.dist-info/METADATA,sha256=TaaOytFZKGXMITWTGqPL6Dvm_v_dhLT-ejsMvQ7hsH4,196
|
|
30
|
+
project_llm_trainer-0.13.4.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
|
|
31
|
+
project_llm_trainer-0.13.4.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
|
|
32
|
+
project_llm_trainer-0.13.4.dist-info/RECORD,,
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
llm_trainer
|