doc-page-extractor 0.2.2__py3-none-any.whl → 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of doc-page-extractor might be problematic. Click here for more details.
- doc_page_extractor/model.py +32 -13
- {doc_page_extractor-0.2.2.dist-info → doc_page_extractor-0.2.3.dist-info}/METADATA +26 -23
- {doc_page_extractor-0.2.2.dist-info → doc_page_extractor-0.2.3.dist-info}/RECORD +13 -24
- {doc_page_extractor-0.2.2.dist-info → doc_page_extractor-0.2.3.dist-info}/WHEEL +1 -2
- doc_page_extractor/struct_eqtable/__init__.py +0 -49
- doc_page_extractor/struct_eqtable/internvl/__init__.py +0 -2
- doc_page_extractor/struct_eqtable/internvl/conversation.py +0 -394
- doc_page_extractor/struct_eqtable/internvl/internvl.py +0 -198
- doc_page_extractor/struct_eqtable/internvl/internvl_lmdeploy.py +0 -81
- doc_page_extractor/struct_eqtable/pix2s/__init__.py +0 -3
- doc_page_extractor/struct_eqtable/pix2s/pix2s.py +0 -76
- doc_page_extractor/struct_eqtable/pix2s/pix2s_trt.py +0 -1047
- doc_page_extractor-0.2.2.dist-info/top_level.txt +0 -2
- tests/__init__.py +0 -0
- tests/test_history_bus.py +0 -55
- {doc_page_extractor-0.2.2.dist-info/licenses → doc_page_extractor-0.2.3.dist-info}/LICENSE +0 -0
|
@@ -1,1047 +0,0 @@
|
|
|
1
|
-
import os
|
|
2
|
-
import time
|
|
3
|
-
import json
|
|
4
|
-
|
|
5
|
-
import torch
|
|
6
|
-
import torch.nn as nn
|
|
7
|
-
|
|
8
|
-
# TensorRT-LLM is treated as optional: if any of these imports fails, a
# warning is printed and the module still loads. The names below are then
# undefined, so the TensorRT code paths fail on first use.
try:
    import tensorrt_llm
    import tensorrt as trt
    import tensorrt_llm.profiler as profiler

    from tensorrt_llm._utils import str_dtype_to_trt, torch_to_numpy
    from tensorrt_llm.lora_manager import LoraManager
    from tensorrt_llm.runtime import Session, TensorInfo, ModelConfig, SamplingConfig
except Exception:
    # ``Exception`` (rather than a bare ``except``) keeps the best-effort
    # behaviour for import/loader errors while letting KeyboardInterrupt and
    # SystemExit propagate.
    print("\033[93mimport tensorrt_llm failed, if do not use tensorrt, ignore this message\033[0m")
