magic-pdf 0.6.2b1__py3-none-any.whl → 0.7.0b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. magic_pdf/dict2md/ocr_mkcontent.py +10 -3
  2. magic_pdf/libs/Constants.py +4 -1
  3. magic_pdf/libs/config_reader.py +10 -10
  4. magic_pdf/libs/draw_bbox.py +66 -1
  5. magic_pdf/libs/ocr_content_type.py +14 -0
  6. magic_pdf/libs/version.py +1 -1
  7. magic_pdf/model/doc_analyze_by_custom_model.py +10 -4
  8. magic_pdf/model/magic_model.py +4 -0
  9. magic_pdf/model/pdf_extract_kit.py +83 -39
  10. magic_pdf/model/pek_sub_modules/structeqtable/StructTableModel.py +22 -0
  11. magic_pdf/resources/model_config/model_configs.yaml +4 -0
  12. magic_pdf/rw/AbsReaderWriter.py +1 -18
  13. magic_pdf/rw/DiskReaderWriter.py +32 -24
  14. magic_pdf/rw/S3ReaderWriter.py +83 -48
  15. magic_pdf/tools/cli.py +79 -0
  16. magic_pdf/tools/cli_dev.py +155 -0
  17. magic_pdf/tools/common.py +122 -0
  18. magic_pdf-0.7.0b1.dist-info/METADATA +421 -0
  19. {magic_pdf-0.6.2b1.dist-info → magic_pdf-0.7.0b1.dist-info}/RECORD +25 -27
  20. {magic_pdf-0.6.2b1.dist-info → magic_pdf-0.7.0b1.dist-info}/WHEEL +1 -1
  21. magic_pdf-0.7.0b1.dist-info/entry_points.txt +3 -0
  22. magic_pdf/cli/magicpdf.py +0 -359
  23. magic_pdf/pdf_parse_for_train.py +0 -685
  24. magic_pdf/train_utils/convert_to_train_format.py +0 -65
  25. magic_pdf/train_utils/extract_caption.py +0 -59
  26. magic_pdf/train_utils/remove_footer_header.py +0 -159
  27. magic_pdf/train_utils/vis_utils.py +0 -327
  28. magic_pdf-0.6.2b1.dist-info/METADATA +0 -344
  29. magic_pdf-0.6.2b1.dist-info/entry_points.txt +0 -2
  30. /magic_pdf/{cli → model/pek_sub_modules/structeqtable}/__init__.py +0 -0
  31. /magic_pdf/{train_utils → tools}/__init__.py +0 -0
  32. {magic_pdf-0.6.2b1.dist-info → magic_pdf-0.7.0b1.dist-info}/LICENSE.md +0 -0
  33. {magic_pdf-0.6.2b1.dist-info → magic_pdf-0.7.0b1.dist-info}/top_level.txt +0 -0
@@ -120,15 +120,20 @@ def ocr_mk_markdown_with_para_core_v2(paras_of_layout, mode, img_buket_path=""):
120
120
  if mode == 'nlp':
121
121
  continue
122
122
  elif mode == 'mm':
123
+ table_caption = ''
123
124
  for block in para_block['blocks']: # 1st.拼table_caption
124
125
  if block['type'] == BlockType.TableCaption:
125
- para_text += merge_para_with_text(block)
126
+ table_caption = merge_para_with_text(block)
126
127
  for block in para_block['blocks']: # 2nd.拼table_body
127
128
  if block['type'] == BlockType.TableBody:
128
129
  for line in block['lines']:
129
130
  for span in line['spans']:
130
131
  if span['type'] == ContentType.Table:
131
- para_text += f"\n![]({join_path(img_buket_path, span['image_path'])}) \n"
132
+ # if processed by table model
133
+ if span.get('latex', ''):
134
+ para_text += f"\n\n$\n {span['latex']}\n$\n\n"
135
+ else:
136
+ para_text += f"\n![{table_caption}]({join_path(img_buket_path, span['image_path'])}) \n"
132
137
  for block in para_block['blocks']: # 3rd.拼table_footnote
133
138
  if block['type'] == BlockType.TableFootnote:
