project_llm_trainer-0.13.4-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of project-llm-trainer might be problematic.

@@ -0,0 +1,38 @@
+ #!python
+
+ if __name__ == '__main__':
+     import os, sys, torch
+
+     arguments = sys.argv[1:]
+     # training script file name
+     run_file_name = arguments[0]
+
+     extra_args = ''
+     if len(arguments) > 1:
+         extra_args = f"{' '.join(arguments[1:])} "
+
+     # prefer deepspeed when importable; otherwise DDP for multi-GPU, plain python for one device
+     try:
+         import deepspeed
+         parallel_type = 'ds'
+     except ImportError:
+         gpu_count = torch.cuda.device_count()
+         if gpu_count <= 1:
+             parallel_type = 'none'
+         else:
+             parallel_type = 'ddp'
+
+     os.environ['PARALLEL_TYPE'] = parallel_type
+
+     if parallel_type == 'ds':
+         command = f'deepspeed {extra_args}{run_file_name}'
+     elif parallel_type == 'ddp':
+         if len(extra_args) == 0:
+             extra_args = '--standalone --nproc_per_node=gpu '
+
+         command = f'torchrun {extra_args}{run_file_name}'
+     else:
+         command = f'python3 {run_file_name}'
+
+     print(f'run command: {command}')
+     os.system(command)
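
For reference, the dispatch above reduces to one of three launcher commands; a minimal sketch of that logic as a pure function (build_command and train.py are hypothetical names used for illustration, not part of the package):

    def build_command(run_file_name: str, extra_args: str, parallel_type: str) -> str:
        # mirrors the branch logic of the launcher script above
        if parallel_type == 'ds':
            return f'deepspeed {extra_args}{run_file_name}'
        if parallel_type == 'ddp':
            extra_args = extra_args or '--standalone --nproc_per_node=gpu '
            return f'torchrun {extra_args}{run_file_name}'
        return f'python3 {run_file_name}'

    assert build_command('train.py', '', 'ddp') == 'torchrun --standalone --nproc_per_node=gpu train.py'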
@@ -0,0 +1,98 @@
+ #!python
+
+ import math
+ import os, sys
+ import matplotlib.pyplot as plt
+ from numpy import ndarray
+ from matplotlib.ticker import MaxNLocator
+ import re
+
+ if __name__ == '__main__':
+     arguments = sys.argv[1:]
+     loss_file = arguments[0]
+
+     if not os.path.exists(loss_file):
+         print(f'{loss_file} not found')
+         sys.exit(0)
+
+     data_map = {}
+     all_metric_keys = []
+
+     with open(loss_file, 'r') as f:
+         for line in f:
+             if '====' in line:
+                 continue
+
+             try:
+                 meta_part, values_part = line.split(' -> ')
+
+                 epoch = int(re.search(r'epoch:\s*(\d+)', meta_part).group(1))
+                 file_str = re.search(r'file:\s*(\d+)', meta_part).group(1)
+                 file_idx = int(file_str)
+                 batch_str = re.search(r'batch:\s*(\d+)', meta_part).group(1)
+                 batch_idx = int(batch_str)
+
+                 sort_key = (epoch, file_idx, batch_idx)
+
+                 current_metrics = {}
+                 values_kvs = values_part.split(', ')
+                 for values_kv in values_kvs:
+                     k, v = values_kv.split(': ')
+                     val = float(v.strip())
+                     current_metrics[k] = val
+
+                     if k not in all_metric_keys:
+                         all_metric_keys.append(k)
+
+                 data_map[sort_key] = current_metrics
+
+             except Exception:
+                 continue
+
+     sorted_keys = sorted(data_map.keys())
+     results = {k: [] for k in all_metric_keys}
+
+     for key in sorted_keys:
+         metrics = data_map[key]
+         for k in all_metric_keys:
+             if k in metrics:
+                 results[k].append(metrics[k])
+
+     if not results:
+         print("No valid data found.")
+         sys.exit(0)
+
+     results_size = len(results)
+     if results_size <= 4:
+         rows = 1
+         cols = results_size
+     else:
+         rows = math.ceil(results_size / 4)
+         cols = 4
+
+     fig, axes = plt.subplots(nrows=rows, ncols=cols, figsize=(4 * cols, 4 * rows))
+
+     if isinstance(axes, ndarray):
+         axes = axes.flatten()
+     else:
+         axes = [axes]
+
+     for idx, title in enumerate(results):
+         ax = axes[idx]
+         y = results[title]
+         x = list(range(len(y)))
+
+         ax.plot(x, y)
+         ax.set_title(title)
+
+         ax.xaxis.set_major_locator(MaxNLocator(nbins=10))
+         ax.tick_params(axis='x', rotation=30)
+         ax.set_xlabel("Step")
+         ax.set_ylabel(title)
+
+     total_plots = len(results)
+     for i in range(total_plots, len(axes)):
+         axes[i].set_visible(False)
+
+     plt.tight_layout()
+     plt.show()
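
The parser above expects each useful line to carry an epoch/file/batch header, then ' -> ', then comma-separated key: value metrics. A hedged example of a line it accepts (the metric names here are assumptions, not taken from the package):

    line = "epoch: 1, file: 0, batch: 42 -> loss: 2.3145, lr: 0.0003"
    meta_part, values_part = line.split(' -> ')
    # meta_part   -> sort_key (1, 0, 42)
    # values_part -> {'loss': 2.3145, 'lr': 0.0003}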
@@ -0,0 +1,44 @@
+ #!python
+
+ import os, sys
+ import matplotlib.pyplot as plt
+ from matplotlib.ticker import MaxNLocator
+
+ if __name__ == '__main__':
+     arguments = sys.argv[1:]
+     lr_file = arguments[0]
+
+     if not os.path.exists(lr_file):
+         print(f'{lr_file} not found')
+         sys.exit(0)
+
+     lrs = {}
+     # expected line format: [time] step: {self.cur_steps}, lr: {lr}
+     with open(lr_file, 'r') as f:
+         for line in f:
+             if 'step: ' not in line:
+                 continue
+
+             data = line.split('step: ')[-1]
+             data = data.split(', lr:')
+
+             step = int(data[0].strip())
+             lr = float(data[1].strip())
+
+             lrs[step] = lr
+
+     plt.title('lr')
+     plt.xlabel("Step")
+     plt.ylabel("Learning Rate")
+
+     y = list(lrs.values())
+     x = list(range(len(y)))
+
+     ax = plt.gca()
+     plt.plot(x, y)
+     ax.xaxis.set_major_locator(MaxNLocator(nbins=20))
+
+     plt.xticks(rotation=30)
+
+     plt.tight_layout()
+     plt.show()
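
Given the format noted in the script's comment, a line like the following parses cleanly (the timestamp is an assumed example):

    line = "[2024-01-01 12:00:00] step: 100, lr: 0.000300"
    data = line.split('step: ')[-1].split(', lr:')
    step, lr = int(data[0].strip()), float(data[1].strip())  # -> 100, 0.0003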
@@ -0,0 +1,9 @@
+ Metadata-Version: 2.4
+ Name: project_llm_trainer
+ Version: 0.13.4
+ Summary: LLM and VLM trainer
+ Author: qibin
+ Author-email: qibin0506@gmail.com
+ Dynamic: author
+ Dynamic: author-email
+ Dynamic: summary
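
Once the wheel is installed, these fields can be read back at runtime with the standard library; a minimal sketch, assuming the package is installed under the name shown above:

    from importlib.metadata import metadata, version

    meta = metadata('project_llm_trainer')   # parses the METADATA file above
    print(meta['Summary'])                   # LLM and VLM trainer
    print(version('project_llm_trainer'))    # 0.13.4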
@@ -0,0 +1,32 @@
+ llm_trainer/__init__.py,sha256=U_rFD6hqNJuNXjcKJ9QnxnAL3SXhyWdGZEcA5GbrU3s,385
+ llm_trainer/base_trainer.py,sha256=62zoWzNajK07cnSLuWovxZSlQOikvK5hGa7nW5Yy9BE,29916
+ llm_trainer/checkpoint.py,sha256=vjarm-9J-9HAklpQAxbB3Bgph2HI6gxBQvUkB3LywwI,4009
+ llm_trainer/dataset.py,sha256=obbJuFmRS3-ntjF3q7acRYkbKYNqLQFMtZij0mCfCjU,10947
+ llm_trainer/dpo_trainer.py,sha256=TI8SZxxiqS3BA8IByQl74fjyjCNe-C6OXAqBNbcO5Yw,13192
+ llm_trainer/ds_checkpoint.py,sha256=0XZEdBV50obVmAXK1dX_mNuS-yomZW6RTzt1R0TdCyw,2611
+ llm_trainer/eval.py,sha256=6qwkRZQXpWJoGm3173Tx39GbgI0gEjA0VNath5J9ekg,1004
+ llm_trainer/generate_utils.py,sha256=Yc6xqS0xIaWx4paJMIHDrvQaLCHi5_R91dKvoEtMXgw,16388
+ llm_trainer/grpo_trainer.py,sha256=3dSxFSzxzTciGYjUZ_7VN6SdHZx71RIILq0c7Ph6QfU,15962
+ llm_trainer/log.py,sha256=BCb8qzs2TGltBFHNuDeEibT6FgBZZTZ-Ijuu1XNOSes,1746
+ llm_trainer/loss.py,sha256=AeiUSIkUV6JqyhH3M5CSrXFY9Y_EscG-kE3aOw4bMBE,10140
+ llm_trainer/parallel.py,sha256=eWRcqFkOfWM50Chv6gKpifAkaoxF3h8lr3592QXBmx8,6199
+ llm_trainer/partition_utils.py,sha256=EMXVGi-AN2piqbOCQei7WmddwQ07jwC5RWClaofIj9Q,8087
+ llm_trainer/ppo_trainer.py,sha256=xXgXNVKxTV1jTuz25J1BMfP6r9I0k-hRVGf-b4yJsyw,28946
+ llm_trainer/scheduler.py,sha256=cNRPeApnIrSh0fRDo9qKkrkRSYJb7JWKlWOJ30rmzoM,6448
+ llm_trainer/sft_trainer.py,sha256=yAHZp8MUlngKgciEUrcVhdEFjjQKRwQ-NqppaBmhc5Y,3687
+ llm_trainer/tokenizer.py,sha256=8Mccp4sCaYWiKVD78dEwBMHlA9uS0xf22FOiVxTVtK4,5875
+ llm_trainer/tools.py,sha256=QGYOwjabWEMyOe_N9z1yL9WNEjNrEshpZFjnv_QOZH0,3323
+ llm_trainer/train_configs.py,sha256=ZL4M5ap3ndaK8hRnBCJ3mjspBYiDyzU8rZxsu2LXJ4E,10519
+ llm_trainer/trainer.py,sha256=X0E5-mU5SZRrpevDhhCuUIVMVs0GhVnY7OwAhEgMo9w,1214
+ llm_trainer/utils.py,sha256=4SBse7AXn6R7xiRKpRGOF9xrx_ZP9SidgyANkO22CxU,23346
+ project_llm_trainer-0.13.4.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
+ project_llm_trainer-0.13.4.data/scripts/ddp_train,sha256=eZSud6KYQAoKLsYB5QB-FI2zq5AZm6Apq1azKdupV3o,477
+ project_llm_trainer-0.13.4.data/scripts/ds_train,sha256=41q4rOxwbvZDUY0FDdAIpG13PEaUWBpthhvFvww8uOc,388
+ project_llm_trainer-0.13.4.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
+ project_llm_trainer-0.13.4.data/scripts/smart_train,sha256=N8dp2n7k6bghGczedBVwOdtf1O66oM_cNPh9QmZt0bM,914
+ project_llm_trainer-0.13.4.data/scripts/vis_log,sha256=hn3HinTbmOhn9PTby_vodAWmNHDwRA0a9yoU7DHqMjg,2626
+ project_llm_trainer-0.13.4.data/scripts/vis_lr,sha256=mgSOckQrRw_42locxk09TTBEeCqSTiu7j1OJ5_vMLDU,923
+ project_llm_trainer-0.13.4.dist-info/METADATA,sha256=TaaOytFZKGXMITWTGqPL6Dvm_v_dhLT-ejsMvQ7hsH4,196
+ project_llm_trainer-0.13.4.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
+ project_llm_trainer-0.13.4.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
+ project_llm_trainer-0.13.4.dist-info/RECORD,,
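
Each RECORD row is "path,sha256=digest,size", where the digest is the file's SHA-256 in URL-safe base64 with the trailing '=' padding stripped (PEP 376 / PEP 427). A minimal sketch for recomputing a row's digest when auditing a wheel (record_hash is a hypothetical helper, not part of the package):

    import base64, hashlib

    def record_hash(path: str) -> str:
        # SHA-256 of the file, URL-safe base64, '=' padding stripped,
        # matching the wheel RECORD digest format
        with open(path, 'rb') as f:
            digest = hashlib.sha256(f.read()).digest()
        return 'sha256=' + base64.urlsafe_b64encode(digest).rstrip(b'=').decode('ascii')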
@@ -0,0 +1,5 @@
+ Wheel-Version: 1.0
+ Generator: setuptools (80.7.1)
+ Root-Is-Purelib: true
+ Tag: py3-none-any
+
@@ -0,0 +1 @@
+ llm_trainer