|
|
18
|
-
|
|
19
|
-
from typing import List
|
|
20
|
-
from transformers import AutoProcessor, AutoTokenizer, AutoConfig
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
def trt_dtype_to_torch(dtype):
    """Translate a TensorRT data type into the corresponding ``torch`` dtype.

    Handles ``trt.float16``, ``trt.float32``, ``trt.int32`` and
    ``trt.bfloat16``.

    Raises:
        TypeError: if ``dtype`` is none of the four handled types.
    """
    # Guard clauses: each ``trt`` attribute is only looked up when the
    # previous comparisons did not match, exactly like an if/elif chain.
    if dtype == trt.float16:
        return torch.float16
    if dtype == trt.float32:
        return torch.float32
    if dtype == trt.int32:
        return torch.int32
    if dtype == trt.bfloat16:
        return torch.bfloat16
    raise TypeError("%s is not supported" % dtype)
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
class Pix2StructTensorRT(nn.Module):
    """Image-to-LaTeX inference wrapper driven by TensorRT engines.

    The visual encoder engine is read from ``<tensorrt_path>/visual_engines``
    and the encoder-decoder LLM engines from ``<tensorrt_path>/llm_engines``;
    the Hugging Face processor, tokenizer and config are loaded from
    ``model_path``. Calling an instance with an image returns a list of
    post-processed LaTeX strings.
    """

    def __init__(
        self,
        model_path,
        tensorrt_path,
        batch_size=1,
        max_new_tokens=4096,
        cache_dir=None,
        local_files_only=None,
        **kwargs,
    ):
        """Load the engines, tokenizer and image processor.

        Args:
            model_path: Checkpoint name or path handed to the Hugging Face
                ``from_pretrained`` loaders.
            tensorrt_path: Directory that contains the ``llm_engines`` and
                ``visual_engines`` sub-directories.
            batch_size: Number of rows the inputs are expanded to.
            max_new_tokens: Generation budget forwarded to the decoder.
            cache_dir: Forwarded to ``from_pretrained``.
            local_files_only: Forwarded to ``from_pretrained``.
            **kwargs: Accepted and ignored.
        """
        # nn.Module requires its own __init__ to run before the instance is
        # used as a module; without it inherited helpers such as ``eval()``
        # or ``to()`` fail on missing internal state.
        super().__init__()

        self.model_ckpt_path = model_path
        self.tensorrt_path = tensorrt_path
        self.batch_size = batch_size
        self.max_new_tokens = max_new_tokens
        self.cache_dir = cache_dir
        self.local_files_only = local_files_only

        self.llm_engine_path = os.path.join(tensorrt_path, 'llm_engines')
        self.visual_engine_path = os.path.join(tensorrt_path, 'visual_engines')

        device_id = torch.cuda.current_device() % torch.cuda.device_count()
        self.device_id = device_id
        self.device = "cuda:%d" % (device_id)

        # All work of this instance is issued on a dedicated CUDA stream,
        # which is also made the current stream.
        self.stream = torch.cuda.Stream(torch.cuda.current_device())
        torch.cuda.set_stream(self.stream)

        # parse model type from visual engine config
        with open(os.path.join(self.visual_engine_path, "config.json"),
                  "r") as f:
            config = json.load(f)
        self.model_type = config['builder_config']['model_type']
        self.vision_precision = config['builder_config']['precision']

        # NOTE(review): the precision read from the config is immediately
        # overridden with 'float16' -- presumably intentional; confirm.
        self.vision_precision = 'float16'
        # BLIP2-T5, pix2struct and Nougat are using encoder-decoder models as LLMs
        self.decoder_llm = not (
            't5' in self.model_type
            or self.model_type in ['nougat', 'pix2struct', 'StructEqTable']
        )

        self.profiling_iterations = 20

        self.init_image_encoder()
        self.init_tokenizer()
        self.init_llm()
        self.init_image_processor()

        self.special_str_list = ['\\midrule', '\\hline']
        self.supported_output_format = ['latex']

    def postprocess_latex_code(self, code):
        """Append a space after every occurrence of each special LaTeX token."""
        for special_str in self.special_str_list:
            code = code.replace(special_str, special_str + ' ')
        return code

    def init_image_processor(self):
        """Load the Hugging Face processor into ``self.data_processor``."""
        self.data_processor = AutoProcessor.from_pretrained(
            pretrained_model_name_or_path=self.model_ckpt_path,
            cache_dir=self.cache_dir,
            local_files_only=self.local_files_only,
        )

    def init_tokenizer(self):
        """Load the Hugging Face tokenizer into ``self.tokenizer``."""
        self.tokenizer = AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path=self.model_ckpt_path,
            use_fast=True,
            use_legacy=False,
            cache_dir=self.cache_dir,
            local_files_only=self.local_files_only,
        )

    def init_image_encoder(self):
        """Deserialize ``visual_encoder.engine`` into a TensorRT session."""
        vision_encoder_path = os.path.join(self.visual_engine_path,
                                           'visual_encoder.engine')
        with open(vision_encoder_path, 'rb') as f:
            engine_buffer = f.read()
        self.visual_encoder_session = Session.from_serialized_engine(
            engine_buffer)

    def init_llm(self):
        """Build the encoder-decoder runtime and cache its decoder config/mapping."""
        self.model = TRTLLMEncDecModel.from_engine(
            os.path.basename(self.model_ckpt_path),
            self.llm_engine_path,
            # These model types feed visual features straight to the decoder.
            skip_encoder=self.model_type in ['nougat', 'pix2struct', 'StructEqTable'],
            debug_mode=False,
            stream=self.stream)

        self.model_config = self.model.decoder_model_config
        self.runtime_mapping = self.model.decoder_runtime_mapping

    def __call__(self, image, **kwargs):
        """Convert ``image`` to LaTeX.

        Args:
            image: Input accepted by the processor's ``image_processor``.
            **kwargs: Accepted and ignored.

        Returns:
            A list of post-processed LaTeX strings, one per batch row; an
            empty list when generation produced no output (``generate``
            returns ``None`` on MPI ranks other than 0).
        """
        # process image to tokens
        image_tokens = self.data_processor.image_processor(
            images=image,
            return_tensors='pt',
        )

        for k, v in image_tokens.items():
            image_tokens[k] = v.cuda()

        model_output = self.run(
            flattened_patches=image_tokens['flattened_patches'],
            attention_mask=image_tokens['attention_mask'],
            max_new_tokens=self.max_new_tokens
        )

        if model_output is None:
            # Nothing to post-process; previously this raised TypeError.
            return []

        # postprocess: each entry holds the beams of one row; beam 0 is used.
        return [self.postprocess_latex_code(code[0]) for code in model_output]

    def preprocess(self, warmup, pre_prompt, post_prompt, image,
                   attention_mask):
        """Run the vision encoder and assemble the prompt inputs.

        Returns:
            ``(input_ids, input_lengths, ptuning_args, visual_features)``.
        """
        if not warmup:
            profiler.start("Vision")

        visual_features, visual_atts = self.get_visual_features(
            torch.stack(image['image_patches'], dim=0)
            if self.model_type == 'fuyu' else image, attention_mask)

        if not warmup:
            profiler.stop("Vision")

        pre_input_ids = self.tokenizer(pre_prompt,
                                       return_tensors="pt",
                                       padding=True).input_ids
        if post_prompt[0] is not None:
            post_input_ids = self.tokenizer(post_prompt,
                                            return_tensors="pt",
                                            padding=True).input_ids
            length = (pre_input_ids.shape[1] + post_input_ids.shape[1] +
                      visual_atts.shape[1])
        else:
            post_input_ids = None
            length = pre_input_ids.shape[1] + visual_atts.shape[1]

        # One length per batch row. ``generate`` indexes this tensor by
        # batch index, so a single-entry tensor breaks for batch_size > 1.
        input_lengths = torch.IntTensor([length] * self.batch_size).to(
            torch.int32)

        input_ids, ptuning_args = self.setup_fake_prompts(
            visual_features, pre_input_ids, post_input_ids, input_lengths)

        return input_ids, input_lengths, ptuning_args, visual_features

    def generate(self, pre_prompt, post_prompt, image, decoder_input_ids,
                 max_new_tokens, attention_mask, warmup):
        """Run one vision + decoder pass.

        Returns:
            ``None`` for warm-up passes and on MPI ranks other than 0;
            otherwise a list (one entry per batch row) of single-element
            lists holding the stripped, decoded text.
        """
        if not warmup:
            profiler.start("Generate")

        input_ids, input_lengths, ptuning_args, visual_features = self.preprocess(
            warmup, pre_prompt, post_prompt, image, attention_mask)

        if warmup:
            return None

        profiler.start("LLM")

        # Trim encoder input_ids to match visual features shape
        ids_shape = (self.batch_size, visual_features.shape[1])
        input_ids = torch.ones(ids_shape, dtype=torch.int32)

        output_ids = self.model.generate(
            input_ids,
            decoder_input_ids,
            max_new_tokens,
            num_beams=1,
            bos_token_id=self.tokenizer.bos_token_id,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
            debug_mode=False,
            prompt_embedding_table=ptuning_args[0],
            prompt_tasks=ptuning_args[1],
            prompt_vocab_size=ptuning_args[2],
            attention_mask=attention_mask)

        # Reset input_lengths to match decoder_input_ids
        input_lengths = torch.ones(input_lengths.shape,
                                   dtype=input_lengths.dtype)
        profiler.stop("LLM")

        if tensorrt_llm.mpi_rank() != 0:
            profiler.stop("Generate")
            return None

        # Extract a list of tensors of shape beam_width x output_ids.
        output_beams_list = [
            self.tokenizer.batch_decode(
                output_ids[batch_idx, :, input_lengths[batch_idx]:],
                skip_special_tokens=True)
            for batch_idx in range(self.batch_size)
        ]

        # Only beam 0 is kept (generation runs with num_beams=1).
        stripped_text = [[
            output_beams_list[batch_idx][beam_idx].strip()
            for beam_idx in range(1)
        ] for batch_idx in range(self.batch_size)]
        profiler.stop("Generate")
        return stripped_text

    def get_visual_features(self, image, attention_mask):
        """Run the visual encoder session on ``image``.

        Returns:
            ``(image_embeds, image_atts)`` where ``image_atts`` is an
            all-ones ``long`` tensor shaped like ``image_embeds`` without
            its last dimension.
        """
        visual_features = {
            'input':
            image.to(
                tensorrt_llm._utils.str_dtype_to_torch(self.vision_precision))
        }
        if attention_mask is not None:
            visual_features['attention_mask'] = attention_mask

        tensor_info = [
            TensorInfo('input', str_dtype_to_trt(self.vision_precision),
                       image.shape)
        ]
        if attention_mask is not None:
            tensor_info.append(
                TensorInfo('attention_mask', trt.DataType.INT32,
                           attention_mask.shape))

        # Allocate one output buffer per tensor the engine reports.
        visual_output_info = self.visual_encoder_session.infer_shapes(
            tensor_info)
        visual_outputs = {
            t.name: torch.empty(tuple(t.shape),
                                dtype=trt_dtype_to_torch(t.dtype),
                                device=image.device)
            for t in visual_output_info
        }

        ok = self.visual_encoder_session.run(visual_features, visual_outputs,
                                             self.stream.cuda_stream)
        assert ok, "Runtime execution failed for vision encoder session"
        self.stream.synchronize()

        image_embeds = visual_outputs['output']
        image_atts = torch.ones(image_embeds.size()[:-1],
                                dtype=torch.long).to(image.device)

        return image_embeds, image_atts

    def setup_fake_prompts(self, visual_features, pre_input_ids, post_input_ids,
                           input_lengths):
        """Build input ids whose out-of-vocabulary entries address image embeddings.

        Returns:
            ``(input_ids, ptuning_args)``; ``ptuning_args`` is
            ``[None, None, None]`` when prompt tuning is not set up here.
        """
        # Assemble fake prompts which points to image embedding actually
        fake_prompt_id = torch.arange(
            self.model_config.vocab_size, self.model_config.vocab_size +
            visual_features.shape[0] * visual_features.shape[1])
        fake_prompt_id = fake_prompt_id.reshape(visual_features.shape[0],
                                                visual_features.shape[1])

        if post_input_ids is not None:
            input_ids = [pre_input_ids, fake_prompt_id, post_input_ids]
        else:
            input_ids = [fake_prompt_id, pre_input_ids]

        input_ids = torch.cat(input_ids, dim=1).contiguous().to(torch.int32)

        if self.decoder_llm or self.runtime_mapping.is_first_pp_rank():
            ptuning_args = self.ptuning_setup(visual_features, input_ids,
                                              input_lengths)
        else:
            ptuning_args = [None, None, None]

        return input_ids, ptuning_args

    def ptuning_setup(self, prompt_table, input_ids, input_lengths):
        """Prepare the prompt-tuning inputs.

        Returns:
            ``[prompt_table, tasks, task_vocab_size]``, all placed on CUDA.
        """
        hidden_size = self.model_config.hidden_size * self.runtime_mapping.tp_size
        if prompt_table is not None:
            task_vocab_size = torch.tensor(
                [prompt_table.shape[1]],
                dtype=torch.int32,
            ).cuda()
            # Flatten the first two dimensions into one table of rows.
            prompt_table = prompt_table.view(
                (prompt_table.shape[0] * prompt_table.shape[1],
                 prompt_table.shape[2]))
            assert prompt_table.shape[
                1] == hidden_size, "Prompt table dimensions do not match hidden size"

            prompt_table = prompt_table.cuda().to(
                dtype=tensorrt_llm._utils.str_dtype_to_torch(
                    self.model_config.dtype))
        else:
            prompt_table = torch.empty([1, hidden_size]).cuda()
            task_vocab_size = torch.zeros([1]).cuda()

        if self.model_config.remove_input_padding:
            tasks = torch.zeros([torch.sum(input_lengths)],
                                dtype=torch.int32).cuda()
            if self.decoder_llm:
                tasks = tasks.unsqueeze(0)
        else:
            tasks = torch.zeros(input_ids.shape, dtype=torch.int32).cuda()

        return [prompt_table, tasks, task_vocab_size]

    def _build_decoder_input_ids(self):
        """Return a ``(batch_size, 1)`` int tensor filled with the decoder start id.

        The id is read from the Hugging Face config: ``decoder_start_token_id``
        first, falling back to ``decoder.bos_token_id`` when that is ``None``.
        """
        config = AutoConfig.from_pretrained(
            pretrained_model_name_or_path=self.model_ckpt_path,
            cache_dir=self.cache_dir,
            local_files_only=self.local_files_only,
        )
        decoder_start_id = config.decoder_start_token_id  # T5
        if decoder_start_id is None:
            decoder_start_id = config.decoder.bos_token_id  # Nougat

        decoder_input_ids = torch.IntTensor([[decoder_start_id]])
        return decoder_input_ids.repeat((self.batch_size, 1))

    def setup_inputs(self, input_text, raw_image):
        """Process a raw image (and optional text) into batched model inputs.

        Returns:
            ``(input_text, pre_prompt, post_prompt, image, decoder_input_ids,
            attention_mask)``; ``decoder_input_ids`` is ``None`` when
            ``self.decoder_llm`` is true.
        """
        image_processor = AutoProcessor.from_pretrained(
            pretrained_model_name_or_path=self.model_ckpt_path,
            cache_dir=self.cache_dir,
            local_files_only=self.local_files_only,
        )
        if input_text is None:
            input_text = ""
        inputs = image_processor(
            images=raw_image,
            text=input_text,
            return_tensors="pt",
        )
        image = inputs['flattened_patches']
        image = image.expand(self.batch_size, -1, -1).contiguous()
        attention_mask = inputs['attention_mask'].to(self.device).to(
            torch.int)
        attention_mask = attention_mask.expand(self.batch_size,
                                               -1).contiguous()

        # Repeat inputs to match batch size
        pre_prompt = [""] * self.batch_size
        post_prompt = [None] * self.batch_size
        image = image.to(self.device)

        # Generate decoder_input_ids for enc-dec models
        # Custom prompts can be added as:
        # decoder_input_ids = model.tokenizer(decoder_prompt).input_ids
        if self.decoder_llm:
            decoder_input_ids = None
        else:
            decoder_input_ids = self._build_decoder_input_ids()

        return input_text, pre_prompt, post_prompt, image, decoder_input_ids, attention_mask

    def run(self, flattened_patches, attention_mask, max_new_tokens):
        """Run a warm-up pass followed by one measured generation pass.

        Args:
            flattened_patches: Patch tensor; expanded to
                ``(batch_size, -1, -1)``.
            attention_mask: Mask tensor; moved to ``self.device``, cast to
                ``torch.int`` and expanded to ``(batch_size, -1)``.
            max_new_tokens: Generation budget forwarded to ``generate``.

        Returns:
            Whatever the non-warm-up ``generate`` call returns.
        """
        pre_prompt = [""] * self.batch_size
        post_prompt = [None] * self.batch_size
        decoder_input_ids = self._build_decoder_input_ids()

        processed_image = flattened_patches.expand(self.batch_size, -1, -1).contiguous()
        attention_mask = attention_mask.to(self.device).to(torch.int)
        attention_mask = attention_mask.expand(self.batch_size, -1).contiguous()

        # Warm-up pass: runs preprocessing only and returns None.
        self.generate(pre_prompt,
                      post_prompt,
                      processed_image,
                      decoder_input_ids,
                      max_new_tokens,
                      attention_mask=attention_mask,
                      warmup=True)

        output_text = self.generate(pre_prompt,
                                    post_prompt,
                                    processed_image,
                                    decoder_input_ids,
                                    max_new_tokens,
                                    attention_mask=attention_mask,
                                    warmup=False)
        return output_text
