doc-page-extractor 0.1.1__tar.gz → 0.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51) hide show
  1. {doc_page_extractor-0.1.1 → doc_page_extractor-0.2.0}/PKG-INFO +3 -2
  2. {doc_page_extractor-0.1.1 → doc_page_extractor-0.2.0}/doc_page_extractor/__init__.py +2 -1
  3. {doc_page_extractor-0.1.1 → doc_page_extractor-0.2.0}/doc_page_extractor/extractor.py +34 -32
  4. doc_page_extractor-0.2.0/doc_page_extractor/latex.py +31 -0
  5. {doc_page_extractor-0.1.1 → doc_page_extractor-0.2.0}/doc_page_extractor/layout_order.py +7 -10
  6. doc_page_extractor-0.2.0/doc_page_extractor/models.py +92 -0
  7. {doc_page_extractor-0.1.1 → doc_page_extractor-0.2.0}/doc_page_extractor/ocr.py +53 -28
  8. {doc_page_extractor-0.1.1 → doc_page_extractor-0.2.0}/doc_page_extractor/onnxocr/predict_base.py +9 -4
  9. {doc_page_extractor-0.1.1 → doc_page_extractor-0.2.0}/doc_page_extractor/onnxocr/predict_cls.py +23 -3
  10. {doc_page_extractor-0.1.1 → doc_page_extractor-0.2.0}/doc_page_extractor/onnxocr/predict_det.py +24 -5
  11. {doc_page_extractor-0.1.1 → doc_page_extractor-0.2.0}/doc_page_extractor/onnxocr/predict_rec.py +30 -7
  12. {doc_page_extractor-0.1.1 → doc_page_extractor-0.2.0}/doc_page_extractor/table.py +4 -5
  13. {doc_page_extractor-0.1.1 → doc_page_extractor-0.2.0}/doc_page_extractor/types.py +29 -5
  14. {doc_page_extractor-0.1.1 → doc_page_extractor-0.2.0}/doc_page_extractor.egg-info/PKG-INFO +3 -2
  15. {doc_page_extractor-0.1.1 → doc_page_extractor-0.2.0}/doc_page_extractor.egg-info/SOURCES.txt +1 -0
  16. {doc_page_extractor-0.1.1 → doc_page_extractor-0.2.0}/doc_page_extractor.egg-info/requires.txt +2 -1
  17. {doc_page_extractor-0.1.1 → doc_page_extractor-0.2.0}/setup.py +3 -2
  18. {doc_page_extractor-0.1.1 → doc_page_extractor-0.2.0}/tests/test_history_bus.py +1 -1
  19. doc_page_extractor-0.1.1/doc_page_extractor/latex.py +0 -57
  20. {doc_page_extractor-0.1.1 → doc_page_extractor-0.2.0}/LICENSE +0 -0
  21. {doc_page_extractor-0.1.1 → doc_page_extractor-0.2.0}/README.md +0 -0
  22. {doc_page_extractor-0.1.1 → doc_page_extractor-0.2.0}/doc_page_extractor/clipper.py +0 -0
  23. {doc_page_extractor-0.1.1 → doc_page_extractor-0.2.0}/doc_page_extractor/downloader.py +0 -0
  24. {doc_page_extractor-0.1.1 → doc_page_extractor-0.2.0}/doc_page_extractor/layoutreader.py +0 -0
  25. {doc_page_extractor-0.1.1 → doc_page_extractor-0.2.0}/doc_page_extractor/ocr_corrector.py +0 -0
  26. {doc_page_extractor-0.1.1 → doc_page_extractor-0.2.0}/doc_page_extractor/onnxocr/__init__.py +0 -0
  27. {doc_page_extractor-0.1.1 → doc_page_extractor-0.2.0}/doc_page_extractor/onnxocr/cls_postprocess.py +0 -0
  28. {doc_page_extractor-0.1.1 → doc_page_extractor-0.2.0}/doc_page_extractor/onnxocr/db_postprocess.py +0 -0
  29. {doc_page_extractor-0.1.1 → doc_page_extractor-0.2.0}/doc_page_extractor/onnxocr/imaug.py +0 -0
  30. {doc_page_extractor-0.1.1 → doc_page_extractor-0.2.0}/doc_page_extractor/onnxocr/operators.py +0 -0
  31. {doc_page_extractor-0.1.1 → doc_page_extractor-0.2.0}/doc_page_extractor/onnxocr/predict_system.py +0 -0
  32. {doc_page_extractor-0.1.1 → doc_page_extractor-0.2.0}/doc_page_extractor/onnxocr/rec_postprocess.py +0 -0
  33. {doc_page_extractor-0.1.1 → doc_page_extractor-0.2.0}/doc_page_extractor/onnxocr/utils.py +0 -0
  34. {doc_page_extractor-0.1.1 → doc_page_extractor-0.2.0}/doc_page_extractor/overlap.py +0 -0
  35. {doc_page_extractor-0.1.1 → doc_page_extractor-0.2.0}/doc_page_extractor/plot.py +0 -0
  36. {doc_page_extractor-0.1.1 → doc_page_extractor-0.2.0}/doc_page_extractor/raw_optimizer.py +0 -0
  37. {doc_page_extractor-0.1.1 → doc_page_extractor-0.2.0}/doc_page_extractor/rectangle.py +0 -0
  38. {doc_page_extractor-0.1.1 → doc_page_extractor-0.2.0}/doc_page_extractor/rotation.py +0 -0
  39. {doc_page_extractor-0.1.1 → doc_page_extractor-0.2.0}/doc_page_extractor/struct_eqtable/__init__.py +0 -0
  40. {doc_page_extractor-0.1.1 → doc_page_extractor-0.2.0}/doc_page_extractor/struct_eqtable/internvl/__init__.py +0 -0
  41. {doc_page_extractor-0.1.1 → doc_page_extractor-0.2.0}/doc_page_extractor/struct_eqtable/internvl/conversation.py +0 -0
  42. {doc_page_extractor-0.1.1 → doc_page_extractor-0.2.0}/doc_page_extractor/struct_eqtable/internvl/internvl.py +0 -0
  43. {doc_page_extractor-0.1.1 → doc_page_extractor-0.2.0}/doc_page_extractor/struct_eqtable/internvl/internvl_lmdeploy.py +0 -0
  44. {doc_page_extractor-0.1.1 → doc_page_extractor-0.2.0}/doc_page_extractor/struct_eqtable/pix2s/__init__.py +0 -0
  45. {doc_page_extractor-0.1.1 → doc_page_extractor-0.2.0}/doc_page_extractor/struct_eqtable/pix2s/pix2s.py +0 -0
  46. {doc_page_extractor-0.1.1 → doc_page_extractor-0.2.0}/doc_page_extractor/struct_eqtable/pix2s/pix2s_trt.py +0 -0
  47. {doc_page_extractor-0.1.1 → doc_page_extractor-0.2.0}/doc_page_extractor/utils.py +0 -0
  48. {doc_page_extractor-0.1.1 → doc_page_extractor-0.2.0}/doc_page_extractor.egg-info/dependency_links.txt +0 -0
  49. {doc_page_extractor-0.1.1 → doc_page_extractor-0.2.0}/doc_page_extractor.egg-info/top_level.txt +0 -0
  50. {doc_page_extractor-0.1.1 → doc_page_extractor-0.2.0}/setup.cfg +0 -0
  51. {doc_page_extractor-0.1.1 → doc_page_extractor-0.2.0}/tests/__init__.py +0 -0
