project-llm-trainer 0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of project-llm-trainer might be problematic. Click here for more details.
- llm_trainer/__init__.py +6 -0
- llm_trainer/checkpoint.py +161 -0
- llm_trainer/dataset.py +140 -0
- llm_trainer/dcp.py +93 -0
- llm_trainer/dpo_trainer.py +300 -0
- llm_trainer/ds_checkpoint.py +61 -0
- llm_trainer/eval.py +86 -0
- llm_trainer/generate_utils.py +424 -0
- llm_trainer/grpo_trainer.py +393 -0
- llm_trainer/log.py +16 -0
- llm_trainer/loss.py +171 -0
- llm_trainer/parallel.py +146 -0
- llm_trainer/parallel_ddp.py +39 -0
- llm_trainer/parallel_ds.py +45 -0
- llm_trainer/parallel_fsdp.py +115 -0
- llm_trainer/parallel_none.py +28 -0
- llm_trainer/scheduler.py +138 -0
- llm_trainer/sft_trainer.py +39 -0
- llm_trainer/tokenizer.py +166 -0
- llm_trainer/tools.py +102 -0
- llm_trainer/train_configs.py +445 -0
- llm_trainer/trainer.py +569 -0
- llm_trainer/utils.py +262 -0
- project_llm_trainer-0.3.data/scripts/calc_intermediate_size +15 -0
- project_llm_trainer-0.3.data/scripts/ddp_train +12 -0
- project_llm_trainer-0.3.data/scripts/ds_train +12 -0
- project_llm_trainer-0.3.data/scripts/plot_loss +39 -0
- project_llm_trainer-0.3.data/scripts/plot_lr +41 -0
- project_llm_trainer-0.3.data/scripts/py_train +12 -0
- project_llm_trainer-0.3.data/scripts/smart_train +28 -0
- project_llm_trainer-0.3.dist-info/METADATA +9 -0
- project_llm_trainer-0.3.dist-info/RECORD +34 -0
- project_llm_trainer-0.3.dist-info/WHEEL +5 -0
- project_llm_trainer-0.3.dist-info/top_level.txt +1 -0
llm_trainer/parallel.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from typing import Optional, Tuple
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
from torch import nn
|
|
7
|
+
import torch.distributed as dist
|
|
8
|
+
from torch.utils.data import Dataset, DataLoader
|
|
9
|
+
from torch.utils.data.distributed import DistributedSampler
|
|
10
|
+
from .log import log
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class Parallel(ABC):
    """Base class for the device / distributed set-up used by the trainers.

    On construction the instance reads ``RANK`` and ``LOCAL_RANK`` from the
    environment, decides whether distributed mode is active, optionally
    initialises the NCCL process group and picks the device to run on.
    Subclasses implement :meth:`process` to prepare the model and optimizer.
    """

    def __init__(
            self,
            init_process_group: bool = True,
            use_parallel: bool = True,
            use_compile: bool = False
    ):
        """
        :param init_process_group: call ``dist.init_process_group`` when
            distributed mode is active
        :param use_parallel: request distributed mode (only honoured when the
            ``RANK`` environment variable is set)
        :param use_compile: remembered for subclasses, which use it to decide
            whether to ``torch.compile`` the model
        """
        self._initialize(init_process_group, use_parallel, use_compile)

    def _initialize(
            self,
            init_process_group: bool,
            use_parallel: bool,
            use_compile: bool
    ):
        """Read the rank environment, set up the process group and the device."""
        # Both variables fall back to -1 when they are not set.
        self._global_rank: int = int(os.environ.get('RANK', -1))
        self._local_rank: int = int(os.environ.get('LOCAL_RANK', -1))
        # Distributed mode needs both the caller's request and a RANK value.
        self._use_parallel: bool = use_parallel and self._global_rank != -1
        self._use_compile = use_compile

        self._sampler: Optional[DistributedSampler] = None

        self.model: Optional[nn.Module] = None
        self.raw_model: Optional[nn.Module] = None

        if use_compile:
            torch.set_float32_matmul_precision('high')

        if self._use_parallel:
            if init_process_group:
                dist.init_process_group(backend='nccl')

            self.device: str = f'cuda:{self._local_rank}'
            self.device_type: str = 'cuda'

            torch.cuda.set_device(self.device)

            log(f'global_rank:{self._global_rank},local_rank:{self._local_rank}, world_size:{self.world_size}')
        else:
            # Single-process mode: prefer CUDA, then MPS, then CPU.
            device = "cpu"
            if torch.cuda.is_available():
                device = "cuda"
            elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
                device = "mps"

            self.device: str = device
            self.device_type: str = device

    @abstractmethod
    def process(
            self,
            model: nn.Module,
            optimizer: torch.optim.Optimizer,
            kwargs: Optional[dict] = None
    ) -> Tuple[nn.Module, torch.optim.Optimizer]:
        """Prepare ``model`` and ``optimizer`` for training and return them."""
        ...

    def process_dataloader(
            self,
            dataset: Dataset,
            data_loader_kwargs: dict,
            sampler_kwargs: Optional[dict] = None
    ) -> DataLoader:
        """Build the ``DataLoader`` for ``dataset``.

        In distributed mode a ``DistributedSampler`` is created (and kept on
        the instance so :meth:`on_epoch_start` can update its epoch) and
        passed to the loader; otherwise a plain ``DataLoader`` is returned.

        :param dataset: dataset to load from
        :param data_loader_kwargs: keyword arguments forwarded to ``DataLoader``
                "batch_size" int,
                "pin_memory" bool,
                "collate_fn" collate_fn,
                "num_workers" int
                "shuffle" bool
                "drop_last" bool
        :param sampler_kwargs: keyword arguments forwarded to
            ``DistributedSampler``; ``None`` means no extra arguments
                "shuffle" bool
                "drop_last" bool
        :return: the configured ``DataLoader``
        """
        if self._use_parallel:
            # sampler_kwargs defaults to None and ``**None`` raises TypeError,
            # so fall back to an empty mapping.
            extra_sampler_args = sampler_kwargs if sampler_kwargs is not None else {}
            self._sampler = DistributedSampler(dataset=dataset, **extra_sampler_args)
            return DataLoader(dataset=dataset, sampler=self._sampler, **data_loader_kwargs)

        return DataLoader(dataset=dataset, **data_loader_kwargs)

    def on_epoch_start(self, epoch):
        """Forward the epoch number to the distributed sampler, if there is one."""
        # Compare against None: the sampler defines __len__, so a plain
        # truthiness test would skip a sampler whose length is 0.
        if self._sampler is not None:
            self._sampler.set_epoch(epoch)

    def on_epoch_end(self, epoch): ...

    def synchronize(self):
        """Wait for pending CUDA work on this device (distributed mode only)."""
        if self._use_parallel:
            torch.cuda.synchronize(device=self.device)

    def destroy(self):
        """Tear down the process group (distributed mode only)."""
        # Skip the call when no group exists (never initialised, or already
        # destroyed) instead of letting destroy_process_group fail.
        if self._use_parallel and dist.is_initialized():
            dist.destroy_process_group()

    # def reduce_loss(self, avg_loss: torch.Tensor, loss: torch.Tensor, batch) -> torch.Tensor:
    #     if self._use_parallel:
    #         world_size = dist.get_world_size()
    #         if world_size < 2:
    #             return loss.detach()
    #
    #         torch.distributed.all_reduce(loss)
    #         # Running mean of the loss over the whole run: fold the latest
    #         # loss into the historical average.
    #         avg_loss = (avg_loss * batch + loss.detach()) / (batch + 1)
    #         return avg_loss
    #
    #     return loss.detach()

    @property
    def parallel_train(self) -> bool:
        """Whether distributed mode is active."""
        return self._use_parallel

    @property
    def is_main_process(self) -> bool:
        """True for global rank 0, and always True outside distributed mode."""
        if self._use_parallel:
            return self._global_rank == 0

        return True

    @property
    def world_size(self) -> int:
        """Process-group world size, or 1 outside distributed mode."""
        if self._use_parallel:
            return dist.get_world_size()
        return 1

    def wait(self):
        """Best-effort barrier across processes, logged before and after."""
        try:
            log(f'wait at {self.device}')
            dist.barrier()
        except Exception:
            # Still best-effort (e.g. no process group to wait on), but
            # KeyboardInterrupt / SystemExit are no longer swallowed.
            pass

        log(f'continue at {self.device}')
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
from typing import Optional, Tuple
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from torch import nn
|
|
5
|
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
6
|
+
|
|
7
|
+
from .parallel import Parallel
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
# python3 -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 gpt.py
|
|
11
|
+
# torchrun --standalone --nproc_per_node=gpu pretrain.py
|
|
12
|
+
# --standalone means running on a single machine (single node)
|
|
13
|
+
# --nproc_per_node=gpu means using all available GPUs; a GPU count n may also be given after the equals sign, in which case the first n GPUs are used
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class DdpParallel(Parallel):
    """Parallel strategy that wraps the model in ``DistributedDataParallel``."""

    def __init__(self):
        super().__init__()

    def process(
            self,
            model: nn.Module,
            optimizer: torch.optim.Optimizer,
            kwargs: Optional[dict] = None
    ) -> Tuple[nn.Module, torch.optim.Optimizer]:
        """Move ``model`` to the device, optionally compile it, then wrap it.

        :param model: model to prepare
        :param optimizer: returned unchanged
        :param kwargs: not used by this strategy
        :return: ``(self.model, optimizer)``; ``self.model`` is the DDP wrapper
            in distributed mode and the (possibly compiled) model otherwise
        """
        model.to(self.device)

        prepared = torch.compile(model) if self._use_compile else model

        if not self._use_parallel:
            # Single-process run: no wrapper, both attributes are the same object.
            self.model = self.raw_model = prepared
            return self.model, optimizer

        # Alternative configuration kept for reference:
        # self.model = DDP(module=model, broadcast_buffers=False, find_unused_parameters=True)
        rank = self._local_rank
        self.model = DDP(module=prepared, device_ids=[rank], output_device=rank)
        # raw_model is the module held inside the DDP wrapper.
        self.raw_model = self.model.module

        return self.model, optimizer
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
from typing import Optional, Tuple
|
|
2
|
+
import torch
|
|
3
|
+
from torch import nn
|
|
4
|
+
from .parallel import Parallel
|
|
5
|
+
|
|
6
|
+
try:
    import deepspeed