|
|
418
|
-
|
|
419
|
-
|
|
420
|
-
def read_config(config_path):
    """Parse an engine ``config.json`` into a runtime ``ModelConfig``.

    Args:
        config_path: Path to a JSON file with ``build_config`` and
            ``pretrained_config`` sections.

    Returns:
        ``(model_config, tp_size, pp_size, gpus_per_node, dtype)``.

    Raises:
        AssertionError: if ``tp_size * pp_size`` differs from the MPI world
            size, or ``num_attention_heads`` is not divisible by ``tp_size``.
        KeyError: if a required config entry is missing.
    """
    with open(config_path, "r") as f:
        config = json.load(f)

    builder_config = config['build_config']
    plugin_config = builder_config['plugin_config']
    pretrained_config = config['pretrained_config']
    lora_config = builder_config['lora_config']
    auto_parallel_config = builder_config['auto_parallel_config']
    use_gpt_attention_plugin = plugin_config["gpt_attention_plugin"]
    remove_input_padding = plugin_config["remove_input_padding"]
    use_lora_plugin = plugin_config["lora_plugin"]
    tp_size = pretrained_config['mapping']['tp_size']
    pp_size = pretrained_config['mapping']['pp_size']
    gpus_per_node = auto_parallel_config['gpus_per_node']
    world_size = tp_size * pp_size
    assert world_size == tensorrt_llm.mpi_world_size(), \
        f'Engine world size ({world_size}) != Runtime world size ({tensorrt_llm.mpi_world_size()})'
    num_heads = pretrained_config["num_attention_heads"]
    hidden_size = pretrained_config["hidden_size"]
    head_size = pretrained_config["head_size"]
    vocab_size = pretrained_config["vocab_size"]
    max_batch_size = builder_config["max_batch_size"]
    max_beam_width = builder_config["max_beam_width"]
    num_layers = pretrained_config["num_hidden_layers"]
    num_kv_heads = pretrained_config.get('num_kv_heads', num_heads)

    # Per-rank sizes under tensor parallelism; KV heads are rounded up.
    assert (num_heads % tp_size) == 0
    num_heads = num_heads // tp_size
    hidden_size = hidden_size // tp_size
    num_kv_heads = (num_kv_heads + tp_size - 1) // tp_size

    cross_attention = pretrained_config["architecture"] == "DecoderModel"
    skip_cross_qkv = pretrained_config.get('skip_cross_qkv', False)
    has_position_embedding = pretrained_config["has_position_embedding"]
    # ``pretrained_config`` is a dict, so ``hasattr`` never sees its keys and
    # always yielded False. Look the key up instead; a null value counts as
    # "no token-type embedding".
    has_token_type_embedding = pretrained_config.get(
        "type_vocab_size") is not None
    use_custom_all_reduce = plugin_config.get('use_custom_all_reduce', False)
    dtype = pretrained_config["dtype"]

    paged_kv_cache = plugin_config['paged_kv_cache']
    tokens_per_block = plugin_config['tokens_per_block']

    gather_context_logits = builder_config.get('gather_context_logits', False)
    gather_generation_logits = builder_config.get('gather_generation_logits',
                                                  False)
    max_prompt_embedding_table_size = builder_config.get(
        'max_prompt_embedding_table_size', 0)

    model_config = ModelConfig(
        num_heads=num_heads,
        num_kv_heads=num_kv_heads,
        hidden_size=hidden_size,
        head_size=head_size,
        max_batch_size=max_batch_size,
        max_beam_width=max_beam_width,
        vocab_size=vocab_size,
        num_layers=num_layers,
        gpt_attention_plugin=use_gpt_attention_plugin,
        remove_input_padding=remove_input_padding,
        paged_kv_cache=paged_kv_cache,
        tokens_per_block=tokens_per_block,
        cross_attention=cross_attention,
        has_position_embedding=has_position_embedding,
        has_token_type_embedding=has_token_type_embedding,
        use_custom_all_reduce=use_custom_all_reduce,
        dtype=dtype,
        gather_context_logits=gather_context_logits,
        gather_generation_logits=gather_generation_logits,
        max_prompt_embedding_table_size=max_prompt_embedding_table_size,
        lora_plugin=use_lora_plugin,
        lora_target_modules=lora_config.get('lora_target_modules'),
        trtllm_modules_to_hf_modules=lora_config.get(
            'trtllm_modules_to_hf_modules'),
        skip_cross_qkv=skip_cross_qkv,
    )

    return model_config, tp_size, pp_size, gpus_per_node, dtype