@@ -1,13 +1,13 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: doc-page-extractor
3
- Version: 0.1.1
3
+ Version: 0.2.0
4
4
  Summary: doc page extractor can identify text and format in images and return structured data.
5
5
  Home-page: https://github.com/Moskize91/doc-page-extractor
6
6
  Author: Tao Zeyu
7
7
  Author-email: i@taozeyu.com
8
8
  Description-Content-Type: text/markdown
9
9
  License-File: LICENSE
10
- Requires-Dist: opencv-python<5.0,>=4.11.0
10
+ Requires-Dist: opencv-python<5.0,>=4.10.0
11
11
  Requires-Dist: pillow<11.0,>=10.3
12
12
  Requires-Dist: pyclipper<2.0,>=1.2.0
13
13
  Requires-Dist: numpy<2.0,>=1.24.0
@@ -16,6 +16,7 @@ Requires-Dist: transformers<=4.47,>=4.42.4
16
16
  Requires-Dist: doclayout_yolo>=0.0.3
17
17
  Requires-Dist: pix2tex<=0.2.0,>=0.1.4
18
18
  Requires-Dist: accelerate<2.0,>=1.6.0
19
+ Requires-Dist: huggingface_hub>=0.30.2
19
20
  Dynamic: author
20
21
  Dynamic: author-email
21
22
  Dynamic: description
@@ -12,4 +12,5 @@ from .types import (
12
12
  PlainLayout,
13
13
  FormulaLayout,
14
14
  TableLayout,
15
- )
15
+ ModelsDownloader
16
+ )
@@ -1,15 +1,14 @@
1
- import os
2
1
 
3
2
  from typing import Literal, Generator
4
- from pathlib import Path
5
3
  from PIL.Image import Image
6
4
  from doclayout_yolo import YOLOv10
5
+ from logging import Logger, getLogger
7
6
 
7
+ from .models import HuggingfaceModelsDownloader
8
8
  from .ocr import OCR
9
9
  from .ocr_corrector import correct_fragments
10
10
  from .raw_optimizer import RawOptimizer
11
11
  from .rectangle import intersection_area, Rectangle
12
- from .downloader import download
13
12
  from .table import Table
14
13
  from .latex import LaTeX
15
14
  from .layout_order import LayoutOrder
@@ -17,50 +16,53 @@ from .overlap import merge_fragments_as_line, remove_overlap_layouts
17
16
  from .clipper import clip_from_image
18
17
  from .types import (
19
18
  ExtractedResult,
19
+ ModelsDownloader,
20
20
  OCRFragment,
21
- TableLayoutParsedFormat,
22
21
  Layout,
23
22
  LayoutClass,
24
23
  PlainLayout,
25
24
  TableLayout,
26
25
  FormulaLayout,
26
+ TableLayoutParsedFormat
27
27
  )
28
28
 
29
29
 
30
30
  class DocExtractor:
31
31
  def __init__(
32
32
  self,
33
- model_dir_path: str,
33
+ model_cache_dir: str | None = None,
34
34
  device: Literal["cpu", "cuda"] = "cpu",
35
- ocr_for_each_layouts: bool = True,
36
- extract_formula: bool = True,
37
- extract_table_format: TableLayoutParsedFormat | None = None,
35
+ models_downloader: ModelsDownloader | None = None,
36
+ logger: Logger | None = None,
38
37
  ):
39
- self._model_dir_path: str = model_dir_path
38
+ self._logger = logger or getLogger(__name__)
39
+ self._models_downloader = models_downloader or HuggingfaceModelsDownloader(self._logger, model_cache_dir)
40
+
40
41
  self._device: Literal["cpu", "cuda"] = device
41
- self._ocr_for_each_layouts: bool = ocr_for_each_layouts
42
- self._extract_formula: bool = extract_formula
43
- self._extract_table_format: TableLayoutParsedFormat | None = extract_table_format
44
42
  self._yolo: YOLOv10 | None = None
