doc-page-extractor 0.1.0__tar.gz → 0.1.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of doc-page-extractor might be problematic. Click here for more details.

Files changed (51)
  1. {doc_page_extractor-0.1.0 → doc_page_extractor-0.1.2}/PKG-INFO +2 -1
  2. {doc_page_extractor-0.1.0 → doc_page_extractor-0.1.2}/doc_page_extractor/__init__.py +2 -1
  3. {doc_page_extractor-0.1.0 → doc_page_extractor-0.1.2}/doc_page_extractor/extractor.py +22 -23
  4. doc_page_extractor-0.1.2/doc_page_extractor/latex.py +29 -0
  5. {doc_page_extractor-0.1.0 → doc_page_extractor-0.1.2}/doc_page_extractor/layout_order.py +5 -9
  6. doc_page_extractor-0.1.2/doc_page_extractor/models.py +92 -0
  7. {doc_page_extractor-0.1.0 → doc_page_extractor-0.1.2}/doc_page_extractor/ocr.py +20 -21
  8. {doc_page_extractor-0.1.0 → doc_page_extractor-0.1.2}/doc_page_extractor/rectangle.py +6 -0
  9. {doc_page_extractor-0.1.0 → doc_page_extractor-0.1.2}/doc_page_extractor/table.py +4 -5
  10. {doc_page_extractor-0.1.0 → doc_page_extractor-0.1.2}/doc_page_extractor/types.py +23 -2
  11. {doc_page_extractor-0.1.0 → doc_page_extractor-0.1.2}/doc_page_extractor.egg-info/PKG-INFO +2 -1
  12. {doc_page_extractor-0.1.0 → doc_page_extractor-0.1.2}/doc_page_extractor.egg-info/SOURCES.txt +1 -0
  13. {doc_page_extractor-0.1.0 → doc_page_extractor-0.1.2}/doc_page_extractor.egg-info/requires.txt +1 -0
  14. {doc_page_extractor-0.1.0 → doc_page_extractor-0.1.2}/setup.py +2 -1
  15. doc_page_extractor-0.1.0/doc_page_extractor/latex.py +0 -57
  16. {doc_page_extractor-0.1.0 → doc_page_extractor-0.1.2}/LICENSE +0 -0
  17. {doc_page_extractor-0.1.0 → doc_page_extractor-0.1.2}/README.md +0 -0
  18. {doc_page_extractor-0.1.0 → doc_page_extractor-0.1.2}/doc_page_extractor/clipper.py +0 -0
  19. {doc_page_extractor-0.1.0 → doc_page_extractor-0.1.2}/doc_page_extractor/downloader.py +0 -0
  20. {doc_page_extractor-0.1.0 → doc_page_extractor-0.1.2}/doc_page_extractor/layoutreader.py +0 -0
  21. {doc_page_extractor-0.1.0 → doc_page_extractor-0.1.2}/doc_page_extractor/ocr_corrector.py +0 -0
  22. {doc_page_extractor-0.1.0 → doc_page_extractor-0.1.2}/doc_page_extractor/onnxocr/__init__.py +0 -0
  23. {doc_page_extractor-0.1.0 → doc_page_extractor-0.1.2}/doc_page_extractor/onnxocr/cls_postprocess.py +0 -0
  24. {doc_page_extractor-0.1.0 → doc_page_extractor-0.1.2}/doc_page_extractor/onnxocr/db_postprocess.py +0 -0
  25. {doc_page_extractor-0.1.0 → doc_page_extractor-0.1.2}/doc_page_extractor/onnxocr/imaug.py +0 -0
  26. {doc_page_extractor-0.1.0 → doc_page_extractor-0.1.2}/doc_page_extractor/onnxocr/operators.py +0 -0
  27. {doc_page_extractor-0.1.0 → doc_page_extractor-0.1.2}/doc_page_extractor/onnxocr/predict_base.py +0 -0
  28. {doc_page_extractor-0.1.0 → doc_page_extractor-0.1.2}/doc_page_extractor/onnxocr/predict_cls.py +0 -0
  29. {doc_page_extractor-0.1.0 → doc_page_extractor-0.1.2}/doc_page_extractor/onnxocr/predict_det.py +0 -0
  30. {doc_page_extractor-0.1.0 → doc_page_extractor-0.1.2}/doc_page_extractor/onnxocr/predict_rec.py +0 -0
  31. {doc_page_extractor-0.1.0 → doc_page_extractor-0.1.2}/doc_page_extractor/onnxocr/predict_system.py +0 -0
  32. {doc_page_extractor-0.1.0 → doc_page_extractor-0.1.2}/doc_page_extractor/onnxocr/rec_postprocess.py +0 -0
  33. {doc_page_extractor-0.1.0 → doc_page_extractor-0.1.2}/doc_page_extractor/onnxocr/utils.py +0 -0
  34. {doc_page_extractor-0.1.0 → doc_page_extractor-0.1.2}/doc_page_extractor/overlap.py +0 -0
  35. {doc_page_extractor-0.1.0 → doc_page_extractor-0.1.2}/doc_page_extractor/plot.py +0 -0
  36. {doc_page_extractor-0.1.0 → doc_page_extractor-0.1.2}/doc_page_extractor/raw_optimizer.py +0 -0
  37. {doc_page_extractor-0.1.0 → doc_page_extractor-0.1.2}/doc_page_extractor/rotation.py +0 -0
  38. {doc_page_extractor-0.1.0 → doc_page_extractor-0.1.2}/doc_page_extractor/struct_eqtable/__init__.py +0 -0
  39. {doc_page_extractor-0.1.0 → doc_page_extractor-0.1.2}/doc_page_extractor/struct_eqtable/internvl/__init__.py +0 -0
  40. {doc_page_extractor-0.1.0 → doc_page_extractor-0.1.2}/doc_page_extractor/struct_eqtable/internvl/conversation.py +0 -0
  41. {doc_page_extractor-0.1.0 → doc_page_extractor-0.1.2}/doc_page_extractor/struct_eqtable/internvl/internvl.py +0 -0
  42. {doc_page_extractor-0.1.0 → doc_page_extractor-0.1.2}/doc_page_extractor/struct_eqtable/internvl/internvl_lmdeploy.py +0 -0
  43. {doc_page_extractor-0.1.0 → doc_page_extractor-0.1.2}/doc_page_extractor/struct_eqtable/pix2s/__init__.py +0 -0
  44. {doc_page_extractor-0.1.0 → doc_page_extractor-0.1.2}/doc_page_extractor/struct_eqtable/pix2s/pix2s.py +0 -0
  45. {doc_page_extractor-0.1.0 → doc_page_extractor-0.1.2}/doc_page_extractor/struct_eqtable/pix2s/pix2s_trt.py +0 -0
  46. {doc_page_extractor-0.1.0 → doc_page_extractor-0.1.2}/doc_page_extractor/utils.py +0 -0
  47. {doc_page_extractor-0.1.0 → doc_page_extractor-0.1.2}/doc_page_extractor.egg-info/dependency_links.txt +0 -0
  48. {doc_page_extractor-0.1.0 → doc_page_extractor-0.1.2}/doc_page_extractor.egg-info/top_level.txt +0 -0
  49. {doc_page_extractor-0.1.0 → doc_page_extractor-0.1.2}/setup.cfg +0 -0
  50. {doc_page_extractor-0.1.0 → doc_page_extractor-0.1.2}/tests/__init__.py +0 -0
  51. {doc_page_extractor-0.1.0 → doc_page_extractor-0.1.2}/tests/test_history_bus.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: doc-page-extractor