|
|
497
|
-
|
|
498
|
-
|
|
499
|
-
class Mapping:
    """Rank layout for tensor (TP), pipeline (PP) and MoE parallelism.

    Group lists for PP and TP are offset by ``rank``; ``pp_rank`` and
    ``tp_rank`` are always 0 in this implementation.
    """

    def __init__(
            self,
            world_size=1,
            rank=0,
            gpus_per_node=8,
            tp_size=1,
            pp_size=1,
            moe_tp_size=-1,  # -1 means no moe
            moe_ep_size=-1):  # -1 means no moe
        """Validate the sizes and build all rank groups.

        Raises:
            ValueError: if ``pp_size * tp_size != world_size`` or
                ``moe_tp_size * moe_ep_size != tp_size``.
        """
        # Non-MoE defaults: MoE TP spans the whole TP group, no expert split.
        if moe_tp_size == -1:
            moe_tp_size, moe_ep_size = tp_size, 1

        if pp_size * tp_size != world_size:
            raise ValueError(
                f"world_size must equal to pp_size * tp_size, but got {world_size} != {pp_size} * {tp_size}"
            )

        moe_span = moe_tp_size * moe_ep_size
        if moe_span != tp_size:
            raise ValueError(
                f"tp_size must equal to moe_tp_size * moe_ep_size, but got {tp_size} != {moe_tp_size} * {moe_ep_size}"
            )

        self.tp_size = tp_size
        self.pp_size = pp_size
        self.moe_tp_size = moe_tp_size
        self.moe_ep_size = moe_ep_size
        self.world_size = world_size
        self.rank = rank
        self.gpus_per_node = gpus_per_node

        # One PP group per TP slot: ranks strided by tp_size, shifted by rank.
        self.pp_groups = [
            list(range(slot + rank, world_size + rank, tp_size))
            for slot in range(tp_size)
        ]
        # One TP group per PP stage: tp_size consecutive ranks, shifted by rank.
        self.tp_groups = [
            list(range(stage * tp_size + rank, (stage + 1) * tp_size + rank))
            for stage in range(pp_size)
        ]
        # MoE TP groups: within each stage, ranks strided by moe_ep_size.
        self.moe_tp_groups = [
            list(range(stage * moe_span + ep, (stage + 1) * moe_span,
                       moe_ep_size))
            for stage in range(pp_size)
            for ep in range(moe_ep_size)
        ]
        # MoE EP groups: within each stage, moe_ep_size consecutive ranks.
        self.moe_ep_groups = [
            list(range(stage * moe_span + tp * moe_ep_size,
                       stage * moe_span + (tp + 1) * moe_ep_size))
            for stage in range(pp_size)
            for tp in range(moe_tp_size)
        ]

        # Pipeline/tensor ranks are pinned to 0 rather than derived from rank.
        self.pp_rank = 0
        self.tp_rank = 0
        self.moe_tp_rank = self.tp_rank // self.moe_ep_size
        self.moe_ep_rank = self.tp_rank % self.moe_ep_size

        tp_group_index = self.pp_rank * moe_ep_size + self.moe_ep_rank
        ep_group_index = self.pp_rank * moe_tp_size + self.moe_tp_rank
        self.moe_tp_group = self.moe_tp_groups[tp_group_index]
        self.moe_ep_group = self.moe_ep_groups[ep_group_index]

        self.node_rank = self.rank // self.gpus_per_node
        self.local_rank = self.rank % self.gpus_per_node

    def get_node_rank(self, rank: int):
        """Return the index of the node that hosts ``rank``."""
        return rank // self.gpus_per_node

    def get_local_rank(self, rank: int):
        """Return the position of ``rank`` within its node."""
        return rank % self.gpus_per_node

    def has_tp(self):
        """Return True when more than one tensor-parallel rank exists."""
        return self.tp_size > 1

    def is_last_pp_rank(self):
        """Return True when ``pp_rank`` is the final pipeline stage."""
        return self.pp_rank == self.pp_size - 1

    def is_first_pp_rank(self):
        """Return True when ``pp_rank`` is stage 0."""
        return self.pp_rank == 0

    def has_pp(self):
        """Return True when more than one pipeline stage exists."""
        return self.pp_size > 1

    def prev_pp_rank(self):
        """Return ``rank - tp_size``, wrapped by ``world_size`` when negative."""
        candidate = self.rank - self.tp_size
        return candidate + self.world_size if candidate < 0 else candidate

    def next_pp_rank(self):
        """Return ``rank + tp_size``, wrapped by ``world_size`` on overflow."""
        candidate = self.rank + self.tp_size
        return (candidate - self.world_size
                if candidate >= self.world_size else candidate)

    def has_moe_tp(self):
        """Return True when MoE tensor parallelism spans more than one rank."""
        return self.moe_tp_size > 1

    def has_moe_ep(self):
        """Return True when MoE expert parallelism spans more than one rank."""
        return self.moe_ep_size > 1

    def pp_layers(self, num_layers: int) -> List[int]:
        """Return the layer indices of this pipeline stage (floor division)."""
        per_stage = num_layers // self.pp_size
        first = self.pp_rank * per_stage
        return list(range(first, first + per_stage))

    def ep_experts(self, num_experts: int) -> List[int]:
        """Return the expert indices of this expert-parallel rank (floor division)."""
        per_rank = num_experts // self.moe_ep_size
        first = self.moe_ep_rank * per_rank
        return list(range(first, first + per_rank))