45
43
  self._ocr: OCR = OCR(
46
44
  device=device,
47
- model_dir_path=os.path.join(model_dir_path, "onnx_ocr"),
45
+ get_model_dir=self._models_downloader.onnx_ocr,
48
46
  )
49
47
  self._table: Table = Table(
50
48
  device=device,
51
- model_path=os.path.join(model_dir_path, "struct_eqtable"),
49
+ get_model_dir=self._models_downloader.struct_eqtable,
52
50
  )
53
51
  self._latex: LaTeX = LaTeX(
54
- model_path=os.path.join(model_dir_path, "latex"),
52
+ get_model_dir=self._models_downloader.latex,
53
+ device=device,
55
54
  )
56
55
  self._layout_order: LayoutOrder = LayoutOrder(
57
- model_path=os.path.join(model_dir_path, "layoutreader"),
56
+ get_model_dir=self._models_downloader.layoutreader,
58
57
  )
59
58
 
60
59
  def extract(
61
60
  self,
62
61
  image: Image,
63
- adjust_points: bool = False,
62
+ extract_formula: bool,
63
+ extract_table_format: TableLayoutParsedFormat | None = None,
64
+ ocr_for_each_layouts: bool = False,
65
+ adjust_points: bool = False
64
66
  ) -> ExtractedResult:
65
67
 
66
68
  raw_optimizer = RawOptimizer(image, adjust_points)
@@ -70,13 +72,13 @@ class DocExtractor:
70
72
  layouts = self._layouts_matched_by_fragments(fragments, layouts)
71
73
  layouts = remove_overlap_layouts(layouts)
72
74
 
73
- if self._ocr_for_each_layouts:
75
+ if ocr_for_each_layouts:
74
76
  self._correct_fragments_by_ocr_layouts(raw_optimizer.image, layouts)
75
77
 
76
78
  layouts = self._layout_order.sort(layouts, raw_optimizer.image.size)
77
79
  layouts = [layout for layout in layouts if self._should_keep_layout(layout)]
78
80
 
79
- self._parse_table_and_formula_layouts(layouts, raw_optimizer)
81
+ self._parse_table_and_formula_layouts(layouts, raw_optimizer, extract_formula=extract_formula, extract_table_format=extract_table_format)
80
82
 
81
83
  for layout in layouts:
82
84
  layout.fragments = merge_fragments_as_line(layout.fragments)
@@ -138,16 +140,22 @@ class DocExtractor:
138
140
  for layout in layouts:
139
141
  correct_fragments(self._ocr, source, layout)
140
142
 
141
- def _parse_table_and_formula_layouts(self, layouts: list[Layout], raw_optimizer: RawOptimizer):
143
+ def _parse_table_and_formula_layouts(
144
+ self,
145
+ layouts: list[Layout],
146
+ raw_optimizer: RawOptimizer,
147
+ extract_formula: bool,
148
+ extract_table_format: TableLayoutParsedFormat | None,
149
+ ):
142
150
  for layout in layouts:
143
- if isinstance(layout, FormulaLayout) and self._extract_formula:
151
+ if isinstance(layout, FormulaLayout) and extract_formula:
144
152
  image = clip_from_image(raw_optimizer.image, layout.rect)
145
153
  layout.latex = self._latex.extract(image)
146
- elif isinstance(layout, TableLayout) and self._extract_table_format is not None:
154
+ elif isinstance(layout, TableLayout) and extract_table_format is not None:
147
155
  image = clip_from_image(raw_optimizer.image, layout.rect)
148
- parsed = self._table.predict(image, self._extract_table_format)
156
+ parsed = self._table.predict(image, extract_table_format)
149
157
  if parsed is not None:
150
- layout.parsed = (parsed, self._extract_table_format)
158
+ layout.parsed = (parsed, extract_table_format)
151
159
 
152
160
  def _split_layouts_by_group(self, layouts: list[Layout]):
153
161
  texts_layouts: list[Layout] = []
@@ -191,14 +199,8 @@ class DocExtractor:
191
199
 
192
200
  def _get_yolo(self) -> YOLOv10:
193
201
  if self._yolo is None:
194
- base_path = os.path.join(self._model_dir_path, "yolo")
195
- os.makedirs(base_path, exist_ok=True)
196
- yolo_model_url = "https://huggingface.co/opendatalab/PDF-Extract-Kit-1.0/resolve/main/models/Layout/YOLO/doclayout_yolo_ft.pt"
197
- yolo_model_name = "doclayout_yolo_ft.pt"
198
- yolo_model_path = Path(os.path.join(base_path, yolo_model_name))
199
- if not yolo_model_path.exists():
200
- download(yolo_model_url, yolo_model_path)
201
- self._yolo = YOLOv10(str(yolo_model_path))
202
+ model_path = self._models_downloader.yolo()
203
+ self._yolo = YOLOv10(str(model_path))
202
204
  return self._yolo
203
205
 
204
206
  def _should_keep_layout(self, layout: Layout) -> bool:
@@ -0,0 +1,31 @@
1
+ import os
2
+ import torch
3
+
4
+ from munch import Munch
5
+ from pix2tex.cli import LatexOCR
6
+ from PIL.Image import Image
7
+ from typing import Literal
8
+ from .utils import expand_image
9
+ from .types import GetModelDir
10
+
11
+ class LaTeX:
12
+ def __init__(self, device: Literal["cpu", "cuda"],get_model_dir: GetModelDir):
13
+ self._model_path: str = get_model_dir()
14
+ self._model: LatexOCR | None = None
15
+ self._device: Literal["cpu", "cuda"] = device
16
+
17
+ def extract(self, image: Image) -> str | None:
18
+ image = expand_image(image, 0.1) # 添加边缘提高识别准确率
19
+ model = self._get_model()
20
+ with torch.no_grad():
21
+ return model(image)
22
+
23
+ def _get_model(self) -> LatexOCR:
24
+ if self._model is None:
25
+ self._model = LatexOCR(Munch({
26
+ "config": os.path.join("settings", "config.yaml"),
27
+ "checkpoint": os.path.join(self._model_path, "checkpoints", "weights.pth"),
28
+ "no_cuda": self._device == "cpu",
29
+ "no_resize": False,
30
+ }))
31
+ return self._model
@@ -1,13 +1,11 @@
1
- import os
2
1
  import torch
