project-llm-trainer 0.12.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_trainer/__init__.py +13 -0
- llm_trainer/base_trainer.py +683 -0
- llm_trainer/checkpoint.py +126 -0
- llm_trainer/dataset.py +335 -0
- llm_trainer/dpo_trainer.py +297 -0
- llm_trainer/ds_checkpoint.py +63 -0
- llm_trainer/eval.py +33 -0
- llm_trainer/generate_utils.py +450 -0
- llm_trainer/grpo_trainer.py +385 -0
- llm_trainer/log.py +65 -0
- llm_trainer/loss.py +268 -0
- llm_trainer/parallel.py +220 -0
- llm_trainer/partition_utils.py +219 -0
- llm_trainer/ppo_trainer.py +521 -0
- llm_trainer/scheduler.py +179 -0
- llm_trainer/sft_trainer.py +97 -0
- llm_trainer/tokenizer.py +162 -0
- llm_trainer/tools.py +116 -0
- llm_trainer/train_configs.py +324 -0
- llm_trainer/trainer.py +34 -0
- llm_trainer/utils.py +547 -0
- project_llm_trainer-0.12.3.data/scripts/calc_intermediate_size +15 -0
- project_llm_trainer-0.12.3.data/scripts/ddp_train +21 -0
- project_llm_trainer-0.12.3.data/scripts/ds_train +17 -0
- project_llm_trainer-0.12.3.data/scripts/plot_log +69 -0
- project_llm_trainer-0.12.3.data/scripts/plot_lr +45 -0
- project_llm_trainer-0.12.3.data/scripts/py_train +12 -0
- project_llm_trainer-0.12.3.data/scripts/smart_train +37 -0
- project_llm_trainer-0.12.3.dist-info/METADATA +9 -0
- project_llm_trainer-0.12.3.dist-info/RECORD +32 -0
- project_llm_trainer-0.12.3.dist-info/WHEEL +5 -0
- project_llm_trainer-0.12.3.dist-info/top_level.txt +1 -0
llm_trainer/checkpoint.py
ADDED
@@ -0,0 +1,126 @@
import os
from typing import Optional, Union
import shutil
import torch
from torch import nn
from torch.optim import Optimizer
from torch.nn.parallel import DistributedDataParallel as DDP

from .parallel import DsParallel
from .scheduler import LRScheduler
from .tools import TrainerTools

DEFAULT_CHECKPOINT_NAME = "checkpoint.pth"


def save_checkpoint(
    model: nn.Module,
    optimizer: Optional[Optimizer] = None
):
    if isinstance(TrainerTools().parallel, DsParallel):
        from .ds_checkpoint import save_ds_checkpoint
        save_ds_checkpoint(model)
    else:
        if TrainerTools().parallel.is_main_process:
            checkpoint_name = os.environ.get('CHECKPOINT_NAME', DEFAULT_CHECKPOINT_NAME)

            raw_model = model if not isinstance(model, DDP) else model.module
            ckpt = {'model_state_dict': raw_model.state_dict()}

            if optimizer:
                ckpt.update({'optim_state_dict': optimizer.state_dict()})

            torch.save(ckpt, checkpoint_name)


def save_best_checkpoint(
    current_loss: float,
    last_best_checkpoint_loss: Optional[float] = None
) -> bool:
    # Saving the best checkpoint has been explicitly disabled
    if os.environ.get('SAVE_BEST_CHECKPOINT', '1') != '1':
        return False

    # Replace the best checkpoint when there is no previous best loss or the current loss is at least as good
    need_replace = last_best_checkpoint_loss is None or current_loss <= last_best_checkpoint_loss
    if need_replace and TrainerTools().parallel.is_main_process:
        try:
            if isinstance(TrainerTools().parallel, DsParallel):
                checkpoint_dir = os.environ.get('DIST_CHECKPOINT_DIR', 'checkpoint')

                if checkpoint_dir.endswith('/'):
                    best_checkpoint_dir = f'{checkpoint_dir[:-1]}_best'
                else:
                    best_checkpoint_dir = f'{checkpoint_dir}_best'

                if not os.path.exists(best_checkpoint_dir):
                    os.makedirs(best_checkpoint_dir)

                if os.path.exists(checkpoint_dir):
                    shutil.rmtree(best_checkpoint_dir)
                    shutil.copytree(checkpoint_dir, best_checkpoint_dir)
            else:
                checkpoint_name = os.environ.get('CHECKPOINT_NAME', DEFAULT_CHECKPOINT_NAME)
                best_checkpoint_name = f'{checkpoint_name}_best'

                if os.path.exists(checkpoint_name):
                    if os.path.exists(best_checkpoint_name):
                        os.remove(best_checkpoint_name)

                    shutil.copy2(checkpoint_name, best_checkpoint_name)
        except Exception:
            pass

    TrainerTools().parallel.wait('save best checkpoint')
    return need_replace


def load_checkpoint(
    model: nn.Module,
    optimizer: Optional[Optimizer] = None,
    device: Optional[Union[torch.device, str]] = None,
    load_module_only: bool = False
):
    if isinstance(TrainerTools().parallel, DsParallel):
        from .ds_checkpoint import load_ds_checkpoint
        load_ds_checkpoint(model, load_module_only=load_module_only)
    else:
        checkpoint_name = os.environ.get('CHECKPOINT_NAME', DEFAULT_CHECKPOINT_NAME)

        if os.path.exists(checkpoint_name):
            state_dict = torch.load(checkpoint_name, weights_only=True, map_location=device)
            raw_model = model.module if isinstance(model, DDP) else model
            raw_model.load_state_dict(state_dict['model_state_dict'])

            if optimizer:
                optimizer.load_state_dict(state_dict['optim_state_dict'])


def load_checkpoint_for_eval(
    model: nn.Module,
    device: Optional[Union[torch.device, str]] = None
):
    if isinstance(TrainerTools().parallel, DsParallel):
        from .ds_checkpoint import load_ds_checkpoint_for_eval
        load_ds_checkpoint_for_eval(model)
    else:
        load_checkpoint(model, None, device)


def save_steps(
    global_steps: int,
    lr_scheduler: Optional[LRScheduler] = None,
):
    # For now, only the main process saves step/scheduler state
    if TrainerTools().parallel.is_main_process:
        steps_checkpoint_name = f"{os.environ.get('LOG_DIR', './')}steps.pt"
        ckpt = {'global_steps': global_steps}
        if lr_scheduler:
            ckpt.update(lr_scheduler.get_ckpt_dict())

        torch.save(ckpt, steps_checkpoint_name)


def load_steps() -> Optional[dict]:
    steps_checkpoint_name = f"{os.environ.get('LOG_DIR', './')}steps.pt"
    if os.path.exists(steps_checkpoint_name):
        return torch.load(steps_checkpoint_name, weights_only=True)

    return None
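For orientation, a minimal sketch of how these checkpoint helpers might be wired into a training loop. The names `model`, `optimizer`, `lr_scheduler`, `num_epochs` and `train_one_epoch` are hypothetical placeholders, not part of the package, and the environment-variable setup mirrors the defaults read above.

# Hypothetical usage sketch; only the llm_trainer helpers shown above are real.
import os

os.environ.setdefault('CHECKPOINT_NAME', 'my_model.pth')

load_checkpoint(model, optimizer, device='cuda')      # resume if a checkpoint already exists
steps_state = load_steps() or {'global_steps': 0}
global_steps = steps_state['global_steps']
best_loss = None

for epoch in range(num_epochs):
    loss, global_steps = train_one_epoch(model, optimizer, global_steps)

    save_checkpoint(model, optimizer)                  # plain/DDP vs. DeepSpeed path chosen internally
    save_steps(global_steps, lr_scheduler)
    if save_best_checkpoint(loss, best_loss):          # copies the checkpoint to a *_best sibling
        best_loss = loss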
llm_trainer/dataset.py
ADDED
@@ -0,0 +1,335 @@
import torch
from torch.utils.data import Dataset
import pickle
import csv
import json
import numpy as np

from .tools import TrainerTools
from .utils import repeat_image_tok


"""
Supported dataset file formats: npy, jsonl and pkl.
"""
def _get_file_type(file_path: str):
    if file_path.endswith('.npy'):
        return 'npy'
    elif file_path.endswith('.jsonl'):
        return 'jsonl'
    elif file_path.endswith('.pkl'):
        return 'pkl'

    return None


class PretrainDataset(Dataset):
    """
    For the pretrain stage. Data files may be npy, jsonl or pkl; jsonl files are fully
    encoded into tokens during __init__.
    1. npy: [recommended] a numpy token array; supports mmap, so memory usage stays very low
    2. jsonl: {'text': 'text1'}\n{'text': 'text2'}
    3. pkl: [0, 1, 2, 3 ...]
    """
    def __init__(
        self,
        file_path,
        block_size,
        stride
    ):
        super().__init__()

        self.block_size = block_size
        self.stride = stride
        self.use_mmap = False

        file_type = _get_file_type(file_path)

        if file_type == 'npy':
            self.input_ids = np.load(file_path, mmap_mode='r')
            self.use_mmap = True
        elif file_type == 'jsonl':
            tokens = []
            with open(file_path, 'r') as f:
                for line in f:
                    tokens.extend(TrainerTools().tokenizer.encode(json.loads(line.strip())['text']))
            self.input_ids = torch.tensor(tokens, dtype=torch.int32)
            del tokens
        elif file_type == 'pkl':
            with open(file_path, 'rb') as f:
                tokens = pickle.load(f)
            self.input_ids = torch.tensor(tokens, dtype=torch.int32)
            del tokens
        else:
            raise Exception(f'unsupported file type for {file_path}')

        if len(self.input_ids) < block_size:
            self.length = 0
        else:
            self.length = (len(self.input_ids) - block_size) // stride + 1

    def __len__(self):
        return self.length

    def __getitem__(self, item):
        if item < 0 or item >= self.length:
            raise IndexError(f"Index {item} out of range")

        start_idx = item * self.stride
        end_idx = start_idx + self.block_size

        data = self.input_ids[start_idx:end_idx]

        if self.use_mmap:
            return torch.from_numpy(data.astype(np.int64))
        else:
            return data.long()


class SFTDataset(Dataset):
    """
    For the SFT stage. Data files may be npy, jsonl or pkl; jsonl records are encoded
    into tokens in __getitem__.
    npy: [
        [0, 1, 2, 3],
        [4, 5, 6, 7]
    ]
    jsonl: [
        {'role': 'system', 'content': 'system_content'},
        {'role': 'user', 'content': 'user_content'},
        {'role': 'assistant', 'think': 'think_content', 'content': 'assistant_content'}
    ]\n
    [
        {'role': 'system', 'content': 'system_content'},
        {'role': 'user', 'content': 'user_content'},
        {'role': 'assistant', 'think': 'think_content', 'content': 'assistant_content'}
    ]
    pkl: [
        [0, 1, 2, 3],
        [4, 5, 6, 7]
    ]
    """
    def __init__(
        self,
        file_path,
        max_len,
        image_tags_file_path=None,
        tokens_per_image=-1
    ):
        super().__init__()

        self.max_len = max_len
        self.tokens_per_image = tokens_per_image
        self.input_ids = []
        self.image_tags = []
        self.plain_text = False

        file_type = _get_file_type(file_path)

        if file_type == 'npy':
            try:
                self.input_ids = np.load(file_path, mmap_mode='r')
            except ValueError:
                self.input_ids = np.load(file_path, allow_pickle=True)
        elif file_type == 'jsonl':
            self.plain_text = True
            with open(file_path, 'r') as f:
                for line in f:
                    self.input_ids.append(json.loads(line.strip()))
        elif file_type == 'pkl':
            with open(file_path, 'rb') as f:
                self.input_ids = pickle.load(f)
        else:
            raise Exception(f'unsupported file type for {file_path}')

        if image_tags_file_path:
            with open(image_tags_file_path, 'r') as f:
                csv_reader = csv.reader(f)
                for line in csv_reader:
                    self.image_tags.append(line[0])

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, item):
        if self.plain_text:
            inputs = TrainerTools().tokenizer.apply_chat_template(self.input_ids[item])
        else:
            inputs = self.input_ids[item]

        if isinstance(inputs, np.ndarray):
            inputs = torch.from_numpy(inputs.astype(np.int64))
        else:
            inputs = torch.tensor(inputs).long()

        image_tag = self.image_tags[item] if self.image_tags else None

        if self.tokens_per_image != -1:
            inputs = repeat_image_tok(inputs, self.tokens_per_image)
        else:
            image_tag = None

        inputs = inputs[:self.max_len]

        return {
            'inputs': inputs,
            'image_tag': image_tag
        }


class DPODataset(Dataset):
    """
    For the DPO stage. Data files may be npy, jsonl or pkl; jsonl records are encoded
    into tokens in __getitem__.
    npy: [
        {'chosen': xxx, 'rejected': xxx},
        {'chosen': xxx, 'rejected': xxx},
    ]
    jsonl: {'chosen':
                [{'role': 'system', 'content': 'system_content'},
                 {'role': 'user', 'content': 'user_content'},
                 {'role': 'assistant', 'think': 'think_content', 'content': 'assistant_content'}],
            'rejected':
                [{'role': 'system', 'content': 'system_content'},
                 {'role': 'user', 'content': 'user_content'},
                 {'role': 'assistant', 'think': 'think_content', 'content': 'assistant_content'}],
           }\n
           {'chosen':
                [{'role': 'system', 'content': 'system_content'},
                 {'role': 'user', 'content': 'user_content'},
                 {'role': 'assistant', 'think': 'think_content', 'content': 'assistant_content'}],
            'rejected':
                [{'role': 'system', 'content': 'system_content'},
                 {'role': 'user', 'content': 'user_content'},
                 {'role': 'assistant', 'think': 'think_content', 'content': 'assistant_content'}],
           }
    pkl: [
        {'chosen': xxx, 'rejected': xxx},
        {'chosen': xxx, 'rejected': xxx},
    ]
    """
    def __init__(self, file_path, max_len):
        self.max_len = max_len
        self.data = []
        self.plain_text = False

        file_type = _get_file_type(file_path)

        if file_type == 'npy':
            try:
                self.data = np.load(file_path, mmap_mode='r')
            except ValueError:
                self.data = np.load(file_path, allow_pickle=True)
        elif file_type == 'jsonl':
            self.plain_text = True
            with open(file_path, 'r') as f:
                for line in f:
                    self.data.append(json.loads(line.strip()))
        elif file_type == 'pkl':
            with open(file_path, 'rb') as f:
                self.data = pickle.load(f)
        else:
            raise Exception(f'unsupported file type for {file_path}')

    def __len__(self):
        return len(self.data)

    def __getitem__(self, item):
        record = self.data[item]

        chosen_raw = record['chosen']
        rejected_raw = record['rejected']

        if self.plain_text:
            chosen_id = TrainerTools().tokenizer.apply_chat_template(chosen_raw)
            rejected_id = TrainerTools().tokenizer.apply_chat_template(rejected_raw)
        else:
            chosen_id = chosen_raw
            rejected_id = rejected_raw

        if isinstance(chosen_id, np.ndarray): chosen_id = chosen_id.tolist()
        if isinstance(rejected_id, np.ndarray): rejected_id = rejected_id.tolist()

        return {
            'chosen': chosen_id[:self.max_len],
            'rejected': rejected_id[:self.max_len]
        }


class RLDataset(Dataset):
    """
    For the RL stage (e.g. PPO, GRPO, GSPO). Data files may be npy, jsonl or pkl;
    jsonl records are encoded into tokens in __getitem__.
    npy: [
        {'prompt': xxx, 'answer': xxx},
        {'prompt': xxx, 'answer': xxx},
    ]
    jsonl: {'prompt':
                [{'role': 'system', 'content': 'system_content'},
                 {'role': 'user', 'content': 'user_content'}],
            'answer': '10'
           }\n
           {'prompt':
                [{'role': 'system', 'content': 'system_content'},
                 {'role': 'user', 'content': 'user_content'}],
            'answer': '10'
           }
    pkl: [
        {'prompt': xxx, 'answer': xxx},
        {'prompt': xxx, 'answer': xxx},
    ]
    """
    def __init__(self, file_path):
        self.data = []
        self.plain_text = False

        file_type = _get_file_type(file_path)

        if file_type == 'npy':
            try:
                self.data = np.load(file_path, mmap_mode='r')
            except ValueError:
                self.data = np.load(file_path, allow_pickle=True)
        elif file_type == 'jsonl':
            self.plain_text = True

            with open(file_path, 'r') as f:
                for line in f:
                    self.data.append(json.loads(line.strip()))
        elif file_type == 'pkl':
            with open(file_path, 'rb') as f:
                self.data = pickle.load(f)
        else:
            raise Exception(f'unsupported file type for {file_path}')

    def __len__(self):
        return len(self.data)

    def __getitem__(self, item):
        record = self.data[item]

        prompt_raw = record['prompt']
        answer_raw = record.get('answer', None)

        if self.plain_text:
            question = TrainerTools().tokenizer.apply_chat_template(prompt_raw)
            answer = TrainerTools().tokenizer.encode(answer_raw) if answer_raw else None
        else:
            question = prompt_raw
            answer = answer_raw

        # Convert to tensors
        if isinstance(question, np.ndarray):
            prompt_tensor = torch.from_numpy(question.astype(np.int64))
        else:
            prompt_tensor = torch.tensor(question).long()

        if answer is not None:
            if isinstance(answer, np.ndarray):
                answer_tensor = torch.from_numpy(answer.astype(np.int64))
            else:
                answer_tensor = torch.tensor(answer).long()
        else:
            answer_tensor = None

        return {
            'prompt': prompt_tensor,
            'answer': answer_tensor
        }
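For reference, a minimal sketch of how input files matching the docstring layouts above might be produced and loaded. The file names are placeholders, and iterating the SFT/RL datasets assumes TrainerTools has been configured with a tokenizer elsewhere, since jsonl records are encoded on access.

# Hypothetical data-preparation sketch; file names are placeholders.
import json
import numpy as np

# Pretrain: one long token stream saved as int32 .npy (mmap-friendly).
tokens = np.array([0, 1, 2, 3, 4, 5, 6, 7], dtype=np.int32)
np.save('pretrain_tokens.npy', tokens)
pretrain_ds = PretrainDataset('pretrain_tokens.npy', block_size=4, stride=2)

# SFT: one chat-format conversation per jsonl line.
conversation = [
    {'role': 'system', 'content': 'system_content'},
    {'role': 'user', 'content': 'user_content'},
    {'role': 'assistant', 'think': 'think_content', 'content': 'assistant_content'},
]
with open('sft.jsonl', 'w') as f:
    f.write(json.dumps(conversation, ensure_ascii=False) + '\n')
sft_ds = SFTDataset('sft.jsonl', max_len=1024)

# RL (PPO/GRPO/GSPO): prompt messages plus an optional reference answer.
sample = {'prompt': conversation[:2], 'answer': '10'}
with open('rl.jsonl', 'w') as f:
    f.write(json.dumps(sample, ensure_ascii=False) + '\n')
rl_ds = RLDataset('rl.jsonl')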