|
|
626
|
-
|
|
627
|
-
|
|
628
|
-
def get_engine_name(rank):
    """Return the engine file name for ``rank``, i.e. ``rank<rank>.engine``."""
    return f'rank{rank}.engine'
|
|
630
|
-
|
|
631
|
-
class TRTLLMEncDecModel:
|
|
632
|
-
|
|
633
|
-
def __init__(
    self,
    engine_name,
    engine_dir,
    lora_dir=None,
    lora_task_uids=None,
    debug_mode=False,
    skip_encoder=False,
    stream: torch.cuda.Stream = None,
):
    """Load the encoder (optional) and decoder engines from ``engine_dir``.

    Args:
        engine_name: Not used by this constructor.
        engine_dir: Directory holding ``encoder/`` and ``decoder/``
            sub-directories, each with ``config.json`` and ``rank0.engine``.
        lora_dir: Passed as ``model_dirs`` to the LoRA managers when the
            corresponding engine was built with the LoRA plugin.
        lora_task_uids: Stored on the instance.
        debug_mode: Forwarded to the decoder ``GenerationSession``.
        skip_encoder: When true, no encoder engine is loaded and all
            encoder attributes are set to ``None``.
        stream: CUDA stream to run on; a new one is created when ``None``.
    """
    # in multi-node setup, it's important to set_device at the very beginning so .to('cuda') refers to current device
    # accordingly, all input & output tensors should be moved to current device
    # otherwise, it's default to 'cuda:0'
    self.device_id = torch.cuda.current_device()
    self.device = torch.cuda.current_device()
    self.skip_encoder = skip_encoder
    self.lora_task_uids = lora_task_uids

    # when enc-dec runs by itself, stream can be None and we create new stream here
    # when enc-dec has to run as a component in a bigger workflow (e.g., multimodal), earlier components in the workflow may have results in its stream, which we should pass that stream in to avoid unnecessary stream sync
    self.stream = stream
    if self.stream is None:
        self.stream = torch.cuda.Stream(self.device)
    torch.cuda.set_stream(self.stream)

    def engine_setup(component):
        """Return ``(model_config, runtime_mapping, engine_buffer)`` for ``component``."""
        # model config
        config_path = os.path.join(engine_dir, component, "config.json")
        model_config, tp_size, pp_size, gpus_per_node, dtype = read_config(
            config_path)

        # MGMN config: the current CUDA device index is used as the rank.
        world_size = tp_size * pp_size
        runtime_rank = torch.cuda.current_device()
        runtime_mapping = Mapping(world_size,
                                  runtime_rank,
                                  tp_size=tp_size,
                                  pp_size=pp_size,
                                  gpus_per_node=gpus_per_node)

        # load engine: the rank-0 engine file is always used.
        engine_fname = get_engine_name(0)
        with open(os.path.join(engine_dir, component, engine_fname), "rb") as f:
            engine_buffer = f.read()

        return model_config, runtime_mapping, engine_buffer

    # Note: encoder and decoder doesn't necessarily have the same TP & PP config

    if not skip_encoder:
        self.encoder_model_config, self.encoder_runtime_mapping, encoder_engine_buffer = engine_setup(
            component='encoder')

        self.nccl_comm = None
        if self.encoder_runtime_mapping.has_pp():
            # for Pipeline Parallelism in encoder
            self.nccl_comm = torch.classes.trtllm.NcclCommunicatorOp(
                self.encoder_runtime_mapping.tp_size,
                self.encoder_runtime_mapping.pp_size,
                self.encoder_runtime_mapping.rank)

        # session setup
        self.encoder_session = tensorrt_llm.runtime.Session.from_serialized_engine(
            encoder_engine_buffer)

        # encoder lora manager setup
        if self.encoder_model_config.lora_plugin:
            self.encoder_lora_manager = LoraManager()
            # TODO: this is only for bart
            self.encoder_lora_manager.load_from_hf(
                model_dirs=lora_dir,
                model_config=self.encoder_model_config,
                runtime_mapping=self.encoder_runtime_mapping,
                component='encoder',
            )
        else:
            self.encoder_lora_manager = None
    else:
        self.encoder_model_config, self.encoder_runtime_mapping, encoder_engine_buffer = None, None, None
        self.nccl_comm, self.encoder_session = None, None
        # Keep the attribute defined on both paths so later reads do not
        # raise AttributeError when the encoder is skipped.
        self.encoder_lora_manager = None

    self.decoder_model_config, self.decoder_runtime_mapping, decoder_engine_buffer = engine_setup(
        component='decoder')

    self.decoder_session = tensorrt_llm.runtime.GenerationSession(
        self.decoder_model_config,
        decoder_engine_buffer,
        self.decoder_runtime_mapping,
        debug_mode=debug_mode)

    # decoder lora manager setup
    if self.decoder_model_config.lora_plugin:
        self.decoder_lora_manager = LoraManager()
        # TODO: this is only for bart
        self.decoder_lora_manager.load_from_hf(
            model_dirs=lora_dir,
            model_config=self.decoder_model_config,
            runtime_mapping=self.decoder_runtime_mapping,
            component='decoder',
        )
    else:
        self.decoder_lora_manager = None