3
2
 
4
3
  from typing import Generator
5
4
  from dataclasses import dataclass
6
5
  from transformers import LayoutLMv3ForTokenClassification
7
6
 
8
- from .types import Layout, LayoutClass
7
+ from .types import Layout, LayoutClass, GetModelDir
9
8
  from .layoutreader import prepare_inputs, boxes2inputs, parse_logits
10
- from .utils import ensure_dir
11
9
 
12
10
 
13
11
  @dataclass
@@ -19,18 +17,17 @@ class _BBox:
19
17
  value: tuple[float, float, float, float]
20
18
 
21
19
  class LayoutOrder:
22
- def __init__(self, model_path: str):
23
- self._model_path: str = model_path
20
+ def __init__(self, get_model_dir: GetModelDir):
21
+ self._model_path: str = get_model_dir()
24
22
  self._model: LayoutLMv3ForTokenClassification | None = None
23
+ self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25
24
 
26
25
  def _get_model(self) -> LayoutLMv3ForTokenClassification:
27
26
  if self._model is None:
28
- model_path = ensure_dir(self._model_path)
29
27
  self._model = LayoutLMv3ForTokenClassification.from_pretrained(
30
- pretrained_model_name_or_path="hantian/layoutreader",
31
- cache_dir=model_path,
32
- local_files_only=os.path.exists(os.path.join(model_path, "models--hantian--layoutreader")),
33
- )
28
+ pretrained_model_name_or_path=self._model_path,
29
+ local_files_only=True,
30
+ ).to(device=self._device)
34
31
  return self._model
35
32
 
36
33
  def sort(self, layouts: list[Layout], size: tuple[int, int]) -> list[Layout]:
@@ -0,0 +1,92 @@
1
+ import os
2
+
3
+ from logging import Logger
4
+ from huggingface_hub import hf_hub_download, snapshot_download, try_to_load_from_cache
5
+ from .types import ModelsDownloader
6
+
7
+ class HuggingfaceModelsDownloader(ModelsDownloader):
8
+ def __init__(
9
+ self,
10
+ logger: Logger,
11
+ model_dir_path: str | None
12
+ ):
13
+ self._logger = logger
14
+ self._model_dir_path: str | None = model_dir_path
15
+
16
+ def onnx_ocr(self) -> str:
17
+ repo_path = try_to_load_from_cache(
18
+ repo_id="moskize/OnnxOCR",
19
+ filename="README.md",
20
+ cache_dir=self._model_dir_path
21
+ )
22
+ if isinstance(repo_path, str):
23
+ return os.path.dirname(repo_path)
24
+ else:
25
+ self._logger.info("Downloading OCR model...")
26
+ return snapshot_download(
27
+ cache_dir=self._model_dir_path,
28
+ repo_id="moskize/OnnxOCR",
29
+ )
30
+
31
+ def yolo(self) -> str:
32
+ yolo_file_path = try_to_load_from_cache(
33
+ repo_id="opendatalab/PDF-Extract-Kit-1.0",
34
+ filename="models/Layout/YOLO/doclayout_yolo_ft.pt",
35
+ cache_dir=self._model_dir_path
36
+ )
37
+ if isinstance(yolo_file_path, str):
38
+ return yolo_file_path
39
+ else:
40
+ self._logger.info("Downloading YOLO model...")
41
+ return hf_hub_download(
42
+ cache_dir=self._model_dir_path,
43
+ repo_id="opendatalab/PDF-Extract-Kit-1.0",
44
+ filename="models/Layout/YOLO/doclayout_yolo_ft.pt",
45
+ )
46
+
47
+ def layoutreader(self) -> str:
48
+ repo_path = try_to_load_from_cache(
49
+ repo_id="hantian/layoutreader",
50
+ filename="model.safetensors",
51
+ cache_dir=self._model_dir_path
52
+ )
53
+ if isinstance(repo_path, str):
54
+ return os.path.dirname(repo_path)
55
+ else:
56
+ self._logger.info("Downloading LayoutReader model...")
57
+ return snapshot_download(
58
+ cache_dir=self._model_dir_path,
59
+ repo_id="hantian/layoutreader",
60
+ )
61
+
62
+ def struct_eqtable(self) -> str:
63
+ repo_path = try_to_load_from_cache(
64
+ repo_id="U4R/StructTable-InternVL2-1B",
65
+ filename="model.safetensors",
66
+ cache_dir=self._model_dir_path
67
+ )
68
+ if isinstance(repo_path, str):
69
+ return os.path.dirname(repo_path)
70
+ else:
71
+ self._logger.info("Downloading StructEqTable model...")
72
+ return snapshot_download(
73
+ cache_dir=self._model_dir_path,
74
+ repo_id="U4R/StructTable-InternVL2-1B",
75
+ )
76
+
77
+ def latex(self):
78
+ repo_path = try_to_load_from_cache(
79
+ repo_id="lukbl/LaTeX-OCR",
80
+ filename="checkpoints/weights.pth",
81
+ repo_type="space",
82
+ cache_dir=self._model_dir_path
83
+ )
84
+ if isinstance(repo_path, str):
85
+ return os.path.dirname(os.path.dirname(repo_path))
86
+ else:
87
+ self._logger.info("Downloading LaTeX model...")
88
+ return snapshot_download(
89
+ cache_dir=self._model_dir_path,
90
+ repo_type="space",
91
+ repo_id="lukbl/LaTeX-OCR",
92
+ )
@@ -5,9 +5,8 @@ import os
5
5
  from typing import Literal, Generator
