project-llm-trainer 0.4.15__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of project-llm-trainer might be problematic.
- llm_trainer/checkpoint.py +0 -50
- llm_trainer/dpo_trainer.py +6 -3
- llm_trainer/eval.py +3 -30
- llm_trainer/generate_utils.py +9 -74
- llm_trainer/grpo_trainer.py +27 -28
- llm_trainer/loss.py +1 -1
- llm_trainer/partition_utils.py +146 -0
- llm_trainer/tokenizer.py +10 -10
- llm_trainer/tools.py +0 -2
- llm_trainer/train_configs.py +5 -25
- llm_trainer/trainer.py +28 -67
- llm_trainer/utils.py +0 -1
- {project_llm_trainer-0.4.15.dist-info → project_llm_trainer-0.5.1.dist-info}/METADATA +1 -1
- project_llm_trainer-0.5.1.dist-info/RECORD +33 -0
- llm_trainer/dcp.py +0 -93
- llm_trainer/ds_model_params.py +0 -72
- llm_trainer/fsdp_checkpoint.py +0 -52
- llm_trainer/fsdp_model_params.py +0 -39
- llm_trainer/model_params.py +0 -28
- llm_trainer/parallel_fsdp.py +0 -121
- project_llm_trainer-0.4.15.dist-info/RECORD +0 -38
- {project_llm_trainer-0.4.15.data → project_llm_trainer-0.5.1.data}/scripts/calc_intermediate_size +0 -0
- {project_llm_trainer-0.4.15.data → project_llm_trainer-0.5.1.data}/scripts/ddp_train +0 -0
- {project_llm_trainer-0.4.15.data → project_llm_trainer-0.5.1.data}/scripts/ds_train +0 -0
- {project_llm_trainer-0.4.15.data → project_llm_trainer-0.5.1.data}/scripts/plot_loss +0 -0
- {project_llm_trainer-0.4.15.data → project_llm_trainer-0.5.1.data}/scripts/plot_lr +0 -0
- {project_llm_trainer-0.4.15.data → project_llm_trainer-0.5.1.data}/scripts/py_train +0 -0
- {project_llm_trainer-0.4.15.data → project_llm_trainer-0.5.1.data}/scripts/smart_train +0 -0
- {project_llm_trainer-0.4.15.dist-info → project_llm_trainer-0.5.1.dist-info}/WHEEL +0 -0
- {project_llm_trainer-0.4.15.dist-info → project_llm_trainer-0.5.1.dist-info}/top_level.txt +0 -0
llm_trainer/tokenizer.py
CHANGED

@@ -26,8 +26,8 @@ class Tokenizer:
         self.text_user = '<user>'
         self.text_assistant = '<assistant>'

-        self.
-        self.
+        self.text_think_start = '<think>'
+        self.text_think_end = '</think>'

         self.text_answer_start = '<answer>'
         self.text_answer_end = '</answer>'

@@ -47,8 +47,8 @@ class Tokenizer:
         additional_special_tokens = [
             AddedToken(self.text_user, lstrip=False, rstrip=False),
             AddedToken(self.text_assistant, lstrip=False, rstrip=False),
-            AddedToken(self.
-            AddedToken(self.
+            AddedToken(self.text_think_start, lstrip=False, rstrip=False),
+            AddedToken(self.text_think_end, lstrip=False, rstrip=False),
             AddedToken(self.text_answer_start, lstrip=False, rstrip=False),
             AddedToken(self.text_answer_end, lstrip=False, rstrip=False),
             AddedToken(self.text_system, lstrip=False, rstrip=False),

@@ -69,8 +69,8 @@ class Tokenizer:
         self.user = self.tokenizer.convert_tokens_to_ids(self.text_user)
         self.assistant = self.tokenizer.convert_tokens_to_ids(self.text_assistant)

-        self.
-        self.
+        self.think_start = self.tokenizer.convert_tokens_to_ids(self.text_think_start)
+        self.think_end = self.tokenizer.convert_tokens_to_ids(self.text_think_end)

         self.answer_start = self.tokenizer.convert_tokens_to_ids(self.text_answer_start)
         self.answer_end = self.tokenizer.convert_tokens_to_ids(self.text_answer_end)

@@ -140,9 +140,9 @@ class Tokenizer:
         {"role":"user", "content":"hello?"},
         {"role":"assistant", "content":"hello"},
         {"role":"user", "content":"hello hello?"},
-        {"role":"assistant", "
+        {"role":"assistant", "think":"thinking", "content":"hello hello"},
         ]
-        <system>{system_prompt}</s><user>hello?</s><assistant>hello</s><user>hello hello?</s><assistant><
+        <system>{system_prompt}</s><user>hello?</s><assistant>hello</s><user>hello hello?</s><assistant><think>thinking</think><answer>hello hello</answer></s>
         """

         chat_template = ''

@@ -154,8 +154,8 @@ class Tokenizer:
             if add_answer_tag_for_assistant and role == 'assistant':
                 content = f"{self.text_answer_start}{content}{self.text_answer_end}"

-            if '
-                content = f"{self.
+            if 'think' in conversation:
+                content = f"{self.text_think_start}{conversation['think']}{self.text_think_end}{content}"

             chat_template = f"{chat_template}{support_roles[role]}{content}{self.text_end}"

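For orientation, here is a minimal, self-contained sketch of the template behaviour the tokenizer change introduces: an optional `think` field on an assistant turn is wrapped in `<think>...</think>` and prepended to the (answer-tagged) content. The tag constants and the `build_chat_template` helper below are illustrative stand-ins, not the package's `Tokenizer` API.

```python
# Hypothetical stand-alone sketch of the 0.5.1 chat-template logic; not the
# package's Tokenizer class. Tag strings mirror the diff above.
from typing import Dict, List

TEXT_END = '</s>'
ROLE_TAGS = {'system': '<system>', 'user': '<user>', 'assistant': '<assistant>'}
THINK_START, THINK_END = '<think>', '</think>'
ANSWER_START, ANSWER_END = '<answer>', '</answer>'


def build_chat_template(conversations: List[Dict[str, str]],
                        add_answer_tag_for_assistant: bool = True) -> str:
    """Render a list of chat turns into the flat tagged string format."""
    chat_template = ''
    for conversation in conversations:
        role = conversation['role']
        content = conversation['content']

        if add_answer_tag_for_assistant and role == 'assistant':
            content = f"{ANSWER_START}{content}{ANSWER_END}"

        # New in 0.5.1: an optional 'think' field is wrapped in <think>...</think>
        # and prepended to the assistant content.
        if 'think' in conversation:
            content = f"{THINK_START}{conversation['think']}{THINK_END}{content}"

        chat_template = f"{chat_template}{ROLE_TAGS[role]}{content}{TEXT_END}"
    return chat_template


if __name__ == '__main__':
    msgs = [
        {"role": "user", "content": "hello?"},
        {"role": "assistant", "content": "hello"},
        {"role": "user", "content": "hello hello?"},
        {"role": "assistant", "think": "thinking", "content": "hello hello"},
    ]
    print(build_chat_template(msgs))
    # prints (one line): <user>hello?</s><assistant><answer>hello</answer></s>
    # <user>hello hello?</s><assistant><think>thinking</think><answer>hello hello</answer></s>
```
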
llm_trainer/tools.py
CHANGED

@@ -3,7 +3,6 @@ from abc import ABC, abstractmethod
 import torch
 from .tokenizer import Tokenizer
 from .parallel_ds import DsParallel
-from .parallel_fsdp import FsdpParallel
 from .parallel_ddp import DdpParallel
 from .parallel_none import NoneParallel
 from .log import log

@@ -11,7 +10,6 @@ from .log import log

 parallel_types = {
     'ds': DsParallel,
-    'fsdp': FsdpParallel,
     'ddp': DdpParallel,
     'none': NoneParallel
 }

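In practice, dropping the `'fsdp'` entry means a lookup of that key in `parallel_types` now fails. Below is a hedged sketch of the registry-lookup pattern with friendlier error reporting; the stand-in classes and the `create_parallel` helper are illustrative, not part of the package.

```python
# Illustrative registry-lookup sketch; _DsParallel/_DdpParallel/_NoneParallel are
# stand-ins for the package's backends, and create_parallel() is hypothetical.
class _DsParallel: ...
class _DdpParallel: ...
class _NoneParallel: ...

# As of 0.5.1 there is no 'fsdp' entry any more.
parallel_types = {
    'ds': _DsParallel,
    'ddp': _DdpParallel,
    'none': _NoneParallel,
}


def create_parallel(name: str):
    try:
        return parallel_types[name]()
    except KeyError:
        raise ValueError(
            f"unknown parallel type {name!r}; expected one of {sorted(parallel_types)}"
        ) from None


if __name__ == '__main__':
    print(type(create_parallel('ddp')).__name__)  # _DdpParallel
    # create_parallel('fsdp') would now raise ValueError
```
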
llm_trainer/train_configs.py
CHANGED

@@ -1,8 +1,7 @@
-from typing import Optional, Union,
+from typing import Optional, Union, Callable, List, Mapping, Any
 from dataclasses import dataclass, field

 import torch
-from torch import nn
 from llm_model import ModelConfig, VLMConfig
 from .tools import FileDataset

@@ -33,6 +32,9 @@ class DsZeROConfig:
     reduce_bucket_size: Optional[Union[str, int]] = 5e8
     contiguous_gradients: Optional[bool] = True

+@dataclass(kw_only=True)
+class DsZero0Config(DsZeROConfig):
+    stage: int = field(default=0, init=False)

 @dataclass(kw_only=True)
 class DsZero1Config(DsZeROConfig):

@@ -84,26 +86,6 @@ class DsConfig:
     activation_checkpointing: Optional[DsActivationCheckpointingConfig] = None


-@dataclass(kw_only=True)
-class FsdpConfig:
-    """
-    Configuration for the FSDP training mode
-    Args:
-        transformer_layer_cls (`Set[Type[nn.Module]]`, *optional*, default is None):
-            The transformer layer classes to wrap
-        wrap_policy_num_params (`int`, *optional*, default is -1):
-            min_num_params for size_based_auto_wrap_policy; -1 disables this policy
-        cpu_offload (`bool`, *optional*, default is False):
-            Whether to use CPU offload
-        offload_params (`bool`, default is False):
-            Whether to offload parameters; only effective when cpu_offload is True
-    """
-    transformer_layer_cls: Optional[Set[Type[nn.Module]]] = None
-    wrap_policy_num_params: int = -1
-    cpu_offload: bool = False
-    offload_params: bool = False
-
-
 @dataclass(kw_only=True)
 class DataLoaderConfig:
     """

@@ -157,6 +139,7 @@ class GRPOConfig:
     clip_eps: float = 0.2
     kl_weight: float = 0.01
     group_size: int = 12
+    mixup_alpha: float = 1.0
     gen_max_new_tokens: Optional[int] = None
     gen_temperature: Optional[float] = None
     gen_k: Optional[int] = None

@@ -210,8 +193,6 @@ class TrainConfig:
             Run model eval every this many batches
         lr_config (`LrConfig`):
             LR configuration
-        fsdp_config: (`FsdpConfig`):
-            FSDP training mode configuration
         data_loader_config: (`DataLoaderConfig`):
             Data loader configuration
         kd_config: (`KDConfig`, *Optional*, default is None):

@@ -231,7 +212,6 @@ class TrainConfig:
     lr_config: LrConfig = field(default_factory=LrConfig)

     ds_config: DsConfig = field(default_factory=DsConfig)
-    fsdp_config: FsdpConfig = field(default_factory=FsdpConfig)

     kd_config: Optional[KDConfig] = None
     dpo_config: Optional[DPOConfig] = None

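The new `DsZero0Config` pins its ZeRO stage with `field(default=0, init=False)`, so callers cannot override it. Below is a self-contained sketch of that dataclass pattern; the `*Sketch` names are illustrative, only `reduce_bucket_size` and `contiguous_gradients` come from the diff, and the assumption that the base class itself declares a `stage` field is mine.

```python
# Sketch of the kw_only dataclass pattern behind DsZero0Config
# (assumed base fields; class names are hypothetical).
from dataclasses import dataclass, field
from typing import Optional, Union


@dataclass(kw_only=True)
class ZeROConfigSketch:
    stage: int = 0                                      # assumption: base declares the stage
    reduce_bucket_size: Optional[Union[str, int]] = 5e8
    contiguous_gradients: Optional[bool] = True


@dataclass(kw_only=True)
class Zero0ConfigSketch(ZeROConfigSketch):
    # init=False removes `stage` from __init__, pinning this config to stage 0.
    stage: int = field(default=0, init=False)


if __name__ == '__main__':
    cfg = Zero0ConfigSketch(reduce_bucket_size=int(2e8))
    print(cfg.stage, cfg.reduce_bucket_size)            # 0 200000000
    # Zero0ConfigSketch(stage=1) raises TypeError: unexpected keyword argument
```
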
llm_trainer/trainer.py
CHANGED

@@ -1,21 +1,18 @@
 import time
 from contextlib import nullcontext
-from typing import Optional, Tuple, List, Dict, Any
+from typing import Optional, Tuple, List, Dict, Any

 import torch
-from torch import nn
 import torch.distributed as dist
-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.utils.data import Dataset
 from llm_model import LlmModel, VlmModel

 from .parallel_ds import DsParallel
-from .parallel_fsdp import FsdpParallel
 from .tools import TrainerTools
 from .loss import LMLoss, KDLoss
 from .dataset import TextDataset
-from .model_params import copy_model_params
 from .eval import submit_gen_task
+from .partition_utils import unwrap_model_for_generation

 from .train_configs import (
     TrainConfig,

@@ -78,7 +75,6 @@ class Trainer:

         self.train_model, self.optimizer = self._init_train_model_and_optim(initial_lr, parallel_kwargs, use_ds_optim)
         self.lr_scheduler = self._init_lr_scheduler(initial_lr)
-        self.eval_model: Optional[nn.Module] = self._init_eval_model()

         self.criterion, self.kd_loss = self._init_loss()

@@ -86,9 +82,7 @@
             device_type=TrainerTools().parallel.device_type,
             dtype=TrainerTools().dtype,
             enabled=True,
-
-            # https://www.zhihu.com/question/642793891
-            cache_enabled=False if isinstance(self.train_model, FSDP) else None
+            cache_enabled=None
         ) if TrainerTools().use_amp else nullcontext()

         load_checkpoint(

@@ -176,12 +170,6 @@ class Trainer:

         return model, optim

-    def _init_eval_model(self) -> Optional[nn.Module]:
-        if TrainerTools().parallel.is_main_process:
-            return self._new_model(self.train_config).to(device='cpu', dtype=TrainerTools().dtype)
-
-        return None
-
     def _init_lr_scheduler(self, initial_lr: float) -> LRScheduler:
         if self.train_config.lr_config.enable_lr_scheduler:
             min_lr = self.train_config.lr_config.min_lr

@@ -313,13 +301,6 @@ class Trainer:
             activation_checkpointing['number_checkpoints'] = activation_checkpointing_config.number_checkpoints

             parallel_kwargs['activation_checkpointing'] = activation_checkpointing
-        elif isinstance(TrainerTools().parallel, FsdpParallel) and self.train_config.fsdp_config:
-            parallel_kwargs = {
-                'transformer_layer_cls': self.train_config.fsdp_config.transformer_layer_cls,
-                'wrap_policy_num_params': self.train_config.fsdp_config.wrap_policy_num_params,
-                'cpu_offload': self.train_config.fsdp_config.cpu_offload,
-                'offload_params': self.train_config.fsdp_config.offload_params
-            }

         dataloader_args = self.train_config.data_loader_config
         data_loader_kwargs = {

@@ -441,54 +422,35 @@

             raise e

-    def
-
-
-
-
+    def _eval(self, tag: str):
+        with unwrap_model_for_generation(self.train_model) as generate_model:
+            if TrainerTools().parallel.is_main_process:
+                generate_model.eval()
+                eval_prompt, eval_image_tag = self._get_eval_data()
+
+                if isinstance(self.train_config, VLMConfig) and self.pixel_values_provider and eval_image_tag:
+                    eval_pixel_values = self.pixel_values_provider([eval_image_tag])
+                else:
+                    eval_pixel_values = None
+
+                submit_gen_task(
+                    generate_model,
+                    self.train_config.eval_config,
+                    tag=tag,
+                    prompt=eval_prompt,
+                    pixel_values=eval_pixel_values,
+                    max_position_embeddings=self.train_config.model_config.max_position_embeddings,
+                    tokens_per_image=self.tokens_per_image
+                )
+                generate_model.train()

-        if TrainerTools().parallel.is_main_process:
-            eval_prompt, eval_image_tag = self._get_eval_data()
-            if isinstance(self.train_config, VLMConfig) and self.pixel_values_provider and eval_image_tag:
-                eval_pixel_values = self.pixel_values_provider([eval_image_tag])
-            else:
-                eval_pixel_values = None
-
-            submit_gen_task(
-                self.eval_model,
-                self.train_config.eval_config,
-                tag=f'sign:batch/{tag}',
-                prompt=eval_prompt,
-                pixel_values=eval_pixel_values,
-                max_position_embeddings=self.train_config.model_config.max_position_embeddings,
-                tokens_per_image=self.tokens_per_image
-            )
         TrainerTools().parallel.wait()

-    def
-
-        tag: str
-    ):
-        copy_model_params(_from=self.train_model, _to=self.eval_model)
-
-        if TrainerTools().parallel.is_main_process:
-            eval_prompt, eval_image_tag = self._get_eval_data()
-            if isinstance(self.train_config, VLMConfig) and self.pixel_values_provider and eval_image_tag:
-                eval_pixel_values = self.pixel_values_provider([eval_image_tag])
-            else:
-                eval_pixel_values = None
-
-            submit_gen_task(
-                self.eval_model,
-                self.train_config.eval_config,
-                tag=f'sign:epoch/{tag}',
-                prompt=eval_prompt,
-                pixel_values=eval_pixel_values,
-                max_position_embeddings=self.train_config.model_config.max_position_embeddings,
-                tokens_per_image=self.tokens_per_image
-            )
+    def _on_batch_end(self, tag: str):
+        self._eval(f'sign:batch/{tag}')

-
+    def _on_epoch_end(self, tag: str):
+        self._eval(f'sign:epoch/{tag}')

     def _on_file_start(
         self,

@@ -574,7 +536,6 @@ class Trainer:
         if need_update_grad:
             loss_tensor = torch.tensor(loss_accumulation, device=TrainerTools().parallel.device)

-            # todo check all_reduce??
            if TrainerTools().parallel.parallel_train:
                dist.all_reduce(loss_tensor, dist.ReduceOp.AVG)

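The separate CPU-side `eval_model` and the `copy_model_params` round trip are gone; evaluation now borrows the live training model through `unwrap_model_for_generation` from the new `partition_utils.py`, whose body is not part of this diff. Below is a hedged sketch of the general shape such a helper can take, covering only plain and DDP-wrapped modules (the `_sketch` suffix marks it as illustrative).

```python
# Minimal illustrative version; the real helper lives in llm_trainer/partition_utils.py
# and is not shown in this diff. Handles only plain and DDP-wrapped modules.
from contextlib import contextmanager
from typing import Iterator

import torch
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP


@contextmanager
def unwrap_model_for_generation_sketch(model: nn.Module) -> Iterator[nn.Module]:
    """Yield the underlying module of a (possibly wrapped) model for generation."""
    # DDP keeps the user's module under `.module`; a ZeRO-3 or FSDP engine would
    # additionally need to gather its sharded parameters here (omitted).
    yield model.module if isinstance(model, DDP) else model


if __name__ == '__main__':
    net = nn.Linear(4, 2)
    # Mirrors the call pattern in Trainer._eval above.
    with unwrap_model_for_generation_sketch(net) as generate_model:
        generate_model.eval()
        with torch.no_grad():
            print(generate_model(torch.randn(1, 4)).shape)  # torch.Size([1, 2])
        generate_model.train()
```
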
project_llm_trainer-0.5.1.dist-info/RECORD
ADDED

@@ -0,0 +1,33 @@
+llm_trainer/__init__.py,sha256=HWgtTEVeQSnZmEyYQm2K6eFEG4X2QAoigMlB5Z2tcXE,260
+llm_trainer/checkpoint.py,sha256=xTmmQSJ_jQDVSTT3km1p_8eRrc7yE_dEsi92z9OX5ec,3251
+llm_trainer/dataset.py,sha256=4QlOo0SFB5816BUYegQjgobUqTUMQvdmZMM_OEAMSjE,4347
+llm_trainer/dpo_trainer.py,sha256=wMREatLt0I8Ajdm_sI2U8Zj-IN1L6txP9s_tH1oI3-s,11431
+llm_trainer/ds_checkpoint.py,sha256=wz48HoLBBt8QGO1tXfvJwrXoiGtPG_gjwHfEqARllso,2175
+llm_trainer/eval.py,sha256=fjASCILU3fSPJxo9cP3rIXEEnkc5ZlUyHqXlZtUiHrw,888
+llm_trainer/generate_utils.py,sha256=CbJ3mfAD6DkQ0GUHcJQ1AK02m-ocwmd-BPXEpiwvNNQ,14933
+llm_trainer/grpo_trainer.py,sha256=qiC3KwxYPSB9UKqyk4eSRvORP3b6GM-2ozqI8u3QvI0,15568
+llm_trainer/log.py,sha256=LxqTGRNZUGMTSQCePRpk-rYyxSnSIbT4kOdP8Fbzr0M,462
+llm_trainer/loss.py,sha256=NZCQeUXnLSj__mmDflE8g89KgE0emAJXIab0IERCLno,6023
+llm_trainer/parallel.py,sha256=DQu8GqEFxD99HQ6hKuIxxyKi-05dMO33eMhImYlPuOI,4468
+llm_trainer/parallel_ddp.py,sha256=Pob9vUlBZnkL4oP1Re11kFob7nufMSE96pn7m7fuOEM,1345
+llm_trainer/parallel_ds.py,sha256=oy8RRxHud3rACWubFlJqqd0pjPEQhKeAPGPQUSdJX2c,1145
+llm_trainer/parallel_none.py,sha256=TG6Pm829Dg-yQu-97O-EHV3FCARBlNcP47KkGFAs16E,676
+llm_trainer/partition_utils.py,sha256=xzv8kwlbKp3dai2pBwX89gN5ymeHk1bGbTkGru5H-UM,5167
+llm_trainer/scheduler.py,sha256=Xz8HhwoRMjRe41sf_NHhpZfkTlEs0I2MYusvMY6hCVw,3531
+llm_trainer/sft_trainer.py,sha256=gxQA7T1o1QGUsHp2CX1Qb_fO5LppBJuNbc0H4ixCYUA,1783
+llm_trainer/tokenizer.py,sha256=SSpgXtb0e1NtQqRW0gCq09TTZi47umggy-Fh5EMHKJg,6708
+llm_trainer/tools.py,sha256=yF17lp6oOfLe2XJeKDQ1juZcbv-6vFamJSLwEeArduA,2975
+llm_trainer/train_configs.py,sha256=m57W71SI5VCCU9aJ_nJkB-3AJrSGiNXmV28rdpuYmLg,7332
+llm_trainer/trainer.py,sha256=zTJVyY1cAjJdTkyXCOy2ZPVP18SOMLdWhD54Mz2JRe4,25314
+llm_trainer/utils.py,sha256=LWNhyQ0NDEZ9mZtk2Ryvh6EulvHIaUGIflugSpqmeFI,6791
+project_llm_trainer-0.5.1.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
+project_llm_trainer-0.5.1.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
+project_llm_trainer-0.5.1.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
+project_llm_trainer-0.5.1.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
+project_llm_trainer-0.5.1.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
+project_llm_trainer-0.5.1.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
+project_llm_trainer-0.5.1.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
+project_llm_trainer-0.5.1.dist-info/METADATA,sha256=x-Bobn0EH7wyKznJydUeVLK9sdIrkBmDYDbEpyG4pKc,195
+project_llm_trainer-0.5.1.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
+project_llm_trainer-0.5.1.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
+project_llm_trainer-0.5.1.dist-info/RECORD,,

llm_trainer/dcp.py
DELETED

@@ -1,93 +0,0 @@
-import os
-from typing import Optional, Dict, Any
-from torch import nn
-from torch.optim import Optimizer
-import torch.distributed.checkpoint as dcp
-from torch.distributed.checkpoint.stateful import Stateful
-from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
-from torch.distributed.checkpoint.format_utils import dcp_to_torch_save, torch_save_to_dcp
-
-DEFAULT_CHECKPOINT_DIR = "checkpoint"
-
-class AppState(Stateful):
-    def __init__(self, model: nn.Module, optimizer: Optimizer):
-        self.model = model
-        self.optimizer = optimizer
-
-    def state_dict(self) -> Dict[str, Any]:
-        model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
-        return {
-            'model_state_dict': model_state_dict,
-            'optim_state_dict': optimizer_state_dict
-        }
-
-    def load_state_dict(self, state_dict: Dict[str, Any]):
-        set_state_dict(
-            model=self.model,
-            optimizers=self.optimizer,
-            model_state_dict=state_dict['model_state_dict'],
-            optim_state_dict=state_dict['optim_state_dict']
-        )
-
-
-def save_dcp(
-    model: nn.Module,
-    optimizer: Optimizer,
-    suffix: Optional[str] = None
-):
-    checkpoint_id = os.environ.get('DIST_CHECKPOINT_DIR', DEFAULT_CHECKPOINT_DIR)
-    if suffix:
-        checkpoint_id = f"{checkpoint_id}_{suffix}"
-
-    state_dict = {'app': AppState(model, optimizer)}
-
-    # fs_storage_writer = dcp.FileSystemWriter(checkpoint_id, overwrite=True)
-    # dcp.save(state_dict=state_dict, storage_writer=fs_storage_writer)
-    dcp.save(state_dict=state_dict, checkpoint_id=checkpoint_id)
-
-
-def load_dcp(
-    model: nn.Module,
-    optimizer: Optional[Optimizer] = None,
-    suffix: Optional[str] = None
-):
-    checkpoint_id = os.environ.get('DIST_CHECKPOINT_DIR', DEFAULT_CHECKPOINT_DIR)
-    if suffix:
-        checkpoint_id = f"{checkpoint_id}_{suffix}"
-
-    if os.path.exists(checkpoint_id):
-        state_dict = {'app': AppState(model, optimizer)}
-        # AppState loads the checkpoint into state_dict, which is then loaded into the model
-        dcp.load(state_dict=state_dict, checkpoint_id=checkpoint_id)
-
-    # if isinstance(model, FSDP):
-    #     state_dict = {'app': AppState(model, optimizer)}
-    #     # AppState loads the checkpoint into state_dict, which is then loaded into the model
-    #     dcp.load(state_dict=state_dict, checkpoint_id=checkpoint_id)
-    # else:
-    #     state_dict = {"model_state_dict": model.state_dict()}
-    #
-    #     if optimizer:
-    #         state_dict.update({'optim_state_dict': optimizer.state_dict()})
-    #
-    #     # since no process group is initialized, DCP will disable any collectives.
-    #     # Load into state_dict, then into the model via model.load_state_dict
-    #     dcp.load(
-    #         state_dict=state_dict,
-    #         checkpoint_id=checkpoint_id,
-    #     )
-    #
-    #     model.load_state_dict(state_dict["model_state_dict"])
-    #     if optimizer:
-    #         optimizer.load_state_dict(state_dict["optim_state_dict"])
-
-def convert_dcp_to_pth(pth_path: str):
-    dcp_path = os.environ.get('DIST_CHECKPOINT_DIR', DEFAULT_CHECKPOINT_DIR)
-    if os.path.exists(dcp_path):
-        # convert dcp model to torch.save (assumes checkpoint was generated as above)
-        dcp_to_torch_save(dcp_path, pth_path)
-
-def convert_pth_to_dcp(pth_path: str):
-    if os.path.exists(pth_path):
-        # converts the torch.save model back to DCP
-        torch_save_to_dcp(pth_path, os.environ.get('DIST_CHECKPOINT_DIR', DEFAULT_CHECKPOINT_DIR))

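With dcp.py removed in 0.5.1, distributed (DCP) checkpoints written by 0.4.x can still be converted to a single torch.save file with the same PyTorch utility the deleted module wrapped. A hedged one-off sketch follows; the output path is a placeholder.

```python
# One-off conversion of a 0.4.x DCP checkpoint directory to a plain .pth file,
# mirroring the removed convert_dcp_to_pth(); 'checkpoint.pth' is a placeholder path.
import os

from torch.distributed.checkpoint.format_utils import dcp_to_torch_save

dcp_dir = os.environ.get('DIST_CHECKPOINT_DIR', 'checkpoint')  # same default the old module used
if os.path.exists(dcp_dir):
    dcp_to_torch_save(dcp_dir, 'checkpoint.pth')
```
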
llm_trainer/ds_model_params.py
DELETED

@@ -1,72 +0,0 @@
-from typing import Optional
-from torch import nn
-import torch.distributed as dist
-
-from .tools import TrainerTools
-
-try:
-    import deepspeed
-    from deepspeed import DeepSpeedEngine
-    from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
-except: ...
-
-
-def _get_ds_full_state_dict_on_rank0(model: DeepSpeedEngine) -> Optional[dict]:
-    """
-    Must be called on all ranks; only rank 0 gets a value
-    """
-
-    if model.zero_optimization_stage() != 3:
-        if TrainerTools().parallel.is_main_process:
-            return {k: v.cpu().clone() for k, v in model.module.state_dict().items()}
-        return None
-
-    # --- ZeRO-3 ---
-    # Call GatheredParameters only once, passing in all parameters
-    with deepspeed.zero.GatheredParameters(model.parameters(), modifier_rank=0):
-        if TrainerTools().parallel.is_main_process:
-            # Inside this 'with' block, model.module on rank 0 holds the full parameters,
-            # so we can call state_dict() directly, just like on an ordinary model
-            full_state_dict = model.module.state_dict()
-
-            # Clone it to CPU and return it
-            return {k: v.cpu().clone() for k, v in full_state_dict.items()}
-
-    # Other ranks reach this point once the context ends and simply return None
-    return None
-
-    # # ZeRO-3
-    # state_dict_on_rank_0 = {}
-    # for param_name, param in model.module.named_parameters():
-    #     if hasattr(param, 'ds_id'):
-    #         with deepspeed.zero.GatheredParameters(param, modifier_rank=0):
-    #             if TrainerTools().parallel.is_main_process:
-    #                 state_dict_on_rank_0[param_name] = param.data.to(torch.float32).cpu().clone()
-    #     else:
-    #         if TrainerTools().parallel.is_main_process:
-    #             state_dict_on_rank_0[param_name] = param.data.to(torch.float32).cpu().clone()
-    #
-    # return state_dict_on_rank_0 if TrainerTools().parallel.is_main_process else None
-
-
-def get_ds_model_params(model: nn.Module, only_rank0=False):
-    """
-    Efficiently extract the complete FP32 state_dict from a running DeepSpeedEngine,
-    compatible with ZeRO stages 0, 1, 2 and 3.
-    Includes correct handling of parameters sharded under ZeRO-3.
-    """
-
-    assert isinstance(model, DeepSpeedEngine)
-    state_dict = _get_ds_full_state_dict_on_rank0(model)
-
-    # At this point only the state_dict on rank 0 is a valid dict; on other ranks it is None.
-    # We need to broadcast it to all processes.
-    if not only_rank0 and TrainerTools().parallel.world_size > 1:
-        # Prepare a list: rank 0 holds the data, other ranks hold a placeholder
-        object_list = [state_dict] if TrainerTools().parallel.is_main_process else [None]
-        # Run the broadcast; this call blocks and synchronizes all processes
-        dist.broadcast_object_list(object_list, src=0)
-        # Every process takes the broadcast state_dict copy from the list
-        state_dict = object_list[0]
-
-    return state_dict

llm_trainer/fsdp_checkpoint.py
DELETED

@@ -1,52 +0,0 @@
-import os
-from typing import Optional, Union, Tuple
-import torch
-from torch import nn
-from torch.optim import Optimizer
-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-
-from .tools import TrainerTools
-
-DEFAULT_CHECKPOINT_NAME = "checkpoint.pth"
-
-def save_fsdp_checkpoint(
-    model: nn.Module,
-    optimizer: Optional[Optimizer] = None,
-    suffix: Optional[str] = None
-):
-    # Untested; see: https://doc.hfai.high-flyer.cn/haiscale/haiscale_fsdp.html
-    # Use rank0_only=True?
-    with FSDP.summon_full_params(
-        module=model,
-        rank0_only=True,
-        writeback=False,
-        offload_to_cpu=True
-    ):
-        if TrainerTools().parallel.is_main_process:
-            checkpoint_name = os.environ.get('CHECKPOINT_NAME', DEFAULT_CHECKPOINT_NAME)
-            if suffix:
-                checkpoint_name = f"{checkpoint_name}_{suffix}"
-
-            ckpt = {'model_state_dict': model.state_dict()}
-            if optimizer:
-                ckpt.update({'optim_state_dict': optimizer.state_dict()})
-
-            torch.save(ckpt, checkpoint_name)
-
-
-def load_fsdp_checkpoint(
-    model: nn.Module,
-    optimizer: Optional[Optimizer] = None,
-    device: Optional[Union[torch.device, str]] = None,
-    suffix: Optional[str] = None
-):
-    checkpoint_name = os.environ.get('CHECKPOINT_NAME', DEFAULT_CHECKPOINT_NAME)
-    if suffix:
-        checkpoint_name = f"{checkpoint_name}_{suffix}"
-
-    with FSDP.summon_full_params(module=model):
-        state_dict = torch.load(checkpoint_name, weights_only=True, map_location=device)
-        model.load_state_dict(state_dict['model_state_dict'])
-
-        if optimizer:
-            optimizer.load_state_dict(state_dict['optim_state_dict'])

llm_trainer/fsdp_model_params.py
DELETED

@@ -1,39 +0,0 @@
-from typing import Optional
-from torch import nn
-import torch.distributed as dist
-
-from .tools import TrainerTools
-
-
-def _get_fsdp_full_state_dict_on_rank0(model: nn.Module) -> Optional[dict]:
-    """
-    Can be called on any rank; only rank 0 gets a value
-    """
-
-    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-    with FSDP.summon_full_params(model, writeback=False, offload_to_cpu=True):
-        if TrainerTools().parallel.is_main_process:
-            return {k: v.clone() for k, v in model.state_dict().items()}
-
-    return None
-
-
-def get_fsdp_model_params(model: nn.Module, only_rank0=False):
-    """
-    Efficiently extract the complete FP32 state_dict from an FSDP-wrapped model.
-    This function gathers all sharded parameters and ensures every rank receives a full copy.
-    """
-
-    state_dict = _get_fsdp_full_state_dict_on_rank0(model)
-
-    # At this point only the state_dict on rank 0 is a valid dict; on other ranks it is None.
-    # We need to broadcast it to all processes.
-    if not only_rank0 and TrainerTools().parallel.world_size > 1:
-        # Prepare a list: rank 0 holds the data, other ranks hold a placeholder
-        object_list = [state_dict] if TrainerTools().parallel.is_main_process else [None]
-        # Run the broadcast; this call blocks and synchronizes all processes
-        dist.broadcast_object_list(object_list, src=0)
-        # Every process takes the broadcast state_dict copy from the list
-        state_dict = object_list[0]
-
-    return state_dict

llm_trainer/model_params.py
DELETED

@@ -1,28 +0,0 @@
-from typing import Optional
-from torch import nn
-from torch.nn.parallel import DistributedDataParallel as DDP
-
-from .tools import TrainerTools
-from .parallel_ds import DsParallel
-from .parallel_fsdp import FsdpParallel
-
-def copy_model_params(
-    _from: nn.Module,
-    _to: Optional[nn.Module]
-):
-    """
-    Must be called on all ranks; on non-rank-0 ranks, _to may be set to None
-    """
-    if isinstance(TrainerTools().parallel, DsParallel):
-        from .ds_model_params import get_ds_model_params
-        state_dict = get_ds_model_params(_from, only_rank0=_to is None)
-    elif isinstance(TrainerTools().parallel, FsdpParallel):
-        from .fsdp_model_params import get_fsdp_model_params
-        state_dict = get_fsdp_model_params(_from, only_rank0=_to is None)
-    elif isinstance(_from, DDP):
-        state_dict = _from.module.state_dict()
-    else:
-        state_dict = _from.state_dict()
-
-    if _to and state_dict:
-        _to.load_state_dict(state_dict)