doc-page-extractor 0.2.0__py3-none-any.whl → 1.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. doc_page_extractor/__init__.py +5 -15
  2. doc_page_extractor/check_env.py +40 -0
  3. doc_page_extractor/extractor.py +88 -215
  4. doc_page_extractor/model.py +97 -0
  5. doc_page_extractor/parser.py +51 -0
  6. doc_page_extractor/plot.py +52 -79
  7. doc_page_extractor/redacter.py +111 -0
  8. doc_page_extractor-1.0.2.dist-info/METADATA +120 -0
  9. doc_page_extractor-1.0.2.dist-info/RECORD +11 -0
  10. {doc_page_extractor-0.2.0.dist-info → doc_page_extractor-1.0.2.dist-info}/WHEEL +1 -2
  11. doc_page_extractor-1.0.2.dist-info/licenses/LICENSE +21 -0
  12. doc_page_extractor/clipper.py +0 -119
  13. doc_page_extractor/downloader.py +0 -16
  14. doc_page_extractor/latex.py +0 -31
  15. doc_page_extractor/layout_order.py +0 -237
  16. doc_page_extractor/layoutreader.py +0 -126
  17. doc_page_extractor/models.py +0 -92
  18. doc_page_extractor/ocr.py +0 -200
  19. doc_page_extractor/ocr_corrector.py +0 -126
  20. doc_page_extractor/onnxocr/__init__.py +0 -1
  21. doc_page_extractor/onnxocr/cls_postprocess.py +0 -26
  22. doc_page_extractor/onnxocr/db_postprocess.py +0 -246
  23. doc_page_extractor/onnxocr/imaug.py +0 -32
  24. doc_page_extractor/onnxocr/operators.py +0 -187
  25. doc_page_extractor/onnxocr/predict_base.py +0 -57
  26. doc_page_extractor/onnxocr/predict_cls.py +0 -109
  27. doc_page_extractor/onnxocr/predict_det.py +0 -139
  28. doc_page_extractor/onnxocr/predict_rec.py +0 -344
  29. doc_page_extractor/onnxocr/predict_system.py +0 -97
  30. doc_page_extractor/onnxocr/rec_postprocess.py +0 -896
  31. doc_page_extractor/onnxocr/utils.py +0 -71
  32. doc_page_extractor/overlap.py +0 -167
  33. doc_page_extractor/raw_optimizer.py +0 -104
  34. doc_page_extractor/rectangle.py +0 -72
  35. doc_page_extractor/rotation.py +0 -158
  36. doc_page_extractor/struct_eqtable/__init__.py +0 -49
  37. doc_page_extractor/struct_eqtable/internvl/__init__.py +0 -2
  38. doc_page_extractor/struct_eqtable/internvl/conversation.py +0 -394
  39. doc_page_extractor/struct_eqtable/internvl/internvl.py +0 -198
  40. doc_page_extractor/struct_eqtable/internvl/internvl_lmdeploy.py +0 -81
  41. doc_page_extractor/struct_eqtable/pix2s/__init__.py +0 -3
  42. doc_page_extractor/struct_eqtable/pix2s/pix2s.py +0 -76
  43. doc_page_extractor/struct_eqtable/pix2s/pix2s_trt.py +0 -1047
  44. doc_page_extractor/table.py +0 -70
  45. doc_page_extractor/types.py +0 -91
  46. doc_page_extractor/utils.py +0 -32
  47. doc_page_extractor-0.2.0.dist-info/METADATA +0 -85
  48. doc_page_extractor-0.2.0.dist-info/RECORD +0 -45
  49. doc_page_extractor-0.2.0.dist-info/licenses/LICENSE +0 -661
  50. doc_page_extractor-0.2.0.dist-info/top_level.txt +0 -2
  51. tests/__init__.py +0 -0
  52. tests/test_history_bus.py +0 -55