6
6
  from dataclasses import dataclass
7
7
  from .onnxocr import TextSystem
8
- from .types import OCRFragment
8
+ from .types import GetModelDir, OCRFragment
9
9
  from .rectangle import Rectangle
10
- from .downloader import download
11
10
  from .utils import is_space_text
12
11
 
13
12
 
@@ -47,14 +46,17 @@ class _OONXParams:
47
46
  det_model_dir: str
48
47
  rec_char_dict_path: str
49
48
 
49
+
50
+
51
+
50
52
  class OCR:
51
53
  def __init__(
52
54
  self,
53
55
  device: Literal["cpu", "cuda"],
54
- model_dir_path: str,
56
+ get_model_dir: GetModelDir,
55
57
  ):
56
58
  self._device: Literal["cpu", "cuda"] = device
57
- self._model_dir_path: str = model_dir_path
59
+ self._get_model_dir: GetModelDir = get_model_dir
58
60
  self._text_system: TextSystem | None = None
59
61
 
60
62
  def search_fragments(self, image: np.ndarray) -> Generator[OCRFragment, None, None]:
@@ -87,20 +89,17 @@ class OCR:
87
89
  for box, res in zip(dt_boxes, rec_res):
88
90
  yield box.tolist(), res
89
91
 
92
+ def make_model_paths(self) -> list[str]:
93
+ model_paths = []
94
+ model_dir = self._get_model_dir()
95
+ for model_path in _MODELS:
96
+ file_name = os.path.join(*model_path)
97
+ model_paths.append(os.path.join(model_dir, file_name))
98
+ return model_paths
99
+
90
100
  def _get_text_system(self) -> TextSystem:
91
101
  if self._text_system is None:
92
- for model_path in _MODELS:
93
- file_path = os.path.join(self._model_dir_path, *model_path)
94
- if os.path.exists(file_path):
95
- continue
96
-
97
- file_dir_path = os.path.dirname(file_path)
98
- os.makedirs(file_dir_path, exist_ok=True)
99
-
100
- url_path = "/".join(model_path)
101
- url = f"https://huggingface.co/moskize/OnnxOCR/resolve/main/{url_path}"
102
- download(url, file_path)
103
-
102
+ model_paths = self.make_model_paths()
104
103
  self._text_system = TextSystem(_OONXParams(
105
104
  use_angle_cls=True,
106
105
  use_gpu=(self._device != "cpu"),
@@ -123,10 +122,10 @@ class OCR:
123
122
  save_crop_res=False,
124
123
  rec_algorithm="SVTR_LCNet",
125
124
  use_space_char=True,
126
- rec_model_dir=os.path.join(self._model_dir_path, *_MODELS[0]),
127
- cls_model_dir=os.path.join(self._model_dir_path, *_MODELS[1]),
128
- det_model_dir=os.path.join(self._model_dir_path, *_MODELS[2]),
129
- rec_char_dict_path=os.path.join(self._model_dir_path, *_MODELS[3]),
125
+ rec_model_dir=model_paths[0],
126
+ cls_model_dir=model_paths[1],
127
+ det_model_dir=model_paths[2],
128
+ rec_char_dict_path=model_paths[3],
130
129
  ))
131
130
 
132
131
  return self._text_system
@@ -142,14 +141,40 @@ class OCR:
142
141
  beta=255,
143
142
  norm_type=cv2.NORM_MINMAX,
144
143
  )
145
- image = cv2.fastNlMeansDenoisingColored(
146
- src=image,
147
- dst=None,
148
- h=10,
149
- hColor=10,
150
- templateWindowSize=7,
151
- searchWindowSize=15,
152
- )
144
+ if cv2.cuda.getCudaEnabledDeviceCount() > 0:
145
+ gpu_frame = cv2.cuda.GpuMat()
146
+ gpu_frame.upload(image)
147
+ image = cv2.cuda.fastNlMeansDenoisingColored(
148
+ src=gpu_frame,
149
+ dst=None,
150
+ h_luminance=10,
151
+ photo_render=10,
152
+ search_window=15,
153
+ block_size=7,
154
+ )
155
+ image = gpu_frame.download()
156
+ elif cv2.ocl.haveOpenCL():
157
+ cv2.ocl.setUseOpenCL(True)
158
+ gpu_frame = cv2.UMat(image)
159
+ image = cv2.fastNlMeansDenoisingColored(
160
+ src=gpu_frame,
161
+ dst=None,
162
+ h=10,
163
+ hColor=10,
164
+ templateWindowSize=7,
165
+ searchWindowSize=15,
166
+ )
167
+ image = image.get()
168
+ else:
169
+ image = cv2.fastNlMeansDenoisingColored(
170
+ src=image,
171
+ dst=None,
172
+ h=10,
173
+ hColor=10,
174
+ templateWindowSize=7,
175
+ searchWindowSize=15,
176
+ )
177
+
153
178
  # image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) # image to gray
154
179
  return image
155
180
 
@@ -1,8 +1,13 @@
1
- import onnxruntime
2
-
3
1
  class PredictBase(object):
4
2
  def __init__(self):
5
- pass
3
+ self._onnxruntime = None
4
+
5
+ @property
6
+ def onnxruntime(self):
7
+ if self._onnxruntime is None:
8
+ import onnxruntime
9
+ self._onnxruntime = onnxruntime
10
+ return self._onnxruntime
6
11
 
