project-llm-trainer 0.12.3__py3-none-any.whl
This diff shows the contents of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- llm_trainer/__init__.py +13 -0
- llm_trainer/base_trainer.py +683 -0
- llm_trainer/checkpoint.py +126 -0
- llm_trainer/dataset.py +335 -0
- llm_trainer/dpo_trainer.py +297 -0
- llm_trainer/ds_checkpoint.py +63 -0
- llm_trainer/eval.py +33 -0
- llm_trainer/generate_utils.py +450 -0
- llm_trainer/grpo_trainer.py +385 -0
- llm_trainer/log.py +65 -0
- llm_trainer/loss.py +268 -0
- llm_trainer/parallel.py +220 -0
- llm_trainer/partition_utils.py +219 -0
- llm_trainer/ppo_trainer.py +521 -0
- llm_trainer/scheduler.py +179 -0
- llm_trainer/sft_trainer.py +97 -0
- llm_trainer/tokenizer.py +162 -0
- llm_trainer/tools.py +116 -0
- llm_trainer/train_configs.py +324 -0
- llm_trainer/trainer.py +34 -0
- llm_trainer/utils.py +547 -0
- project_llm_trainer-0.12.3.data/scripts/calc_intermediate_size +15 -0
- project_llm_trainer-0.12.3.data/scripts/ddp_train +21 -0
- project_llm_trainer-0.12.3.data/scripts/ds_train +17 -0
- project_llm_trainer-0.12.3.data/scripts/plot_log +69 -0
- project_llm_trainer-0.12.3.data/scripts/plot_lr +45 -0
- project_llm_trainer-0.12.3.data/scripts/py_train +12 -0
- project_llm_trainer-0.12.3.data/scripts/smart_train +37 -0
- project_llm_trainer-0.12.3.dist-info/METADATA +9 -0
- project_llm_trainer-0.12.3.dist-info/RECORD +32 -0
- project_llm_trainer-0.12.3.dist-info/WHEEL +5 -0
- project_llm_trainer-0.12.3.dist-info/top_level.txt +1 -0
project_llm_trainer-0.12.3.data/scripts/plot_log
@@ -0,0 +1,69 @@
#!python
import math
import os, sys
import matplotlib.pyplot as plt
from numpy import ndarray
from matplotlib.ticker import MaxNLocator

if __name__ == '__main__':
    arguments = sys.argv[1:]
    loss_file = arguments[0]

    if not os.path.exists(loss_file):
        print(f'{loss_file} not found')
        exit(0)

    results = {}

    # Expected log format:
    #   ====epoch: {epoch}, start train {file_name}====
    #   [time] keys_key1: keys_value1, keys_key2: keys_value2 -> values_key1: values_value1, values_key2: values_value2
    with open(loss_file, 'r') as f:
        for line in f:
            # skip epoch markers and lines that do not match the metric format
            if '====' in line or ' -> ' not in line:
                continue

            # values_key1: values_value1, values_key2: values_value2
            values_kvs = line.split(' -> ')[1].split(', ')
            for values_kv in values_kvs:
                k, v = values_kv.split(': ')
                if k not in results:
                    results[k] = [float(v.strip())]
                else:
                    results[k].append(float(v.strip()))

    # Up to 4 subplots per row; add rows as the number of metrics grows.
    results_size = len(results.keys())
    if results_size <= 4:
        rows = 1
        cols = results_size
    else:
        rows = math.ceil(results_size / 4)
        cols = 4

    fig, axes = plt.subplots(nrows=rows, ncols=cols, figsize=(4 * cols, 4 * rows))

    # plt.subplots returns a bare Axes (not an ndarray) when rows == cols == 1
    if isinstance(axes, ndarray):
        axes = axes.flatten()
    else:
        axes = [axes]

    for idx, title in enumerate(results.keys()):
        ax = axes[idx]
        y = results[title]
        x = list(range(len(y)))

        ax.plot(x, y)
        ax.set_title(title)
        ax.xaxis.set_major_locator(MaxNLocator(nbins=10))
        ax.tick_params(axis='x', rotation=30)
        ax.set_xlabel("Step")
        ax.set_ylabel(title)

    # Hide any unused subplots in the grid.
    total_plots = len(results.keys())
    for i in range(total_plots, len(axes)):
        axes[i].set_visible(False)

    plt.tight_layout()
    plt.show()
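For reference, plot_log takes the path of a training log as its only argument, collects every metric that appears after the ` -> ` separator, and draws one subplot per metric. A minimal sketch of a matching input, assuming a hypothetical file name (train.log) and metric names (loss, lr) that the script itself does not mandate:

    # build a sample log and plot it; names and values are illustrative only
    sample = (
        '====epoch: 0, start train data_000.bin====\n'
        '[2024-01-01 00:00:01] epoch: 0, step: 1 -> loss: 2.31, lr: 0.0001\n'
        '[2024-01-01 00:00:02] epoch: 0, step: 2 -> loss: 2.27, lr: 0.0002\n'
    )
    with open('train.log', 'w') as f:
        f.write(sample)
    # then run: plot_log train.log   (one subplot each for loss and lr)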
project_llm_trainer-0.12.3.data/scripts/plot_lr
@@ -0,0 +1,45 @@
#!python
import os, sys
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator

if __name__ == '__main__':
    arguments = sys.argv[1:]
    lr_file = arguments[0]

    if not os.path.exists(lr_file):
        print(f'{lr_file} not found')
        exit(0)

    lrs = {}
    # Expected log format: [time] step: {self.cur_steps}, lr: {lr}
    with open(lr_file, 'r') as f:
        for line in f:
            # skip blank lines (file iteration yields '\n', which is truthy)
            if not line.strip():
                continue

            data = line.split('step: ')[-1]
            data = data.split(', lr:')

            step = int(data[0].strip())
            lr = float(data[1].strip())

            lrs[step] = lr

    plt.title('lr')
    plt.xlabel("Step")
    plt.ylabel("Learning Rate")

    y = list(lrs.values())
    x = list(range(len(y)))

    ax = plt.gca()
    plt.plot(x, y)
    ax.xaxis.set_major_locator(MaxNLocator(nbins=20))

    plt.xticks(rotation=30)

    plt.tight_layout()
    plt.show()
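plot_lr follows the same pattern for learning-rate logs, expecting one `step: ..., lr: ...` entry per line; everything before 'step: ' (e.g. a timestamp) is discarded by the split. A minimal matching input, with a hypothetical file name and values:

    sample = (
        '[2024-01-01 00:00:01] step: 100, lr: 0.000125\n'
        '[2024-01-01 00:00:02] step: 200, lr: 0.000250\n'
    )
    with open('lr.log', 'w') as f:
        f.write(sample)
    # then run: plot_lr lr.log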
project_llm_trainer-0.12.3.data/scripts/py_train
@@ -0,0 +1,12 @@
#!python

if __name__ == '__main__':
    import os, sys
    arguments = sys.argv[1:]
    run_file_name = arguments[0]

    os.environ['PARALLEL_TYPE'] = 'none'
    command = f'python3 {run_file_name}'

    print(f'real command is {command}')
    os.system(command)
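py_train is the single-process launcher: it pins PARALLEL_TYPE to 'none' in the environment (which the shell spawned by os.system inherits) and delegates to python3. A usage sketch with a hypothetical training script:

    py_train train.py
    # prints: real command is python3 train.py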
project_llm_trainer-0.12.3.data/scripts/smart_train
@@ -0,0 +1,37 @@
#!python

if __name__ == '__main__':
    import os, sys, torch

    arguments = sys.argv[1:]
    # file name
    run_file_name = arguments[0]

    extra_args = ''
    if len(arguments) > 1:
        extra_args = f"{' '.join(arguments[1:])} "

    # Prefer deepspeed when it is installed; otherwise fall back to DDP
    # (multi-GPU) or a plain single process.
    try:
        import deepspeed
        parallel_type = 'ds'
    except ImportError:
        gpu_count = torch.cuda.device_count()
        if gpu_count <= 1:
            parallel_type = 'none'
        else:
            parallel_type = 'ddp'

    os.environ['PARALLEL_TYPE'] = parallel_type

    if parallel_type == 'ds':
        command = f'deepspeed {extra_args}{run_file_name}'
    elif parallel_type == 'ddp':
        if len(extra_args) == 0:
            extra_args = '--standalone --nproc_per_node=gpu '

        command = f'torchrun {extra_args}{run_file_name}'
    else:
        command = f'python3 {run_file_name}'

    print(f'run command {command}')
    os.system(command)
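smart_train picks the launcher by probing the environment: deepspeed when the deepspeed package imports, torchrun when more than one CUDA device is visible, and plain python3 otherwise; the chosen mode is also exported as PARALLEL_TYPE for the training script to read. A sketch of the possible dispatches, with a hypothetical script name:

    smart_train train.py
    # deepspeed importable -> deepspeed train.py                                  (PARALLEL_TYPE=ds)
    # else, >1 GPU         -> torchrun --standalone --nproc_per_node=gpu train.py (PARALLEL_TYPE=ddp)
    # else                 -> python3 train.py                                    (PARALLEL_TYPE=none)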
project_llm_trainer-0.12.3.dist-info/RECORD
@@ -0,0 +1,32 @@
llm_trainer/__init__.py,sha256=U_rFD6hqNJuNXjcKJ9QnxnAL3SXhyWdGZEcA5GbrU3s,385
llm_trainer/base_trainer.py,sha256=tAqUdSsrJBTBZFZKsinHAoBr7KmDD72qJHux2lMOMYg,29029
llm_trainer/checkpoint.py,sha256=Aal5D7pVPVRlLZU3WAJKC6-cXoDTIj2JdH_InOaP_1E,4466
llm_trainer/dataset.py,sha256=SuUedIU46yiHRIz-Fa5pgQr5h9UMQKQ6OSyvQ8xkMow,10917
llm_trainer/dpo_trainer.py,sha256=jSx2g9snX6sNounpU9gcUZzv4XVRyslxrM5msR5o6Ko,12687
llm_trainer/ds_checkpoint.py,sha256=I67co_LttpX7nIr5rW_qjtt_QJxKG_UiIvqKYI89rA0,2304
llm_trainer/eval.py,sha256=uuzWF40xfEx5nPntVEXdyb9UnWiG9cSWF0N3v5FFZDk,981
llm_trainer/generate_utils.py,sha256=wdOmU3PvMP0OzlsE8_zvoK_Kcq0saQm10_vTozfFxjA,15792
llm_trainer/grpo_trainer.py,sha256=EKwkmTZWAQrNPQQDSRV4ucAunj0_iEYBhJoog9yOQWE,14882
llm_trainer/log.py,sha256=BCb8qzs2TGltBFHNuDeEibT6FgBZZTZ-Ijuu1XNOSes,1746
llm_trainer/loss.py,sha256=56Q0sIO8J4uVOgyvbnHDBdls5m3iW3HrsQ2XWN4zC-I,10228
llm_trainer/parallel.py,sha256=eWRcqFkOfWM50Chv6gKpifAkaoxF3h8lr3592QXBmx8,6199
llm_trainer/partition_utils.py,sha256=EMXVGi-AN2piqbOCQei7WmddwQ07jwC5RWClaofIj9Q,8087
llm_trainer/ppo_trainer.py,sha256=8uY2cYfCYLb_hIN9u0VgP-IMY7D-c0lfIUY-a66Dy84,22445
llm_trainer/scheduler.py,sha256=7VTmv6slOSB03-KY9nCEzsOrqPW9Jw-jPDxVudmGPzw,5178
llm_trainer/sft_trainer.py,sha256=NWUkHJe3Ii54bwlnBKWs2pP7zIOUM47Sc7A5TWXG_AI,3682
llm_trainer/tokenizer.py,sha256=8Mccp4sCaYWiKVD78dEwBMHlA9uS0xf22FOiVxTVtK4,5875
llm_trainer/tools.py,sha256=7i5ZdCE-TOtoD8hz1Xzx9mIe3wANTd3la_T3vXp6LuM,3328
llm_trainer/train_configs.py,sha256=FjYuW2e9CuTGm07-wfjow_49R7mhAjdcHpdifFPcuRo,10384
llm_trainer/trainer.py,sha256=PsSDZvvNVrFun7B_sUYA0QsBaC-2C-CYb6ey3PlRWCw,1210
llm_trainer/utils.py,sha256=TumXZvN7EyxvTsXYdGwaKlPfup-VK3HsF3GJOM0zrf4,20380
project_llm_trainer-0.12.3.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
project_llm_trainer-0.12.3.data/scripts/ddp_train,sha256=eZSud6KYQAoKLsYB5QB-FI2zq5AZm6Apq1azKdupV3o,477
project_llm_trainer-0.12.3.data/scripts/ds_train,sha256=41q4rOxwbvZDUY0FDdAIpG13PEaUWBpthhvFvww8uOc,388
project_llm_trainer-0.12.3.data/scripts/plot_log,sha256=EuYQ2_xx98PEtuDr84B4dIji3QSPBHC6WefqyqX7GwI,1872
project_llm_trainer-0.12.3.data/scripts/plot_lr,sha256=TfLXzqHIFo3mVPy-v-WZlD8zK6Q8IEb1V-fZiwoOug0,922
project_llm_trainer-0.12.3.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
project_llm_trainer-0.12.3.data/scripts/smart_train,sha256=N8dp2n7k6bghGczedBVwOdtf1O66oM_cNPh9QmZt0bM,914
project_llm_trainer-0.12.3.dist-info/METADATA,sha256=QuIPMCqL2V4KoiJkdDF-8Zsb2PZU9tMgqXdYVH53j1g,196
project_llm_trainer-0.12.3.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
project_llm_trainer-0.12.3.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
project_llm_trainer-0.12.3.dist-info/RECORD,,
project_llm_trainer-0.12.3.dist-info/top_level.txt
@@ -0,0 +1 @@
llm_trainer