doc_page_extractor/struct_eqtable/pix2s/pix2s_trt.py
@@ -1,1047 +0,0 @@
- import os
- import time
- import json
-
- import torch
- import torch.nn as nn
-
- try:
-     import tensorrt_llm
-     import tensorrt as trt
-     import tensorrt_llm.profiler as profiler
-
-     from tensorrt_llm._utils import str_dtype_to_trt, torch_to_numpy
-     from tensorrt_llm.lora_manager import LoraManager
-     from tensorrt_llm.runtime import Session, TensorInfo, ModelConfig, SamplingConfig
- except:
-     print("\033[93mimport tensorrt_llm failed, if do not use tensorrt, ignore this message\033[0m")
-
- from typing import List
- from transformers import AutoProcessor, AutoTokenizer, AutoConfig
-
-
- def trt_dtype_to_torch(dtype):
-     if dtype == trt.float16:
-         return torch.float16
-     elif dtype == trt.float32:
-         return torch.float32
-     elif dtype == trt.int32:
-         return torch.int32
-     elif dtype == trt.bfloat16:
-         return torch.bfloat16
-     else:
-         raise TypeError("%s is not supported" % dtype)
-
-
- class Pix2StructTensorRT(nn.Module):
-
-     def __init__(
-         self,
-         model_path,
-         tensorrt_path,
-         batch_size=1,
-         max_new_tokens=4096,
-         cache_dir=None,
-         local_files_only=None,
-         **kwargs,
-     ):
-         self.model_ckpt_path = model_path
-         self.tensorrt_path = tensorrt_path
-         self.batch_size = batch_size
-         self.max_new_tokens = max_new_tokens
-         self.cache_dir = cache_dir
-         self.local_files_only = local_files_only
-
-         self.llm_engine_path = os.path.join(tensorrt_path, 'llm_engines')
-         self.visual_engine_path = os.path.join(tensorrt_path, 'visual_engines')
-
-         device_id = torch.cuda.current_device() % torch.cuda.device_count()
-         self.device_id = device_id
-         self.device = "cuda:%d" % (device_id)
-
-         self.stream = torch.cuda.Stream(torch.cuda.current_device())
-         torch.cuda.set_stream(self.stream)
-
-         # parse model type from visual engine config
-         with open(os.path.join(self.visual_engine_path, "config.json"),
-                   "r") as f:
-             config = json.load(f)
-             self.model_type = config['builder_config']['model_type']
-             self.vision_precision = config['builder_config']['precision']
-
-         self.vision_precision = 'float16'
-         self.decoder_llm = not (
-             't5' in self.model_type
-             or self.model_type in ['nougat', 'pix2struct', 'StructEqTable']
-         ) # BLIP2-T5, pix2struct and Nougat are using encoder-decoder models as LLMs
-
-         self.profiling_iterations = 20
-
-         self.init_image_encoder()
-         self.init_tokenizer()
-         self.init_llm()
-         self.init_image_processor()
-
-         self.special_str_list = ['\\midrule', '\\hline']
-         self.supported_output_format = ['latex']
-
-     def postprocess_latex_code(self, code):
-         for special_str in self.special_str_list:
-             code = code.replace(special_str, special_str + ' ')
-         return code
-
-     def init_image_processor(self):
-         self.data_processor = AutoProcessor.from_pretrained(
-             pretrained_model_name_or_path=self.model_ckpt_path,
-             cache_dir=self.cache_dir,
-             local_files_only=self.local_files_only,
-         )
-
-     def init_tokenizer(self):
-         self.tokenizer = AutoTokenizer.from_pretrained(
-             pretrained_model_name_or_path=self.model_ckpt_path,
-             use_fast=True,
-             use_legacy=False,
-             cache_dir=self.cache_dir,
-             local_files_only=self.local_files_only,
-         )
-         # self.tokenizer.padding_side = "right"
-
-     def init_image_encoder(self):
-         vision_encoder_path = os.path.join(self.visual_engine_path,
-                                            'visual_encoder.engine')
-         with open(vision_encoder_path, 'rb') as f:
-             engine_buffer = f.read()
-         self.visual_encoder_session = Session.from_serialized_engine(
-             engine_buffer)
-
-     def init_llm(self):
-
-         self.model = TRTLLMEncDecModel.from_engine(
-             os.path.basename(self.model_ckpt_path),
-             self.llm_engine_path,
-             skip_encoder=self.model_type in ['nougat', 'pix2struct', 'StructEqTable'],
-             debug_mode=False,
-             stream=self.stream)
-
-         self.model_config = self.model.decoder_model_config
-         self.runtime_mapping = self.model.decoder_runtime_mapping
-
-     def __call__(self, image, **kwargs):
-         # process image to tokens
-         image_tokens = self.data_processor.image_processor(
-             images=image,
-             return_tensors='pt',
-         )
-
-         for k, v in image_tokens.items():
-             image_tokens[k] = v.cuda()
-
-         model_output = self.run(
-             flattened_patches=image_tokens['flattened_patches'],
-             attention_mask=image_tokens['attention_mask'],
-             max_new_tokens=self.max_new_tokens
-         )
-
-         # postprocess
-         latex_codes = []
-         for i, code in enumerate(model_output):
-             latex_codes.append(self.postprocess_latex_code(code[0]))
-
-         return latex_codes
-
-     def preprocess(self, warmup, pre_prompt, post_prompt, image,
-                    attention_mask):
-         if not warmup:
-             profiler.start("Vision")
-
-         visual_features, visual_atts = self.get_visual_features(
-             torch.stack(image['image_patches'], dim=0)
-             if self.model_type == 'fuyu' else image, attention_mask)
-
-         if not warmup:
-             profiler.stop("Vision")
-
-         pre_input_ids = self.tokenizer(pre_prompt,
-                                        return_tensors="pt",
-                                        padding=True).input_ids
-         if post_prompt[0] is not None:
-             post_input_ids = self.tokenizer(post_prompt,
-                                             return_tensors="pt",
-                                             padding=True).input_ids
-             length = pre_input_ids.shape[1] + post_input_ids.shape[
-                 1] + visual_atts.shape[1]
-         else:
-             post_input_ids = None
-             length = pre_input_ids.shape[1] + visual_atts.shape[1]
-
-         input_lengths = torch.IntTensor([length] * 1).to(
-             torch.int32)
-
-         input_ids, ptuning_args = self.setup_fake_prompts(
-             visual_features, pre_input_ids, post_input_ids, input_lengths)
-
-         return input_ids, input_lengths, ptuning_args, visual_features
-
-     def generate(self, pre_prompt, post_prompt, image, decoder_input_ids,
-                  max_new_tokens, attention_mask, warmup):
-         if not warmup:
-             profiler.start("Generate")
-
-         input_ids, input_lengths, ptuning_args, visual_features = self.preprocess(
-             warmup, pre_prompt, post_prompt, image, attention_mask)
-
-         if warmup: return None
-
-         profiler.start("LLM")
-
-         # Trim encoder input_ids to match visual features shape
-         ids_shape = (self.batch_size, visual_features.shape[1])
-
-         input_ids = torch.ones(ids_shape, dtype=torch.int32)
-
-         output_ids = self.model.generate(
-             input_ids,
-             decoder_input_ids,
-             max_new_tokens,
-             num_beams=1,
-             bos_token_id=self.tokenizer.bos_token_id,
-             pad_token_id=self.tokenizer.pad_token_id,
-             eos_token_id=self.tokenizer.eos_token_id,
-             debug_mode=False,
-             prompt_embedding_table=ptuning_args[0],
-             prompt_tasks=ptuning_args[1],
-             prompt_vocab_size=ptuning_args[2],
-             attention_mask=attention_mask)
-
-         # Reset input_lengths to match decoder_input_ids
-         input_lengths = torch.ones(input_lengths.shape,
-                                    dtype=input_lengths.dtype)
-         profiler.stop("LLM")
-
-         if tensorrt_llm.mpi_rank() == 0:
-             # Extract a list of tensors of shape beam_width x output_ids.
-             output_beams_list = [
-                 self.tokenizer.batch_decode(
-                     output_ids[batch_idx, :, input_lengths[batch_idx]:],
-                     skip_special_tokens=True)
-                 for batch_idx in range(self.batch_size)
-             ]
-
-             stripped_text = [[
-                 output_beams_list[batch_idx][beam_idx].strip()
-                 for beam_idx in range(1)
-             ] for batch_idx in range(self.batch_size)]
-             profiler.stop("Generate")
-             return stripped_text
-         else:
-             profiler.stop("Generate")
-             return None
-
-     def get_visual_features(self, image, attention_mask):
-         visual_features = {
-             'input':
-                 image.to(
-                     tensorrt_llm._utils.str_dtype_to_torch(self.vision_precision))
-         }
-         if attention_mask is not None:
-             visual_features['attention_mask'] = attention_mask
-         tensor_info = [
-             TensorInfo('input', str_dtype_to_trt(self.vision_precision),
-                        image.shape)
-         ]
-         if attention_mask is not None:
-             tensor_info.append(
-                 TensorInfo('attention_mask', trt.DataType.INT32,
-                            attention_mask.shape))
-         visual_output_info = self.visual_encoder_session.infer_shapes(
-             tensor_info)
-         visual_outputs = {
-             t.name: torch.empty(tuple(t.shape),
-                                 dtype=trt_dtype_to_torch(t.dtype),
-                                 device=image.device)
-             for t in visual_output_info
-         }
-
-         ok = self.visual_encoder_session.run(visual_features, visual_outputs,
-                                              self.stream.cuda_stream)
-         assert ok, "Runtime execution failed for vision encoder session"
-         self.stream.synchronize()
-
-         image_embeds = visual_outputs['output']
-         image_atts = torch.ones(image_embeds.size()[:-1],
-                                 dtype=torch.long).to(image.device)
-
-         return image_embeds, image_atts
-
-     def setup_fake_prompts(self, visual_features, pre_input_ids, post_input_ids,
-                            input_lengths):
-         # Assemble fake prompts which points to image embedding actually
-         fake_prompt_id = torch.arange(
-             self.model_config.vocab_size, self.model_config.vocab_size +
-             visual_features.shape[0] * visual_features.shape[1])
-         fake_prompt_id = fake_prompt_id.reshape(visual_features.shape[0],
-                                                 visual_features.shape[1])
-
-         if post_input_ids is not None:
-             input_ids = [pre_input_ids, fake_prompt_id, post_input_ids]
-         else:
-             input_ids = [fake_prompt_id, pre_input_ids]
-
-         input_ids = torch.cat(input_ids, dim=1).contiguous().to(torch.int32)
-
-         if self.decoder_llm or self.runtime_mapping.is_first_pp_rank():
-             ptuning_args = self.ptuning_setup(visual_features, input_ids,
-                                               input_lengths)
-         else:
-             ptuning_args = [None, None, None]
-
-         return input_ids, ptuning_args
-
-     def ptuning_setup(self, prompt_table, input_ids, input_lengths):
-         hidden_size = self.model_config.hidden_size * self.runtime_mapping.tp_size
-         if prompt_table is not None:
-             task_vocab_size = torch.tensor(
-                 [prompt_table.shape[1]],
-                 dtype=torch.int32,
-             ).cuda()
-             prompt_table = prompt_table.view(
-                 (prompt_table.shape[0] * prompt_table.shape[1],
-                  prompt_table.shape[2]))
-             assert prompt_table.shape[
-                 1] == hidden_size, "Prompt table dimensions do not match hidden size"
-
-             prompt_table = prompt_table.cuda().to(
-                 dtype=tensorrt_llm._utils.str_dtype_to_torch(
-                     self.model_config.dtype))
-         else:
-             prompt_table = torch.empty([1, hidden_size]).cuda()
-             task_vocab_size = torch.zeros([1]).cuda()
-
-         if self.model_config.remove_input_padding:
-             tasks = torch.zeros([torch.sum(input_lengths)],
-                                 dtype=torch.int32).cuda()
-             if self.decoder_llm: tasks = tasks.unsqueeze(0)
-         else:
-             tasks = torch.zeros(input_ids.shape, dtype=torch.int32).cuda()
-
-         return [prompt_table, tasks, task_vocab_size]
-
-     def setup_inputs(self, input_text, raw_image):
-         attention_mask = None
-
-         image_processor = AutoProcessor.from_pretrained(
-             pretrained_model_name_or_path=self.model_ckpt_path,
-             cache_dir=self.cache_dir,
-             local_files_only=self.local_files_only,
-         )
-         if input_text is None:
-             input_text = ""
-         inputs = image_processor(
-             images=raw_image,
-             text=input_text,
-             return_tensors="pt",
-         )
-         image = inputs['flattened_patches']
-         image = image.expand(self.batch_size, -1, -1).contiguous()
-         attention_mask = inputs['attention_mask'].to(self.device).to(
-             torch.int)
-         attention_mask = attention_mask.expand(self.batch_size,
-                                                -1).contiguous()
-         pre_prompt = ""
-         post_prompt = None
-
-         # Repeat inputs to match batch size
-         pre_prompt = [pre_prompt] * self.batch_size
-         post_prompt = [post_prompt] * self.batch_size
-         image = image.to(self.device)
-
-         # Generate decoder_input_ids for enc-dec models
-         # Custom prompts can be added as:
-         # decoder_input_ids = model.tokenizer(decoder_prompt).input_ids
-         if self.decoder_llm:
-             decoder_input_ids = None
-         else:
-             config = AutoConfig.from_pretrained(
-                 pretrained_model_name_or_path=self.model_ckpt_path,
-                 cache_dir=self.cache_dir,
-                 local_files_only=self.local_files_only,
-             )
-             decoder_start_id = config.decoder_start_token_id # T5
-             if decoder_start_id is None:
-                 decoder_start_id = config.decoder.bos_token_id # Nougat
-
-             decoder_input_ids = torch.IntTensor([[decoder_start_id]])
-             decoder_input_ids = decoder_input_ids.repeat((self.batch_size, 1))
-
-         return input_text, pre_prompt, post_prompt, image, decoder_input_ids, attention_mask
-
-     def run(self, flattened_patches, attention_mask, max_new_tokens):
-         # input_text, pre_prompt, post_prompt, processed_image, decoder_input_ids, attention_mask = self.setup_inputs(
-         # None, raw_image)
-         pre_prompt = [""] * self.batch_size
-         post_prompt = [None] * self.batch_size
-         config = AutoConfig.from_pretrained(
-             pretrained_model_name_or_path=self.model_ckpt_path,
-             cache_dir=self.cache_dir,
-             local_files_only=self.local_files_only,
-         )
-         decoder_start_id = config.decoder_start_token_id # T5
-         decoder_input_ids = torch.IntTensor([[decoder_start_id]])
-         decoder_input_ids = decoder_input_ids.repeat((self.batch_size, 1))
-
-         processed_image = flattened_patches.expand(self.batch_size, -1, -1).contiguous()
-         attention_mask = attention_mask.to(self.device).to(torch.int)
-         attention_mask = attention_mask.expand(self.batch_size,-1).contiguous()
-
-         self.generate(pre_prompt,
-                       post_prompt,
-                       processed_image,
-                       decoder_input_ids,
-                       max_new_tokens,
-                       attention_mask=attention_mask,
-                       warmup=True)
-         # num_iters = self.profiling_iterations if self.args.run_profiling else 1
-         num_iters = 1
-         # print(num_iters)
-         for _ in range(num_iters):
-             output_text = self.generate(pre_prompt,
-                                         post_prompt,
-                                         processed_image,
-                                         decoder_input_ids,
-                                         max_new_tokens,
-                                         attention_mask=attention_mask,
-                                         warmup=False)
-         # if self.runtime_rank == 0:
-         # self.print_result(input_text, output_text)
-         return output_text
-
-
- def read_config(config_path):
-     with open(config_path, "r") as f:
-         config = json.load(f)
-
-     builder_config = config['build_config']
-     plugin_config = builder_config['plugin_config']
-     pretrained_config = config['pretrained_config']
-     lora_config = builder_config['lora_config']
-     auto_parallel_config = builder_config['auto_parallel_config']
-     use_gpt_attention_plugin = plugin_config["gpt_attention_plugin"]
-     remove_input_padding = plugin_config["remove_input_padding"]
-     use_lora_plugin = plugin_config["lora_plugin"]
-     tp_size = pretrained_config['mapping']['tp_size']
-     pp_size = pretrained_config['mapping']['pp_size']
-     gpus_per_node = auto_parallel_config['gpus_per_node']
-     world_size = tp_size * pp_size
-     assert world_size == tensorrt_llm.mpi_world_size(), \
-         f'Engine world size ({world_size}) != Runtime world size ({tensorrt_llm.mpi_world_size()})'
-     num_heads = pretrained_config["num_attention_heads"]
-     hidden_size = pretrained_config["hidden_size"]
-     head_size = pretrained_config["head_size"]
-     vocab_size = pretrained_config["vocab_size"]
-     max_batch_size = builder_config["max_batch_size"]
-     max_beam_width = builder_config["max_beam_width"]
-     num_layers = pretrained_config["num_hidden_layers"]
-     num_kv_heads = pretrained_config.get('num_kv_heads', num_heads)
-
-     assert (num_heads % tp_size) == 0
-     num_heads = num_heads // tp_size
-     hidden_size = hidden_size // tp_size
-     num_kv_heads = (num_kv_heads + tp_size - 1) // tp_size
-
-     cross_attention = pretrained_config["architecture"] == "DecoderModel"
-     skip_cross_qkv = pretrained_config.get('skip_cross_qkv', False)
-     has_position_embedding = pretrained_config["has_position_embedding"]
-     has_token_type_embedding = hasattr(pretrained_config, "type_vocab_size")
-     use_custom_all_reduce = plugin_config.get('use_custom_all_reduce', False)
-     dtype = pretrained_config["dtype"]
-
-     paged_kv_cache = plugin_config['paged_kv_cache']
-     tokens_per_block = plugin_config['tokens_per_block']
-
-     gather_context_logits = builder_config.get('gather_context_logits', False)
-     gather_generation_logits = builder_config.get('gather_generation_logits',
-                                                   False)
-     max_prompt_embedding_table_size = builder_config.get(
-         'max_prompt_embedding_table_size', 0)
-
-     model_config = ModelConfig(
-         num_heads=num_heads,
-         num_kv_heads=num_kv_heads,
-         hidden_size=hidden_size,
-         head_size=head_size,
-         max_batch_size=max_batch_size,
-         max_beam_width=max_beam_width,
-         vocab_size=vocab_size,
-         num_layers=num_layers,
-         gpt_attention_plugin=use_gpt_attention_plugin,
-         remove_input_padding=remove_input_padding,
-         paged_kv_cache=paged_kv_cache,
-         tokens_per_block=tokens_per_block,
-         cross_attention=cross_attention,
-         has_position_embedding=has_position_embedding,
-         has_token_type_embedding=has_token_type_embedding,
-         use_custom_all_reduce=use_custom_all_reduce,
-         dtype=dtype,
-         gather_context_logits=gather_context_logits,
-         gather_generation_logits=gather_generation_logits,
-         max_prompt_embedding_table_size=max_prompt_embedding_table_size,
-         lora_plugin=use_lora_plugin,
-         lora_target_modules=lora_config.get('lora_target_modules'),
-         trtllm_modules_to_hf_modules=lora_config.get(
-             'trtllm_modules_to_hf_modules'),
-         skip_cross_qkv=skip_cross_qkv,
-     )
-
-     return model_config, tp_size, pp_size, gpus_per_node, dtype
-
-
- class Mapping(object):
-     def __init__(
-         self,
-         world_size=1,
-         rank=0,
-         gpus_per_node=8,
-         tp_size=1,
-         pp_size=1,
-         moe_tp_size=-1, # -1 means no moe
-         moe_ep_size=-1): # -1 means no moe
-         # set default values for non-moe cases
-         if moe_tp_size == -1:
-             moe_tp_size = tp_size
-             moe_ep_size = 1
-
-         if pp_size * tp_size != world_size:
-             raise ValueError(
-                 f"world_size must equal to pp_size * tp_size, but got {world_size} != {pp_size} * {tp_size}"
-             )
-
-         moe_tp_ep_size = moe_tp_size * moe_ep_size
-         if moe_tp_ep_size != tp_size:
-             raise ValueError(
-                 f"tp_size must equal to moe_tp_size * moe_ep_size, but got {tp_size} != {moe_tp_size} * {moe_ep_size}"
-             )
-
-         self.tp_size = tp_size
-         self.pp_size = pp_size
-         self.moe_tp_size = moe_tp_size
-         self.moe_ep_size = moe_ep_size
-         self.world_size = world_size
-         self.rank = rank
-         self.gpus_per_node = gpus_per_node
-
-         self.pp_groups = []
-         self.tp_groups = []
-         self.moe_tp_groups = []
-         self.moe_ep_groups = []
-
-         # init pp group
-         for i in range(tp_size):
-             ranks = range(i+ self.rank, world_size+ self.rank, tp_size)
-             self.pp_groups.append(list(ranks))
-
-         # init tp group
-         for i in range(pp_size):
-             ranks = range(i * tp_size + self.rank, (i + 1) * tp_size + self.rank)
-             self.tp_groups.append(list(ranks))
-
-         # init moe tp group
-         for i in range(pp_size):
-             for j in range(moe_ep_size):
-                 ranks = range(i * moe_tp_ep_size + j, (i + 1) * moe_tp_ep_size,
-                               moe_ep_size)
-                 self.moe_tp_groups.append(list(ranks))
-
-         # init moe ep group
-         for i in range(pp_size):
-             for j in range(moe_tp_size):
-                 ranks = range(i * moe_tp_ep_size + j * moe_ep_size,
-                               i * moe_tp_ep_size + (j + 1) * moe_ep_size)
-                 self.moe_ep_groups.append(list(ranks))
-
-         # self.pp_rank = self.rank // self.tp_size
-         # self.tp_rank = self.rank % self.tp_size
-         self.pp_rank = 0
-         self.tp_rank = 0
-         self.moe_tp_rank = self.tp_rank // self.moe_ep_size
-         self.moe_ep_rank = self.tp_rank % self.moe_ep_size
-
-         # self.tp_group = self.tp_groups[self.pp_rank]
-         # self.pp_group = self.pp_groups[self.tp_rank]
-         self.moe_tp_group = self.moe_tp_groups[self.pp_rank * moe_ep_size +
-                                                self.moe_ep_rank]
-         self.moe_ep_group = self.moe_ep_groups[self.pp_rank * moe_tp_size +
-                                                self.moe_tp_rank]
-
-         self.node_rank = self.rank // self.gpus_per_node
-         self.local_rank = self.rank % self.gpus_per_node
-
-     def get_node_rank(self, rank: int):
-         return rank // self.gpus_per_node
-
-     def get_local_rank(self, rank: int):
-         return rank % self.gpus_per_node
-
-     def has_tp(self):
-         return self.tp_size > 1
-
-     def is_last_pp_rank(self):
-         return self.pp_rank == self.pp_size - 1
-
-     def is_first_pp_rank(self):
-         return self.pp_rank == 0
-
-     def has_pp(self):
-         return self.pp_size > 1
-
-     def prev_pp_rank(self):
-         p = self.rank - self.tp_size
-         if p < 0:
-             p = p + self.world_size
-         return p
-
-     def next_pp_rank(self):
-         p = self.rank + self.tp_size
-         if p >= self.world_size:
-             p = p - self.world_size
-         return p
-
-     def has_moe_tp(self):
-         return self.moe_tp_size > 1
-
-     def has_moe_ep(self):
-         return self.moe_ep_size > 1
-
-     def pp_layers(self, num_layers: int) -> List[int]:
-         layers_per_pipeline_stage = num_layers // self.pp_size
-         layers_range = range(self.pp_rank * layers_per_pipeline_stage,
-                              (self.pp_rank + 1) * layers_per_pipeline_stage)
-         return list(layers_range)
-
-     def ep_experts(self, num_experts: int) -> List[int]:
-         experts_per_rank = num_experts // self.moe_ep_size
-         experts_range = range(self.moe_ep_rank * experts_per_rank,
-                               (self.moe_ep_rank + 1) * experts_per_rank)
-         return list(experts_range)
-
-
- def get_engine_name(rank):
-     return 'rank{}.engine'.format(rank)
-
- class TRTLLMEncDecModel:
-
-     def __init__(
-         self,
-         engine_name,
-         engine_dir,
-         lora_dir=None,
-         lora_task_uids=None,
-         debug_mode=False,
-         skip_encoder=False,
-         stream: torch.cuda.Stream = None,
-     ):
-         # in multi-node setup, it's important to set_device at the very beginning so .to('cuda') refers to current device
-         # accordingly, all input & output tensors should be moved to current device
-         # otherwise, it's default to 'cuda:0'
-
-         # self.runtime_rank = tensorrt_llm.mpi_rank()
-         self.device_id = torch.cuda.current_device()
-         # torch.cuda.set_device(device_id)
-         self.device = torch.cuda.current_device()
-         self.skip_encoder = skip_encoder
-         self.lora_task_uids = lora_task_uids
-
-         # when enc-dec runs by itself, stream can be None and we create new stream here
-         # when enc-dec has to run as a component in a bigger workflow (e.g., multimodal), earlier components in the workflow may have results in its stream, which we should pass that stream in to avoid unnecessary stream sync
-         self.stream = stream
-         if self.stream is None:
-             self.stream = torch.cuda.Stream(self.device)
-         torch.cuda.set_stream(self.stream)
-
-         def engine_setup(component):
-             # model config
-             config_path = os.path.join(engine_dir, component, "config.json")
-             model_config, tp_size, pp_size, gpus_per_node, dtype = read_config(
-                 config_path)
-
-             # MGMN config
-             world_size = tp_size * pp_size
-             # runtime_rank = tensorrt_llm.mpi_rank()
-             runtime_rank = torch.cuda.current_device()
-             # assert runtime_rank < world_size, "Runtime GPU rank exceeds MPI world size. Did you launch more MPI processes than required?"
-             # runtime_mapping = tensorrt_llm.Mapping(world_size,
-             # runtime_rank,
-             # tp_size=tp_size,
-             # pp_size=pp_size,
-             # gpus_per_node=gpus_per_node)
-             # tensorrt_llm.Mapping
-             runtime_mapping = Mapping(world_size,
-                                       runtime_rank,
-                                       tp_size=tp_size,
-                                       pp_size=pp_size,
-                                       gpus_per_node=gpus_per_node)
-             # load engine
-             # engine_fname = get_engine_name(runtime_rank)
-             engine_fname = get_engine_name(0)
-             with open(os.path.join(engine_dir, component, engine_fname), "rb") as f:
-                 engine_buffer = f.read()
-
-             return model_config, runtime_mapping, engine_buffer
-
-         # Note: encoder and decoder doesn't necessarily have the same TP & PP config
-
-         if not skip_encoder:
-             self.encoder_model_config, self.encoder_runtime_mapping, encoder_engine_buffer = engine_setup(
-                 component='encoder')
-
-             self.nccl_comm = None
-             if self.encoder_runtime_mapping.has_pp():
-                 # for Pipeline Parallelism in encoder
-                 self.nccl_comm = torch.classes.trtllm.NcclCommunicatorOp(
-                     self.encoder_runtime_mapping.tp_size,
-                     self.encoder_runtime_mapping.pp_size,
-                     self.encoder_runtime_mapping.rank)
-
-             # session setup
-             self.encoder_session = tensorrt_llm.runtime.Session.from_serialized_engine(
-                 encoder_engine_buffer)
-
-             # encoder lora manager setup
-             if self.encoder_model_config.lora_plugin:
-                 self.encoder_lora_manager = LoraManager()
-                 # TODO: this is only for bart
-                 self.encoder_lora_manager.load_from_hf(
-                     model_dirs=lora_dir,
-                     model_config=self.encoder_model_config,
-                     runtime_mapping=self.encoder_runtime_mapping,
-                     component='encoder',
-                 )
-             else:
-                 self.encoder_lora_manager = None
-         else:
-             self.encoder_model_config, self.encoder_runtime_mapping, encoder_engine_buffer = None, None, None
-             self.nccl_comm, self.encoder_session = None, None
-
-         self.decoder_model_config, self.decoder_runtime_mapping, decoder_engine_buffer = engine_setup(
-             component='decoder')
-
-         self.decoder_session = tensorrt_llm.runtime.GenerationSession(
-             self.decoder_model_config,
-             decoder_engine_buffer,
-             self.decoder_runtime_mapping,
-             debug_mode=debug_mode)
-
-         # decoder lora manager setup
-         if self.decoder_model_config.lora_plugin:
-             self.decoder_lora_manager = LoraManager()
-             # TODO: this is only for bart
-             self.decoder_lora_manager.load_from_hf(
-                 model_dirs=lora_dir,
-                 model_config=self.decoder_model_config,
-                 runtime_mapping=self.decoder_runtime_mapping,
-                 component='decoder',
-             )
-         else:
-             self.decoder_lora_manager = None
-
-     @classmethod
-     def from_engine(cls,
-                     engine_name,
-                     engine_dir,
-                     lora_dir=None,
-                     lora_task_uids=None,
-                     debug_mode=False,
-                     skip_encoder=False,
-                     stream=None):
-         return cls(engine_name,
-                    engine_dir,
-                    lora_dir,
-                    lora_task_uids,
-                    debug_mode=debug_mode,
-                    skip_encoder=skip_encoder,
-                    stream=stream)
-
-     def process_input(self,
-                       input_ids,
-                       remove_input_padding=False,
-                       pad_token_id=0,
-                       prompt_tasks=None):
-         if remove_input_padding:
-             # in remove padding mode --> flatten input, calculate actual length and max length
-             # Note: 1st token should never be removed, even if it is pad_token_id
-             first_ids = input_ids[:, 0]
-             input_ids = input_ids[:, 1:]
-             input_lengths = 1 + (input_ids != pad_token_id).sum(dim=1).type(
-                 torch.IntTensor).to(self.device) # [batch_size]
-             new_ids = []
-             for i in range(len(input_ids)):
-                 row = input_ids[i, :]
-                 row = row[row != pad_token_id]
-                 new_ids.append(
-                     torch.cat(
-                         (torch.IntTensor([first_ids[i]]).to(self.device), row)))
-             input_ids = torch.cat(new_ids) # [num_tokens]
-             if prompt_tasks is not None:
-                 prompt_tasks = prompt_tasks[:input_ids.shape[0]]
-         else:
-             # in padding mode --> keep input, just calculate actual length and max length
-             # Note: 1st token should always count, even if it is pad_token_id. e.g., decoder start id in enc-dec models could be a single pad_token_id, we should count
-             input_lengths = torch.tensor(
-                 1 + (input_ids[:, 1:] != pad_token_id).sum(dim=1).type(
-                     torch.IntTensor).to(self.device),
-                 dtype=torch.int32,
-                 device=self.device)
-         max_input_length = torch.max(input_lengths).item()
-         return input_ids, input_lengths, max_input_length, prompt_tasks
-
-     def encoder_run(self,
-                     input_ids,
-                     input_lengths,
-                     max_input_length,
-                     position_ids=None,
-                     token_type_ids=None,
-                     debug_mode=False,
-                     prompt_embedding_table=None,
-                     prompt_tasks=None,
-                     prompt_vocab_size=None,
-                     attention_mask=None):
-
-         # each engine has hidden_dim/TP, don't forget to multiply TP
-         hidden_size = self.encoder_model_config.hidden_size * self.encoder_runtime_mapping.tp_size
-         if input_ids.dim() == 1:
-             hidden_states_shape = (input_ids.shape[0], hidden_size
-                                    ) # [num_tokens,D]
-         else:
-             hidden_states_shape = (input_ids.shape[0], input_ids.shape[1],
-                                    hidden_size) # [BS,seqlen,D]
-         hidden_states_dtype = lambda name: trt_dtype_to_torch(
-             self.encoder_session.engine.get_tensor_dtype(name))
-
-         # input tensors. only first PP rank has id input, others are hidden_states input
-         inputs = {}
-         if self.encoder_runtime_mapping.is_first_pp_rank():
-             inputs['input_ids'] = input_ids.contiguous()
-             if self.encoder_model_config.has_position_embedding:
-                 if position_ids is None:
-                     if self.encoder_model_config.remove_input_padding:
-                         position_ids = [
-                             torch.arange(sample_length,
-                                          dtype=torch.int32,
-                                          device=input_ids.device)
-                             for sample_length in torch_to_numpy(input_lengths)
-                         ]
-                         position_ids = torch.cat(position_ids)
-                     else:
-                         bsz, seq_len = input_ids.shape[:2]
-                         position_ids = torch.arange(
-                             seq_len, dtype=torch.int32,
-                             device=input_ids.device).expand(bsz, -1)
-                 inputs['position_ids'] = position_ids.contiguous()
-             if self.encoder_model_config.has_token_type_embedding:
-                 inputs['token_type_ids'] = token_type_ids.contiguous()
-
-             if self.encoder_model_config.max_prompt_embedding_table_size > 0:
-                 inputs[
-                     'prompt_embedding_table'] = prompt_embedding_table.contiguous(
-                     )
-                 inputs['tasks'] = prompt_tasks.contiguous()
-                 inputs['prompt_vocab_size'] = prompt_vocab_size.contiguous()
-         else:
-             # just need a placeholder, engine will call NCCL to recv and fill data from previous rank
-             inputs['hidden_states_input'] = torch.empty(
-                 hidden_states_shape,
-                 dtype=hidden_states_dtype('hidden_states_input'),
-                 device=self.device).contiguous()
-         if attention_mask is not None and not self.encoder_model_config.gpt_attention_plugin:
-             inputs['attention_mask'] = attention_mask.contiguous()
-
-         inputs['input_lengths'] = input_lengths
-         # use shape info to pass max length info in remove padding mode
-         inputs['max_input_length'] = torch.empty(
-             (max_input_length, ),
-             dtype=hidden_states_dtype('max_input_length'),
-             device=self.device).contiguous()
-         batch_size = input_lengths.size(0)
-         inputs['host_request_types'] = torch.IntTensor([0] *
-                                                        batch_size).to('cpu')
-         if self.encoder_model_config.remove_input_padding:
-             inputs['host_context_lengths'] = input_lengths.to('cpu')
-
-         if self.encoder_model_config.lora_plugin and self.encoder_lora_manager is not None:
-             inputs.update(
-                 self.encoder_lora_manager.input_buffers(
-                     self.lora_task_uids,
-                     self.encoder_runtime_mapping,
-                     self.encoder_model_config.num_layers,
-                 ))
-
-         # Note: runtime.Session's run() method will set input/output tensor address, here we only need to provide tensor shape
-         self.encoder_session.set_shapes(inputs)
-
-         # output tensors. only last PP rank final encoder output, others are intermediate hidden_states output. Need broadcast later
-         outputs = {}
-         if self.encoder_runtime_mapping.is_last_pp_rank():
-             outputs['encoder_output'] = torch.empty(
-                 hidden_states_shape,
-                 dtype=hidden_states_dtype('encoder_output'),
-                 device=self.device).contiguous()
-         else:
-             outputs['hidden_states_output'] = torch.empty(
-                 hidden_states_shape,
-                 dtype=hidden_states_dtype('hidden_states_output'),
-                 device=self.device).contiguous()
-
-         # -------------------------------------------
-         if debug_mode:
-             engine = self.encoder_session.engine
-             context = self.encoder_session.context
-             # setup debugging buffer for the encoder
-             for i in range(self.encoder_session.engine.num_io_tensors):
-                 name = engine.get_tensor_name(i)
-                 if engine.get_tensor_mode(
-                         name
-                 ) == trt.TensorIOMode.OUTPUT and name not in outputs.keys():
-                     dtype = engine.get_tensor_dtype(name)
-                     shape = context.get_tensor_shape(name)
-                     outputs[name] = torch.zeros(tuple(shape),
-                                                 dtype=trt_dtype_to_torch(dtype),
-                                                 device=self.device)
-                     context.set_tensor_address(name, outputs[name].data_ptr())
-         # -------------------------------------------
-
-         # TRT session run
-         # Note: need cuda stream ID, not a torch Stream
-         ok = self.encoder_session.run(inputs, outputs, self.stream.cuda_stream)
-         assert ok, "Runtime execution failed"
-         self.stream.synchronize()
-
-         # Tensor Parallelism is handled by model/engine definition
-         # But we need to broadcast among PP group at the end of encoder's Pipeline Parallelism
-         # After this, all ranks should recv the encoder output, and world might be re-configured using decoder's TP-PP config
-         def pp_communicate_encoder_output(encoder_output):
-             if self.encoder_runtime_mapping.is_last_pp_rank():
-                 for pp_rank in self.encoder_runtime_mapping.pp_group:
-                     if pp_rank != self.encoder_runtime_mapping.rank:
-                         self.nccl_comm.send(encoder_output, pp_rank)
-                 return encoder_output
-             else:
-                 self.nccl_comm.recv(encoder_output,
-                                     self.encoder_runtime_mapping.pp_group[-1])
-                 return encoder_output
-
-         if self.encoder_runtime_mapping.has_pp():
-             # use hidden_states output buffer to receive output as the shapes are same
-             encoder_output_buf = outputs[
-                 'encoder_output'] if self.encoder_runtime_mapping.is_last_pp_rank(
-                 ) else outputs['hidden_states_output']
-             encoder_output = pp_communicate_encoder_output(encoder_output_buf)
-         else:
-             encoder_output = outputs['encoder_output']
-
-         return encoder_output
-
-     def generate(self,
-                  encoder_input_ids,
-                  decoder_input_ids,
-                  max_new_tokens,
-                  num_beams=1,
-                  pad_token_id=None,
-                  eos_token_id=None,
-                  bos_token_id=None,
-                  debug_mode=False,
-                  return_dict=False,
-                  prompt_embedding_table=None,
-                  prompt_tasks=None,
-                  prompt_vocab_size=None,
-                  attention_mask=None,
-                  time_encoder=False,
-                  return_encoder_output=False):
-         ## ensure all externally provided tensors are on the correct device.
-         encoder_input_ids = encoder_input_ids.to(self.device)
-         decoder_input_ids = decoder_input_ids.to(self.device)
-
-         if attention_mask is not None:
-             attention_mask = torch.tensor(attention_mask,
-                                           dtype=torch.int32,
-                                           device=self.device)
-
-         ## encoder run
-         encoder_remove_input_padding = self.encoder_model_config.remove_input_padding if self.encoder_model_config else self.decoder_model_config.remove_input_padding
-
-         encoder_input_ids, encoder_input_lengths, encoder_max_input_length, prompt_tasks = self.process_input(
-             encoder_input_ids, encoder_remove_input_padding, pad_token_id,
-             prompt_tasks)
-
-         if not self.skip_encoder:
-             #logger.info(f"Rank {self.runtime_rank} Running encoder engine ...")
-             if time_encoder:
-                 tik = time.time()
-             encoder_output = self.encoder_run(
-                 encoder_input_ids,
-                 encoder_input_lengths,
-                 encoder_max_input_length,
-                 debug_mode=debug_mode,
-                 prompt_embedding_table=prompt_embedding_table,
-                 prompt_tasks=prompt_tasks,
-                 prompt_vocab_size=prompt_vocab_size,
-                 attention_mask=attention_mask)
-             if time_encoder:
-                 tok = time.time()
-                 print(f"TRT-LLM Encoder time {(tok-tik)*1000}ms")
-         else:
-             encoder_output = prompt_embedding_table
-             if encoder_input_ids.dim() > 1:
-                 encoder_output = encoder_output.unsqueeze(0)
-
-         ## decoder run
-         # logger.info(f"Rank {self.runtime_rank} Running decoder engine ...")
-         decoder_input_ids, decoder_input_lengths, decoder_max_input_length, _ = self.process_input(
-             decoder_input_ids, self.decoder_model_config.remove_input_padding,
-             pad_token_id)
-
-         # `cross_attention_mask` in context phase [batch_size, query_len, encoder_input_len]
-         # where query_len happens to be 1 in current cases, but not necessarily always, and
-         # `cross_attention_mask` in generation phase [batch_size, 1, encoder_input_len] where
-         # the query_len is always 1 since we have kv cache.
-         cross_attention_mask = None
-         if attention_mask is not None:
-             cross_attention_mask = torch.tensor(attention_mask,
-                                                 dtype=torch.int32,
-                                                 device=self.device).reshape(
-                                                     attention_mask.shape[0], 1,
-                                                     attention_mask.shape[1])
-
-         # generation config
-         sampling_config = SamplingConfig(end_id=eos_token_id,
-                                          pad_id=pad_token_id,
-                                          num_beams=num_beams,
-                                          min_length=1,
-                                          return_dict=return_dict)
-         sampling_config.update(output_cum_log_probs=return_dict,
-                                output_log_probs=return_dict)
-
-         # decoder autoregressive generation
-         self.decoder_session.setup(
-             decoder_input_lengths.size(0),
-             decoder_max_input_length,
-             max_new_tokens,
-             num_beams,
-             max_attention_window_size=None,
-             encoder_max_input_length=encoder_max_input_length,
-             lora_manager=self.decoder_lora_manager,
-             lora_uids=self.lora_task_uids,
-         )
-
-         output = self.decoder_session.decode(
-             decoder_input_ids,
-             decoder_input_lengths,
-             sampling_config,
-             encoder_output=encoder_output,
-             encoder_input_lengths=encoder_input_lengths,
-             return_dict=return_dict,
-             cross_attention_mask=cross_attention_mask)
-
-         if return_dict and return_encoder_output:
-             output['encoder_output'] = encoder_output
-
-         return output