3
- Version: 0.1.0
3
+ Version: 0.1.2
4
4
  Summary: doc page extractor can identify text and format in images and return structured data.
5
5
  Home-page: https://github.com/Moskize91/doc-page-extractor
6
6
  Author: Tao Zeyu
@@ -16,6 +16,7 @@ Requires-Dist: transformers<=4.47,>=4.42.4
16
16
  Requires-Dist: doclayout_yolo>=0.0.3
17
17
  Requires-Dist: pix2tex<=0.2.0,>=0.1.4
18
18
  Requires-Dist: accelerate<2.0,>=1.6.0
19
+ Requires-Dist: huggingface_hub>=0.30.2
19
20
  Dynamic: author
20
21
  Dynamic: author-email
21
22
  Dynamic: description
@@ -12,4 +12,5 @@ from .types import (
12
12
  PlainLayout,
13
13
  FormulaLayout,
14
14
  TableLayout,
15
- )
15
+ ModelsDownloader
16
+ )
@@ -1,15 +1,14 @@
1
- import os
2
1
 
3
2
  from typing import Literal, Generator
4
- from pathlib import Path
5
3
  from PIL.Image import Image
6
4
  from doclayout_yolo import YOLOv10
5
+ from logging import Logger, getLogger
7
6
 
7
+ from .models import HuggingfaceModelsDownloader
8
8
  from .ocr import OCR