7
12
  def get_onnx_session(self, model_dir, use_gpu):
8
13
  # 使用gpu
@@ -11,7 +16,7 @@ class PredictBase(object):
11
16
  else:
12
17
  providers = providers = ['CPUExecutionProvider']
13
18
 
14
- onnx_session = onnxruntime.InferenceSession(model_dir, None,providers=providers)
19
+ onnx_session = self.onnxruntime.InferenceSession(model_dir, None, providers=providers)
15
20
 
16
21
  # print("providers:", onnxruntime.get_device())
17
22
  return onnx_session
@@ -9,15 +9,35 @@ from .predict_base import PredictBase
9
9
 
10
10
  class TextClassifier(PredictBase):
11
11
  def __init__(self, args):
12
+ super().__init__()
12
13
  self.cls_image_shape = args.cls_image_shape
13
14
  self.cls_batch_num = args.cls_batch_num
14
15
  self.cls_thresh = args.cls_thresh
15
16
  self.postprocess_op = ClsPostProcess(label_list=args.label_list)
17
+ self._args = args
16
18
 
17
19
  # 初始化模型
18
- self.cls_onnx_session = self.get_onnx_session(args.cls_model_dir, args.use_gpu)
19
- self.cls_input_name = self.get_input_name(self.cls_onnx_session)
20
- self.cls_output_name = self.get_output_name(self.cls_onnx_session)
20
+ self._cls_onnx_session = None
21
+ self._cls_input_name = None
22
+ self._cls_output_name = None
23
+
24
+ @property
25
+ def cls_onnx_session(self):
26
+ if self._cls_onnx_session is None:
27
+ self._cls_onnx_session = self.get_onnx_session(self._args.cls_model_dir, self._args.use_gpu)
28
+ return self._cls_onnx_session
29
+
30
+ @property
31
+ def cls_input_name(self):
32
+ if self._cls_input_name is None:
33
+ self._cls_input_name = self.get_input_name(self.cls_onnx_session)
34
+ return self._cls_input_name
35
+
36
+ @property
37
+ def cls_output_name(self):
38
+ if self._cls_output_name is None:
39
+ self._cls_output_name = self.get_output_name(self.cls_onnx_session)
40
+ return self._cls_output_name
21
41
 
22
42
  def resize_norm_img(self, img):
23
43
  imgC, imgH, imgW = self.cls_image_shape
@@ -6,7 +6,8 @@ from .predict_base import PredictBase
6
6
 
7
7
  class TextDetector(PredictBase):
8
8
  def __init__(self, args):
9
- self.args = args
9
+ super().__init__()
10
+ self._args = args
10
11
  self.det_algorithm = args.det_algorithm
11
12
  pre_process_list = [
12
13
  {
@@ -43,9 +44,27 @@ class TextDetector(PredictBase):
43
44
  self.postprocess_op = DBPostProcess(**postprocess_params)
44
45
 
45
46
  # 初始化模型
46
- self.det_onnx_session = self.get_onnx_session(args.det_model_dir, args.use_gpu)
47
- self.det_input_name = self.get_input_name(self.det_onnx_session)
48
- self.det_output_name = self.get_output_name(self.det_onnx_session)
47
+ self._det_onnx_session = None
48
+ self._det_input_name = None
49
+ self._det_output_name = None
50
+
51
+ @property
52
+ def det_onnx_session(self):
53
+ if self._det_onnx_session is None:
54
+ self._det_onnx_session = self.get_onnx_session(self._args.det_model_dir, self._args.use_gpu)
55
+ return self._det_onnx_session
56
+
57
+ @property
58
+ def det_input_name(self):
59
+ if self._det_input_name is None:
60
+ self._det_input_name = self.get_input_name(self.det_onnx_session)
61
+ return self._det_input_name
62
+
63
+ @property
64
+ def det_output_name(self):
65
+ if self._det_output_name is None:
66
+ self._det_output_name = self.get_output_name(self.det_onnx_session)
67
+ return self._det_output_name
49
68
 
50
69
  def order_points_clockwise(self, pts):
51
70
  rect = np.zeros((4, 2), dtype="float32")
@@ -112,7 +131,7 @@ class TextDetector(PredictBase):
112
131
  post_result = self.postprocess_op(preds, shape_list)
113
132
  dt_boxes = post_result[0]["points"]
114
133
 
115
- if self.args.det_box_type == "poly":
134
+ if self._args.det_box_type == "poly":
116
135
  dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_im.shape)
117
136
  else:
118
137
  dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)
@@ -10,6 +10,8 @@ from .predict_base import PredictBase
10
10
 
11
11
  class TextRecognizer(PredictBase):
12
12
  def __init__(self, args):
13
+ super().__init__()
14
+ self._args = args
13
15
  self.rec_image_shape = args.rec_image_shape
14
16
  self.rec_batch_num = args.rec_batch_num
15
17
  self.rec_algorithm = args.rec_algorithm
@@ -19,9 +21,29 @@ class TextRecognizer(PredictBase):
19
21
  )
20
22
 
21
23
  # 初始化模型
22
- self.rec_onnx_session = self.get_onnx_session(args.rec_model_dir, args.use_gpu)
23
- self.rec_input_name = self.get_input_name(self.rec_onnx_session)
24
- self.rec_output_name = self.get_output_name(self.rec_onnx_session)
24
+ self._rec_onnx_session = None
25
+ self._rec_input_name = None
26
+ self._rec_output_name = None
27
+
28
+ @property
29
+ def rec_onnx_session(self):
30
+ if self._rec_onnx_session is None:
31
+ self._rec_onnx_session = self.get_onnx_session(
32
+ self._args.rec_model_dir, self._args.use_gpu
33
+ )
34
+ return self._rec_onnx_session
35
+
36
+ @property
37
+ def rec_input_name(self):
38
+ if self._rec_input_name is None:
39
+ self._rec_input_name = self.get_input_name(self.rec_onnx_session)
40
+ return self._rec_input_name
41
+
42
+ @property
43
+ def rec_output_name(self):
44
+ if self._rec_output_name is None:
45
+ self._rec_output_name = self.get_output_name(self.rec_onnx_session)
46
+ return self._rec_output_name
25
47
 
