project-llm-trainer 0.5.5__py3-none-any.whl → 0.5.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of project-llm-trainer has been flagged as potentially problematic.
- llm_trainer/checkpoint.py +26 -26
- llm_trainer/scheduler.py +41 -11
- llm_trainer/trainer.py +8 -5
- {project_llm_trainer-0.5.5.dist-info → project_llm_trainer-0.5.7.dist-info}/METADATA +1 -1
- {project_llm_trainer-0.5.5.dist-info → project_llm_trainer-0.5.7.dist-info}/RECORD +14 -14
- {project_llm_trainer-0.5.5.data → project_llm_trainer-0.5.7.data}/scripts/calc_intermediate_size +0 -0
- {project_llm_trainer-0.5.5.data → project_llm_trainer-0.5.7.data}/scripts/ddp_train +0 -0
- {project_llm_trainer-0.5.5.data → project_llm_trainer-0.5.7.data}/scripts/ds_train +0 -0
- {project_llm_trainer-0.5.5.data → project_llm_trainer-0.5.7.data}/scripts/plot_loss +0 -0
- {project_llm_trainer-0.5.5.data → project_llm_trainer-0.5.7.data}/scripts/plot_lr +0 -0
- {project_llm_trainer-0.5.5.data → project_llm_trainer-0.5.7.data}/scripts/py_train +0 -0
- {project_llm_trainer-0.5.5.data → project_llm_trainer-0.5.7.data}/scripts/smart_train +0 -0
- {project_llm_trainer-0.5.5.dist-info → project_llm_trainer-0.5.7.dist-info}/WHEEL +0 -0
- {project_llm_trainer-0.5.5.dist-info → project_llm_trainer-0.5.7.dist-info}/top_level.txt +0 -0
llm_trainer/checkpoint.py
CHANGED
@@ -38,29 +38,32 @@ def save_best_checkpoint(
 ) -> bool:
     need_replace = not last_best_checkpoint_loss or current_loss <= last_best_checkpoint_loss
     if need_replace and TrainerTools().parallel.is_main_process:
-
-
+        try:
+            if isinstance(TrainerTools().parallel, DsParallel):
+                checkpoint_dir = os.environ.get('DIST_CHECKPOINT_DIR', 'checkpoint')
 
-
-
-
-
+                if checkpoint_dir.endswith('/'):
+                    best_checkpoint_dir = f'{checkpoint_dir[:-1]}_best'
+                else:
+                    best_checkpoint_dir = f'{checkpoint_dir}_best'
 
-
-
+                if not os.path.exists(best_checkpoint_dir):
+                    os.makedirs(best_checkpoint_dir)
 
-
-
-
-
-
-
+                if os.path.exists(checkpoint_dir):
+                    shutil.rmtree(best_checkpoint_dir)
+                    shutil.copytree(checkpoint_dir, best_checkpoint_dir)
+            else:
+                checkpoint_name = os.environ.get('CHECKPOINT_NAME', DEFAULT_CHECKPOINT_NAME)
+                best_checkpoint_name = f'{checkpoint_name}_best'
 
-
-
-
+                if os.path.exists(checkpoint_name):
+                    if os.path.exists(best_checkpoint_name):
+                        os.remove(best_checkpoint_name)
 
-
+                    shutil.copy2(checkpoint_name, best_checkpoint_name)
+        except:
+            pass
 
     TrainerTools().parallel.wait()
     return need_replace
@@ -101,17 +104,14 @@ def save_steps(global_steps: int, lr_scheduler: Optional[LRScheduler] = None):
     # For now, only save from the main process
     if TrainerTools().parallel.is_main_process:
         steps_checkpoint_name = f"{os.environ.get('LOG_DIR', './')}steps.pt"
-        ckpt = {'global_steps': global_steps
+        ckpt = {'global_steps': global_steps}
+        ckpt.update(lr_scheduler.get_ckpt_dict())
         torch.save(ckpt, steps_checkpoint_name)
 
 
-def load_steps(
-    default_global_steps: int = 0,
-    default_lr_steps: int = 0
-) -> Tuple[Optional[int], Optional[int]]:
+def load_steps() -> Optional[dict]:
     steps_checkpoint_name = f"{os.environ.get('LOG_DIR', './')}steps.pt"
     if os.path.exists(steps_checkpoint_name):
-
-        return ckpt['global_steps'], ckpt['lr_steps']
+        return torch.load(steps_checkpoint_name, weights_only=True)
 
-    return
+    return None
llm_trainer/scheduler.py
CHANGED
@@ -15,15 +15,18 @@ class LRScheduler(ABC):
     @abstractmethod
     def cur_lr(self): ...
 
-    @abstractmethod
-    def update_steps(self, steps): ...
-
     @abstractmethod
     def step(self): ...
 
     @abstractmethod
     def can_clip_grad(self): ...
 
+    @abstractmethod
+    def get_ckpt_dict(self) -> dict: ...
+
+    @abstractmethod
+    def restore_ckpt_dict(self, ckpt: dict): ...
+
 
 class WarmupCosineAnnealingLRScheduler(LRScheduler):
     def __init__(
@@ -72,11 +75,6 @@ class WarmupCosineAnnealingLRScheduler(LRScheduler):
     def cur_lr(self):
         return self._current_lr
 
-    def update_steps(self, steps):
-        log(f'update step to {steps}')
-        self._steps = steps
-        self._update_lr()
-
     def step(self):
         self._steps += 1
         self._update_lr()
@@ -122,6 +120,33 @@ class WarmupCosineAnnealingLRScheduler(LRScheduler):
         if self.need_log:
             log(f"step={self.cur_steps},lr={lr}\n", f'{get_log_dir()}lr.txt')
 
+    def get_ckpt_dict(self) -> dict:
+        return {
+            'cur_lr': self._current_lr,
+            'lr_steps': self.cur_steps,
+            'cosine_annealing_base_lr': self._cosine_annealing_base_lr,
+            't_cur': self.T_cur,
+            'cycle': self.cycle,
+        }
+
+    def restore_ckpt_dict(self, ckpt: dict):
+        if ckpt['cur_lr']:
+            self._current_lr = ckpt['cur_lr']
+
+        if ckpt['lr_steps']:
+            self._steps = ckpt['lr_steps']
+
+        if ckpt['cosine_annealing_base_lr']:
+            self._cosine_annealing_base_lr = ckpt['cosine_annealing_base_lr']
+
+        if ckpt['t_cur']:
+            self.T_cur = ckpt['t_cur']
+
+        if ckpt['cycle']:
+            self.cycle = ckpt['cycle']
+
+        self._update_lr()
+
 
 class NoneLRScheduler(LRScheduler):
     def __init__(self, initial_lr):
@@ -135,9 +160,14 @@ class NoneLRScheduler(LRScheduler):
     def cur_lr(self):
         return self._current_lr
 
-    def update_steps(self, steps): ...
-
     def step(self): ...
 
     def can_clip_grad(self):
-        return True
+        return True
+
+    def get_ckpt_dict(self) -> dict:
+        return {'cur_lr': self._current_lr}
+
+    def restore_ckpt_dict(self, ckpt: dict):
+        if ckpt['cur_lr']:
+            self._current_lr = ckpt['cur_lr']
llm_trainer/trainer.py
CHANGED
@@ -92,12 +92,15 @@ class Trainer:
             device=TrainerTools().parallel.device
         )
 
-
-
-
+        steps_dict = load_steps()
+        if steps_dict:
+            self.last_global_steps = steps_dict['global_steps']
+            if not self.last_global_steps:
+                self.last_global_steps = 0
 
-
-
+            self.lr_scheduler.restore_ckpt_dict(steps_dict)
+
+            log(f'restore steps_dict = {steps_dict}')
 
         if isinstance(train_config.model_config, VLMConfig):
             self.pixel_values_provider = train_config.pixel_values_provider
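Resumption is now handled during Trainer construction: load_steps() seeds last_global_steps, and the same dict is handed to lr_scheduler.restore_ckpt_dict(). A hypothetical resume flow follows; the Trainer(train_config=...) signature is inferred from the surrounding code, train_config is assumed to be a prebuilt config from llm_trainer.train_configs, and the run directory is illustrative:

import os
from llm_trainer.trainer import Trainer

os.environ['LOG_DIR'] = './runs/exp1/'   # the run that wrote steps.pt

trainer = Trainer(train_config=train_config)  # load_steps() runs in __init__
# If steps.pt existed, trainer.last_global_steps and the scheduler's lr,
# step count, and cosine cycle now match the interrupted run; otherwise
# training starts fresh and nothing is restored.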
{project_llm_trainer-0.5.5.dist-info → project_llm_trainer-0.5.7.dist-info}/RECORD
CHANGED
@@ -1,5 +1,5 @@
 llm_trainer/__init__.py,sha256=HWgtTEVeQSnZmEyYQm2K6eFEG4X2QAoigMlB5Z2tcXE,260
-llm_trainer/checkpoint.py,sha256=
+llm_trainer/checkpoint.py,sha256=RoRlIB-Qtvl3MyY3g0FbEBHLpFRLnEMZLpEncXOLToQ,4242
 llm_trainer/dataset.py,sha256=4QlOo0SFB5816BUYegQjgobUqTUMQvdmZMM_OEAMSjE,4347
 llm_trainer/dpo_trainer.py,sha256=1A_4QP2_xqM_YeqdXy-0RaMvEL80gim-pgnPQyHww9U,12052
 llm_trainer/ds_checkpoint.py,sha256=D092fkS1Up4QmpV9YCpqbSzfX_caCAeX-UiOrhOE1I8,1947
@@ -13,21 +13,21 @@ llm_trainer/parallel_ddp.py,sha256=Pob9vUlBZnkL4oP1Re11kFob7nufMSE96pn7m7fuOEM,1
 llm_trainer/parallel_ds.py,sha256=oy8RRxHud3rACWubFlJqqd0pjPEQhKeAPGPQUSdJX2c,1145
 llm_trainer/parallel_none.py,sha256=TG6Pm829Dg-yQu-97O-EHV3FCARBlNcP47KkGFAs16E,676
 llm_trainer/partition_utils.py,sha256=xzv8kwlbKp3dai2pBwX89gN5ymeHk1bGbTkGru5H-UM,5167
-llm_trainer/scheduler.py,sha256=
+llm_trainer/scheduler.py,sha256=LAI_0VxClsIQkix0bRoduRD4vPfVuIZDhZgTAT_KK8k,4901
 llm_trainer/sft_trainer.py,sha256=gxQA7T1o1QGUsHp2CX1Qb_fO5LppBJuNbc0H4ixCYUA,1783
 llm_trainer/tokenizer.py,sha256=SSpgXtb0e1NtQqRW0gCq09TTZi47umggy-Fh5EMHKJg,6708
 llm_trainer/tools.py,sha256=yF17lp6oOfLe2XJeKDQ1juZcbv-6vFamJSLwEeArduA,2975
 llm_trainer/train_configs.py,sha256=c6bgivkkWRYcPD3NzI5uRItAUhZiIBgKVMuMgVFRnFo,7336
-llm_trainer/trainer.py,sha256=
+llm_trainer/trainer.py,sha256=sqN5cXsFAH9xe8-px6tAgcUe5nw6iZU5PEjT9mgEusE,26106
 llm_trainer/utils.py,sha256=LWNhyQ0NDEZ9mZtk2Ryvh6EulvHIaUGIflugSpqmeFI,6791
-project_llm_trainer-0.5.
-project_llm_trainer-0.5.
-project_llm_trainer-0.5.
-project_llm_trainer-0.5.
-project_llm_trainer-0.5.
-project_llm_trainer-0.5.
-project_llm_trainer-0.5.
-project_llm_trainer-0.5.
-project_llm_trainer-0.5.
-project_llm_trainer-0.5.
-project_llm_trainer-0.5.
+project_llm_trainer-0.5.7.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
+project_llm_trainer-0.5.7.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
+project_llm_trainer-0.5.7.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
+project_llm_trainer-0.5.7.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
+project_llm_trainer-0.5.7.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
+project_llm_trainer-0.5.7.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
+project_llm_trainer-0.5.7.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
+project_llm_trainer-0.5.7.dist-info/METADATA,sha256=3yxEJlE4psbIzjpHGKnrpz04BT1n03vUR0xlnqu0-V0,195
+project_llm_trainer-0.5.7.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
+project_llm_trainer-0.5.7.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
+project_llm_trainer-0.5.7.dist-info/RECORD,,
{project_llm_trainer-0.5.5.data → project_llm_trainer-0.5.7.data}/scripts/calc_intermediate_size
RENAMED
File without changes
{project_llm_trainer-0.5.5.data → project_llm_trainer-0.5.7.data}/scripts/ddp_train
RENAMED
File without changes
{project_llm_trainer-0.5.5.data → project_llm_trainer-0.5.7.data}/scripts/ds_train
RENAMED
File without changes
{project_llm_trainer-0.5.5.data → project_llm_trainer-0.5.7.data}/scripts/plot_loss
RENAMED
File without changes
{project_llm_trainer-0.5.5.data → project_llm_trainer-0.5.7.data}/scripts/plot_lr
RENAMED
File without changes
{project_llm_trainer-0.5.5.data → project_llm_trainer-0.5.7.data}/scripts/py_train
RENAMED
File without changes
{project_llm_trainer-0.5.5.data → project_llm_trainer-0.5.7.data}/scripts/smart_train
RENAMED
File without changes
{project_llm_trainer-0.5.5.dist-info → project_llm_trainer-0.5.7.dist-info}/WHEEL
RENAMED
File without changes
{project_llm_trainer-0.5.5.dist-info → project_llm_trainer-0.5.7.dist-info}/top_level.txt
RENAMED
File without changes