except Exception:
    # DeepSpeed is treated as optional: the module must stay importable
    # without it. The handler is deliberately broad (as the original was) but
    # no longer bare, so KeyboardInterrupt / SystemExit propagate. Binding the
    # name to None keeps `deepspeed` defined for later checks.
    deepspeed = None
|
|
9
|
+
|
|
10
|
+
class DsParallel(Parallel):
    """Parallel strategy that hands the model and optimizer to DeepSpeed."""

    def __init__(self):
        # DeepSpeed sets up the distributed backend itself, so the base class
        # is told not to call init_process_group again.
        deepspeed.init_distributed(dist_backend='nccl')
        super().__init__(init_process_group=False)

    def process(
            self,
            model: nn.Module,
            optimizer: torch.optim.Optimizer,
            kwargs: Optional[dict] = None
    ) -> Tuple[nn.Module, torch.optim.Optimizer]:
        """Initialise DeepSpeed for ``model`` and ``optimizer``.

        :param model: model to wrap; also stored as ``self.raw_model``
        :param optimizer: optimizer passed to ``deepspeed.initialize``
        :param kwargs: passed as ``config_params``; see the DeepSpeed
            configuration reference
        :return: the first two values returned by ``deepspeed.initialize``
        """
        self.raw_model = model

        # deepspeed.initialize returns four values; only the first two are used.
        engine, ds_optimizer, _, _ = deepspeed.initialize(
            model=model,
            optimizer=optimizer,
            dist_init_required=False,
            config_params=kwargs
        )

        self.model = engine
        return engine, ds_optimizer

    # Both hooks are overridden as no-ops for this strategy.
    def synchronize(self): ...

    def destroy(self): ...
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
|
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
from typing import Optional, Tuple
|
|
2
|
+
import functools
|
|
3
|
+
import torch
|
|
4
|
+
from torch import nn
|
|
5
|
+
from torch.distributed.fsdp import (
|
|
6
|
+
FullyShardedDataParallel as FSDP,
|
|
7
|
+
MixedPrecision,
|
|
8
|
+
ShardingStrategy,
|
|
9
|
+
BackwardPrefetch,
|
|
10
|
+
CPUOffload,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
from torch.distributed.fsdp.wrap import (
|
|
14
|
+
size_based_auto_wrap_policy,
|
|
15
|
+
transformer_auto_wrap_policy,
|
|
16
|
+
always_wrap_policy,
|
|
17
|
+
enable_wrap,
|
|
18
|
+
wrap,
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
from .parallel import Parallel
|
|
22
|
+
|
|
23
|
+
class FsdpParallel(Parallel):
    """Parallel strategy that wraps the model in ``FullyShardedDataParallel``."""

    def __init__(self):
        super().__init__()

    def process(
            self,
            model: nn.Module,
            optimizer: torch.optim.Optimizer,
            kwargs: Optional[dict] = None
    ) -> Tuple[nn.Module, torch.optim.Optimizer]:
        """Move ``model`` to the device, optionally compile it, then wrap it.

        :param model: model to prepare
        :param optimizer: returned unchanged
        :param kwargs: optional settings; ``None`` is treated as empty
            "transformer_layer_cls" layer classes for transformer_auto_wrap_policy
                (takes precedence over "wrap_policy_num_params")
            "wrap_policy_num_params" int minimum parameter count for size_based_auto_wrap_policy
            "cpu_offload" bool whether to use CPU offload
            "offload_params" bool whether to offload parameters; only takes
                effect when cpu_offload is True
        :return: ``(self.model, optimizer)``; ``self.model`` is the FSDP wrapper
            in distributed mode and the (possibly compiled) model otherwise
        """
        # kwargs defaults to None; normalise it so the lookups below cannot
        # raise TypeError.
        options = kwargs if kwargs is not None else {}

        model.to(self.device)

        if self._use_compile:
            model = torch.compile(model)

        if not self._use_parallel:
            self.model = model
            self.raw_model = model
            return self.model, optimizer

        if 'transformer_layer_cls' in options:
            auto_wrap_policy = functools.partial(
                transformer_auto_wrap_policy,
                transformer_layer_cls=options['transformer_layer_cls']
            )
        elif 'wrap_policy_num_params' in options:
            auto_wrap_policy = functools.partial(
                size_based_auto_wrap_policy,
                min_num_params=options['wrap_policy_num_params']
            )
        else:
            auto_wrap_policy = None

        # Honour the *value* of cpu_offload, not just the presence of the key,
        # so {"cpu_offload": False, "offload_params": True} does not offload.
        if options.get('cpu_offload'):
            # With cpu_offload, wrapped parameters can be moved to the CPU
            # while they are not used in computation. This improves memory
            # efficiency at the cost of host<->device transfer overhead.
            cpu_offload = CPUOffload(offload_params=options.get('offload_params', False))
        else:
            cpu_offload = None

        if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
            mixed_precision = MixedPrecision(
                param_dtype=torch.bfloat16,
                # Gradient communication precision.
                reduce_dtype=torch.bfloat16,
                # Buffer precision.
                buffer_dtype=torch.bfloat16,
            )
        else:
            mixed_precision = None

        # raw_model is the module before FSDP wrapping.
        self.raw_model = model

        self.model = FSDP(
            model,
            sharding_strategy=ShardingStrategy.FULL_SHARD,
            auto_wrap_policy=auto_wrap_policy,
            mixed_precision=mixed_precision,
            cpu_offload=cpu_offload,
            device_id=torch.cuda.current_device(),
            process_group=None,
            # use_orig_params=True,
            # backward_prefetch=BackwardPrefetch.BACKWARD_PRE, # bit faster async comms, bit higher memory
            # limit_all_gathers=False,
            # forward_prefetch=True,
        )

        return self.model, optimizer
|
|
114
|
+
|
|
115
|
+
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
from typing import Optional, Tuple
|
|
2
|
+
import torch
|
|
3
|
+
from torch import nn
|
|
4
|
+
|
|
5
|
+
from .parallel import Parallel
|
|
6
|
+
|
|
7
|
+
class NoneParallel(Parallel):
    """Strategy for single-process runs: distributed mode is always disabled."""

    def __init__(self):
        super().__init__(use_parallel=False)

    def process(
            self,
            model: nn.Module,
            optimizer: torch.optim.Optimizer,
            kwargs: Optional[dict] = None
    ) -> Tuple[nn.Module, torch.optim.Optimizer]:
        """Move ``model`` to the device and optionally compile it.

        :param model: model to prepare
        :param optimizer: returned unchanged
        :param kwargs: not used by this strategy
        :return: ``(self.model, optimizer)``
        """
        model.to(self.device)

        prepared = torch.compile(model) if self._use_compile else model

        # No wrapper is applied, so both attributes point at the same object.
        self.raw_model = prepared
        self.model = prepared

        return self.model, optimizer
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
|
llm_trainer/scheduler.py
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
import math
|
|
3
|
+
import torch
|
|
4
|
+
from .log import (
|
|
5
|
+
log,
|
|
6
|
+
get_log_dir
|
|
7
|
+
)
|
|
8
|
+
|
|
9
|
+
class LRScheduler(ABC):
    """Interface implemented by the learning-rate schedulers in this module."""

    @property
    @abstractmethod
    def cur_steps(self):
        """Current step counter."""
        ...

    @property
    @abstractmethod
    def cur_lr(self):
        """Most recently computed learning rate."""
        ...

    @abstractmethod
    def update_steps(self, steps):
        """Set the step counter to ``steps``."""
        ...

    @abstractmethod
    def step(self):
        """Advance the scheduler by one step."""
        ...

    @abstractmethod
    def can_clip_grad(self):
        """Whether the scheduler currently allows gradient clipping."""
        ...


class WarmupCosineAnnealingLRScheduler(LRScheduler):
    """Linear warmup followed by cosine annealing with restarts.

    For steps ``0..warmup_iters`` the rate is ``initial_lr + step * increment``
    with ``increment = (max_lr - initial_lr) / warmup_iters`` (0 when
    ``warmup_iters`` is 0). Afterwards it follows a cosine curve from the rate
    reached at the end of warmup down towards ``min_lr``, restarting each
    cycle; cycle ``c`` lasts ``period * period_mul ** c`` steps.

    The rate is a function of the step counter alone, so jumping to a step
    (``update_steps``, e.g. when resuming) gives the same rate as stepping
    there one step at a time.
    """

    def __init__(
            self,
            *,
            optimizer: torch.optim.Optimizer,
            initial_lr: float,
            min_lr: float,
            max_lr: float,
            warmup_iters: int,
            period: int,  # number of steps in a cycle
            period_mul: int = 1,  # multiplier applied to the cycle length
            need_log: bool = False
    ):
        super().__init__()

        self._optimizer = optimizer
        self._initial_lr = initial_lr
        self._min_lr = min_lr
        self._max_lr = max_lr
        self._warmup_iters = warmup_iters

        self._period = period
        self._period_mul = period_mul

        self.T_cur = 0  # steps taken inside the current cycle
        self.cycle = 0  # index of the current cycle

        if warmup_iters != 0:
            self._lr_increment = (max_lr - initial_lr) / warmup_iters
        else:
            self._lr_increment = 0

        # -1 so that the first call to step() lands on step 0.
        self._steps = -1
        self._current_lr = initial_lr
        self._cosine_annealing_base_lr = None

        self.need_log = need_log

    @property
    def cur_steps(self):
        return self._steps

    @property
    def cur_lr(self):
        return self._current_lr

    def update_steps(self, steps):
        """Jump to ``steps`` and recompute the learning rate for that step."""
        log(f'update step to {steps}')
        self._steps = steps
        self._update_lr()

    def step(self):
        """Advance one step and apply the new learning rate."""
        self._steps += 1
        self._update_lr()

    def can_clip_grad(self):
        """Gradient clipping is allowed once warmup is over."""
        return self._steps > self._warmup_iters

    @staticmethod
    def _cycle_span(t_max):
        """Return how many steps a cycle of nominal length ``t_max`` occupies."""
        if t_max <= 0:
            raise ValueError(f'cosine cycle length must be positive, got {t_max}')
        # A cycle ends at the first integer step count >= t_max.
        return math.ceil(t_max)

    def _cycle_position(self, cosine_steps):
        """Locate ``cosine_steps`` (steps taken since warmup ended).

        :return: ``(cycle, t_cur, t_max)``: cycle index, steps taken inside
            that cycle and the nominal length of that cycle
        """
        cycle = 0
        t_max = self._period * (self._period_mul ** cycle)

        if self._period_mul == 1:
            # Every cycle has the same length: no need to walk the cycles.
            cycle, t_cur = divmod(cosine_steps, self._cycle_span(t_max))
            return cycle, t_cur, t_max

        remaining = cosine_steps
        while True:
            span = self._cycle_span(t_max)
            if remaining < span:
                return cycle, remaining, t_max
            remaining -= span
            cycle += 1
            t_max = self._period * (self._period_mul ** cycle)

    def _update_lr(self):
        """Compute the rate for the current step and write it to the optimizer."""
        if self._steps <= self._warmup_iters:
            # Warmup: adjust learning rate linearly
            # (max_lr - initial_lr) / warmup_iters
            lr = self._initial_lr + self._steps * self._lr_increment
            self.T_cur = 0
            self.cycle = 0
        else:
            if self._cosine_annealing_base_lr is None:
                # Rate reached on the last warmup step. Computed from the
                # configuration rather than read from cur_lr, so it is also
                # right when warmup was skipped by a jump / resume.
                self._cosine_annealing_base_lr = (
                    self._initial_lr + self._warmup_iters * self._lr_increment
                )

            # Derive the cycle state from the step count instead of
            # incrementing it, so it does not depend on call history.
            self.cycle, self.T_cur, t_max = self._cycle_position(
                self._steps - self._warmup_iters
            )

            # Compute the new learning rate.
            cos_factor = (1 + math.cos(math.pi * self.T_cur / t_max)) / 2
            lr = self._min_lr + (self._cosine_annealing_base_lr - self._min_lr) * cos_factor

        for param_group in self._optimizer.param_groups:
            param_group['lr'] = lr

        self._current_lr = lr

        if self.need_log:
            log(f"step={self.cur_steps},lr={lr}\n", f'{get_log_dir()}lr.txt')
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
class NoneLRScheduler(LRScheduler):
    """Scheduler that leaves the learning rate at the value it was given."""

    def __init__(self, initial_lr):
        self._current_lr = initial_lr

    @property
    def cur_steps(self):
        # No step counter is kept; -1 is always reported.
        return -1

    @property
    def cur_lr(self):
        return self._current_lr

    def update_steps(self, steps):
        """No-op: there is no step state to update."""
        ...

    def step(self):
        """No-op: the learning rate never changes."""
        ...

    def can_clip_grad(self):
        """Gradient clipping is always allowed."""
        return True
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
from typing import Optional, Tuple, List
|
|
2
|
+
|
|
3
|
+
from torch.utils.data import Dataset
|
|
4
|
+
|
|
5
|
+
from .trainer import Trainer
|
|
6
|
+
from .train_configs import TrainConfig, VLMConfig
|
|
7
|
+
from .dataset import LineByLineTextDataset
|
|
8
|
+
from .utils import get_sft_collate_fn
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class SFTTrainer(Trainer):
    """Trainer variant that uses the SFT collate function and a
    ``LineByLineTextDataset``."""

    def __init__(
            self,
            *,
            train_config: TrainConfig,
            eval_prompts: List[str],
            eval_image_tags: Optional[List[int]] = None
    ):
        super().__init__(
            train_config=train_config,
            eval_prompts=eval_prompts,
            eval_image_tags=eval_image_tags
        )

    def _convert_train_args(self) -> Tuple[dict, dict, dict, bool]:
        """Return the base trainer's arguments with ``collate_fn`` replaced by
        the SFT collate function (built from ``train_config.mask_prompt``)."""
        collate = get_sft_collate_fn(self.train_config.mask_prompt)

        converted = super()._convert_train_args()
        parallel_kwargs, data_loader_kwargs, sampler_kwargs, use_ds_optim = converted
        data_loader_kwargs.update({"collate_fn": collate})

        return parallel_kwargs, data_loader_kwargs, sampler_kwargs, use_ds_optim

    def _create_dataset(self, file_path) -> Dataset:
        """Build the dataset for ``file_path``.

        ``tokens_per_image`` comes from the model config when it is a
        ``VLMConfig``; otherwise -1 is passed.
        """
        model_config = self.train_config.model_config
        max_position_embeddings = model_config.max_position_embeddings

        tokens_per_image = (
            self.train_config.model_config.tokens_per_image
            if isinstance(self.train_config.model_config, VLMConfig)
            else -1
        )

        return LineByLineTextDataset(file_path, max_position_embeddings, tokens_per_image)
|