134
139
  para_text += merge_para_with_text(block)
@@ -163,7 +168,7 @@ def merge_para_with_text(para_block):
163
168
  else:
164
169
  content = ocr_escape_special_markdown_char(content)
165
170
  elif span_type == ContentType.InlineEquation:
166
- content = f"${span['content']}$"
171
+ content = f" ${span['content']}$ "
167
172
  elif span_type == ContentType.InterlineEquation:
168
173
  content = f"\n$$\n{span['content']}\n$$\n"
169
174
 
@@ -249,6 +254,8 @@ def para_to_standard_format_v2(para_block, img_buket_path, page_idx):
249
254
  }
250
255
  for block in para_block['blocks']:
251
256
  if block['type'] == BlockType.TableBody:
257
+ if block["lines"][0]["spans"][0].get('latex', ''):
258
+ para_content['table_body'] = f"\n\n$\n {block['lines'][0]['spans'][0]['latex']}\n$\n\n"
252
259
  para_content['img_path'] = join_path(img_buket_path, block["lines"][0]["spans"][0]['image_path'])
253
260
  if block['type'] == BlockType.TableCaption:
254
261
  para_content['table_caption'] = merge_para_with_text(block)
@@ -8,4 +8,7 @@ CROSS_PAGE = "cross_page"
8
8
  block维度自定义字段
