project-llm-trainer 0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of project-llm-trainer might be problematic.

@@ -0,0 +1,6 @@
+ from .trainer import Trainer
+ from .sft_trainer import SFTTrainer
+ from .dpo_trainer import DPOTrainer
+ from .grpo_trainer import GRPOTrainer
+ from .tools import TrainerTools, FileDataset, estimate_data_size
+ from .generate_utils import generate, streaming_generate
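
The names re-exported above form the package's public surface. A minimal import sketch, assuming the hunk above is the `__init__.py` of the `llm_trainer` package (the package directory name is taken from the file paths later in this diff; constructor signatures are not shown here):

    # Hypothetical import sketch; only the re-exported names above are assumed to exist.
    from llm_trainer import Trainer, SFTTrainer, DPOTrainer, GRPOTrainer
    from llm_trainer import TrainerTools, FileDataset, estimate_data_size
    from llm_trainer import generate, streaming_generate
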
@@ -0,0 +1,161 @@
+ import os
+ from typing import Optional, Union, Tuple
+ import torch
+ from torch import nn
+ from torch.optim import Optimizer
+
+ from .parallel_ds import DsParallel
+ from .parallel_fsdp import FsdpParallel
+ from .parallel_ddp import DdpParallel
+ from .scheduler import LRScheduler
+ from .tools import TrainerTools
+
+ try:
+     from .dcp import save_dcp, load_dcp, convert_dcp_to_pth
+ except Exception:
+     os.environ['ENABLE_DCP'] = "0"
+
+ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+
+ # https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html
+
+ DEFAULT_CHECKPOINT_NAME = "checkpoint.pth"
+
+
+ def _can_use_dcp(model: nn.Module) -> bool:
+     if os.environ.get('ENABLE_DCP', '1') != '1':
+         return False
+
+     # DCP saving is only available under FSDP or DDP parallelism.
+     if (isinstance(TrainerTools().parallel, FsdpParallel)
+             or isinstance(TrainerTools().parallel, DdpParallel)):
+         return True
+
+     return False
+
+
+ def save_checkpoint(
+     model: nn.Module,
+     optimizer: Optional[Optimizer] = None,
+     suffix: Optional[str] = None
+ ):
+     if isinstance(TrainerTools().parallel, DsParallel):
+         from .ds_checkpoint import save_ds_checkpoint
+         save_ds_checkpoint(model, suffix)
+     elif _can_use_dcp(model):
+         save_dcp(model, optimizer, suffix)
+     else:
+         if isinstance(model, FSDP):
+             # Untested; see https://doc.hfai.high-flyer.cn/haiscale/haiscale_fsdp.html
+             # Should rank0_only=True be used here?
+             with FSDP.summon_full_params(
+                 module=model,
+                 rank0_only=True,
+                 writeback=False,
+                 offload_to_cpu=True
+             ):
+                 if TrainerTools().parallel.is_main_process:
+                     checkpoint_name = os.environ.get('CHECKPOINT_NAME', DEFAULT_CHECKPOINT_NAME)
+                     if suffix:
+                         checkpoint_name = f"{checkpoint_name}_{suffix}"
+
+                     ckpt = {'model_state_dict': model.state_dict()}
+
+                     if optimizer:
+                         ckpt.update({'optim_state_dict': optimizer.state_dict()})
+
+                     torch.save(ckpt, checkpoint_name)
+         else:
+             if TrainerTools().parallel.is_main_process:
+                 checkpoint_name = os.environ.get('CHECKPOINT_NAME', DEFAULT_CHECKPOINT_NAME)
+                 if suffix:
+                     checkpoint_name = f"{checkpoint_name}_{suffix}"
+
+                 ckpt = {'model_state_dict': TrainerTools().parallel.raw_model.state_dict()}
+
+                 if optimizer:
+                     ckpt.update({'optim_state_dict': optimizer.state_dict()})
+
+                 torch.save(ckpt, checkpoint_name)
+
+
+ def load_checkpoint(
+     model: nn.Module,
+     optimizer: Optional[Optimizer] = None,
+     device: Optional[Union[torch.device, str]] = None,
+     load_module_only: bool = False,
+     suffix: Optional[str] = None
+ ):
+     if isinstance(TrainerTools().parallel, DsParallel):
+         from .ds_checkpoint import load_ds_checkpoint
+         load_ds_checkpoint(model, load_module_only=load_module_only, suffix=suffix)
+     elif _can_use_dcp(model):
+         load_dcp(model, optimizer, suffix)
+     else:
+         checkpoint_name = os.environ.get('CHECKPOINT_NAME', DEFAULT_CHECKPOINT_NAME)
+         if suffix:
+             checkpoint_name = f"{checkpoint_name}_{suffix}"
+
+         if os.path.exists(checkpoint_name):
+             # Untested; the else branch has been tested and also works under FSDP.
+             if isinstance(model, FSDP):
+                 with FSDP.summon_full_params(module=model):
+                     state_dict = torch.load(checkpoint_name, weights_only=True, map_location=device)
+                     model.load_state_dict(state_dict['model_state_dict'])
+
+                     if optimizer:
+                         optimizer.load_state_dict(state_dict['optim_state_dict'])
+             else:
+                 state_dict = torch.load(checkpoint_name, weights_only=True, map_location=device)
+                 model.load_state_dict(state_dict['model_state_dict'])
+
+                 if optimizer:
+                     optimizer.load_state_dict(state_dict['optim_state_dict'])
+
+
+ def load_checkpoint_for_eval(
+     model: nn.Module,
+     device: Optional[Union[torch.device, str]] = None,
+     suffix: Optional[str] = None
+ ):
+     if isinstance(TrainerTools().parallel, DsParallel):
+         from .ds_checkpoint import load_ds_checkpoint_for_eval
+         load_ds_checkpoint_for_eval(model)
+     elif _can_use_dcp(model):
+         checkpoint_name = os.environ.get('CHECKPOINT_NAME', DEFAULT_CHECKPOINT_NAME)
+
+         # load_dcp fails on CPU, so first convert the DCP checkpoint to a .pth file and load that instead.
+         # load_dcp(model, optimizer)
+         pth_name = os.environ.get('EVAL_CHECKPOINT_NAME', checkpoint_name)
+         if suffix:
+             pth_name = f'{pth_name}_{suffix}'
+
+         convert_dcp_to_pth(pth_name)
+
+         if os.path.exists(pth_name):
+             ckpt = torch.load(pth_name, map_location=device, weights_only=True)
+             model.load_state_dict(ckpt['app']['model_state_dict'])
+             # Delete the converted file once it has been loaded.
+             os.remove(pth_name)
+     else:
+         load_checkpoint(model, None, device, suffix=suffix)
+
+
+ def save_steps(global_steps: int, lr_scheduler: Optional[LRScheduler] = None):
+     # For now, only the main process saves step state.
+     if TrainerTools().parallel.is_main_process:
+         steps_checkpoint_name = f"{os.environ.get('CHECKPOINT_NAME', DEFAULT_CHECKPOINT_NAME)}.steps"
+         ckpt = {'global_steps': global_steps, 'lr_steps': lr_scheduler.cur_steps if lr_scheduler else 0}
+         torch.save(ckpt, steps_checkpoint_name)
+
+
+ def load_steps(
+     default_global_steps: int = 0,
+     default_lr_steps: int = 0
+ ) -> Tuple[Optional[int], Optional[int]]:
+     steps_checkpoint_name = f"{os.environ.get('CHECKPOINT_NAME', DEFAULT_CHECKPOINT_NAME)}.steps"
+     if os.path.exists(steps_checkpoint_name):
+         ckpt = torch.load(steps_checkpoint_name, weights_only=True)
+         return ckpt['global_steps'], ckpt['lr_steps']
+
+     return default_global_steps, default_lr_steps
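
A minimal usage sketch for the checkpoint helpers above. The signatures are taken from the diff; the model, optimizer, paths, and suffix are placeholders, the functions are assumed to already be in scope (this module's filename is not shown in the diff), and the TrainerTools parallel backend is assumed to have been initialized by a trainer:

    import os
    import torch
    from torch import nn

    os.environ['CHECKPOINT_NAME'] = 'my_run.pth'   # defaults to "checkpoint.pth"

    model = nn.Linear(16, 16)                      # placeholder model for illustration
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    # Dispatches to DeepSpeed, DCP, or plain torch.save depending on the parallel backend.
    save_checkpoint(model, optimizer, suffix='step1000')
    save_steps(global_steps=1000)                  # writes "my_run.pth.steps"

    # Later: restore weights, optimizer state, and step counters.
    load_checkpoint(model, optimizer=optimizer, device='cpu', suffix='step1000')
    global_steps, lr_steps = load_steps()
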
llm_trainer/dataset.py ADDED
@@ -0,0 +1,140 @@
+ import os.path
+
+ import torch
+ from torch.utils.data import Dataset
+ import pickle
+
+ from .tools import TrainerTools
+ from .utils import extra_image_tag_and_repeat_image_tok
+
+
+ def _try_load_pkl(file_path: str):
+     # Returns the pickled tokens, or None if the file cannot be loaded as a pickle
+     # (callers then fall back to re-tokenizing the raw file or the cache).
+     try:
+         with open(file_path, 'rb') as f:
+             return pickle.load(f)
+     except Exception:
+         return None
+
+
+ class TextDataset(Dataset):
+     """
+     Used in the pretrain stage.
+     """
+     def __init__(self, file_path, block_size, stride):
+         super().__init__()
+
+         self.input_ids = []
+
+         tokens = _try_load_pkl(file_path)
+         if not tokens:
+             cache_file = f'{file_path}.cache'
+             if os.path.exists(cache_file):
+                 tokens = _try_load_pkl(cache_file)
+             else:
+                 tokens = []
+                 with open(file_path, 'r') as f:
+                     for line in f:
+                         tokens.extend(TrainerTools().tokenizer.encode(line))
+
+                 with open(cache_file, 'wb') as f:
+                     pickle.dump(tokens, f)
+
+         for i in range(0, len(tokens) - block_size + 1, stride):
+             self.input_ids.append(tokens[i:i+block_size])
+
+     def __len__(self):
+         return len(self.input_ids)
+
+     def __getitem__(self, item):
+         return torch.tensor(self.input_ids[item]).long()
+
+
+ class LineByLineTextDataset(Dataset):
+     """
+     Used in the SFT stage.
+     """
+     def __init__(self, file_path, max_len, tokens_per_image=-1):
+         super().__init__()
+
+         self.max_len = max_len
+         self.tokens_per_image = tokens_per_image
+         self.input_ids = []
+
+         tokens = _try_load_pkl(file_path)
+         if not tokens:
+             cache_file = f'{file_path}.cache'
+             if os.path.exists(cache_file):
+                 tokens = _try_load_pkl(cache_file)
+             else:
+                 tokens = []
+                 with open(file_path, 'r') as f:
+                     for line in f:
+                         tokens.append(TrainerTools().tokenizer.encode(line))
+
+                 with open(cache_file, 'wb') as f:
+                     pickle.dump(tokens, f)
+
+         self.input_ids = tokens
+
+     def __len__(self):
+         return len(self.input_ids)
+
+     def __getitem__(self, item):
+         inputs = self.input_ids[item]
+         if self.tokens_per_image != -1:
+             inputs, image_tag = extra_image_tag_and_repeat_image_tok(inputs, self.tokens_per_image)
+         else:
+             image_tag = None
+
+         inputs = inputs[:self.max_len]
+
+         return {'inputs': torch.tensor(inputs).long(), 'image_tag': image_tag}
+
+
+ class DPODataset(Dataset):
+     def __init__(self, file_path, max_len):
+         self.max_len = max_len
+         self.chosen_ids = []
+         self.rejected_ids = []
+
+         # [{'chosen': xxx, 'rejected': xxx} ...]
+         tokens = _try_load_pkl(file_path)
+         for token in tokens:
+             self.chosen_ids.append(token['chosen'])
+             self.rejected_ids.append(token['rejected'])
+
+     def __len__(self):
+         return len(self.chosen_ids)
+
+     def __getitem__(self, item):
+         chosen_id = self.chosen_ids[item]
+         rejected_id = self.rejected_ids[item]
+
+         return {'chosen': chosen_id[:self.max_len], 'rejected': rejected_id[:self.max_len]}
+
+
+ class GRPORolloutDataset(Dataset):
+     def __init__(self, file_path):
+         self.questions = []
+         self.answers = []
+
+         # [{'prompt': xxx, 'answer': xxx} ...]
+         tokens = _try_load_pkl(file_path)
+         for token in tokens:
+             self.questions.append(token['prompt'])
+             self.answers.append(token['answer'])
+
+     def __len__(self):
+         return len(self.questions)
+
+     def __getitem__(self, item):
+         question = self.questions[item]
+         answer = self.answers[item]
+
+         return {
+             'prompt': torch.tensor(question).long(),
+             'answer': torch.tensor(answer).long()
+         }
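
A minimal sketch of how these datasets might be wired up, assuming the distribution installs as the `llm_trainer` package, that TrainerTools' tokenizer has been initialized elsewhere, and that the file paths, block size, and batch size are illustrative only:

    from torch.utils.data import DataLoader
    from llm_trainer.dataset import TextDataset, LineByLineTextDataset

    # Pretrain: fixed-size blocks cut from one long token stream with a sliding window.
    pretrain_ds = TextDataset('data/pretrain.txt', block_size=1024, stride=1024)
    pretrain_loader = DataLoader(pretrain_ds, batch_size=8, shuffle=True)

    # SFT: one already-tokenized sample per entry, truncated to max_len in __getitem__.
    sft_ds = LineByLineTextDataset('data/sft.pkl', max_len=1024)
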
llm_trainer/dcp.py ADDED
@@ -0,0 +1,93 @@
+ import os
+ from typing import Optional, Dict, Any
+ from torch import nn
+ from torch.optim import Optimizer
+ import torch.distributed.checkpoint as dcp
+ from torch.distributed.checkpoint.stateful import Stateful
+ from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
+ from torch.distributed.checkpoint.format_utils import dcp_to_torch_save, torch_save_to_dcp
+
+ DEFAULT_CHECKPOINT_DIR = "checkpoint"
+
+ class AppState(Stateful):
+     def __init__(self, model: nn.Module, optimizer: Optimizer):
+         self.model = model
+         self.optimizer = optimizer
+
+     def state_dict(self) -> Dict[str, Any]:
+         model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
+         return {
+             'model_state_dict': model_state_dict,
+             'optim_state_dict': optimizer_state_dict
+         }
+
+     def load_state_dict(self, state_dict: Dict[str, Any]):
+         set_state_dict(
+             model=self.model,
+             optimizers=self.optimizer,
+             model_state_dict=state_dict['model_state_dict'],
+             optim_state_dict=state_dict['optim_state_dict']
+         )
+
+
+ def save_dcp(
+     model: nn.Module,
+     optimizer: Optimizer,
+     suffix: Optional[str] = None
+ ):
+     checkpoint_id = os.environ.get('DIST_CHECKPOINT_DIR', DEFAULT_CHECKPOINT_DIR)
+     if suffix:
+         checkpoint_id = f"{checkpoint_id}_{suffix}"
+
+     state_dict = {'app': AppState(model, optimizer)}
+
+     # fs_storage_writer = dcp.FileSystemWriter(checkpoint_id, overwrite=True)
+     # dcp.save(state_dict=state_dict, storage_writer=fs_storage_writer)
+     dcp.save(state_dict=state_dict, checkpoint_id=checkpoint_id)
+
+
+ def load_dcp(
+     model: nn.Module,
+     optimizer: Optional[Optimizer] = None,
+     suffix: Optional[str] = None
+ ):
+     checkpoint_id = os.environ.get('DIST_CHECKPOINT_DIR', DEFAULT_CHECKPOINT_DIR)
+     if suffix:
+         checkpoint_id = f"{checkpoint_id}_{suffix}"
+
+     if os.path.exists(checkpoint_id):
+         state_dict = {'app': AppState(model, optimizer)}
+         # AppState loads the checkpoint into state_dict and then into the model.
+         dcp.load(state_dict=state_dict, checkpoint_id=checkpoint_id)
+
+         # if isinstance(model, FSDP):
+         #     state_dict = {'app': AppState(model, optimizer)}
+         #     # AppState loads the checkpoint into state_dict and then into the model.
+         #     dcp.load(state_dict=state_dict, checkpoint_id=checkpoint_id)
+         # else:
+         #     state_dict = {"model_state_dict": model.state_dict()}
+         #
+         #     if optimizer:
+         #         state_dict.update({'optim_state_dict': optimizer.state_dict()})
+         #
+         #     # Since no process group is initialized, DCP will disable any collectives.
+         #     # Load into state_dict, then into the model via model.load_state_dict.
+         #     dcp.load(
+         #         state_dict=state_dict,
+         #         checkpoint_id=checkpoint_id,
+         #     )
+         #
+         #     model.load_state_dict(state_dict["model_state_dict"])
+         #     if optimizer:
+         #         optimizer.load_state_dict(state_dict["optim_state_dict"])
+
+ def convert_dcp_to_pth(pth_path: str):
+     dcp_path = os.environ.get('DIST_CHECKPOINT_DIR', DEFAULT_CHECKPOINT_DIR)
+     if os.path.exists(dcp_path):
+         # Convert the DCP checkpoint to a torch.save file (assumes it was generated as above).
+         dcp_to_torch_save(dcp_path, pth_path)
+
+ def convert_pth_to_dcp(pth_path: str):
+     if os.path.exists(pth_path):
+         # Convert the torch.save checkpoint back to DCP format.
+         torch_save_to_dcp(pth_path, os.environ.get('DIST_CHECKPOINT_DIR', DEFAULT_CHECKPOINT_DIR))
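
A small sketch of the offline conversion helpers above, assuming a DCP checkpoint directory was produced by save_dcp, that DIST_CHECKPOINT_DIR points at it, and that the directory and .pth file names are illustrative:

    import os
    from llm_trainer.dcp import convert_dcp_to_pth, convert_pth_to_dcp

    os.environ['DIST_CHECKPOINT_DIR'] = 'checkpoint_step1000'   # directory written by save_dcp

    # DCP directory -> single torch.save file (usable on CPU, without a process group).
    convert_dcp_to_pth('checkpoint_step1000.pth')

    # And back again, if a sharded DCP layout is needed.
    convert_pth_to_dcp('checkpoint_step1000.pth')
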