9
9
  from .ocr_corrector import correct_fragments
10
10
  from .raw_optimizer import RawOptimizer
11
11
  from .rectangle import intersection_area, Rectangle
12
- from .downloader import download
13
12
  from .table import Table
14
13
  from .latex import LaTeX
15
14
  from .layout_order import LayoutOrder
@@ -17,6 +16,7 @@ from .overlap import merge_fragments_as_line, remove_overlap_layouts
17
16
  from .clipper import clip_from_image
18
17
  from .types import (
19
18
  ExtractedResult,
19
+ ModelsDownloader,
20
20
  OCRFragment,
21
21
  TableLayoutParsedFormat,
22
22
  Layout,
@@ -30,13 +30,17 @@ from .types import (
30
30
  class DocExtractor:
31
31
  def __init__(
32
32
  self,
33
- model_dir_path: str,
33
+ model_cache_dir: str | None = None,
34
34
  device: Literal["cpu", "cuda"] = "cpu",
35
35
  ocr_for_each_layouts: bool = True,
36
36
  extract_formula: bool = True,
37
37
  extract_table_format: TableLayoutParsedFormat | None = None,
38
+ models_downloader: ModelsDownloader | None = None,
39
+ logger: Logger | None = None,
38
40
  ):
39
- self._model_dir_path: str = model_dir_path
41
+ self._logger = logger or getLogger(__name__)
42
+ self._models_downloader = models_downloader or HuggingfaceModelsDownloader(self._logger, model_cache_dir)
43
+
40
44
  self._device: Literal["cpu", "cuda"] = device
41
45
  self._ocr_for_each_layouts: bool = ocr_for_each_layouts
42
46
  self._extract_formula: bool = extract_formula
@@ -44,17 +48,17 @@ class DocExtractor:
44
48
  self._yolo: YOLOv10 | None = None
45
49
  self._ocr: OCR = OCR(
46
50
  device=device,
47
- model_dir_path=os.path.join(model_dir_path, "onnx_ocr"),
51
+ get_model_dir=self._models_downloader.onnx_ocr,
48
52
  )
49
53
  self._table: Table = Table(
50
54
  device=device,
51
- model_path=os.path.join(model_dir_path, "struct_eqtable"),
55
+ get_model_dir=self._models_downloader.struct_eqtable,
52
56
  )
53
57
  self._latex: LaTeX = LaTeX(
54
- model_path=os.path.join(model_dir_path, "latex"),
58
+ get_model_dir=self._models_downloader.latex,
55
59
  )
56
60
  self._layout_order: LayoutOrder = LayoutOrder(
57
- model_path=os.path.join(model_dir_path, "layoutreader"),
61
+ get_model_dir=self._models_downloader.layoutreader,
58
62
  )
59
63
 
60
64
  def extract(
@@ -116,12 +120,13 @@ class DocExtractor:
116
120
  lb=(x1, y2),
117
121
  rb=(x2, y2),
118
122
  )
119
- if cls == LayoutClass.TABLE:
120
- yield TableLayout(cls=cls, rect=rect, fragments=[], parsed=None)
121
- elif cls == LayoutClass.ISOLATE_FORMULA:
122
- yield FormulaLayout(cls=cls, rect=rect, fragments=[], latex=None)
123
- else:
124
- yield PlainLayout(cls=cls, rect=rect, fragments=[])
123
+ if rect.is_valid:
124
+ if cls == LayoutClass.TABLE:
125
+ yield TableLayout(cls=cls, rect=rect, fragments=[], parsed=None)
126
+ elif cls == LayoutClass.ISOLATE_FORMULA:
127
+ yield FormulaLayout(cls=cls, rect=rect, fragments=[], latex=None)
128
+ else:
129
+ yield PlainLayout(cls=cls, rect=rect, fragments=[])
125
130
 
126
131
  def _layouts_matched_by_fragments(self, fragments: list[OCRFragment], layouts: list[Layout]):
127
132
  layouts_group = self._split_layouts_by_group(layouts)
@@ -190,14 +195,8 @@ class DocExtractor:
190
195
 
191
196
  def _get_yolo(self) -> YOLOv10:
192
197
  if self._yolo is None:
193
- base_path = os.path.join(self._model_dir_path, "yolo")
194
- os.makedirs(base_path, exist_ok=True)
195
- yolo_model_url = "https://huggingface.co/opendatalab/PDF-Extract-Kit-1.0/resolve/main/models/Layout/YOLO/doclayout_yolo_ft.pt"
196
- yolo_model_name = "doclayout_yolo_ft.pt"
197
- yolo_model_path = Path(os.path.join(base_path, yolo_model_name))
198
- if not yolo_model_path.exists():
199
- download(yolo_model_url, yolo_model_path)
200
- self._yolo = YOLOv10(str(yolo_model_path))
198
+ model_path = self._models_downloader.yolo()
199
+ self._yolo = YOLOv10(str(model_path))
201
200
  return self._yolo
202
201
 
203
202
  def _should_keep_layout(self, layout: Layout) -> bool:
@@ -0,0 +1,29 @@
1
+ import os
2
+ import torch
3
+
4
+ from munch import Munch
5
+ from pix2tex.cli import LatexOCR
6
+ from PIL.Image import Image
7
+ from .utils import expand_image
8
+ from .types import GetModelDir
9
+
10
+ class LaTeX:
11
+ def __init__(self, get_model_dir: GetModelDir):
12
+ self._model_path: str = get_model_dir()
13
+ self._model: LatexOCR | None = None
14
+
15
+ def extract(self, image: Image) -> str | None:
16
+ image = expand_image(image, 0.1) # 添加边缘提高识别准确率
17
+ model = self._get_model()
18
+ with torch.no_grad():
19
+ return model(image)
20
+
21
+ def _get_model(self) -> LatexOCR:
22
+ if self._model is None:
23
+ self._model = LatexOCR(Munch({
24
+ "config": os.path.join("settings", "config.yaml"),
25
+ "checkpoint": os.path.join(self._model_path, "checkpoints", "weights.pth"),
26
+ "no_cuda": True,
27
+ "no_resize": False,
28
+ }))
29
+ return self._model
@@ -1,13 +1,11 @@
1
- import os
2
1
  import torch
3
2
 
4
3
  from typing import Generator
5
4
  from dataclasses import dataclass
6
5
  from transformers import LayoutLMv3ForTokenClassification
7
6
 
8
- from .types import Layout, LayoutClass
7
+ from .types import Layout, LayoutClass, GetModelDir
9
8
  from .layoutreader import prepare_inputs, boxes2inputs, parse_logits
10
- from .utils import ensure_dir
11
9
 
12
10
 
13
11
  @dataclass
@@ -19,17 +17,15 @@ class _BBox:
19
17
  value: tuple[float, float, float, float]
20
18
 
21
19
  class LayoutOrder:
22
- def __init__(self, model_path: str):
23
- self._model_path: str = model_path
20
+ def __init__(self, get_model_dir: GetModelDir):
21
+ self._model_path: str = get_model_dir()
24
22
  self._model: LayoutLMv3ForTokenClassification | None = None
25
23
 
26
24
  def _get_model(self) -> LayoutLMv3ForTokenClassification:
27
25
  if self._model is None:
28
- model_path = ensure_dir(self._model_path)
29
26
  self._model = LayoutLMv3ForTokenClassification.from_pretrained(
30
- pretrained_model_name_or_path="hantian/layoutreader",
31
- cache_dir=model_path,
32
- local_files_only=os.path.exists(os.path.join(model_path, "models--hantian--layoutreader")),
27
+ pretrained_model_name_or_path=self._model_path,
28
+ local_files_only=True,
33
29
  )
34
30
  return self._model
35
31
 
@@ -0,0 +1,92 @@
1
+ import os
2
+
3
+ from logging import Logger
4
+ from huggingface_hub import hf_hub_download, snapshot_download, try_to_load_from_cache
5
+ from .types import ModelsDownloader
6
+
7
+ class HuggingfaceModelsDownloader(ModelsDownloader):
8
+ def __init__(
9
+ self,
10
+ logger: Logger,
11
+ model_dir_path: str | None
12
+ ):
13
+ self._logger = logger
14
+ self._model_dir_path: str | None = model_dir_path
15
+
16
+ def onnx_ocr(self) -> str:
17
+ repo_path = try_to_load_from_cache(
18
+ repo_id="moskize/OnnxOCR",
19
+ filename="README.md",
20
+ cache_dir=self._model_dir_path
21
+ )
22
+ if isinstance(repo_path, str):
23
+ return os.path.dirname(repo_path)
24
+ else:
25
+ self._logger.info("Downloading OCR model...")
26
+ return snapshot_download(
27
+ cache_dir=self._model_dir_path,
28
+ repo_id="moskize/OnnxOCR",
29
+ )
30
+
31
+ def yolo(self) -> str:
32
+ yolo_file_path = try_to_load_from_cache(
33
+ repo_id="opendatalab/PDF-Extract-Kit-1.0",
34
+ filename="models/Layout/YOLO/doclayout_yolo_ft.pt",
35
+ cache_dir=self._model_dir_path
36
+ )
37
+ if isinstance(yolo_file_path, str):
38
+ return yolo_file_path
39
+ else:
40
+ self._logger.info("Downloading YOLO model...")
41
+ return hf_hub_download(
42
+ cache_dir=self._model_dir_path,
43
+ repo_id="opendatalab/PDF-Extract-Kit-1.0",
44
+ filename="models/Layout/YOLO/doclayout_yolo_ft.pt",
45
+ )
46
+
47
+ def layoutreader(self) -> str:
48
+ repo_path = try_to_load_from_cache(
49
+ repo_id="hantian/layoutreader",
50
+ filename="model.safetensors",
51
+ cache_dir=self._model_dir_path
52
+ )
53
+ if isinstance(repo_path, str):
54
+ return os.path.dirname(repo_path)
55
+ else:
56
+ self._logger.info("Downloading LayoutReader model...")
57
+ return snapshot_download(
58
+ cache_dir=self._model_dir_path,
59
+ repo_id="hantian/layoutreader",
60
+ )
61
+
62
+ def struct_eqtable(self) -> str:
63
+ repo_path = try_to_load_from_cache(
64
+ repo_id="U4R/StructTable-InternVL2-1B",
65
+ filename="model.safetensors",
66
+ cache_dir=self._model_dir_path
67
+ )
68
+ if isinstance(repo_path, str):
69
+ return os.path.dirname(repo_path)
70
+ else:
71
+ self._logger.info("Downloading StructEqTable model...")
72
+ return snapshot_download(
73
+ cache_dir=self._model_dir_path,
74
+ repo_id="U4R/StructTable-InternVL2-1B",
75
+ )
76
+
77
+ def latex(self):
78
+ repo_path = try_to_load_from_cache(
79
+ repo_id="lukbl/LaTeX-OCR",
80
+ filename="checkpoints/weights.pth",
81
+ repo_type="space",
82
+ cache_dir=self._model_dir_path
83
+ )
84
+ if isinstance(repo_path, str):
85
+ return os.path.dirname(os.path.dirname(repo_path))
86
+ else:
87
+ self._logger.info("Downloading LaTeX model...")
88
+ return snapshot_download(
89
+ cache_dir=self._model_dir_path,
90
+ repo_type="space",
91
+ repo_id="lukbl/LaTeX-OCR",
92
+ )
@@ -5,9 +5,8 @@ import os
5
5
  from typing import Literal, Generator
6
6
  from dataclasses import dataclass
7
7
  from .onnxocr import TextSystem
8
- from .types import OCRFragment
8
+ from .types import GetModelDir, OCRFragment
9
9
  from .rectangle import Rectangle
10
- from .downloader import download
11
10
  from .utils import is_space_text
12
11
 
13
12
 
@@ -47,14 +46,17 @@ class _OONXParams:
47
46
  det_model_dir: str
48
47
  rec_char_dict_path: str
49
48
 
49
+
50
+
51
+
50
52
  class OCR:
51
53
  def __init__(
52
54
  self,
53
55
  device: Literal["cpu", "cuda"],
54
- model_dir_path: str,
56
+ get_model_dir: GetModelDir,
55
57
  ):
56
58
  self._device: Literal["cpu", "cuda"] = device
57
- self._model_dir_path: str = model_dir_path
59
+ self._get_model_dir: GetModelDir = get_model_dir
58
60
  self._text_system: TextSystem | None = None
59
61
 
60
62
  def search_fragments(self, image: np.ndarray) -> Generator[OCRFragment, None, None]:
@@ -69,7 +71,7 @@ class OCR:
69
71
  rb=(box[2][0], box[2][1]),
70
72
  lb=(box[3][0], box[3][1]),
71
73
  )
72
- if rect.area == 0.0:
74
+ if not rect.is_valid or rect.area == 0.0:
73
75
  continue
74
76
 
75
77
  yield OCRFragment(
@@ -87,20 +89,17 @@ class OCR:
87
89
  for box, res in zip(dt_boxes, rec_res):
88
90
  yield box.tolist(), res
89
91
 
92
+ def make_model_paths(self) -> list[str]:
93
+ model_paths = []
94
+ model_dir = self._get_model_dir()
95
+ for model_path in _MODELS:
96
+ file_name = os.path.join(*model_path)
97
+ model_paths.append(os.path.join(model_dir, file_name))
98
+ return model_paths
99
+
90
100
  def _get_text_system(self) -> TextSystem:
91
101
  if self._text_system is None:
92
- for model_path in _MODELS:
93
- file_path = os.path.join(self._model_dir_path, *model_path)
94
- if os.path.exists(file_path):
95
- continue
96
-
97
- file_dir_path = os.path.dirname(file_path)
98
- os.makedirs(file_dir_path, exist_ok=True)
99
-
100
- url_path = "/".join(model_path)
101
- url = f"https://huggingface.co/moskize/OnnxOCR/resolve/main/{url_path}"
102
- download(url, file_path)
103
-
102
+ model_paths = self.make_model_paths()
104
103
  self._text_system = TextSystem(_OONXParams(
105
104
  use_angle_cls=True,
106
105
  use_gpu=(self._device != "cpu"),
@@ -123,10 +122,10 @@ class OCR:
123
122
  save_crop_res=False,
124
123
  rec_algorithm="SVTR_LCNet",
125
124
  use_space_char=True,
126
- rec_model_dir=os.path.join(self._model_dir_path, *_MODELS[0]),
127
- cls_model_dir=os.path.join(self._model_dir_path, *_MODELS[1]),
128
- det_model_dir=os.path.join(self._model_dir_path, *_MODELS[2]),
129
- rec_char_dict_path=os.path.join(self._model_dir_path, *_MODELS[3]),
125
+ rec_model_dir=model_paths[0],
126
+ cls_model_dir=model_paths[1],
127
+ det_model_dir=model_paths[2],
128
+ rec_char_dict_path=model_paths[3],
130
129
  ))
131
130
 
132
131
  return self._text_system
@@ -19,6 +19,10 @@ class Rectangle:
19
19
  yield self.rb
20
20
  yield self.rt
21
21
 
22
+ @property
23
+ def is_valid(self) -> bool:
24
+ return Polygon(self).is_valid
25
+
22
26
  @property
23
27
  def segments(self) -> Generator[tuple[Point, Point], None, None]:
24
28
  yield (self.lt, self.lb)
@@ -60,6 +64,8 @@ class Rectangle:
60
64
  def intersection_area(rect1: Rectangle, rect2: Rectangle) -> float:
61
65
  poly1 = Polygon(rect1)
62
66
  poly2 = Polygon(rect2)
67
+ if not poly1.is_valid or not poly2.is_valid:
68
+ return 0.0
63
69
  intersection = poly1.intersection(poly2)
64
70
  if intersection.is_empty:
65
71
  return 0.0
@@ -3,7 +3,7 @@ import torch
3
3
 
4
4
  from typing import Literal, Any
5
5
  from PIL.Image import Image
6
- from .types import TableLayoutParsedFormat
6
+ from .types import TableLayoutParsedFormat, GetModelDir
7
7
  from .utils import expand_image
8
8
 
9
9
 
@@ -13,10 +13,10 @@ class Table:
13
13
  def __init__(
14
14
  self,
15
15
  device: Literal["cpu", "cuda"],
16
- model_path: str,
16
+ get_model_dir: GetModelDir,
17
17
  ):
18
18
  self._model: Any | None = None
19
- self._model_path: str = model_path
19
+ self._model_path: str = get_model_dir()
20
20
  self._ban: bool = False
21
21
  if device == "cpu" or not torch.cuda.is_available():
22
22
  self._ban = True
@@ -58,13 +58,12 @@ class Table:
58
58
 
59
59
  from .struct_eqtable import build_model
60
60
  model = build_model(
61
- model_ckpt="U4R/StructTable-InternVL2-1B",
61
+ model_ckpt=self._model_path,
62
62
  max_new_tokens=1024,
63
63
  max_time=30,
64
64
  lmdeploy=False,
65
65
  flash_attn=True,
66
66
  batch_size=1,
67
- cache_dir=self._model_path,
68
67
  local_files_only=local_files_only,
69
68
  )
70
69
  self._model = model.cuda()
@@ -1,5 +1,5 @@
1
1
  from dataclasses import dataclass
2
- from typing import Literal
2
+ from typing import Literal, Callable, Protocol, runtime_checkable
3
3
  from enum import auto, Enum
4
4
  from PIL.Image import Image
5
5
  from .rectangle import Rectangle
@@ -64,4 +64,25 @@ class ExtractedResult:
64
64
  rotation: float
65
65
  layouts: list[Layout]
66
66
  extracted_image: Image
67
- adjusted_image: Image | None
67
+ adjusted_image: Image | None
68
+
69
+ GetModelDir = Callable[[], str]
70
+
71
+
72
+ @runtime_checkable
73
+ class ModelsDownloader(Protocol):
74
+
75
+ def onnx_ocr(self) -> str:
76
+ pass
77
+
78
+ def yolo(self) -> str:
79
+ pass
80
+
81
+ def layoutreader(self) -> str:
82
+ pass
83
+
84
+ def struct_eqtable(self) -> str:
85
+ pass
86
+
87
+ def latex(self) -> str:
88
+ pass
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: doc-page-extractor
3
- Version: 0.1.0
3
+ Version: 0.1.2
4
4
  Summary: doc page extractor can identify text and format in images and return structured data.
5
5
  Home-page: https://github.com/Moskize91/doc-page-extractor
6
6
  Author: Tao Zeyu
@@ -16,6 +16,7 @@ Requires-Dist: transformers<=4.47,>=4.42.4
16
16
  Requires-Dist: doclayout_yolo>=0.0.3
17
17
  Requires-Dist: pix2tex<=0.2.0,>=0.1.4
18
18
  Requires-Dist: accelerate<2.0,>=1.6.0
19
+ Requires-Dist: huggingface_hub>=0.30.2
19
20
  Dynamic: author
20
21
  Dynamic: author-email
21
22
  Dynamic: description
@@ -8,6 +8,7 @@ doc_page_extractor/extractor.py
8
8
  doc_page_extractor/latex.py
9
9
  doc_page_extractor/layout_order.py
10
10
  doc_page_extractor/layoutreader.py
11
+ doc_page_extractor/models.py
11
12
  doc_page_extractor/ocr.py
12
13
  doc_page_extractor/ocr_corrector.py
13
14
  doc_page_extractor/overlap.py
@@ -7,3 +7,4 @@ transformers<=4.47,>=4.42.4
7
7
  doclayout_yolo>=0.0.3
8
8
  pix2tex<=0.2.0,>=0.1.4
9
9
  accelerate<2.0,>=1.6.0
10
+ huggingface_hub>=0.30.2
@@ -5,7 +5,7 @@ if "doc_page_extractor.struct_eqtable" not in find_packages():
5
5
 
6
6
  setup(
7
7
  name="doc-page-extractor",
8
- version="0.1.0",
8
+ version="0.1.2",
9
9
  author="Tao Zeyu",
10
10
  author_email="i@taozeyu.com",
11
11
  url="https://github.com/Moskize91/doc-page-extractor",
@@ -23,5 +23,6 @@ setup(
23
23
  "doclayout_yolo>=0.0.3",
24
24
  "pix2tex>=0.1.4,<=0.2.0",
25
25
  "accelerate>=1.6.0,<2.0",
26
+ "huggingface_hub>=0.30.2",
26
27
  ],
27
28
  )
@@ -1,57 +0,0 @@
1
- import os
2
- import torch
3
- import requests
4
-
5
- from munch import Munch
6
- from pix2tex.cli import LatexOCR
7
- from PIL.Image import Image
8
- from .utils import expand_image
9
-
10
-
11
- class LaTeX:
12
- def __init__(self, model_path: str):
13
- self._model_path: str = model_path
14
- self._model: LatexOCR | None = None
15
-
16
- def extract(self, image: Image) -> str | None:
17
- image = expand_image(image, 0.1) # 添加边缘提高识别准确率
18
- model = self._get_model()
19
- with torch.no_grad():
20
- return model(image)
21
-
22
- def _get_model(self) -> LatexOCR:
23
- if self._model is None:
24
- if not os.path.exists(self._model_path):
25
- self._download_model()
26
-
27
- self._model = LatexOCR(Munch({
28
- "config": os.path.join("settings", "config.yaml"),
29
- "checkpoint": os.path.join(self._model_path, "weights.pth"),
30
- "no_cuda": True,
31
- "no_resize": False,
32
- }))
33
- return self._model
34
-
35
- # from https://github.com/lukas-blecher/LaTeX-OCR/blob/5c1ac929bd19a7ecf86d5fb8d94771c8969fcb80/pix2tex/model/checkpoints/get_latest_checkpoint.py#L37-L45
36
- def _download_model(self):
37
- os.makedirs(self._model_path, exist_ok=True)
38
- tag = "v0.0.1"
39
- files: list[tuple[str, str]] = (
40
- ("weights.pth", f"https://github.com/lukas-blecher/LaTeX-OCR/releases/download/{tag}/weights.pth"),
41
- ("image_resizer.pth", f"https://github.com/lukas-blecher/LaTeX-OCR/releases/download/{tag}/image_resizer.pth")
42
- )
43
- for file_name, url in files:
44
- file_path = os.path.join(self._model_path, file_name)
45
- try:
46
- with open(file_path, "wb") as file:
47
- response = requests.get(url, stream=True, timeout=15)
48
- response.raise_for_status()
49
- for chunk in response.iter_content(chunk_size=8192):
50
- if chunk: # 过滤掉保持连接的新块
51
- file.write(chunk)
52
- file.flush()
53
-
54
- except BaseException as e:
55
- if os.path.exists(file_path):
56
- os.remove(file_path)
57
- raise e