9
9
  """
10
10
  # block中lines是否被删除
11
- LINES_DELETED = "lines_deleted"
11
+ LINES_DELETED = "lines_deleted"
12
+
13
+ # table recognition max time default value
14
+ TABLE_MAX_TIME_VALUE = 400
@@ -57,16 +57,6 @@ def get_bucket_name(path):
57
57
  return bucket
58
58
 
59
59
 
60
- def get_local_dir():
61
- config = read_config()
62
- local_dir = config.get("temp-output-dir")
63
- if local_dir is None:
64
- logger.warning(f"'temp-output-dir' not found in {CONFIG_FILE_NAME}, use '/tmp' as default")
65
- return "/tmp"
66
- else:
67
- return local_dir
68
-
69
-
70
60
  def get_local_models_dir():
71
61
  config = read_config()
72
62
  models_dir = config.get("models-dir")
@@ -87,5 +77,15 @@ def get_device():
87
77
  return device
88
78
 
89
79
 
80
+ def get_table_recog_config():
81
+ config = read_config()
82
+ table_config = config.get("table-config")
83
+ if table_config is None:
84
+ logger.warning(f"'table-config' not found in {CONFIG_FILE_NAME}, use 'False' as default")
85
+ return json.loads('{"is_table_recog_enable": false, "max_time": 400}')
86
+ else:
87
+ return table_config
88
+
89
+
90
90
  if __name__ == "__main__":
91
91
  ak, sk, endpoint = get_s3_config("llm-raw")
@@ -1,6 +1,7 @@
1
1
  from magic_pdf.libs.Constants import CROSS_PAGE
2
2
  from magic_pdf.libs.commons import fitz # PyMuPDF
3
- from magic_pdf.libs.ocr_content_type import ContentType, BlockType
3
+ from magic_pdf.libs.ocr_content_type import ContentType, BlockType, CategoryId
4
+ from magic_pdf.model.magic_model import MagicModel
4
5
 
5
6
 
6
7
  def draw_bbox_without_number(i, bbox_list, page, rgb_config, fill_config):
@@ -225,3 +226,67 @@ def draw_span_bbox(pdf_info, pdf_bytes, out_path):
225
226
 
226
227
  # Save the PDF
227
228
  pdf_docs.save(f"{out_path}/spans.pdf")
229
+
230
+
231
+ def drow_model_bbox(model_list: list, pdf_bytes, out_path):
232
+ dropped_bbox_list = []
233
+ tables_body_list, tables_caption_list, tables_footnote_list = [], [], []
234
+ imgs_body_list, imgs_caption_list = [], []
235
+ titles_list = []
236
+ texts_list = []
237
+ interequations_list = []
238
+ pdf_docs = fitz.open("pdf", pdf_bytes)
239
+ magic_model = MagicModel(model_list, pdf_docs)
240
+ for i in range(len(model_list)):
241
+ page_dropped_list = []
242
+ tables_body, tables_caption, tables_footnote = [], [], []
243
+ imgs_body, imgs_caption = [], []
244
+ titles = []
245
+ texts = []
246
+ interequations = []
247
+ page_info = magic_model.get_model_list(i)
248
+ layout_dets = page_info["layout_dets"]
249
+ for layout_det in layout_dets:
250
+ bbox = layout_det["bbox"]
251
+ if layout_det["category_id"] == CategoryId.Text:
252
+ texts.append(bbox)
253
+ elif layout_det["category_id"] == CategoryId.Title:
254
+ titles.append(bbox)
255
+ elif layout_det["category_id"] == CategoryId.TableBody:
256
+ tables_body.append(bbox)
257
+ elif layout_det["category_id"] == CategoryId.TableCaption:
258
+ tables_caption.append(bbox)
259
+ elif layout_det["category_id"] == CategoryId.TableFootnote:
260
+ tables_footnote.append(bbox)
261
+ elif layout_det["category_id"] == CategoryId.ImageBody:
262
+ imgs_body.append(bbox)
263
+ elif layout_det["category_id"] == CategoryId.ImageCaption:
264
+ imgs_caption.append(bbox)
265
+ elif layout_det["category_id"] == CategoryId.InterlineEquation_YOLO:
266
+ interequations.append(bbox)
267
+ elif layout_det["category_id"] == CategoryId.Abandon:
268
+ page_dropped_list.append(bbox)
269
+
270
+ tables_body_list.append(tables_body)
271
+ tables_caption_list.append(tables_caption)
272
+ tables_footnote_list.append(tables_footnote)
273
+ imgs_body_list.append(imgs_body)
274
+ imgs_caption_list.append(imgs_caption)
275
+ titles_list.append(titles)
276
+ texts_list.append(texts)
277
+ interequations_list.append(interequations)
278
+ dropped_bbox_list.append(page_dropped_list)
279
+
280
+ for i, page in enumerate(pdf_docs):
281
+ draw_bbox_with_number(i, dropped_bbox_list, page, [158, 158, 158], True) # color !
282
+ draw_bbox_with_number(i, tables_body_list, page, [204, 204, 0], True)
283
+ draw_bbox_with_number(i, tables_caption_list, page, [255, 255, 102], True)
284
+ draw_bbox_with_number(i, tables_footnote_list, page, [229, 255, 204], True)
285
+ draw_bbox_with_number(i, imgs_body_list, page, [153, 255, 51], True)
286
+ draw_bbox_with_number(i, imgs_caption_list, page, [102, 178, 255], True)
287
+ draw_bbox_with_number(i, titles_list, page, [102, 102, 255], True)
288
+ draw_bbox_with_number(i, texts_list, page, [153, 0, 76], True)
289
+ draw_bbox_with_number(i, interequations_list, page, [0, 255, 0], True)
290
+
291
+ # Save the PDF
292
+ pdf_docs.save(f"{out_path}/model.pdf")
@@ -19,3 +19,17 @@ class BlockType:
19
19
  Footnote = "footnote"
20
20
  Discarded = "discarded"
21
21
 
22
+
23
+ class CategoryId:
24
+ Title = 0
25
+ Text = 1
26
+ Abandon = 2
27
+ ImageBody = 3
28
+ ImageCaption = 4
29
+ TableBody = 5
30
+ TableCaption = 6
31
+ TableFootnote = 7
32
+ InterlineEquation_Layout = 8
33
+ InlineEquation = 13
34
+ InterlineEquation_YOLO = 14
35
+ OcrText = 15
magic_pdf/libs/version.py CHANGED
@@ -1 +1 @@
1
- __version__ = "0.6.2b1"
1
+ __version__ = "0.7.0b1"
@@ -4,7 +4,7 @@ import fitz
4
4
  import numpy as np
5
5
  from loguru import logger
6
6
 
7
- from magic_pdf.libs.config_reader import get_local_models_dir, get_device
7
+ from magic_pdf.libs.config_reader import get_local_models_dir, get_device, get_table_recog_config
8
8
  from magic_pdf.model.model_list import MODEL
9
9
  import magic_pdf.model as model_config
10
10
 
@@ -37,8 +37,8 @@ def load_images_from_pdf(pdf_bytes: bytes, dpi=200) -> list:
37
37
  mat = fitz.Matrix(dpi / 72, dpi / 72)
38
38
  pm = page.get_pixmap(matrix=mat, alpha=False)
39
39
 
40
- # if width or height > 3000 pixels, don't enlarge the image
41
- if pm.width > 3000 or pm.height > 3000:
40
+ # If the width or height exceeds 9000 after scaling, do not scale further.
41
+ if pm.width > 9000 or pm.height > 9000:
42
42
  pm = page.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False)
43
43
 
44
44
  img = Image.frombytes("RGB", (pm.width, pm.height), pm.samples)
@@ -84,7 +84,13 @@ def custom_model_init(ocr: bool = False, show_log: bool = False):
84
84
  # 从配置文件读取model-dir和device
85
85
  local_models_dir = get_local_models_dir()
86
86
  device = get_device()
87
- custom_model = CustomPEKModel(ocr=ocr, show_log=show_log, models_dir=local_models_dir, device=device)
87
+ table_config = get_table_recog_config()
88
+ model_input = {"ocr": ocr,
89
+ "show_log": show_log,
90
+ "models_dir": local_models_dir,
91
+ "device": device,
92
+ "table_config": table_config}
93
+ custom_model = CustomPEKModel(**model_input)
88
94
  else:
89
95
  logger.error("Not allow model_name!")
90
96
  exit(1)
@@ -560,6 +560,10 @@ class MagicModel:
560
560
  if category_id == 3:
561
561
  span["type"] = ContentType.Image
562
562
  elif category_id == 5:
563
+ # 获取table模型结果
564
+ latex = layout_det.get("latex", None)
565
+ if latex:
566
+ span["latex"] = latex
563
567
  span["type"] = ContentType.Table
564
568
  elif category_id == 13:
565
569
  span["content"] = layout_det["latex"]
@@ -2,6 +2,8 @@ from loguru import logger
2
2
  import os
3
3
  import time
4
4
 
5
+ from magic_pdf.libs.Constants import TABLE_MAX_TIME_VALUE
6
+
5
7
  os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
6
8
  try:
7
9
  import cv2
@@ -10,6 +12,7 @@ try:
10
12
  import numpy as np
11
13
  import torch
12
14
  import torchtext
15
+
13
16
  if torchtext.__version__ >= "0.18.0":
14
17
  torchtext.disable_torchtext_deprecation_warning()
15
18
  from PIL import Image
@@ -24,12 +27,18 @@ except ImportError as e:
24
27
  logger.exception(e)
25
28
  logger.error(
26
29
  'Required dependency not installed, please install by \n'
27
- '"pip install magic-pdf[full] detectron2 --extra-index-url https://myhloli.github.io/wheels/"')
30
+ '"pip install magic-pdf[full] --extra-index-url https://myhloli.github.io/wheels/"')
28
31
  exit(1)
29
32
 
30
33
  from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor
31
34
  from magic_pdf.model.pek_sub_modules.post_process import get_croped_image, latex_rm_whitespace
32
35
  from magic_pdf.model.pek_sub_modules.self_modify import ModifiedPaddleOCR
36
+ from magic_pdf.model.pek_sub_modules.structeqtable.StructTableModel import StructTableModel
37
+
38
+
39
+ def table_model_init(model_path, max_time, _device_='cpu'):
40
+ table_model = StructTableModel(model_path, max_time=max_time, device=_device_)
41
+ return table_model
33
42
 
34
43
 
35
44
  def mfd_model_init(weight):
@@ -95,10 +104,13 @@ class CustomPEKModel:
95
104
  # 初始化解析配置
96
105
  self.apply_layout = kwargs.get("apply_layout", self.configs["config"]["layout"])
97
106
  self.apply_formula = kwargs.get("apply_formula", self.configs["config"]["formula"])
107
+ self.table_config = kwargs.get("table_config", self.configs["config"]["table_config"])
108
+ self.apply_table = self.table_config.get("is_table_recog_enable", False)
109
+ self.table_max_time = self.table_config.get("max_time", TABLE_MAX_TIME_VALUE)
98
110
  self.apply_ocr = ocr
99
111
  logger.info(
100
- "DocAnalysis init, this may take some times. apply_layout: {}, apply_formula: {}, apply_ocr: {}".format(
101
- self.apply_layout, self.apply_formula, self.apply_ocr
112
+ "DocAnalysis init, this may take some times. apply_layout: {}, apply_formula: {}, apply_ocr: {}, apply_table: {}".format(
113
+ self.apply_layout, self.apply_formula, self.apply_ocr, self.apply_table
102
114
  )
103
115
  )
104
116
  assert self.apply_layout, "DocAnalysis must contain layout model."
@@ -129,6 +141,10 @@ class CustomPEKModel:
129
141
  if self.apply_ocr:
130
142
  self.ocr_model = ModifiedPaddleOCR(show_log=show_log)
131
143
 
144
+ # init structeqtable
145
+ if self.apply_table:
146
+ self.table_model = table_model_init(str(os.path.join(models_dir, self.configs["weights"]["table"])),
147
+ max_time = self.table_max_time, _device_=self.device)
132
148
  logger.info('DocAnalysis init done!')
133
149
 
134
150
  def __call__(self, image):
@@ -172,50 +188,56 @@ class CustomPEKModel:
172
188
  mfr_cost = round(time.time() - mfr_start, 2)
173
189
  logger.info(f"formula nums: {len(mf_image_list)}, mfr time: {mfr_cost}")
174
190
 
191
+ # Select regions for OCR / formula regions / table regions
192
+ ocr_res_list = []
193
+ table_res_list = []
194
+ single_page_mfdetrec_res = []
195
+ for res in layout_res:
196
+ if int(res['category_id']) in [13, 14]:
197
+ single_page_mfdetrec_res.append({
198
+ "bbox": [int(res['poly'][0]), int(res['poly'][1]),
199
+ int(res['poly'][4]), int(res['poly'][5])],
200
+ })
201
+ elif int(res['category_id']) in [0, 1, 2, 4, 6, 7]:
202
+ ocr_res_list.append(res)
203
+ elif int(res['category_id']) in [5]:
204
+ table_res_list.append(res)
205
+
206
+ # Unified crop img logic
207
+ def crop_img(input_res, input_pil_img, crop_paste_x=0, crop_paste_y=0):
208
+ crop_xmin, crop_ymin = int(input_res['poly'][0]), int(input_res['poly'][1])
209
+ crop_xmax, crop_ymax = int(input_res['poly'][4]), int(input_res['poly'][5])
210
+ # Create a white background with an additional width and height of 50
211
+ crop_new_width = crop_xmax - crop_xmin + crop_paste_x * 2
212
+ crop_new_height = crop_ymax - crop_ymin + crop_paste_y * 2
213
+ return_image = Image.new('RGB', (crop_new_width, crop_new_height), 'white')
214
+
215
+ # Crop image
216
+ crop_box = (crop_xmin, crop_ymin, crop_xmax, crop_ymax)
217
+ cropped_img = input_pil_img.crop(crop_box)
218
+ return_image.paste(cropped_img, (crop_paste_x, crop_paste_y))
219
+ return_list = [crop_paste_x, crop_paste_y, crop_xmin, crop_ymin, crop_xmax, crop_ymax, crop_new_width, crop_new_height]
220
+ return return_image, return_list
221
+
222
+ pil_img = Image.fromarray(image)
223
+
175
224
  # ocr识别
176
225
  if self.apply_ocr:
177
226
  ocr_start = time.time()
178
- pil_img = Image.fromarray(image)
179
-
180
- # 筛选出需要OCR的区域和公式区域
181
- ocr_res_list = []
182
- single_page_mfdetrec_res = []
183
- for res in layout_res:
184
- if int(res['category_id']) in [13, 14]:
185
- single_page_mfdetrec_res.append({
186
- "bbox": [int(res['poly'][0]), int(res['poly'][1]),
187
- int(res['poly'][4]), int(res['poly'][5])],
188
- })
189
- elif int(res['category_id']) in [0, 1, 2, 4, 6, 7]:
190
- ocr_res_list.append(res)
191
-
192
- # 对每一个需OCR处理的区域进行处理
227
+ # Process each area that requires OCR processing
193
228
  for res in ocr_res_list:
194
- xmin, ymin = int(res['poly'][0]), int(res['poly'][1])
195
- xmax, ymax = int(res['poly'][4]), int(res['poly'][5])
196
-
197
- paste_x = 50
198
- paste_y = 50
199
- # 创建一个宽高各多50的白色背景
200
- new_width = xmax - xmin + paste_x * 2
201
- new_height = ymax - ymin + paste_y * 2
202
- new_image = Image.new('RGB', (new_width, new_height), 'white')
203
-
204
- # 裁剪图像
205
- crop_box = (xmin, ymin, xmax, ymax)
206
- cropped_img = pil_img.crop(crop_box)
207
- new_image.paste(cropped_img, (paste_x, paste_y))
208
-
209
- # 调整公式区域坐标
229
+ new_image, useful_list = crop_img(res, pil_img, crop_paste_x=50, crop_paste_y=50)
230
+ paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
231
+ # Adjust the coordinates of the formula area
210
232
  adjusted_mfdetrec_res = []
211
233
  for mf_res in single_page_mfdetrec_res:
212
234
  mf_xmin, mf_ymin, mf_xmax, mf_ymax = mf_res["bbox"]
213
- # 将公式区域坐标调整为相对于裁剪区域的坐标
235
+ # Adjust the coordinates of the formula area to the coordinates relative to the cropping area
214
236
  x0 = mf_xmin - xmin + paste_x
215
237
  y0 = mf_ymin - ymin + paste_y
216
238
  x1 = mf_xmax - xmin + paste_x
217
239
  y1 = mf_ymax - ymin + paste_y
218
- # 过滤在图外的公式块
240
+ # Filter formula blocks outside the graph
219
241
  if any([x1 < 0, y1 < 0]) or any([x0 > new_width, y0 > new_height]):
220
242
  continue
221
243
  else:
@@ -223,17 +245,17 @@ class CustomPEKModel:
223
245
  "bbox": [x0, y0, x1, y1],
224
246
  })
225
247
 
226
- # OCR识别
248
+ # OCR recognition
227
249
  new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
228
250
  ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[0]
229
251
 
230
- # 整合结果
252
+ # Integration results
231
253
  if ocr_res:
232
254
  for box_ocr_res in ocr_res:
233
255
  p1, p2, p3, p4 = box_ocr_res[0]
234
256
  text, score = box_ocr_res[1]
235
257
 
236
- # 将坐标转换回原图坐标系
258
+ # Convert the coordinates back to the original coordinate system
237
259
  p1 = [p1[0] - paste_x + xmin, p1[1] - paste_y + ymin]
238
260
  p2 = [p2[0] - paste_x + xmin, p2[1] - paste_y + ymin]
239
261
  p3 = [p3[0] - paste_x + xmin, p3[1] - paste_y + ymin]
@@ -249,4 +271,26 @@ class CustomPEKModel:
249
271
  ocr_cost = round(time.time() - ocr_start, 2)
250
272
  logger.info(f"ocr cost: {ocr_cost}")
251
273
 
274
+ # 表格识别 table recognition
275
+ if self.apply_table:
276
+ table_start = time.time()
277
+ for res in table_res_list:
278
+ new_image, _ = crop_img(res, pil_img)
279
+ single_table_start_time = time.time()
280
+ logger.info("------------------table recognition processing begins-----------------")
281
+ with torch.no_grad():
282
+ latex_code = self.table_model.image2latex(new_image)[0]
283
+ run_time = time.time() - single_table_start_time
284
+ logger.info(f"------------table recognition processing ends within {run_time}s-----")
285
+ if run_time > self.table_max_time:
286
+ logger.warning(f"------------table recognition processing exceeds max time {self.table_max_time}s----------")
287
+ # 判断是否返回正常
288
+ expected_ending = latex_code.strip().endswith('end{tabular}') or latex_code.strip().endswith('end{table}')
289
+ if latex_code and expected_ending:
290
+ res["latex"] = latex_code
291
+ else:
292
+ logger.warning(f"------------table recognition processing fails----------")
293
+ table_cost = round(time.time() - table_start, 2)
294
+ logger.info(f"table cost: {table_cost}")
295
+
252
296
  return layout_res
@@ -0,0 +1,22 @@
1
+ from struct_eqtable.model import StructTable
2
+ from pypandoc import convert_text
3
+ class StructTableModel:
4
+ def __init__(self, model_path, max_new_tokens=2048, max_time=400, device = 'cpu'):
5
+ # init
6
+ self.model_path = model_path
7
+ self.max_new_tokens = max_new_tokens # maximum output tokens length
8
+ self.max_time = max_time # timeout for processing in seconds
9
+ if device == 'cuda':
10
+ self.model = StructTable(self.model_path, self.max_new_tokens, self.max_time).cuda()
11
+ else:
12
+ self.model = StructTable(self.model_path, self.max_new_tokens, self.max_time)
13
+
14
+ def image2latex(self, image) -> str:
15
+ #
16
+ table_latex = self.model.forward(image)
17
+ return table_latex
18
+
19
+ def image2html(self, image) -> str:
20
+ table_latex = self.image2latex(image)
21
+ table_html = convert_text(table_latex, 'html', format='latex')
22
+ return table_html
@@ -2,8 +2,12 @@ config:
2
2
  device: cpu
3
3
  layout: True
4
4
  formula: True
5
+ table_config:
6
+ is_table_recog_enable: False
7
+ max_time: 400
5
8
 
6
9
  weights:
7
10
  layout: Layout/model_final.pth
8
11
  mfd: MFD/weights.pt
9
12
  mfr: MFR/UniMERNet
13
+ table: TabRec/StructEqTable
@@ -2,33 +2,16 @@ from abc import ABC, abstractmethod
2
2
 
3
3
 
4
4
  class AbsReaderWriter(ABC):
5
- """
6
- 同时支持二进制和文本读写的抽象类
7
- """
8
5
  MODE_TXT = "text"
9
6
  MODE_BIN = "binary"
10
-
11
- def __init__(self, parent_path):
12
- # 初始化代码可以在这里添加,如果需要的话
13
- self.parent_path = parent_path # 对于本地目录是父目录,对于s3是会写到这个path下。
14
-
15
7
  @abstractmethod
16
8
  def read(self, path: str, mode=MODE_TXT):
17
- """
18
- 无论对于本地还是s3的路径,检查如果path是绝对路径,那么就不再 拼接parent_path, 如果是相对路径就拼接parent_path
19
- """
20
9
  raise NotImplementedError
21
10
 
22
11
  @abstractmethod
23
12
  def write(self, content: str, path: str, mode=MODE_TXT):
24
- """
25
- 无论对于本地还是s3的路径,检查如果path是绝对路径,那么就不再 拼接parent_path, 如果是相对路径就拼接parent_path
26
- """
27
13
  raise NotImplementedError
28
14
 
29
15
  @abstractmethod
30
- def read_jsonl(self, path: str, byte_start=0, byte_end=None, encoding='utf-8'):
31
- """
32
- 无论对于本地还是s3的路径,检查如果path是绝对路径,那么就不再 拼接parent_path, 如果是相对路径就拼接parent_path
33
- """
16
+ def read_offset(self, path: str, offset=0, limit=None) -> bytes:
34
17
  raise NotImplementedError
@@ -3,34 +3,29 @@ from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter
3
3
  from loguru import logger
4
4
 
5
5
 
6
- MODE_TXT = "text"
7
- MODE_BIN = "binary"
8
-
9
-
10
6
  class DiskReaderWriter(AbsReaderWriter):
11
-
12
7
  def __init__(self, parent_path, encoding="utf-8"):
13
8
  self.path = parent_path
14
9
  self.encoding = encoding
15
10
 
16
- def read(self, path, mode=MODE_TXT):
11
+ def read(self, path, mode=AbsReaderWriter.MODE_TXT):
17
12
  if os.path.isabs(path):
18
13
  abspath = path
19
14
  else:
20
15
  abspath = os.path.join(self.path, path)
21
16
  if not os.path.exists(abspath):
22
- logger.error(f"文件 {abspath} 不存在")
23
- raise Exception(f"文件 {abspath} 不存在")
24
- if mode == MODE_TXT:
17
+ logger.error(f"file {abspath} not exists")
18
+ raise Exception(f"file {abspath} no exists")
19
+ if mode == AbsReaderWriter.MODE_TXT:
25
20
  with open(abspath, "r", encoding=self.encoding) as f:
26
21
  return f.read()
27
- elif mode == MODE_BIN:
22
+ elif mode == AbsReaderWriter.MODE_BIN:
28
23
  with open(abspath, "rb") as f:
29
24
  return f.read()
30
25
  else:
31
26
  raise ValueError("Invalid mode. Use 'text' or 'binary'.")
32
27
 
33
- def write(self, content, path, mode=MODE_TXT):
28
+ def write(self, content, path, mode=AbsReaderWriter.MODE_TXT):
34
29
  if os.path.isabs(path):
35
30
  abspath = path
36
31
  else:
@@ -38,29 +33,42 @@ class DiskReaderWriter(AbsReaderWriter):
38
33
  directory_path = os.path.dirname(abspath)
39
34
  if not os.path.exists(directory_path):
40
35
  os.makedirs(directory_path)
41
- if mode == MODE_TXT:
36
+ if mode == AbsReaderWriter.MODE_TXT:
42
37
  with open(abspath, "w", encoding=self.encoding, errors="replace") as f:
43
38
  f.write(content)
44
39
 
45
- elif mode == MODE_BIN:
40
+ elif mode == AbsReaderWriter.MODE_BIN:
46
41
  with open(abspath, "wb") as f:
47
42
  f.write(content)
48
43
  else:
49
44
  raise ValueError("Invalid mode. Use 'text' or 'binary'.")
50
45
 
51
- def read_jsonl(self, path: str, byte_start=0, byte_end=None, encoding="utf-8"):
52
- return self.read(path)
46
+ def read_offset(self, path: str, offset=0, limit=None):
47
+ abspath = path
48
+ if not os.path.isabs(path):
49
+ abspath = os.path.join(self.path, path)
50
+ with open(abspath, "rb") as f:
51
+ f.seek(offset)
52
+ return f.read(limit)
53
53
 
54
54
 
55
- # 使用示例
56
55
  if __name__ == "__main__":
57
- file_path = "io/test/example.txt"
58
- drw = DiskReaderWriter("D:\projects\papayfork\Magic-PDF\magic_pdf")
56
+ if 0:
57
+ file_path = "io/test/example.txt"
58
+ drw = DiskReaderWriter("D:\projects\papayfork\Magic-PDF\magic_pdf")
59
+
60
+ # 写入内容到文件
61
+ drw.write(b"Hello, World!", path="io/test/example.txt", mode="binary")
62
+
63
+ # 从文件读取内容
64
+ content = drw.read(path=file_path)
65
+ if content:
66
+ logger.info(f"从 {file_path} 读取的内容: {content}")
67
+ if 1:
68
+ drw = DiskReaderWriter("/opt/data/pdf/resources/test/io/")
69
+ content_bin = drw.read_offset("1.txt")
70
+ assert content_bin == b"ABCD!"
59
71
 
60
- # 写入内容到文件
61
- drw.write(b"Hello, World!", path="io/test/example.txt", mode="binary")
72
+ content_bin = drw.read_offset("1.txt", offset=1, limit=2)
73
+ assert content_bin == b"BC"
62
74
 
63
- # 从文件读取内容
64
- content = drw.read(path=file_path)
65
- if content:
66
- logger.info(f"从 {file_path} 读取的内容: {content}")