magic-pdf 0.5.13__py3-none-any.whl → 0.6.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. magic_pdf/cli/magicpdf.py +18 -7
  2. magic_pdf/dict2md/ocr_mkcontent.py +2 -2
  3. magic_pdf/libs/config_reader.py +10 -0
  4. magic_pdf/libs/version.py +1 -1
  5. magic_pdf/model/__init__.py +1 -0
  6. magic_pdf/model/doc_analyze_by_custom_model.py +38 -15
  7. magic_pdf/model/model_list.py +1 -0
  8. magic_pdf/model/pdf_extract_kit.py +200 -0
  9. magic_pdf/model/pek_sub_modules/__init__.py +0 -0
  10. magic_pdf/model/pek_sub_modules/layoutlmv3/__init__.py +0 -0
  11. magic_pdf/model/pek_sub_modules/layoutlmv3/backbone.py +179 -0
  12. magic_pdf/model/pek_sub_modules/layoutlmv3/beit.py +671 -0
  13. magic_pdf/model/pek_sub_modules/layoutlmv3/deit.py +476 -0
  14. magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/__init__.py +7 -0
  15. magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/__init__.py +2 -0
  16. magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/cord.py +171 -0
  17. magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/data_collator.py +124 -0
  18. magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/funsd.py +136 -0
  19. magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/image_utils.py +284 -0
  20. magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/xfund.py +213 -0
  21. magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/__init__.py +7 -0
  22. magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/__init__.py +24 -0
  23. magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/configuration_layoutlmv3.py +60 -0
  24. magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/modeling_layoutlmv3.py +1282 -0
  25. magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3.py +32 -0
  26. magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3_fast.py +34 -0
  27. magic_pdf/model/pek_sub_modules/layoutlmv3/model_init.py +150 -0
  28. magic_pdf/model/pek_sub_modules/layoutlmv3/rcnn_vl.py +163 -0
  29. magic_pdf/model/pek_sub_modules/layoutlmv3/visualizer.py +1236 -0
  30. magic_pdf/model/pek_sub_modules/post_process.py +36 -0
  31. magic_pdf/model/pek_sub_modules/self_modify.py +260 -0
  32. magic_pdf/model/pp_structure_v2.py +7 -0
  33. magic_pdf/pipe/AbsPipe.py +8 -14
  34. magic_pdf/pipe/OCRPipe.py +12 -8
  35. magic_pdf/pipe/TXTPipe.py +12 -8
  36. magic_pdf/pipe/UNIPipe.py +9 -7
  37. magic_pdf/resources/model_config/UniMERNet/demo.yaml +46 -0
  38. magic_pdf/resources/model_config/layoutlmv3/layoutlmv3_base_inference.yaml +351 -0
  39. magic_pdf/resources/model_config/model_configs.yaml +9 -0
  40. {magic_pdf-0.5.13.dist-info → magic_pdf-0.6.1.dist-info}/METADATA +95 -12
  41. {magic_pdf-0.5.13.dist-info → magic_pdf-0.6.1.dist-info}/RECORD +45 -19
  42. magic_pdf/model/360_layout_analysis.py +0 -8
  43. {magic_pdf-0.5.13.dist-info → magic_pdf-0.6.1.dist-info}/LICENSE.md +0 -0
  44. {magic_pdf-0.5.13.dist-info → magic_pdf-0.6.1.dist-info}/WHEEL +0 -0
  45. {magic_pdf-0.5.13.dist-info → magic_pdf-0.6.1.dist-info}/entry_points.txt +0 -0
  46. {magic_pdf-0.5.13.dist-info → magic_pdf-0.6.1.dist-info}/top_level.txt +0 -0
magic_pdf/cli/magicpdf.py CHANGED
@@ -28,18 +28,20 @@ from loguru import logger
  from pathlib import Path
  from magic_pdf.libs.version import __version__

- from magic_pdf.libs.MakeContentConfig import DropMode
+ from magic_pdf.libs.MakeContentConfig import DropMode, MakeMode
  from magic_pdf.libs.draw_bbox import draw_layout_bbox, draw_span_bbox
  from magic_pdf.pipe.UNIPipe import UNIPipe
  from magic_pdf.pipe.OCRPipe import OCRPipe
  from magic_pdf.pipe.TXTPipe import TXTPipe
- from magic_pdf.libs.config_reader import get_s3_config
  from magic_pdf.libs.path_utils import (
      parse_s3path,
      parse_s3_range_params,
      remove_non_official_s3_args,
  )
- from magic_pdf.libs.config_reader import get_local_dir
+ from magic_pdf.libs.config_reader import (
+     get_local_dir,
+     get_s3_config,
+ )
  from magic_pdf.rw.S3ReaderWriter import S3ReaderWriter
  from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter
  from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter
@@ -81,10 +83,12 @@ def do_parse(
      f_dump_model_json=True,
      f_dump_orig_pdf=True,
      f_dump_content_list=True,
+     f_make_md_mode=MakeMode.MM_MD,
  ):
      orig_model_list = copy.deepcopy(model_list)

      local_image_dir, local_md_dir = prepare_env(pdf_file_name, parse_method)
+     logger.info(f"local output dir is {local_md_dir}")
      image_writer, md_writer = DiskReaderWriter(local_image_dir), DiskReaderWriter(local_md_dir)
      image_dir = str(os.path.basename(local_image_dir))

@@ -105,6 +109,7 @@ def do_parse(
      if len(model_list) == 0:
          if model_config.__use_inside_model__:
              pipe.pipe_analyze()
+             orig_model_list = copy.deepcopy(pipe.model_list)
          else:
              logger.error("need model list input")
              exit(1)
@@ -116,7 +121,7 @@ def do_parse(
      if f_draw_span_bbox:
          draw_span_bbox(pdf_info, pdf_bytes, local_md_dir)

-     md_content = pipe.pipe_mk_markdown(image_dir, drop_mode=DropMode.NONE)
+     md_content = pipe.pipe_mk_markdown(image_dir, drop_mode=DropMode.NONE, md_make_mode=f_make_md_mode)
      if f_dump_md:
          """write markdown"""
          md_writer.write(
@@ -175,8 +180,10 @@ def cli():
      default="auto",
  )
  @click.option("--inside_model", type=click.BOOL, default=False, help="test with the built-in model")
- def json_command(json, method, inside_model):
+ @click.option("--model_mode", type=click.STRING, default="full", help="built-in model selection. lite: fast parsing, lower accuracy; full: high-accuracy parsing, slower")
+ def json_command(json, method, inside_model, model_mode):
      model_config.__use_inside_model__ = inside_model
+     model_config.__model_mode__ = model_mode

      if not json.startswith("s3://"):
          logger.error("usage: magic-pdf json-command --json s3://some_bucket/some_path")
@@ -226,8 +233,10 @@ def json_command(json, method, inside_model):
      default="auto",
  )
  @click.option("--inside_model", type=click.BOOL, default=False, help="test with the built-in model")
- def local_json_command(local_json, method, inside_model):
+ @click.option("--model_mode", type=click.STRING, default="full", help="built-in model selection. lite: fast parsing, lower accuracy; full: high-accuracy parsing, slower")
+ def local_json_command(local_json, method, inside_model, model_mode):
      model_config.__use_inside_model__ = inside_model
+     model_config.__model_mode__ = model_mode

      def read_s3_path(s3path):
          bucket, key = parse_s3path(s3path)
@@ -278,8 +287,10 @@ def local_json_command(local_json, method, inside_model):
      default="auto",
  )
  @click.option("--inside_model", type=click.BOOL, default=False, help="test with the built-in model")
- def pdf_command(pdf, model, method, inside_model):
+ @click.option("--model_mode", type=click.STRING, default="full", help="built-in model selection. lite: fast parsing, lower accuracy; full: high-accuracy parsing, slower")
+ def pdf_command(pdf, model, method, inside_model, model_mode):
      model_config.__use_inside_model__ = inside_model
+     model_config.__model_mode__ = model_mode

      def read_fn(path):
          disk_rw = DiskReaderWriter(os.path.dirname(path))
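Taken together, the three commands above now expose the same `--model_mode` switch, whose only effect is to set module-level flags that `doc_analyze` later consults. A minimal sketch of that mapping, using only names that appear in the hunks (the driver lines themselves are illustrative, not shipped code):

```python
# Illustrative only: what --inside_model / --model_mode set before parsing.
import magic_pdf.model as model_config

model_config.__use_inside_model__ = True   # --inside_model true
model_config.__model_mode__ = "lite"       # --model_mode lite ("full" is the default)
# doc_analyze() (see magic_pdf/model/doc_analyze_by_custom_model.py below)
# then picks the Paddle backend for "lite" and PDF-Extract-Kit for "full".
```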
magic_pdf/dict2md/ocr_mkcontent.py CHANGED
@@ -112,7 +112,7 @@ def ocr_mk_markdown_with_para_core_v2(paras_of_layout, mode, img_buket_path=""):
              for line in block['lines']:
                  for span in line['spans']:
                      if span['type'] == ContentType.Image:
-                         para_text += f"\n![]({join_path(img_buket_path, span['image_path'])})\n"
+                         para_text += f"\n![]({join_path(img_buket_path, span['image_path'])}) \n"
      for block in para_block['blocks']:  # 2nd: append image_caption
          if block['type'] == BlockType.ImageCaption:
              para_text += merge_para_with_text(block)
@@ -128,7 +128,7 @@ def ocr_mk_markdown_with_para_core_v2(paras_of_layout, mode, img_buket_path=""):
              for line in block['lines']:
                  for span in line['spans']:
                      if span['type'] == ContentType.Table:
-                         para_text += f"\n![]({join_path(img_buket_path, span['image_path'])})\n"
+                         para_text += f"\n![]({join_path(img_buket_path, span['image_path'])}) \n"
      for block in para_block['blocks']:  # 3rd: append table_footnote
          if block['type'] == BlockType.TableFootnote:
              para_text += merge_para_with_text(block)
magic_pdf/libs/config_reader.py CHANGED
@@ -59,5 +59,15 @@ def get_local_dir():
      return config.get("temp-output-dir", "/tmp")


+ def get_local_models_dir():
+     config = read_config()
+     return config.get("models-dir", "/tmp/models")
+
+
+ def get_device():
+     config = read_config()
+     return config.get("device-mode", "cpu")
+
+
  if __name__ == "__main__":
      ak, sk, endpoint = get_s3_config("llm-raw")
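The two new getters follow the same pattern as `get_local_dir()`: read the user config and fall back to a default. A hedged sketch of the lookups (key names and defaults are from the hunk; where `read_config()` finds the config file depends on the installation):

```python
# Key names and defaults come from the hunk above; requires a readable
# user config wherever read_config() looks for it.
from magic_pdf.libs.config_reader import get_device, get_local_models_dir

models_dir = get_local_models_dir()  # "models-dir" key, default "/tmp/models"
device = get_device()                # "device-mode" key, default "cpu"
print(models_dir, device)
```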
magic_pdf/libs/version.py CHANGED
@@ -1 +1 @@
- __version__ = "0.5.13"
+ __version__ = "0.6.1"
magic_pdf/model/__init__.py CHANGED
@@ -1 +1,2 @@
  __use_inside_model__ = False
+ __model_mode__ = "full"
magic_pdf/model/doc_analyze_by_custom_model.py CHANGED
@@ -1,6 +1,10 @@
+ import time
+
  import fitz
  import numpy as np
  from loguru import logger
+
+ from magic_pdf.libs.config_reader import get_local_models_dir, get_device
  from magic_pdf.model.model_list import MODEL
  import magic_pdf.model as model_config

@@ -21,10 +25,11 @@ def remove_duplicates_dicts(lst):

  def load_images_from_pdf(pdf_bytes: bytes, dpi=200) -> list:
      try:
-         import cv2
          from PIL import Image
      except ImportError:
-         logger.error("opencv-python and Pillow are not installed, please install by pip.")
+         logger.error("Pillow not installed, please install by pip.")
+         exit(1)
+
      images = []
      with fitz.open("pdf", pdf_bytes) as doc:
          for index in range(0, doc.page_count):
@@ -32,32 +37,49 @@ def load_images_from_pdf(pdf_bytes: bytes, dpi=200) -> list:
              mat = fitz.Matrix(dpi / 72, dpi / 72)
              pm = page.get_pixmap(matrix=mat, alpha=False)

-             # if width or height > 2000 pixels, don't enlarge the image
-             # if pm.width > 2000 or pm.height > 2000:
-             #     pm = page.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False)
+             # if width or height > 3000 pixels, don't enlarge the image
+             if pm.width > 3000 or pm.height > 3000:
+                 pm = page.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False)

-             img = Image.frombytes("RGB", [pm.width, pm.height], pm.samples)
-             img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
+             img = Image.frombytes("RGB", (pm.width, pm.height), pm.samples)
+             img = np.array(img)
              img_dict = {"img": img, "width": pm.width, "height": pm.height}
              images.append(img_dict)
      return images


- def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False, model=MODEL.Paddle):
+ def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False):
+     model = None
+
+     if model_config.__model_mode__ == "lite":
+         model = MODEL.Paddle
+     elif model_config.__model_mode__ == "full":
+         model = MODEL.PEK

      if model_config.__use_inside_model__:
-         from magic_pdf.model.pp_structure_v2 import CustomPaddleModel
+         model_init_start = time.time()
+         if model == MODEL.Paddle:
+             from magic_pdf.model.pp_structure_v2 import CustomPaddleModel
+             custom_model = CustomPaddleModel(ocr=ocr, show_log=show_log)
+         elif model == MODEL.PEK:
+             from magic_pdf.model.pdf_extract_kit import CustomPEKModel
+             # read model-dir and device from the config file
+             local_models_dir = get_local_models_dir()
+             device = get_device()
+             custom_model = CustomPEKModel(ocr=ocr, show_log=show_log, models_dir=local_models_dir, device=device)
+         else:
+             logger.error("Not allow model_name!")
+             exit(1)
+         model_init_cost = time.time() - model_init_start
+         logger.info(f"model init cost: {model_init_cost}")
      else:
          logger.error("use_inside_model is False, not allow to use inside model")
          exit(1)

      images = load_images_from_pdf(pdf_bytes)
-     custom_model = None
-     if model == MODEL.Paddle:
-         custom_model = CustomPaddleModel(ocr=ocr, show_log=show_log)
-     else:
-         pass
+
      model_json = []
+     doc_analyze_start = time.time()
      for index, img_dict in enumerate(images):
          img = img_dict["img"]
          page_width = img_dict["width"]
@@ -65,7 +87,8 @@ def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False, mod
          result = custom_model(img)
          page_info = {"page_no": index, "height": page_height, "width": page_width}
          page_dict = {"layout_dets": result, "page_info": page_info}
-
          model_json.append(page_dict)
+     doc_analyze_cost = time.time() - doc_analyze_start
+     logger.info(f"doc analyze cost: {doc_analyze_cost}")

      return model_json
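`doc_analyze` no longer takes a `model` argument; the backend comes from `model_config.__model_mode__` ("lite" maps to `MODEL.Paddle`, "full" to `MODEL.PEK`). A minimal usage sketch consistent with the new signature (the input file name is made up):

```python
# Sketch only: exercises the new doc_analyze signature from the hunk above.
import magic_pdf.model as model_config
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze

model_config.__use_inside_model__ = True
model_config.__model_mode__ = "full"   # "lite" -> MODEL.Paddle, "full" -> MODEL.PEK

with open("example.pdf", "rb") as f:   # hypothetical input file
    model_json = doc_analyze(f.read(), ocr=True, show_log=False)

# One dict per page: {"layout_dets": [...], "page_info": {...}}
print(len(model_json), model_json[0]["page_info"])
```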
magic_pdf/model/model_list.py CHANGED
@@ -1,2 +1,3 @@
  class MODEL:
      Paddle = "pp_structure_v2"
+     PEK = "pdf_extract_kit"
magic_pdf/model/pdf_extract_kit.py ADDED
@@ -0,0 +1,200 @@
+ from loguru import logger
+ import os
+ try:
+     import cv2
+     import yaml
+     import time
+     import argparse
+     import numpy as np
+     import torch
+
+     from paddleocr import draw_ocr
+     from PIL import Image
+     from torchvision import transforms
+     from torch.utils.data import Dataset, DataLoader
+     from ultralytics import YOLO
+     from unimernet.common.config import Config
+     import unimernet.tasks as tasks
+     from unimernet.processors import load_processor
+
+     from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor
+     from magic_pdf.model.pek_sub_modules.post_process import get_croped_image, latex_rm_whitespace
+     from magic_pdf.model.pek_sub_modules.self_modify import ModifiedPaddleOCR
+ except ImportError:
+     logger.error('Required dependency not installed, please install by \n"pip install magic-pdf[full-cpu] detectron2 --extra-index-url https://myhloli.github.io/wheels/"')
+     exit(1)
+
+
+ def mfd_model_init(weight):
+     mfd_model = YOLO(weight)
+     return mfd_model
+
+
+ def mfr_model_init(weight_dir, cfg_path, _device_='cpu'):
+     args = argparse.Namespace(cfg_path=cfg_path, options=None)
+     cfg = Config(args)
+     cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.bin")
+     cfg.config.model.model_config.model_name = weight_dir
+     cfg.config.model.tokenizer_config.path = weight_dir
+     task = tasks.setup_task(cfg)
+     model = task.build_model(cfg)
+     model = model.to(_device_)
+     vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
+     return model, vis_processor
+
+
+ def layout_model_init(weight, config_file, device):
+     model = Layoutlmv3_Predictor(weight, config_file, device)
+     return model
+
+
+ class MathDataset(Dataset):
+     def __init__(self, image_paths, transform=None):
+         self.image_paths = image_paths
+         self.transform = transform
+
+     def __len__(self):
+         return len(self.image_paths)
+
+     def __getitem__(self, idx):
+         # if not pil image, then convert to pil image
+         if isinstance(self.image_paths[idx], str):
+             raw_image = Image.open(self.image_paths[idx])
+         else:
+             raw_image = self.image_paths[idx]
+         if self.transform:
+             image = self.transform(raw_image)
+             return image
+
+
+ class CustomPEKModel:
+
+     def __init__(self, ocr: bool = False, show_log: bool = False, **kwargs):
+         """
+         ======== model init ========
+         """
+         # absolute path of the current file (pdf_extract_kit.py)
+         current_file_path = os.path.abspath(__file__)
+         # directory containing the current file (model)
+         current_dir = os.path.dirname(current_file_path)
+         # parent directory (magic_pdf)
+         root_dir = os.path.dirname(current_dir)
+         # model_config directory
+         model_config_dir = os.path.join(root_dir, 'resources', 'model_config')
+         # full path to model_configs.yaml
+         config_path = os.path.join(model_config_dir, 'model_configs.yaml')
+         with open(config_path, "r") as f:
+             self.configs = yaml.load(f, Loader=yaml.FullLoader)
+         # initialize parsing configuration
+         self.apply_layout = kwargs.get("apply_layout", self.configs["config"]["layout"])
+         self.apply_formula = kwargs.get("apply_formula", self.configs["config"]["formula"])
+         self.apply_ocr = ocr
+         logger.info(
+             "DocAnalysis init, this may take some times. apply_layout: {}, apply_formula: {}, apply_ocr: {}".format(
+                 self.apply_layout, self.apply_formula, self.apply_ocr
+             )
+         )
+         assert self.apply_layout, "DocAnalysis must contain layout model."
+         # initialize the parsing backends
+         self.device = kwargs.get("device", self.configs["config"]["device"])
+         logger.info("using device: {}".format(self.device))
+         models_dir = kwargs.get("models_dir", os.path.join(root_dir, "resources", "models"))
+
+         # initialize formula recognition
+         if self.apply_formula:
+             # initialize the formula detection model
+             self.mfd_model = mfd_model_init(str(os.path.join(models_dir, self.configs["weights"]["mfd"])))
+
+             # initialize the formula recognition model
+             mfr_weight_dir = str(os.path.join(models_dir, self.configs["weights"]["mfr"]))
+             mfr_cfg_path = str(os.path.join(model_config_dir, "UniMERNet", "demo.yaml"))
+             self.mfr_model, mfr_vis_processors = mfr_model_init(mfr_weight_dir, mfr_cfg_path, _device_=self.device)
+             self.mfr_transform = transforms.Compose([mfr_vis_processors, ])
+
+         # initialize the layout model
+         self.layout_model = Layoutlmv3_Predictor(
+             str(os.path.join(models_dir, self.configs['weights']['layout'])),
+             str(os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml")),
+             device=self.device
+         )
+         # initialize OCR
+         if self.apply_ocr:
+             self.ocr_model = ModifiedPaddleOCR(show_log=show_log)
+
+         logger.info('DocAnalysis init done!')
+
+     def __call__(self, image):
+
+         latex_filling_list = []
+         mf_image_list = []
+
+         # layout detection
+         layout_start = time.time()
+         layout_res = self.layout_model(image, ignore_catids=[])
+         layout_cost = round(time.time() - layout_start, 2)
+         logger.info(f"layout detection cost: {layout_cost}")
+
+         # formula detection
+         mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True)[0]
+         for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()):
+             xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
+             new_item = {
+                 'category_id': 13 + int(cla.item()),
+                 'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
+                 'score': round(float(conf.item()), 2),
+                 'latex': '',
+             }
+             layout_res.append(new_item)
+             latex_filling_list.append(new_item)
+             bbox_img = get_croped_image(Image.fromarray(image), [xmin, ymin, xmax, ymax])
+             mf_image_list.append(bbox_img)
+
+         # formula recognition
+         mfr_start = time.time()
+         dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
+         dataloader = DataLoader(dataset, batch_size=64, num_workers=0)
+         mfr_res = []
+         for mf_img in dataloader:
+             mf_img = mf_img.to(self.device)
+             output = self.mfr_model.generate({'image': mf_img})
+             mfr_res.extend(output['pred_str'])
+         for res, latex in zip(latex_filling_list, mfr_res):
+             res['latex'] = latex_rm_whitespace(latex)
+         mfr_cost = round(time.time() - mfr_start, 2)
+         logger.info(f"formula nums: {len(mf_image_list)}, mfr time: {mfr_cost}")
+
+         # OCR
+         if self.apply_ocr:
+             ocr_start = time.time()
+             pil_img = Image.fromarray(image)
+             single_page_mfdetrec_res = []
+             for res in layout_res:
+                 if int(res['category_id']) in [13, 14]:
+                     xmin, ymin = int(res['poly'][0]), int(res['poly'][1])
+                     xmax, ymax = int(res['poly'][4]), int(res['poly'][5])
+                     single_page_mfdetrec_res.append({
+                         "bbox": [xmin, ymin, xmax, ymax],
+                     })
+             for res in layout_res:
+                 if int(res['category_id']) in [0, 1, 2, 4, 6, 7]:  # categories that need OCR
+                     xmin, ymin = int(res['poly'][0]), int(res['poly'][1])
+                     xmax, ymax = int(res['poly'][4]), int(res['poly'][5])
+                     crop_box = (xmin, ymin, xmax, ymax)
+                     cropped_img = Image.new('RGB', pil_img.size, 'white')
+                     cropped_img.paste(pil_img.crop(crop_box), crop_box)
+                     cropped_img = cv2.cvtColor(np.asarray(cropped_img), cv2.COLOR_RGB2BGR)
+                     ocr_res = self.ocr_model.ocr(cropped_img, mfd_res=single_page_mfdetrec_res)[0]
+                     if ocr_res:
+                         for box_ocr_res in ocr_res:
+                             p1, p2, p3, p4 = box_ocr_res[0]
+                             text, score = box_ocr_res[1]
+                             layout_res.append({
+                                 'category_id': 15,
+                                 'poly': p1 + p2 + p3 + p4,
+                                 'score': round(score, 2),
+                                 'text': text,
+                             })
+             ocr_cost = round(time.time() - ocr_start, 2)
+             logger.info(f"ocr cost: {ocr_cost}")
+
+         return layout_res
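`CustomPEKModel.__call__` returns one flat `layout_res` list per page, mixing layout regions from LayoutLMv3, formula detections (category_id 13/14, whose `latex` field is filled in by the UniMERNet pass), and OCR lines (category_id 15). Illustrative entries, with field names taken from the code above and values invented:

```python
# Field names are from the code above; the values are invented examples.
formula_item = {
    "category_id": 13,                            # 13 + YOLO class id
    "poly": [60, 40, 420, 40, 420, 90, 60, 90],   # 4-point quad from the bbox
    "score": 0.93,
    "latex": "E=mc^{2}",                          # filled by the MFR pass
}
ocr_item = {
    "category_id": 15,
    "poly": [60, 120, 400, 120, 400, 150, 60, 150],
    "score": 0.99,
    "text": "An example recognized text line.",
}
```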
magic_pdf/model/pek_sub_modules/__init__.py: File without changes
magic_pdf/model/pek_sub_modules/layoutlmv3/__init__.py: File without changes
magic_pdf/model/pek_sub_modules/layoutlmv3/backbone.py ADDED
@@ -0,0 +1,179 @@
+ # --------------------------------------------------------------------------------
+ # VIT: Multi-Path Vision Transformer for Dense Prediction
+ # Copyright (c) 2022 Electronics and Telecommunications Research Institute (ETRI).
+ # All Rights Reserved.
+ # Written by Youngwan Lee
+ # This source code is licensed(Dual License(GPL3.0 & Commercial)) under the license found in the
+ # LICENSE file in the root directory of this source tree.
+ # --------------------------------------------------------------------------------
+ # References:
+ # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
+ # CoaT: https://github.com/mlpc-ucsd/CoaT
+ # --------------------------------------------------------------------------------
+
+
+ import torch
+
+ from detectron2.layers import (
+     ShapeSpec,
+ )
+ from detectron2.modeling import Backbone, BACKBONE_REGISTRY, FPN
+ from detectron2.modeling.backbone.fpn import LastLevelP6P7, LastLevelMaxPool
+
+ from .beit import beit_base_patch16, dit_base_patch16, dit_large_patch16, beit_large_patch16
+ from .deit import deit_base_patch16, mae_base_patch16
+ from .layoutlmft.models.layoutlmv3 import LayoutLMv3Model
+ from transformers import AutoConfig
+
+ __all__ = [
+     "build_vit_fpn_backbone",
+ ]
+
+
+ class VIT_Backbone(Backbone):
+     """
+     Implement VIT backbone.
+     """
+
+     def __init__(self, name, out_features, drop_path, img_size, pos_type, model_kwargs,
+                  config_path=None, image_only=False, cfg=None):
+         super().__init__()
+         self._out_features = out_features
+         if 'base' in name:
+             self._out_feature_strides = {"layer3": 4, "layer5": 8, "layer7": 16, "layer11": 32}
+             self._out_feature_channels = {"layer3": 768, "layer5": 768, "layer7": 768, "layer11": 768}
+         else:
+             self._out_feature_strides = {"layer7": 4, "layer11": 8, "layer15": 16, "layer23": 32}
+             self._out_feature_channels = {"layer7": 1024, "layer11": 1024, "layer15": 1024, "layer23": 1024}
+
+         if name == 'beit_base_patch16':
+             model_func = beit_base_patch16
+         elif name == 'dit_base_patch16':
+             model_func = dit_base_patch16
+         elif name == "deit_base_patch16":
+             model_func = deit_base_patch16
+         elif name == "mae_base_patch16":
+             model_func = mae_base_patch16
+         elif name == "dit_large_patch16":
+             model_func = dit_large_patch16
+         elif name == "beit_large_patch16":
+             model_func = beit_large_patch16
+
+         if 'beit' in name or 'dit' in name:
+             if pos_type == "abs":
+                 self.backbone = model_func(img_size=img_size,
+                                            out_features=out_features,
+                                            drop_path_rate=drop_path,
+                                            use_abs_pos_emb=True,
+                                            **model_kwargs)
+             elif pos_type == "shared_rel":
+                 self.backbone = model_func(img_size=img_size,
+                                            out_features=out_features,
+                                            drop_path_rate=drop_path,
+                                            use_shared_rel_pos_bias=True,
+                                            **model_kwargs)
+             elif pos_type == "rel":
+                 self.backbone = model_func(img_size=img_size,
+                                            out_features=out_features,
+                                            drop_path_rate=drop_path,
+                                            use_rel_pos_bias=True,
+                                            **model_kwargs)
+             else:
+                 raise ValueError()
+         elif "layoutlmv3" in name:
+             config = AutoConfig.from_pretrained(config_path)
+             # disable relative bias as DiT
+             config.has_spatial_attention_bias = False
+             config.has_relative_attention_bias = False
+             self.backbone = LayoutLMv3Model(config, detection=True,
+                                             out_features=out_features, image_only=image_only)
+         else:
+             self.backbone = model_func(img_size=img_size,
+                                        out_features=out_features,
+                                        drop_path_rate=drop_path,
+                                        **model_kwargs)
+         self.name = name
+
+     def forward(self, x):
+         """
+         Args:
+             x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.
+
+         Returns:
+             dict[str->Tensor]: names and the corresponding features
+         """
+         if "layoutlmv3" in self.name:
+             return self.backbone.forward(
+                 input_ids=x["input_ids"] if "input_ids" in x else None,
+                 bbox=x["bbox"] if "bbox" in x else None,
+                 images=x["images"] if "images" in x else None,
+                 attention_mask=x["attention_mask"] if "attention_mask" in x else None,
+                 # output_hidden_states=True,
+             )
+         assert x.dim() == 4, f"VIT takes an input of shape (N, C, H, W). Got {x.shape} instead!"
+         return self.backbone.forward_features(x)
+
+     def output_shape(self):
+         return {
+             name: ShapeSpec(
+                 channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
+             )
+             for name in self._out_features
+         }
+
+
+ def build_VIT_backbone(cfg):
+     """
+     Create a VIT instance from config.
+
+     Args:
+         cfg: a detectron2 CfgNode
+
+     Returns:
+         A VIT backbone instance.
+     """
+     # fmt: off
+     name = cfg.MODEL.VIT.NAME
+     out_features = cfg.MODEL.VIT.OUT_FEATURES
+     drop_path = cfg.MODEL.VIT.DROP_PATH
+     img_size = cfg.MODEL.VIT.IMG_SIZE
+     pos_type = cfg.MODEL.VIT.POS_TYPE
+
+     model_kwargs = eval(str(cfg.MODEL.VIT.MODEL_KWARGS).replace("`", ""))
+
+     if 'layoutlmv3' in name:
+         if cfg.MODEL.CONFIG_PATH != '':
+             config_path = cfg.MODEL.CONFIG_PATH
+         else:
+             config_path = cfg.MODEL.WEIGHTS.replace('pytorch_model.bin', '')  # layoutlmv3 pre-trained models
+             config_path = config_path.replace('model_final.pth', '')  # detection fine-tuned models
+     else:
+         config_path = None
+
+     return VIT_Backbone(name, out_features, drop_path, img_size, pos_type, model_kwargs,
+                         config_path=config_path, image_only=cfg.MODEL.IMAGE_ONLY, cfg=cfg)
+
+
+ @BACKBONE_REGISTRY.register()
+ def build_vit_fpn_backbone(cfg, input_shape: ShapeSpec):
+     """
+     Create a VIT w/ FPN backbone.
+
+     Args:
+         cfg: a detectron2 CfgNode
+
+     Returns:
+         backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`.
+     """
+     bottom_up = build_VIT_backbone(cfg)
+     in_features = cfg.MODEL.FPN.IN_FEATURES
+     out_channels = cfg.MODEL.FPN.OUT_CHANNELS
+     backbone = FPN(
+         bottom_up=bottom_up,
+         in_features=in_features,
+         out_channels=out_channels,
+         norm=cfg.MODEL.FPN.NORM,
+         top_block=LastLevelMaxPool(),
+         fuse_type=cfg.MODEL.FPN.FUSE_TYPE,
+     )
+     return backbone
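`build_vit_fpn_backbone` only becomes reachable once this module is imported, because registration happens via the `@BACKBONE_REGISTRY.register()` decorator. A hedged sketch of selecting it through a detectron2 config (this assumes the config file sets `MODEL.BACKBONE.NAME` to `build_vit_fpn_backbone` and defines the custom `MODEL.VIT.*` keys read above; the bundled layoutlmv3_base_inference.yaml is the likely candidate, but that is an assumption here):

```python
# Hedged sketch: building the registered backbone through detectron2.
# Assumes the YAML provides MODEL.BACKBONE.NAME = "build_vit_fpn_backbone"
# and the MODEL.VIT.* keys consumed by build_VIT_backbone above.
from detectron2.config import get_cfg
from detectron2.modeling import build_backbone

import magic_pdf.model.pek_sub_modules.layoutlmv3.backbone  # noqa: F401 (triggers registration)

cfg = get_cfg()
cfg.set_new_allowed(True)  # permit the non-standard MODEL.VIT / MODEL.CONFIG_PATH keys
cfg.merge_from_file("magic_pdf/resources/model_config/layoutlmv3/layoutlmv3_base_inference.yaml")
backbone = build_backbone(cfg)  # dispatches through BACKBONE_REGISTRY
print(backbone.output_shape())
```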