|
|
746
|
-
|
|
747
|
-
@classmethod
def from_engine(cls,
                engine_name,
                engine_dir,
                lora_dir=None,
                lora_task_uids=None,
                debug_mode=False,
                skip_encoder=False,
                stream=None):
    """Alternate constructor that forwards all arguments to ``cls(...)``.

    ``engine_name``, ``engine_dir``, ``lora_dir`` and ``lora_task_uids`` are
    passed positionally (in that order); ``debug_mode``, ``skip_encoder`` and
    ``stream`` are passed by keyword.

    Returns:
        Whatever ``cls(...)`` returns, i.e. a new instance of the class.
    """
    runtime_options = {
        'debug_mode': debug_mode,
        'skip_encoder': skip_encoder,
        'stream': stream,
    }
    return cls(engine_name, engine_dir, lora_dir, lora_task_uids,
               **runtime_options)
|
|
763
|
-
|
|
764
|
-
def process_input(self,
                  input_ids,
                  remove_input_padding=False,
                  pad_token_id=0,
                  prompt_tasks=None):
    """Compute per-sample lengths and, optionally, strip padding tokens.

    Args:
        input_ids: 2-D tensor of token ids, indexed as ``[batch, position]``.
        remove_input_padding: when true, every token equal to
            ``pad_token_id`` (except the first token of each row) is dropped
            and the rows are concatenated into one 1-D tensor. When false,
            ``input_ids`` is returned untouched.
        pad_token_id: id that marks padding positions.
        prompt_tasks: optional sequence; in remove-padding mode it is
            truncated to the number of remaining tokens, otherwise it is
            returned as is.

    Returns:
        ``(input_ids, input_lengths, max_input_length, prompt_tasks)`` where
        ``input_lengths`` is an int32 tensor on ``self.device`` with one entry
        per row and ``max_input_length`` is its maximum as a Python number.
    """
    # Positions to keep. The first token of each row always counts, even if
    # it equals pad_token_id (e.g. a decoder start id that is the pad id).
    keep = input_ids != pad_token_id
    keep[:, 0] = True

    # One vectorised reduction replaces the per-row loop; converting with
    # .to() avoids wrapping an existing tensor in torch.tensor(), which
    # emits a copy-construct UserWarning.
    input_lengths = keep.sum(dim=1).to(dtype=torch.int32,
                                       device=self.device)  # [batch_size]

    if remove_input_padding:
        # Boolean-mask indexing walks the rows in order, so this yields the
        # same flattened sequence as concatenating each filtered row.
        input_ids = input_ids[keep]  # [num_tokens]
        if prompt_tasks is not None:
            prompt_tasks = prompt_tasks[:input_ids.shape[0]]

    max_input_length = torch.max(input_lengths).item()
    return input_ids, input_lengths, max_input_length, prompt_tasks
