project-llm-trainer 0.3 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of project-llm-trainer might be problematic.

@@ -0,0 +1,146 @@
+ import os
+ from typing import Optional, Tuple
+ from abc import ABC, abstractmethod
+
+ import torch
+ from torch import nn
+ import torch.distributed as dist
+ from torch.utils.data import Dataset, DataLoader
+ from torch.utils.data.distributed import DistributedSampler
+ from .log import log
+
+
+ class Parallel(ABC):
+     def __init__(
+         self,
+         init_process_group: bool = True,
+         use_parallel: bool = True,
+         use_compile: bool = False
+     ):
+         self._initialize(init_process_group, use_parallel, use_compile)
+
+     def _initialize(
+         self,
+         init_process_group: bool,
+         use_parallel: bool,
+         use_compile: bool
+     ):
+         # torchrun / torch.distributed set RANK and LOCAL_RANK; -1 means "not launched in parallel".
+         self._global_rank: int = int(os.environ.get('RANK', -1))
+         self._local_rank: int = int(os.environ.get('LOCAL_RANK', -1))
+         self._use_parallel: bool = use_parallel and self._global_rank != -1
+         self._use_compile = use_compile
+
+         self._sampler: Optional[DistributedSampler] = None
+
+         self.model: Optional[nn.Module] = None
+         self.raw_model: Optional[nn.Module] = None
+
+         if use_compile:
+             torch.set_float32_matmul_precision('high')
+
+         if self._use_parallel:
+             if init_process_group:
+                 dist.init_process_group(backend='nccl')
+
+             self.device: str = f'cuda:{self._local_rank}'
+             self.device_type: str = 'cuda'
+
+             torch.cuda.set_device(self.device)
+
+             log(f'global_rank:{self._global_rank}, local_rank:{self._local_rank}, world_size:{self.world_size}')
+         else:
+             device = "cpu"
+             if torch.cuda.is_available():
+                 device = "cuda"
+             elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+                 device = "mps"
+
+             self.device: str = device
+             self.device_type: str = device
+
+     @abstractmethod
+     def process(
+         self,
+         model: nn.Module,
+         optimizer: torch.optim.Optimizer,
+         kwargs: Optional[dict] = None
+     ) -> Tuple[nn.Module, torch.optim.Optimizer]: ...
+
+     def process_dataloader(
+         self,
+         dataset: Dataset,
+         data_loader_kwargs: dict,
+         sampler_kwargs: Optional[dict] = None
+     ) -> DataLoader:
+         """
+         :param dataset: dataset to wrap in a DataLoader
+         :param data_loader_kwargs: forwarded to DataLoader, e.g.
+             "batch_size" (int), "pin_memory" (bool), "collate_fn",
+             "num_workers" (int), "shuffle" (bool), "drop_last" (bool)
+         :param sampler_kwargs: forwarded to DistributedSampler during parallel training, e.g.
+             "shuffle" (bool), "drop_last" (bool)
+         :return: a DataLoader, backed by a DistributedSampler when parallel training is enabled
+         """
+
+         if self._use_parallel:
+             self._sampler = DistributedSampler(dataset=dataset, **(sampler_kwargs or {}))
+             return DataLoader(dataset=dataset, sampler=self._sampler, **data_loader_kwargs)
+
+         return DataLoader(dataset=dataset, **data_loader_kwargs)
+
+     def on_epoch_start(self, epoch):
+         if self._sampler:
+             self._sampler.set_epoch(epoch)
+
+     def on_epoch_end(self, epoch): ...
+
+     def synchronize(self):
+         if self._use_parallel:
+             torch.cuda.synchronize(device=self.device)
+
+     def destroy(self):
+         if self._use_parallel:
+             dist.destroy_process_group()
+
+     # def reduce_loss(self, avg_loss: torch.Tensor, loss: torch.Tensor, batch) -> torch.Tensor:
+     #     if self._use_parallel:
+     #         world_size = dist.get_world_size()
+     #         if world_size < 2:
+     #             return loss.detach()
+     #
+     #         torch.distributed.all_reduce(loss)
+     #         # Running mean of the loss over the whole run: extend the historical average with the latest loss.
+     #         avg_loss = (avg_loss * batch + loss.detach()) / (batch + 1)
+     #         return avg_loss
+     #
+     #     return loss.detach()
+
+     @property
+     def parallel_train(self) -> bool:
+         return self._use_parallel
+
+     @property
+     def is_main_process(self) -> bool:
+         if self._use_parallel:
+             return self._global_rank == 0
+
+         return True
+
+     @property
+     def world_size(self) -> int:
+         if self._use_parallel:
+             return dist.get_world_size()
+         return 1
+
+     def wait(self):
+         try:
+             log(f'wait at {self.device}')
+             dist.barrier()
+         except Exception: ...  # barrier is a no-op when no process group is initialized
+         log(f'continue at {self.device}')
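The RANK / LOCAL_RANK detection above means the same class works both under torchrun and as a plain single-process run. Below is a minimal, hedged sketch of process_dataloader with the kwargs listed in its docstring; ToyParallel, DummyDataset and the import path are illustrative placeholders, not part of the package.

import torch
from torch.utils.data import Dataset

from project_llm_trainer.parallel import Parallel  # assumed install path

class ToyParallel(Parallel):
    # Hypothetical concrete subclass used only for this sketch.
    def process(self, model, optimizer, kwargs=None):
        model.to(self.device)
        self.model = self.raw_model = model
        return model, optimizer

class DummyDataset(Dataset):
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return torch.tensor([float(idx)])

parallel = ToyParallel(use_parallel=False)  # no RANK set -> cpu/cuda/mps fallback
loader = parallel.process_dataloader(
    DummyDataset(),
    data_loader_kwargs={"batch_size": 4, "shuffle": True, "drop_last": False},
    sampler_kwargs={"shuffle": True},  # only consulted when a DistributedSampler is created
)
for batch in loader:
    print(batch.shape)  # torch.Size([4, 1])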
@@ -0,0 +1,39 @@
+ from typing import Optional, Tuple
+
+ import torch
+ from torch import nn
+ from torch.nn.parallel import DistributedDataParallel as DDP
+
+ from .parallel import Parallel
+
+
+ # python3 -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 gpt.py
+ # torchrun --standalone --nproc_per_node=gpu pretrain.py
+ # --standalone: single-machine run
+ # --nproc_per_node=gpu: use all available GPUs; a number n may be given instead, in which case the first n GPUs are used
+
+
+ class DdpParallel(Parallel):
+     def __init__(self):
+         super().__init__()
+
+     def process(
+         self,
+         model: nn.Module,
+         optimizer: torch.optim.Optimizer,
+         kwargs: Optional[dict] = None
+     ) -> Tuple[nn.Module, torch.optim.Optimizer]:
+         model.to(self.device)
+
+         if self._use_compile:
+             model = torch.compile(model)
+
+         if self._use_parallel:
+             # self.model = DDP(module=model, broadcast_buffers=False, find_unused_parameters=True)
+             self.model = DDP(module=model, device_ids=[self._local_rank], output_device=self._local_rank)
+             self.raw_model = self.model.module
+         else:
+             self.model = model
+             self.raw_model = model
+
+         return self.model, optimizer
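A hedged end-to-end sketch of driving DdpParallel follows; the model and data are toy placeholders and the module path is an assumption. The script would be launched with `torchrun --standalone --nproc_per_node=2 train.py` as noted in the comments above, or run directly for a single process.

import torch
from torch import nn

from project_llm_trainer.ddp_parallel import DdpParallel  # assumed module name

parallel = DdpParallel()
model = nn.Linear(16, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
model, optimizer = parallel.process(model, optimizer)

x = torch.randn(8, 16, device=parallel.device)
y = torch.randint(0, 2, (8,), device=parallel.device)

for epoch in range(2):
    parallel.on_epoch_start(epoch)  # re-seeds the DistributedSampler if one is in use
    loss = nn.functional.cross_entropy(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if parallel.is_main_process:
        print(epoch, loss.item())

parallel.destroy()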
@@ -0,0 +1,45 @@
+ from typing import Optional, Tuple
+
+ import torch
+ from torch import nn
+
+ from .parallel import Parallel
+
+ try:
+     import deepspeed
+ except ImportError: ...
+
+
+ class DsParallel(Parallel):
+     def __init__(self):
+         deepspeed.init_distributed(dist_backend='nccl')
+         super().__init__(init_process_group=False)
+
+     def process(
+         self,
+         model: nn.Module,
+         optimizer: torch.optim.Optimizer,
+         kwargs: Optional[dict] = None
+     ) -> Tuple[nn.Module, torch.optim.Optimizer]:
+         """
+         :param model:
+         :param optimizer:
+         :param kwargs: DeepSpeed config dict; see the DeepSpeed configuration reference
+         :return:
+         """
+         self.raw_model = model
+
+         model, optim, _, _ = deepspeed.initialize(
+             model=model,
+             optimizer=optimizer,
+             dist_init_required=False,
+             config_params=kwargs
+         )
+
+         self.model = model
+         return model, optim
+
+     def synchronize(self): ...
+
+     def destroy(self): ...
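Since `kwargs` is passed straight through to `deepspeed.initialize` as `config_params`, it is an ordinary DeepSpeed config dict. A minimal sketch with illustrative values only, reusing the model and optimizer from the DDP sketch above:

ds_config = {
    "train_micro_batch_size_per_gpu": 8,
    "gradient_accumulation_steps": 1,
    "bf16": {"enabled": True},
    "zero_optimization": {"stage": 2},
}

parallel = DsParallel()
model, optimizer = parallel.process(model, optimizer, kwargs=ds_config)
# model is now a DeepSpeed engine; model.backward(loss) and model.step() replace the manual loop.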
@@ -0,0 +1,115 @@
+ from typing import Optional, Tuple
+ import functools
+
+ import torch
+ from torch import nn
+ from torch.distributed.fsdp import (
+     FullyShardedDataParallel as FSDP,
+     MixedPrecision,
+     ShardingStrategy,
+     BackwardPrefetch,
+     CPUOffload,
+ )
+ from torch.distributed.fsdp.wrap import (
+     size_based_auto_wrap_policy,
+     transformer_auto_wrap_policy,
+ )
+
+ from .parallel import Parallel
+
+
+ class FsdpParallel(Parallel):
+     def __init__(self):
+         super().__init__()
+
+     def process(
+         self,
+         model: nn.Module,
+         optimizer: torch.optim.Optimizer,
+         kwargs: Optional[dict] = None
+     ) -> Tuple[nn.Module, torch.optim.Optimizer]:
+         """
+         :param model:
+         :param optimizer:
+         :param kwargs:
+             "transformer_layer_cls" set, layer classes for transformer_auto_wrap_policy
+             "wrap_policy_num_params" int, minimum parameter count for size_based_auto_wrap_policy
+             "cpu_offload" bool, whether to use CPU offload
+             "offload_params" bool, whether to offload parameters; only effective when cpu_offload is set
+         :return:
+         """
+         # kwargs may be None; treat it as an empty config.
+         kwargs = kwargs or {}
+
+         model.to(self.device)
+
+         if self._use_compile:
+             model = torch.compile(model)
+
+         if self._use_parallel:
+             if 'transformer_layer_cls' in kwargs:
+                 auto_wrap_policy = functools.partial(
+                     transformer_auto_wrap_policy,
+                     transformer_layer_cls=kwargs['transformer_layer_cls']
+                 )
+             elif 'wrap_policy_num_params' in kwargs:
+                 auto_wrap_policy = functools.partial(
+                     size_based_auto_wrap_policy,
+                     min_num_params=kwargs['wrap_policy_num_params']
+                 )
+             else:
+                 auto_wrap_policy = None
+
+             if 'cpu_offload' in kwargs:
+                 offload_params = False
+                 if 'offload_params' in kwargs:
+                     offload_params = kwargs['offload_params']
+
+                 # Configure cpu_offload so that wrapped parameters not currently used in computation
+                 # are offloaded to the CPU. This further improves memory efficiency at the cost of
+                 # host/device data transfer overhead.
+                 cpu_offload = CPUOffload(offload_params=offload_params)
+             else:
+                 cpu_offload = None
+
+             if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
+                 mixed_precision = MixedPrecision(
+                     # Parameter precision.
+                     param_dtype=torch.bfloat16,
+                     # Gradient communication precision.
+                     reduce_dtype=torch.bfloat16,
+                     # Buffer precision.
+                     buffer_dtype=torch.bfloat16,
+                 )
+             else:
+                 mixed_precision = None
+
+             self.raw_model = model
+
+             # device_mesh = init_device_mesh("cuda", (self.world_size,))
+             # self.model = FSDP(
+             #     model,
+             #     auto_wrap_policy=auto_wrap_policy,
+             #     mixed_precision=mixed_precision,
+             #     cpu_offload=cpu_offload,
+             #     device_id=torch.cuda.current_device(),
+             #     device_mesh=device_mesh
+             # )
+
+             self.model = FSDP(
+                 model,
+                 sharding_strategy=ShardingStrategy.FULL_SHARD,
+                 auto_wrap_policy=auto_wrap_policy,
+                 mixed_precision=mixed_precision,
+                 cpu_offload=cpu_offload,
+                 device_id=torch.cuda.current_device(),
+                 process_group=None,
+                 # use_orig_params=True,
+                 # backward_prefetch=BackwardPrefetch.BACKWARD_PRE,  # slightly faster async comms, slightly higher memory
+                 # limit_all_gathers=False,
+                 # forward_prefetch=True,
+             )
+         else:
+             self.model = model
+             self.raw_model = model
+
+         return self.model, optimizer
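A hedged sketch of the two kwargs shapes the method above accepts; MyTransformerBlock is a hypothetical layer class standing in for whatever block the model actually uses, and the numeric values are illustrative.

# Wrap by layer type (transformer_auto_wrap_policy):
fsdp_kwargs = {"transformer_layer_cls": {MyTransformerBlock}}

# ...or wrap by parameter count (size_based_auto_wrap_policy), with CPU offload:
fsdp_kwargs = {
    "wrap_policy_num_params": 1_000_000,  # modules with at least 1M params become their own FSDP unit
    "cpu_offload": True,                  # presence of the key enables CPUOffload
    "offload_params": True,               # offload parameters to CPU when not in use
}

parallel = FsdpParallel()
model, optimizer = parallel.process(model, optimizer, kwargs=fsdp_kwargs)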
@@ -0,0 +1,28 @@
+ from typing import Optional, Tuple
+
+ import torch
+ from torch import nn
+
+ from .parallel import Parallel
+
+
+ class NoneParallel(Parallel):
+     def __init__(self):
+         super().__init__(use_parallel=False)
+
+     def process(
+         self,
+         model: nn.Module,
+         optimizer: torch.optim.Optimizer,
+         kwargs: Optional[dict] = None
+     ) -> Tuple[nn.Module, torch.optim.Optimizer]:
+         model.to(self.device)
+
+         if self._use_compile:
+             model = torch.compile(model)
+
+         self.raw_model = model
+         self.model = model
+
+         return self.model, optimizer
@@ -0,0 +1,138 @@
+ from abc import ABC, abstractmethod
+ import math
+
+ import torch
+
+ from .log import (
+     log,
+     get_log_dir
+ )
+
+
+ class LRScheduler(ABC):
+     @property
+     @abstractmethod
+     def cur_steps(self): ...
+
+     @property
+     @abstractmethod
+     def cur_lr(self): ...
+
+     @abstractmethod
+     def update_steps(self, steps): ...
+
+     @abstractmethod
+     def step(self): ...
+
+     @abstractmethod
+     def can_clip_grad(self): ...
+
+
+ class WarmupCosineAnnealingLRScheduler(LRScheduler):
+     def __init__(
+         self,
+         *,
+         optimizer: torch.optim.Optimizer,
+         initial_lr: float,
+         min_lr: float,
+         max_lr: float,
+         warmup_iters: int,
+         period: int,  # number of steps in the first cosine cycle
+         period_mul: int = 1,  # multiplier applied to the cycle length after each cycle
+         need_log: bool = False
+     ):
+         super().__init__()
+
+         self._optimizer = optimizer
+         self._initial_lr = initial_lr
+         self._min_lr = min_lr
+         self._max_lr = max_lr
+         self._warmup_iters = warmup_iters
+
+         self._period = period
+         self._period_mul = period_mul
+
+         self.T_cur = 0  # steps taken within the current cycle
+         self.cycle = 0  # index of the current cycle
+
+         if warmup_iters != 0:
+             self._lr_increment = (max_lr - initial_lr) / warmup_iters
+         else:
+             self._lr_increment = 0
+
+         self._steps = -1
+         self._current_lr = initial_lr
+         self._cosine_annealing_base_lr = None
+
+         self.need_log = need_log
+
+     @property
+     def cur_steps(self):
+         return self._steps
+
+     @property
+     def cur_lr(self):
+         return self._current_lr
+
+     def update_steps(self, steps):
+         log(f'update step to {steps}')
+         self._steps = steps
+         self._update_lr()
+
+     def step(self):
+         self._steps += 1
+         self._update_lr()
+
+     def can_clip_grad(self):
+         return self._steps > self._warmup_iters
+
+     def _update_lr(self):
+         if self._steps <= self._warmup_iters:
+             # Warmup: increase the learning rate linearly by
+             # (max_lr - initial_lr) / warmup_iters per step.
+             lr = self._initial_lr + self._steps * self._lr_increment
+             for param_group in self._optimizer.param_groups:
+                 param_group['lr'] = lr
+         else:
+             if self._cosine_annealing_base_lr is None:
+                 self._cosine_annealing_base_lr = self.cur_lr
+
+             # Cosine annealing with warm restarts: update the learning rate every step.
+             # Maximum number of steps in the current cycle.
+             T_max = self._period * (self._period_mul ** self.cycle)
+
+             # Advance the cycle state.
+             self.T_cur += 1
+             if self.T_cur >= T_max:
+                 self.cycle += 1
+                 self.T_cur = 0  # restart the cycle
+
+             # Compute and apply the new learning rate.
+             cos_factor = (1 + math.cos(math.pi * self.T_cur / T_max)) / 2
+             lr = self._min_lr + (self._cosine_annealing_base_lr - self._min_lr) * cos_factor
+
+             for param_group in self._optimizer.param_groups:
+                 param_group['lr'] = lr
+
+         self._current_lr = lr
+
+         if self.need_log:
+             log(f"step={self.cur_steps},lr={lr}\n", f'{get_log_dir()}lr.txt')
+
+
+ class NoneLRScheduler(LRScheduler):
+     def __init__(self, initial_lr):
+         self._current_lr = initial_lr
+
+     @property
+     def cur_steps(self):
+         return -1
+
+     @property
+     def cur_lr(self):
+         return self._current_lr
+
+     def update_steps(self, steps): ...
+
+     def step(self): ...
+
+     def can_clip_grad(self):
+         return True
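A small sketch of the schedule the class above produces: the learning rate climbs linearly from initial_lr to max_lr over warmup_iters steps, then follows cosine cycles between the warmup-end value and min_lr, each cycle period_mul times longer than the last. The values are illustrative only, and the sketch assumes the scheduler class has been imported from the package.

import torch
from torch import nn

model = nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)

scheduler = WarmupCosineAnnealingLRScheduler(
    optimizer=optimizer,
    initial_lr=1e-4,   # lr at step 0
    min_lr=1e-5,       # floor of each cosine cycle
    max_lr=1e-3,       # lr reached at the end of warmup
    warmup_iters=10,
    period=100,        # first cosine cycle lasts 100 steps
    period_mul=2,      # each following cycle is twice as long
)

for step in range(300):
    # optimizer.step() would run here in a real training loop
    scheduler.step()
    if step % 50 == 0:
        print(step, f"{scheduler.cur_lr:.2e}")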
@@ -0,0 +1,39 @@
+ from typing import Optional, Tuple, List
+
+ from torch.utils.data import Dataset
+
+ from .trainer import Trainer
+ from .train_configs import TrainConfig, VLMConfig
+ from .dataset import LineByLineTextDataset
+ from .utils import get_sft_collate_fn
+
+
+ class SFTTrainer(Trainer):
+     def __init__(
+         self,
+         *,
+         train_config: TrainConfig,
+         eval_prompts: List[str],
+         eval_image_tags: Optional[List[int]] = None
+     ):
+         super().__init__(
+             train_config=train_config,
+             eval_prompts=eval_prompts,
+             eval_image_tags=eval_image_tags
+         )
+
+     def _convert_train_args(self) -> Tuple[dict, dict, dict, bool]:
+         sft_collate_fn = get_sft_collate_fn(self.train_config.mask_prompt)
+         parallel_kwargs, data_loader_kwargs, sampler_kwargs, use_ds_optim = super()._convert_train_args()
+         data_loader_kwargs.update({"collate_fn": sft_collate_fn})
+
+         return parallel_kwargs, data_loader_kwargs, sampler_kwargs, use_ds_optim
+
+     def _create_dataset(self, file_path) -> Dataset:
+         max_position_embeddings = self.train_config.model_config.max_position_embeddings
+         if isinstance(self.train_config.model_config, VLMConfig):
+             tokens_per_image = self.train_config.model_config.tokens_per_image
+         else:
+             tokens_per_image = -1
+
+         return LineByLineTextDataset(file_path, max_position_embeddings, tokens_per_image)