26
48
  def resize_norm_img(self, img, max_wh_ratio):
27
49
  imgC, imgH, imgW = self.rec_image_shape
@@ -30,9 +52,9 @@ class TextRecognizer(PredictBase):
30
52
  # return padding_im
31
53
  image_pil = Image.fromarray(np.uint8(img))
32
54
  if self.rec_algorithm == "ViTSTR":
33
- img = image_pil.resize([imgW, imgH], Image.BICUBIC)
55
+ img = image_pil.resize([imgW, imgH], Image.Resampling.BICUBIC)
34
56
  else:
35
- img = image_pil.resize([imgW, imgH], Image.ANTIALIAS)
57
+ img = image_pil.resize([imgW, imgH], Image.Resampling.LANCZOS)
36
58
  img = np.array(img)
37
59
  norm_img = np.expand_dims(img, -1)
38
60
  norm_img = norm_img.transpose((2, 0, 1))
@@ -250,8 +272,9 @@ class TextRecognizer(PredictBase):
250
272
  def norm_img_can(self, img, image_shape):
251
273
  img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) # CAN only predict gray scale image
252
274
 
253
- if self.inverse:
254
- img = 255 - img
275
+ # FIXME
276
+ # if self.inverse:
277
+ # img = 255 - img
255
278
 
256
279
  if self.rec_image_shape[0] == 1:
257
280
  h, w = img.shape
@@ -3,7 +3,7 @@ import torch
3
3
 
4
4
  from typing import Literal, Any
5
5
  from PIL.Image import Image
6
- from .types import TableLayoutParsedFormat
6
+ from .types import TableLayoutParsedFormat, GetModelDir
7
7
  from .utils import expand_image
8
8
 
9
9
 
@@ -13,10 +13,10 @@ class Table:
13
13
  def __init__(
14
14
  self,
15
15
  device: Literal["cpu", "cuda"],
16
- model_path: str,
16
+ get_model_dir: GetModelDir,
17
17
  ):
18
18
  self._model: Any | None = None
19
- self._model_path: str = model_path
19
+ self._model_path: str = get_model_dir()
20
20
  self._ban: bool = False
21
21
  if device == "cpu" or not torch.cuda.is_available():
22
22
  self._ban = True
@@ -58,13 +58,12 @@ class Table:
58
58
 
59
59
  from .struct_eqtable import build_model
60
60
  model = build_model(
61
- model_ckpt="U4R/StructTable-InternVL2-1B",
61
+ model_ckpt=self._model_path,
62
62
  max_new_tokens=1024,
63
63
  max_time=30,
64
64
  lmdeploy=False,
65
65
  flash_attn=True,
66
66
  batch_size=1,
67
- cache_dir=self._model_path,
68
67
  local_files_only=local_files_only,
69
68
  )
70
69
  self._model = model.cuda()
@@ -1,5 +1,5 @@
1
1
  from dataclasses import dataclass
2
- from typing import Literal
2
+ from typing import Literal, Callable, Protocol, runtime_checkable, List
3
3
  from enum import auto, Enum
4
4
  from PIL.Image import Image
5
5
  from .rectangle import Rectangle
@@ -32,7 +32,7 @@ class TableLayoutParsedFormat(Enum):
32
32
  @dataclass
33
33
  class BaseLayout:
34
34
  rect: Rectangle
35
- fragments: list[OCRFragment]
35
+ fragments: List[OCRFragment]
36
36
 
37
37
  @dataclass
38
38
  class PlainLayout(BaseLayout):
@@ -59,9 +59,33 @@ class FormulaLayout(BaseLayout):
59
59
 
60
60
  Layout = PlainLayout | TableLayout | FormulaLayout
61
61
 
62
+
62
63
  @dataclass
63
64
  class ExtractedResult:
64
65
  rotation: float
65
- layouts: list[Layout]
66
- extracted_image: Image
67
- adjusted_image: Image | None
66
+ layouts: List[Layout]
67
+ extracted_image: Image | None
68
+ adjusted_image: Image | None
69
+
70
+ GetModelDir = Callable[[], str]
71
+
72
+
73
+ @runtime_checkable
74
+ class ModelsDownloader(Protocol):
75
+
76
+ def onnx_ocr(self) -> str:
77
+ pass
78
+
79
+ def yolo(self) -> str:
80
+ pass
81
+
82
+ def layoutreader(self) -> str:
83
+ pass
84
+
85
+ def struct_eqtable(self) -> str:
86
+ pass
87
+
88
+ def latex(self) -> str:
89
+ pass
90
+
91
+
@@ -1,13 +1,13 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: doc-page-extractor
3
- Version: 0.1.1
3
+ Version: 0.2.0
4
4
  Summary: doc page extractor can identify text and format in images and return structured data.
5
5
  Home-page: https://github.com/Moskize91/doc-page-extractor
6
6
  Author: Tao Zeyu
7
7
  Author-email: i@taozeyu.com
8
8
  Description-Content-Type: text/markdown
9
9
  License-File: LICENSE
10
- Requires-Dist: opencv-python<5.0,>=4.11.0
10
+ Requires-Dist: opencv-python<5.0,>=4.10.0
11
11
  Requires-Dist: pillow<11.0,>=10.3
12
12
  Requires-Dist: pyclipper<2.0,>=1.2.0