|
|
796
|
-
|
|
797
|
-
def encoder_run(self,
                input_ids,
                input_lengths,
                max_input_length,
                position_ids=None,
                token_type_ids=None,
                debug_mode=False,
                prompt_embedding_table=None,
                prompt_tasks=None,
                prompt_vocab_size=None,
                attention_mask=None):
    """Run the encoder engine once and return its output tensor.

    Builds the ``inputs``/``outputs`` tensor dictionaries expected by
    ``self.encoder_session``, executes the session on ``self.stream``,
    synchronises the stream and, when the encoder runtime mapping reports
    pipeline parallelism, exchanges the result over ``self.nccl_comm`` so the
    returned buffer is filled on every rank of the pipeline group.

    Args:
        input_ids: 1-D (flattened tokens) or 2-D/3-D token tensor; its
            leading dimensions determine the hidden-state buffer shape.
        input_lengths: per-sample lengths; ``size(0)`` is used as batch size.
        max_input_length: only its value as a *shape* is passed to the engine.
        position_ids: optional; generated here when the model has position
            embeddings and none are supplied.
        token_type_ids: used only when the model has token-type embeddings.
        debug_mode: when true, allocates a zero buffer for every engine
            output tensor not already present in ``outputs``.
        prompt_embedding_table, prompt_tasks, prompt_vocab_size: used only
            when ``max_prompt_embedding_table_size > 0``.
        attention_mask: forwarded unless the GPT attention plugin is enabled.

    Returns:
        The encoder output tensor.
    """
    model_config = self.encoder_model_config
    mapping = self.encoder_runtime_mapping
    session = self.encoder_session

    # each engine has hidden_dim/TP, don't forget to multiply TP
    hidden_size = model_config.hidden_size * mapping.tp_size
    if input_ids.dim() == 1:
        hidden_states_shape = (input_ids.shape[0], hidden_size)  # [num_tokens,D]
    else:
        hidden_states_shape = (input_ids.shape[0], input_ids.shape[1],
                               hidden_size)  # [BS,seqlen,D]

    def hidden_states_dtype(name):
        # Torch dtype matching the engine's declared dtype for tensor `name`.
        return trt_dtype_to_torch(session.engine.get_tensor_dtype(name))

    def hidden_states_buffer(name):
        # Uninitialised hidden-state-shaped buffer for engine tensor `name`.
        return torch.empty(hidden_states_shape,
                           dtype=hidden_states_dtype(name),
                           device=self.device).contiguous()

    # input tensors. only first PP rank has id input, others are hidden_states input
    inputs = {}
    if mapping.is_first_pp_rank():
        inputs['input_ids'] = input_ids.contiguous()

        if model_config.has_position_embedding:
            if position_ids is None:
                if model_config.remove_input_padding:
                    # Flattened layout: one 0..len-1 range per sample, concatenated.
                    per_sample = [
                        torch.arange(sample_length,
                                     dtype=torch.int32,
                                     device=input_ids.device)
                        for sample_length in torch_to_numpy(input_lengths)
                    ]
                    position_ids = torch.cat(per_sample)
                else:
                    bsz, seq_len = input_ids.shape[:2]
                    positions = torch.arange(seq_len,
                                             dtype=torch.int32,
                                             device=input_ids.device)
                    position_ids = positions.expand(bsz, -1)
            inputs['position_ids'] = position_ids.contiguous()

        if model_config.has_token_type_embedding:
            inputs['token_type_ids'] = token_type_ids.contiguous()

        if model_config.max_prompt_embedding_table_size > 0:
            inputs['prompt_embedding_table'] = prompt_embedding_table.contiguous()
            inputs['tasks'] = prompt_tasks.contiguous()
            inputs['prompt_vocab_size'] = prompt_vocab_size.contiguous()
    else:
        # just need a placeholder, engine will call NCCL to recv and fill data from previous rank
        inputs['hidden_states_input'] = hidden_states_buffer('hidden_states_input')

    if attention_mask is not None and not model_config.gpt_attention_plugin:
        inputs['attention_mask'] = attention_mask.contiguous()

    inputs['input_lengths'] = input_lengths
    # use shape info to pass max length info in remove padding mode
    inputs['max_input_length'] = torch.empty(
        (max_input_length, ),
        dtype=hidden_states_dtype('max_input_length'),
        device=self.device).contiguous()

    batch_size = input_lengths.size(0)
    inputs['host_request_types'] = torch.zeros(batch_size,
                                               dtype=torch.int32,
                                               device='cpu')
    if model_config.remove_input_padding:
        inputs['host_context_lengths'] = input_lengths.to('cpu')

    if model_config.lora_plugin and self.encoder_lora_manager is not None:
        lora_buffers = self.encoder_lora_manager.input_buffers(
            self.lora_task_uids,
            mapping,
            model_config.num_layers,
        )
        inputs.update(lora_buffers)

    # Note: runtime.Session's run() method will set input/output tensor address, here we only need to provide tensor shape
    session.set_shapes(inputs)

    # output tensors. only last PP rank final encoder output, others are intermediate hidden_states output. Need broadcast later
    outputs = {}
    if mapping.is_last_pp_rank():
        outputs['encoder_output'] = hidden_states_buffer('encoder_output')
    else:
        outputs['hidden_states_output'] = hidden_states_buffer(
            'hidden_states_output')

    # -------------------------------------------
    if debug_mode:
        engine = session.engine
        context = session.context
        # setup debugging buffer for the encoder
        for tensor_index in range(engine.num_io_tensors):
            name = engine.get_tensor_name(tensor_index)
            if (engine.get_tensor_mode(name) == trt.TensorIOMode.OUTPUT
                    and name not in outputs):
                dtype = engine.get_tensor_dtype(name)
                shape = context.get_tensor_shape(name)
                outputs[name] = torch.zeros(tuple(shape),
                                            dtype=trt_dtype_to_torch(dtype),
                                            device=self.device)
                context.set_tensor_address(name, outputs[name].data_ptr())
    # -------------------------------------------

    # TRT session run
    # Note: need cuda stream ID, not a torch Stream
    ok = session.run(inputs, outputs, self.stream.cuda_stream)
    assert ok, "Runtime execution failed"
    self.stream.synchronize()

    # Tensor Parallelism is handled by model/engine definition
    # But we need to broadcast among PP group at the end of encoder's Pipeline Parallelism
    # After this, all ranks should recv the encoder output, and world might be re-configured using decoder's TP-PP config
    def pp_communicate_encoder_output(buffer):
        if mapping.is_last_pp_rank():
            # Last rank owns the result: send it to every other rank in the group.
            for pp_rank in mapping.pp_group:
                if pp_rank != mapping.rank:
                    self.nccl_comm.send(buffer, pp_rank)
        else:
            # Other ranks receive into the buffer from the last rank of the group.
            self.nccl_comm.recv(buffer, mapping.pp_group[-1])
        return buffer

    if not mapping.has_pp():
        return outputs['encoder_output']

    # use hidden_states output buffer to receive output as the shapes are same
    buffer_name = ('encoder_output'
                   if mapping.is_last_pp_rank() else 'hidden_states_output')
    return pp_communicate_encoder_output(outputs[buffer_name])
|
|
942
|
-
|
|
943
|
-
def generate(self,
             encoder_input_ids,
             decoder_input_ids,
             max_new_tokens,
             num_beams=1,
             pad_token_id=None,
             eos_token_id=None,
             bos_token_id=None,
             debug_mode=False,
             return_dict=False,
             prompt_embedding_table=None,
             prompt_tasks=None,
             prompt_vocab_size=None,
             attention_mask=None,
             time_encoder=False,
             return_encoder_output=False):
    """Run the encoder (unless skipped) and then autoregressive decoding.

    Args:
        encoder_input_ids: encoder token ids; moved to ``self.device``.
        decoder_input_ids: decoder start ids; moved to ``self.device``.
        max_new_tokens: passed to the decoder session setup.
        num_beams: beam width for the sampling config and session setup.
        pad_token_id: padding id used by ``process_input`` and as ``pad_id``.
        eos_token_id: used as ``end_id`` of the sampling config.
        bos_token_id: accepted for interface compatibility; not read here.
        debug_mode: forwarded to ``encoder_run``.
        return_dict: forwarded to the sampling config and ``decode``; also
            enables log-prob outputs.
        prompt_embedding_table: forwarded to the encoder; when the encoder is
            skipped it is used directly as the encoder output.
        prompt_tasks, prompt_vocab_size: forwarded to the encoder.
        attention_mask: optional mask (tensor or array-like), indexed as
            ``[batch, encoder_len]``; converted to int32 on ``self.device``.
        time_encoder: when true, prints the encoder wall time.
        return_encoder_output: when true together with ``return_dict``, the
            encoder output is added to the result under ``'encoder_output'``.

    Returns:
        The value returned by ``self.decoder_session.decode``.
    """
    ## ensure all externally provided tensors are on the correct device.
    encoder_input_ids = encoder_input_ids.to(self.device)
    decoder_input_ids = decoder_input_ids.to(self.device)

    if attention_mask is not None:
        if torch.is_tensor(attention_mask):
            # torch.tensor(<tensor>) emits a copy-construct UserWarning;
            # detach + to(copy=True) produces the same independent copy.
            attention_mask = attention_mask.detach().to(device=self.device,
                                                        dtype=torch.int32,
                                                        copy=True)
        else:
            attention_mask = torch.tensor(attention_mask,
                                          dtype=torch.int32,
                                          device=self.device)

    ## encoder run
    encoder_remove_input_padding = self.encoder_model_config.remove_input_padding if self.encoder_model_config else self.decoder_model_config.remove_input_padding

    encoder_input_ids, encoder_input_lengths, encoder_max_input_length, prompt_tasks = self.process_input(
        encoder_input_ids, encoder_remove_input_padding, pad_token_id,
        prompt_tasks)

    if not self.skip_encoder:
        #logger.info(f"Rank {self.runtime_rank} Running encoder engine ...")
        if time_encoder:
            # Monotonic clock: immune to system clock adjustments.
            tik = time.perf_counter()
        encoder_output = self.encoder_run(
            encoder_input_ids,
            encoder_input_lengths,
            encoder_max_input_length,
            debug_mode=debug_mode,
            prompt_embedding_table=prompt_embedding_table,
            prompt_tasks=prompt_tasks,
            prompt_vocab_size=prompt_vocab_size,
            attention_mask=attention_mask)
        if time_encoder:
            tok = time.perf_counter()
            print(f"TRT-LLM Encoder time {(tok-tik)*1000}ms")
    else:
        # Encoder skipped: the prompt embedding table stands in for its output.
        encoder_output = prompt_embedding_table
        if encoder_input_ids.dim() > 1:
            encoder_output = encoder_output.unsqueeze(0)

    ## decoder run
    # logger.info(f"Rank {self.runtime_rank} Running decoder engine ...")
    decoder_input_ids, decoder_input_lengths, decoder_max_input_length, _ = self.process_input(
        decoder_input_ids, self.decoder_model_config.remove_input_padding,
        pad_token_id)

    # `cross_attention_mask` in context phase [batch_size, query_len, encoder_input_len]
    # where query_len happens to be 1 in current cases, but not necessarily always, and
    # `cross_attention_mask` in generation phase [batch_size, 1, encoder_input_len] where
    # the query_len is always 1 since we have kv cache.
    cross_attention_mask = None
    if attention_mask is not None:
        # attention_mask is already an int32 tensor on self.device here;
        # clone() keeps the decoder's mask independent, as the original copy did.
        cross_attention_mask = attention_mask.clone().reshape(
            attention_mask.shape[0], 1, attention_mask.shape[1])

    # generation config
    sampling_config = SamplingConfig(end_id=eos_token_id,
                                     pad_id=pad_token_id,
                                     num_beams=num_beams,
                                     min_length=1,
                                     return_dict=return_dict)
    sampling_config.update(output_cum_log_probs=return_dict,
                           output_log_probs=return_dict)

    # decoder autoregressive generation
    self.decoder_session.setup(
        decoder_input_lengths.size(0),
        decoder_max_input_length,
        max_new_tokens,
        num_beams,
        max_attention_window_size=None,
        encoder_max_input_length=encoder_max_input_length,
        lora_manager=self.decoder_lora_manager,
        lora_uids=self.lora_task_uids,
    )

    output = self.decoder_session.decode(
        decoder_input_ids,
        decoder_input_lengths,
        sampling_config,
        encoder_output=encoder_output,
        encoder_input_lengths=encoder_input_lengths,
        return_dict=return_dict,
        cross_attention_mask=cross_attention_mask)

    if return_dict and return_encoder_output:
        output['encoder_output'] = encoder_output

    return output
|