13
13
  Requires-Dist: numpy<2.0,>=1.24.0
@@ -16,6 +16,7 @@ Requires-Dist: transformers<=4.47,>=4.42.4
16
16
  Requires-Dist: doclayout_yolo>=0.0.3
17
17
  Requires-Dist: pix2tex<=0.2.0,>=0.1.4
18
18
  Requires-Dist: accelerate<2.0,>=1.6.0
19
+ Requires-Dist: huggingface_hub>=0.30.2
19
20
  Dynamic: author
20
21
  Dynamic: author-email
21
22
  Dynamic: description
@@ -8,6 +8,7 @@ doc_page_extractor/extractor.py
8
8
  doc_page_extractor/latex.py
9
9
  doc_page_extractor/layout_order.py
10
10
  doc_page_extractor/layoutreader.py
11
+ doc_page_extractor/models.py
11
12
  doc_page_extractor/ocr.py
12
13
  doc_page_extractor/ocr_corrector.py
13
14
  doc_page_extractor/overlap.py
@@ -1,4 +1,4 @@
1
- opencv-python<5.0,>=4.11.0
1
+ opencv-python<5.0,>=4.10.0
2
2
  pillow<11.0,>=10.3
3
3
  pyclipper<2.0,>=1.2.0
4
4
  numpy<2.0,>=1.24.0
@@ -7,3 +7,4 @@ transformers<=4.47,>=4.42.4
7
7
  doclayout_yolo>=0.0.3
8
8
  pix2tex<=0.2.0,>=0.1.4
9
9
  accelerate<2.0,>=1.6.0
10
+ huggingface_hub>=0.30.2
@@ -5,7 +5,7 @@ if "doc_page_extractor.struct_eqtable" not in find_packages():
5
5
 
6
6
  setup(
7
7
  name="doc-page-extractor",
8
- version="0.1.1",
8
+ version="0.2.0",
9
9
  author="Tao Zeyu",
10
10
  author_email="i@taozeyu.com",
11
11
  url="https://github.com/Moskize91/doc-page-extractor",
@@ -14,7 +14,7 @@ setup(
14
14
  long_description=open("./README.md", encoding="utf8").read(),
15
15
  long_description_content_type="text/markdown",
16
16
  install_requires=[
17
- "opencv-python>=4.11.0,<5.0",
17
+ "opencv-python>=4.10.0,<5.0",
18
18
  "pillow>=10.3,<11.0",
19
19
  "pyclipper>=1.2.0,<2.0",
20
20
  "numpy>=1.24.0,<2.0",
@@ -23,5 +23,6 @@ setup(
23
23
  "doclayout_yolo>=0.0.3",
24
24
  "pix2tex>=0.1.4,<=0.2.0",
25
25
  "accelerate>=1.6.0,<2.0",
26
+ "huggingface_hub>=0.30.2",
26
27
  ],
27
28
  )
@@ -15,7 +15,7 @@ class TestGroup(unittest.TestCase):
15
15
  layouts: list[tuple[LayoutClass, list[str]]]
16
16
 
17
17
  with Image.open(image_path) as image:
18
- result = extractor.extract(image, "ch")
18
+ result = extractor.extract(image, extract_formula=False)
19
19
  layouts = [self._format_Layout(layout) for layout in result.layouts]
20
20
 
21
21
  self.assertEqual(layouts, [
@@ -1,57 +0,0 @@
1
- import os
2
- import torch
3
- import requests
4
-
5
- from munch import Munch
6
- from pix2tex.cli import LatexOCR
7
- from PIL.Image import Image
8
- from .utils import expand_image
9
-
10
-
11
- class LaTeX:
12
- def __init__(self, model_path: str):
13
- self._model_path: str = model_path
14
- self._model: LatexOCR | None = None
15
-
16
- def extract(self, image: Image) -> str | None:
17
- image = expand_image(image, 0.1) # 添加边缘提高识别准确率
18
- model = self._get_model()
19
- with torch.no_grad():
20
- return model(image)
21
-
22
- def _get_model(self) -> LatexOCR:
23
- if self._model is None:
24
- if not os.path.exists(self._model_path):
25
- self._download_model()
26
-
27
- self._model = LatexOCR(Munch({
28
- "config": os.path.join("settings", "config.yaml"),
29
- "checkpoint": os.path.join(self._model_path, "weights.pth"),
30
- "no_cuda": True,
31
- "no_resize": False,
32
- }))
33
- return self._model
34
-
35
- # from https://github.com/lukas-blecher/LaTeX-OCR/blob/5c1ac929bd19a7ecf86d5fb8d94771c8969fcb80/pix2tex/model/checkpoints/get_latest_checkpoint.py#L37-L45
36
- def _download_model(self):
37
- os.makedirs(self._model_path, exist_ok=True)
38
- tag = "v0.0.1"
39
- files: list[tuple[str, str]] = (
40
- ("weights.pth", f"https://github.com/lukas-blecher/LaTeX-OCR/releases/download/{tag}/weights.pth"),
41
- ("image_resizer.pth", f"https://github.com/lukas-blecher/LaTeX-OCR/releases/download/{tag}/image_resizer.pth")
42
- )
43
- for file_name, url in files:
44
- file_path = os.path.join(self._model_path, file_name)
45
- try:
46
- with open(file_path, "wb") as file:
47
- response = requests.get(url, stream=True, timeout=15)
48
- response.raise_for_status()
49
- for chunk in response.iter_content(chunk_size=8192):
50
- if chunk: # 过滤掉保持连接的新块
51
- file.write(chunk)
52
- file.flush()
53
-
54
- except BaseException as e:
55
- if os.path.exists(file_path):
56
- os.remove(file_path)
57
- raise e