doc-page-extractor 0.1.1__py3-none-any.whl → 1.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. doc_page_extractor/__init__.py +5 -14
  2. doc_page_extractor/check_env.py +40 -0
  3. doc_page_extractor/extractor.py +87 -212
  4. doc_page_extractor/model.py +97 -0
  5. doc_page_extractor/parser.py +51 -0
  6. doc_page_extractor/plot.py +52 -79
  7. doc_page_extractor/redacter.py +111 -0
  8. doc_page_extractor-1.0.2.dist-info/METADATA +120 -0
  9. doc_page_extractor-1.0.2.dist-info/RECORD +11 -0
  10. {doc_page_extractor-0.1.1.dist-info → doc_page_extractor-1.0.2.dist-info}/WHEEL +1 -2
  11. doc_page_extractor-1.0.2.dist-info/licenses/LICENSE +21 -0
  12. doc_page_extractor/clipper.py +0 -119
  13. doc_page_extractor/downloader.py +0 -16
  14. doc_page_extractor/latex.py +0 -57
  15. doc_page_extractor/layout_order.py +0 -240
  16. doc_page_extractor/layoutreader.py +0 -126
  17. doc_page_extractor/ocr.py +0 -175
  18. doc_page_extractor/ocr_corrector.py +0 -126
  19. doc_page_extractor/onnxocr/__init__.py +0 -1
  20. doc_page_extractor/onnxocr/cls_postprocess.py +0 -26
  21. doc_page_extractor/onnxocr/db_postprocess.py +0 -246
  22. doc_page_extractor/onnxocr/imaug.py +0 -32
  23. doc_page_extractor/onnxocr/operators.py +0 -187
  24. doc_page_extractor/onnxocr/predict_base.py +0 -52
  25. doc_page_extractor/onnxocr/predict_cls.py +0 -89
  26. doc_page_extractor/onnxocr/predict_det.py +0 -120
  27. doc_page_extractor/onnxocr/predict_rec.py +0 -321
  28. doc_page_extractor/onnxocr/predict_system.py +0 -97
  29. doc_page_extractor/onnxocr/rec_postprocess.py +0 -896
  30. doc_page_extractor/onnxocr/utils.py +0 -71
  31. doc_page_extractor/overlap.py +0 -167
  32. doc_page_extractor/raw_optimizer.py +0 -104
  33. doc_page_extractor/rectangle.py +0 -72
  34. doc_page_extractor/rotation.py +0 -158
  35. doc_page_extractor/struct_eqtable/__init__.py +0 -49
  36. doc_page_extractor/struct_eqtable/internvl/__init__.py +0 -2
  37. doc_page_extractor/struct_eqtable/internvl/conversation.py +0 -394
  38. doc_page_extractor/struct_eqtable/internvl/internvl.py +0 -198
  39. doc_page_extractor/struct_eqtable/internvl/internvl_lmdeploy.py +0 -81
  40. doc_page_extractor/struct_eqtable/pix2s/__init__.py +0 -3
  41. doc_page_extractor/struct_eqtable/pix2s/pix2s.py +0 -76
  42. doc_page_extractor/struct_eqtable/pix2s/pix2s_trt.py +0 -1047
  43. doc_page_extractor/table.py +0 -71
  44. doc_page_extractor/types.py +0 -67
  45. doc_page_extractor/utils.py +0 -32
  46. doc_page_extractor-0.1.1.dist-info/METADATA +0 -84
  47. doc_page_extractor-0.1.1.dist-info/RECORD +0 -44
  48. doc_page_extractor-0.1.1.dist-info/licenses/LICENSE +0 -661
  49. doc_page_extractor-0.1.1.dist-info/top_level.txt +0 -2
  50. tests/__init__.py +0 -0
  51. tests/test_history_bus.py +0 -55
@@ -1,15 +1,6 @@
1
- from .extractor import DocExtractor
2
- from .clipper import clip, clip_from_image
1
+ from .extractor import Layout, PageExtractor
2
+ from .model import DeepSeekOCRSize
3
3
  from .plot import plot
4
- from .rectangle import Point, Rectangle
5
- from .types import (
6
- ExtractedResult,
7
- OCRFragment,
8
- LayoutClass,
9
- TableLayoutParsedFormat,
10
- Layout,
11
- BaseLayout,
12
- PlainLayout,
13
- FormulaLayout,
14
- TableLayout,
15
- )
4
+
5
# Keep in sync with the published distribution version (this wheel is 1.0.2).
__version__ = "1.0.2"
6
+ __all__ = ["DeepSeekOCRSize", "Layout", "PageExtractor", "plot"]
@@ -0,0 +1,40 @@
1
+ import warnings
2
+
3
+ import torch
4
+
5
# Module-level latch: flipped to True by the first check_env() call so the
# CUDA probe and its warning happen at most once per process.
_env_checked = False


def check_env() -> None:
    """Warn once if PyTorch reports that CUDA is unavailable.

    Only the first call does any work: it marks the environment as checked and,
    when ``torch.cuda.is_available()`` is false, emits a ``RuntimeWarning`` with
    troubleshooting hints. Every later call returns immediately. Nothing is
    raised; the function only warns.
    """
    global _env_checked
    first_call = not _env_checked
    _env_checked = True
    if not first_call or torch.cuda.is_available():
        return

    # NOTE(review): the message lines are reproduced as they appear in the
    # flattened diff; confirm their leading whitespace against the real file.
    message = """
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
CUDA is not available!
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

This package requires CUDA to run, but torch.cuda.is_available() returned False.

Possible causes:
1. You installed CPU-only PyTorch. Reinstall with CUDA support:
pip uninstall torch torchvision
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121

2. Your NVIDIA GPU driver is outdated. Update it from:
https://www.nvidia.com/download/index.aspx

3. You don't have a CUDA-compatible GPU.

To verify your setup, run: nvidia-smi

━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
""".strip()
    # stacklevel=2 attributes the warning to whoever called check_env().
    warnings.warn(message, RuntimeWarning, stacklevel=2)
@@ -1,213 +1,88 @@
1
- import os
2
-
3
- from typing import Literal, Generator
1
+ import tempfile
2
+ from dataclasses import dataclass
3
+ from os import PathLike
4
4
  from pathlib import Path
5
- from PIL.Image import Image
6
- from doclayout_yolo import YOLOv10
7
-
8
- from .ocr import OCR
9
- from .ocr_corrector import correct_fragments
10
- from .raw_optimizer import RawOptimizer
11
- from .rectangle import intersection_area, Rectangle
12
- from .downloader import download
13
- from .table import Table
14
- from .latex import LaTeX
15
- from .layout_order import LayoutOrder
16
- from .overlap import merge_fragments_as_line, remove_overlap_layouts
17
- from .clipper import clip_from_image
18
- from .types import (
19
- ExtractedResult,
20
- OCRFragment,
21
- TableLayoutParsedFormat,
22
- Layout,
23
- LayoutClass,
24
- PlainLayout,
25
- TableLayout,
26
- FormulaLayout,
27
- )
28
-
29
-
30
- class DocExtractor:
31
- def __init__(
32
- self,
33
- model_dir_path: str,
34
- device: Literal["cpu", "cuda"] = "cpu",
35
- ocr_for_each_layouts: bool = True,
36
- extract_formula: bool = True,
37
- extract_table_format: TableLayoutParsedFormat | None = None,
38
- ):
39
- self._model_dir_path: str = model_dir_path
40
- self._device: Literal["cpu", "cuda"] = device
41
- self._ocr_for_each_layouts: bool = ocr_for_each_layouts
42
- self._extract_formula: bool = extract_formula
43
- self._extract_table_format: TableLayoutParsedFormat | None = extract_table_format
44
- self._yolo: YOLOv10 | None = None
45
- self._ocr: OCR = OCR(
46
- device=device,
47
- model_dir_path=os.path.join(model_dir_path, "onnx_ocr"),
48
- )
49
- self._table: Table = Table(
50
- device=device,
51
- model_path=os.path.join(model_dir_path, "struct_eqtable"),
52
- )
53
- self._latex: LaTeX = LaTeX(
54
- model_path=os.path.join(model_dir_path, "latex"),
55
- )
56
- self._layout_order: LayoutOrder = LayoutOrder(
57
- model_path=os.path.join(model_dir_path, "layoutreader"),
58
- )
59
-
60
- def extract(
61
- self,
62
- image: Image,
63
- adjust_points: bool = False,
64
- ) -> ExtractedResult:
65
-
66
- raw_optimizer = RawOptimizer(image, adjust_points)
67
- fragments = list(self._ocr.search_fragments(raw_optimizer.image_np))
68
- raw_optimizer.receive_raw_fragments(fragments)
69
- layouts = list(self._yolo_extract_layouts(raw_optimizer.image))
70
- layouts = self._layouts_matched_by_fragments(fragments, layouts)
71
- layouts = remove_overlap_layouts(layouts)
72
-
73
- if self._ocr_for_each_layouts:
74
- self._correct_fragments_by_ocr_layouts(raw_optimizer.image, layouts)
75
-
76
- layouts = self._layout_order.sort(layouts, raw_optimizer.image.size)
77
- layouts = [layout for layout in layouts if self._should_keep_layout(layout)]
78
-
79
- self._parse_table_and_formula_layouts(layouts, raw_optimizer)
80
-
81
- for layout in layouts:
82
- layout.fragments = merge_fragments_as_line(layout.fragments)
83
-
84
- raw_optimizer.receive_raw_layouts(layouts)
85
-
86
- return ExtractedResult(
87
- rotation=raw_optimizer.rotation,
88
- layouts=layouts,
89
- extracted_image=image,
90
- adjusted_image=raw_optimizer.adjusted_image,
91
- )
92
-
93
- def _yolo_extract_layouts(self, source: Image) -> Generator[Layout, None, None]:
94
- # about source parameter to see:
95
- # https://github.com/opendatalab/DocLayout-YOLO/blob/7c4be36bc61f11b67cf4a44ee47f3c41e9800a91/doclayout_yolo/data/build.py#L157-L175
96
- det_res = self._get_yolo().predict(
97
- source=source,
98
- imgsz=1024,
99
- conf=0.2,
100
- device=self._device # Device to use (e.g., "cuda" or "cpu")
101
- )
102
- boxes = det_res[0].__dict__["boxes"]
103
-
104
- for cls_id, rect in zip(boxes.cls, boxes.xyxy):
105
- cls_id = cls_id.item()
106
- cls=LayoutClass(round(cls_id))
107
-
108
- x1, y1, x2, y2 = rect
109
- x1 = x1.item()
110
- y1 = y1.item()
111
- x2 = x2.item()
112
- y2 = y2.item()
113
- rect = Rectangle(
114
- lt=(x1, y1),
115
- rt=(x2, y1),
116
- lb=(x1, y2),
117
- rb=(x2, y2),
118
- )
119
- if rect.is_valid:
120
- if cls == LayoutClass.TABLE:
121
- yield TableLayout(cls=cls, rect=rect, fragments=[], parsed=None)
122
- elif cls == LayoutClass.ISOLATE_FORMULA:
123
- yield FormulaLayout(cls=cls, rect=rect, fragments=[], latex=None)
124
- else:
125
- yield PlainLayout(cls=cls, rect=rect, fragments=[])
126
-
127
- def _layouts_matched_by_fragments(self, fragments: list[OCRFragment], layouts: list[Layout]):
128
- layouts_group = self._split_layouts_by_group(layouts)
129
- for fragment in fragments:
130
- for sub_layouts in layouts_group:
131
- layout = self._find_matched_layout(fragment, sub_layouts)
132
- if layout is not None:
133
- layout.fragments.append(fragment)
134
- break
135
- return layouts
136
-
137
- def _correct_fragments_by_ocr_layouts(self, source: Image, layouts: list[Layout]):
138
- for layout in layouts:
139
- correct_fragments(self._ocr, source, layout)
140
-
141
- def _parse_table_and_formula_layouts(self, layouts: list[Layout], raw_optimizer: RawOptimizer):
142
- for layout in layouts:
143
- if isinstance(layout, FormulaLayout) and self._extract_formula:
144
- image = clip_from_image(raw_optimizer.image, layout.rect)
145
- layout.latex = self._latex.extract(image)
146
- elif isinstance(layout, TableLayout) and self._extract_table_format is not None:
147
- image = clip_from_image(raw_optimizer.image, layout.rect)
148
- parsed = self._table.predict(image, self._extract_table_format)
149
- if parsed is not None:
150
- layout.parsed = (parsed, self._extract_table_format)
151
-
152
- def _split_layouts_by_group(self, layouts: list[Layout]):
153
- texts_layouts: list[Layout] = []
154
- abandon_layouts: list[Layout] = []
155
-
156
- for layout in layouts:
157
- cls = layout.cls
158
- if cls == LayoutClass.TITLE or \
159
- cls == LayoutClass.PLAIN_TEXT or \
160
- cls == LayoutClass.FIGURE_CAPTION or \
161
- cls == LayoutClass.TABLE_CAPTION or \
162
- cls == LayoutClass.TABLE_FOOTNOTE or \
163
- cls == LayoutClass.FORMULA_CAPTION:
164
- texts_layouts.append(layout)
165
- elif cls == LayoutClass.ABANDON:
166
- abandon_layouts.append(layout)
167
-
168
- return texts_layouts, abandon_layouts
169
-
170
- def _find_matched_layout(self, fragment: OCRFragment, layouts: list[Layout]) -> Layout | None:
171
- fragment_area = fragment.rect.area
172
- primary_layouts: list[(Layout, float)] = []
173
-
174
- if fragment_area == 0.0:
175
- return None
176
-
177
- for layout in layouts:
178
- area = intersection_area(fragment.rect, layout.rect)
179
- if area / fragment_area > 0.85:
180
- primary_layouts.append((layout, layout.rect.area))
181
-
182
- min_area: float = float("inf")
183
- min_layout: Layout | None = None
184
-
185
- for layout, area in primary_layouts:
186
- if area < min_area:
187
- min_area = area
188
- min_layout = layout
189
-
190
- return min_layout
191
-
192
- def _get_yolo(self) -> YOLOv10:
193
- if self._yolo is None:
194
- base_path = os.path.join(self._model_dir_path, "yolo")
195
- os.makedirs(base_path, exist_ok=True)
196
- yolo_model_url = "https://huggingface.co/opendatalab/PDF-Extract-Kit-1.0/resolve/main/models/Layout/YOLO/doclayout_yolo_ft.pt"
197
- yolo_model_name = "doclayout_yolo_ft.pt"
198
- yolo_model_path = Path(os.path.join(base_path, yolo_model_name))
199
- if not yolo_model_path.exists():
200
- download(yolo_model_url, yolo_model_path)
201
- self._yolo = YOLOv10(str(yolo_model_path))
202
- return self._yolo
203
-
204
- def _should_keep_layout(self, layout: Layout) -> bool:
205
- if len(layout.fragments) > 0:
206
- return True
207
- cls = layout.cls
208
- return (
209
- cls == LayoutClass.FIGURE or
210
- cls == LayoutClass.TABLE or
211
- cls == LayoutClass.ISOLATE_FORMULA
212
- )
213
-
5
+ from typing import Generator, cast
6
+
7
+ from PIL import Image
8
+
9
+ from .check_env import check_env
10
+ from .model import DeepSeekOCRModel, DeepSeekOCRSize
11
+ from .parser import ParsedItemKind, parse_ocr_response
12
+ from .redacter import background_color, redact
13
+
14
+
15
+ @dataclass
16
+ class Layout:
17
+ ref: str
18
+ det: tuple[int, int, int, int]
19
+ text: str | None
20
+
21
+
22
class PageExtractor:
    """Run the DeepSeek-OCR model over a page image and report its layouts.

    Inference is delegated to ``DeepSeekOCRModel``. This class converts the
    tagged model response into ``Layout`` records and can repeat the pass on a
    redacted copy of the page.
    """

    def __init__(
        self,
        model_path: PathLike | None = None,
        local_only: bool = False,
    ) -> None:
        """Create the extractor.

        Args:
            model_path: Directory handed to the model wrapper (converted to a
                ``Path``); a falsy value is forwarded as ``None``.
            local_only: Forwarded unchanged to the model wrapper.
        """
        self._model: DeepSeekOCRModel = DeepSeekOCRModel(
            model_path=Path(model_path) if model_path else None,
            local_only=local_only,
        )

    def download_models(self) -> None:
        """Ask the model wrapper to download its files."""
        self._model.download()

    def load_models(self) -> None:
        """Ask the model wrapper to load the model now rather than on first use."""
        self._model.load()

    def extract(
        self, image: Image.Image, size: DeepSeekOCRSize, stages: int = 1
    ) -> Generator[tuple[Image.Image, list[Layout]], None, None]:
        """Yield ``(image, layouts)`` once per stage.

        Every stage sends the current image to the model and parses the
        response into ``Layout`` records. Before each stage after the first,
        the rectangles found by the previous stage are painted over on a copy
        of the image, so the next pass sees only what is left.

        Args:
            image: Page image to analyse.
            size: Name of the inference-size preset passed to the model.
            stages: Number of passes to run; must be at least 1.

        Raises:
            ValueError: If ``stages`` is smaller than 1. Because this is a
                generator, the error surfaces on the first iteration.
        """
        check_env()
        # Explicit check instead of ``assert``: assertions are removed under
        # ``python -O`` and an invalid value would silently yield nothing.
        if stages < 1:
            raise ValueError("stages must be at least 1")

        with tempfile.TemporaryDirectory() as temp_path:
            fill_color: tuple[int, int, int] | None = None
            for i in range(stages):
                response = self._model.generate(
                    image=image,
                    prompt="<image>\n<|grounding|>Convert the document to markdown.",
                    temp_path=temp_path,
                    size=size,
                )
                layouts: list[Layout] = [
                    Layout(ref, det, text)
                    for ref, det, text in self._parse_response(image, response)
                ]
                yield image, layouts

                if i < stages - 1:
                    # The fill colour is sampled once, from the image that is
                    # current after the first stage, and reused afterwards.
                    if fill_color is None:
                        fill_color = background_color(image)
                    image = redact(
                        image=image.copy(),
                        fill_color=fill_color,
                        rectangles=(layout.det for layout in layouts),
                    )

    def _parse_response(
        self, image: Image.Image, response: str
    ) -> Generator[tuple[str, tuple[int, int, int, int], str | None], None, None]:
        """Group the parsed response items into ``(ref, det, text)`` triples.

        A triple is emitted with the following text as soon as both a ``ref``
        and a ``det`` have been seen. Text that arrives while the pair is
        incomplete is ignored. A complete pair that is followed by another tag
        (or by the end of the response) instead of text is emitted with
        ``None`` as its text.
        """
        width, height = image.size
        det: tuple[int, int, int, int] | None = None
        ref: str | None = None

        for kind, content in parse_ocr_response(response, width, height):
            if kind == ParsedItemKind.TEXT:
                if det is not None and ref is not None:
                    yield ref, det, cast(str, content)
                    det = None
                    ref = None
                continue

            # A tag arrived. If a complete pair is still waiting for its text,
            # flush it first so it is not overwritten, then record the new tag
            # so that it is not lost either.
            if det is not None and ref is not None:
                yield ref, det, None
                det = None
                ref = None
            if kind == ParsedItemKind.DET:
                det = cast(tuple[int, int, int, int], content)
            elif kind == ParsedItemKind.REF:
                ref = cast(str, content)

        # Flush a trailing pair that was never followed by text.
        if det is not None and ref is not None:
            yield ref, det, None
@@ -0,0 +1,97 @@
1
+ import os
2
+ from dataclasses import dataclass
3
+ from pathlib import Path
4
+ from typing import Any, Literal
5
+
6
+ import torch
7
+ from huggingface_hub import snapshot_download
8
+ from PIL import Image
9
+ from transformers import AutoModel, AutoTokenizer
10
+
11
+ DeepSeekOCRSize = Literal["tiny", "small", "base", "large", "gundam"]
12
+
13
+
14
@dataclass
class _SizeConfig:
    """Inference-size preset; its fields are passed to the model's ``infer`` call."""

    # Forwarded as ``base_size``.
    base_size: int
    # Forwarded as ``image_size``.
    image_size: int
    # Forwarded as ``crop_mode``.
    crop_mode: bool
19
+
20
+
21
+ _SIZE_CONFIGS: dict[DeepSeekOCRSize, _SizeConfig] = {
22
+ "tiny": _SizeConfig(base_size=512, image_size=512, crop_mode=False),
23
+ "small": _SizeConfig(base_size=640, image_size=640, crop_mode=False),
24
+ "base": _SizeConfig(base_size=1024, image_size=1024, crop_mode=False),
25
+ "large": _SizeConfig(base_size=1280, image_size=1280, crop_mode=False),
26
+ "gundam": _SizeConfig(base_size=1024, image_size=640, crop_mode=True),
27
+ }
28
+
29
+ _ATTN_IMPLEMENTATION: str
30
+ try:
31
+ import flash_attn # type: ignore # pylint: disable=unused-import
32
+
33
+ _ATTN_IMPLEMENTATION = "flash_attention_2"
34
+ except ImportError:
35
+ _ATTN_IMPLEMENTATION = "eager"
36
+
37
+ _Models = tuple[Any, Any]
38
+
39
+
40
class DeepSeekOCRModel:
    """Lazily loaded wrapper around the ``deepseek-ai/DeepSeek-OCR`` checkpoint.

    The tokenizer and model are created on first use (or by ``load``) and then
    cached on the instance.
    """

    def __init__(self, model_path: Path | None, local_only: bool) -> None:
        """Store the configuration; nothing is downloaded or loaded here.

        Args:
            model_path: Used (as a string) for ``cache_dir``; a falsy value
                leaves ``cache_dir`` as ``None``.
            local_only: Passed as ``local_files_only`` when loading.
        """
        self._model_name = "deepseek-ai/DeepSeek-OCR"
        self._cache_dir = None if not model_path else str(model_path)
        self._local_only = local_only
        self._models: _Models | None = None

    def download(self) -> None:
        """Fetch the model repository snapshot into the configured cache dir."""
        snapshot_download(
            repo_id=self._model_name,
            repo_type="model",
            cache_dir=self._cache_dir,
        )

    def load(self) -> None:
        """Load the tokenizer and model if they are not loaded yet."""
        self._ensure_models()

    def _ensure_models(self) -> _Models:
        """Return the cached ``(tokenizer, model)`` pair, creating it on first call."""
        if self._models is not None:
            return self._models

        # Options shared by the tokenizer and the model loader.
        hub_options = {
            "trust_remote_code": True,
            "cache_dir": self._cache_dir,
            "local_files_only": self._local_only,
        }
        tokenizer = AutoTokenizer.from_pretrained(self._model_name, **hub_options)
        model = AutoModel.from_pretrained(
            pretrained_model_name_or_path=self._model_name,
            _attn_implementation=_ATTN_IMPLEMENTATION,
            use_safetensors=True,
            **hub_options,
        )
        # Move to the GPU and cast to bfloat16 before caching.
        self._models = (tokenizer, model.cuda().to(torch.bfloat16))
        return self._models

    def generate(
        self, image: Image.Image, prompt: str, temp_path: str, size: DeepSeekOCRSize
    ) -> str:
        """Run one inference pass and return the model's result.

        The image is written to ``temp_image.png`` inside ``temp_path`` and
        handed to the model by file name; ``temp_path`` is also given to the
        model as its output directory.

        Args:
            image: Image to process.
            prompt: Prompt passed to the model.
            temp_path: Existing directory used for the temporary image and
                as the model's output path.
            size: Key into the size-preset table.

        Returns:
            The value returned by ``model.infer``.
        """
        tokenizer, model = self._ensure_models()
        preset = _SIZE_CONFIGS[size]
        image_file = os.path.join(temp_path, "temp_image.png")
        image.save(image_file)
        return model.infer(
            tokenizer,
            prompt=prompt,
            image_file=image_file,
            output_path=temp_path,
            base_size=preset.base_size,
            image_size=preset.image_size,
            crop_mode=preset.crop_mode,
            save_results=True,
            test_compress=True,
            eval_mode=True,
        )
@@ -0,0 +1,51 @@
1
+ import re
2
+ from enum import Enum, auto
3
+ from typing import Generator
4
+
5
+ _TAG_PATTERN = re.compile(r"<\|(det|ref)\|>(.+?)<\|/\1\|>")
6
+ _DET_COORDS_PATTERN = re.compile(r"\[\[(\d+),\s*(\d+),\s*(\d+),\s*(\d+)\]\]")
7
+
8
+
9
class ParsedItemKind(Enum):
    """Kind of item produced while scanning an OCR response."""

    # Bounding-box coordinates taken from a ``det`` tag.
    DET = auto()
    # Content of a ``ref`` tag.
    REF = auto()
    # Plain text found outside the tags.
    TEXT = auto()
13
+
14
+
15
+ ParsedItem = (
16
+ tuple[ParsedItemKind.DET, tuple[int, int, int, int]]
17
+ | tuple[ParsedItemKind.REF, str]
18
+ | tuple[ParsedItemKind.TEXT, str]
19
+ )
20
+
21
+
22
def parse_ocr_response(
    response: str, width: int, height: int
) -> Generator[ParsedItem, None, None]:
    """Split a tagged OCR response into DET, REF and TEXT items, in order.

    Args:
        response: Raw model output containing ``det`` / ``ref`` tags.
        width: Image width used to scale the x coordinates.
        height: Image height used to scale the y coordinates.

    Yields:
        ``(ParsedItemKind.TEXT, str)`` for every non-empty stretch of text
        outside the tags, ``(ParsedItemKind.REF, str)`` for each ``ref`` tag,
        and ``(ParsedItemKind.DET, (x1, y1, x2, y2))`` for each ``det`` tag
        whose content holds a ``[[a, b, c, d]]`` group. Coordinates are
        divided by 1000, multiplied by the image dimension and rounded. A
        ``det`` tag without such a group yields nothing.
    """
    cursor = 0
    for tag in _TAG_PATTERN.finditer(response):
        # Text between the previous tag (or the start) and this tag.
        leading_text = response[cursor : tag.start()]
        if leading_text:
            yield ParsedItemKind.TEXT, leading_text
        cursor = tag.end()

        tag_name, body = tag.groups()
        if tag_name == "ref":
            yield ParsedItemKind.REF, body
        elif tag_name == "det":
            coords = _DET_COORDS_PATTERN.search(body)
            if coords is None:
                continue
            left, top, right, bottom = (int(value) for value in coords.groups())
            yield ParsedItemKind.DET, (
                round(left / 1000 * width),
                round(top / 1000 * height),
                round(right / 1000 * width),
                round(bottom / 1000 * height),
            )

    # Whatever follows the last tag (the whole response if there were no tags).
    trailing_text = response[cursor:]
    if trailing_text:
        yield ParsedItemKind.TEXT, trailing_text
@@ -1,91 +1,64 @@
1
- from typing import Iterable
1
+ from typing import Iterable, cast
2
+
2
3
  from PIL import ImageDraw
3
- from PIL.ImageFont import load_default, FreeTypeFont
4
4
  from PIL.Image import Image
5
- from .types import Layout, LayoutClass
6
- from .rectangle import Point
5
+ from PIL.ImageFont import FreeTypeFont, load_default
6
+
7
+ from .extractor import Layout
7
8
 
8
- _FRAGMENT_COLOR = (0x49, 0xCF, 0xCB) # Light Green
9
+ _FRAGMENT_COLOR = (0x49, 0xCF, 0xCB) # Light Green
9
10
  _Color = tuple[int, int, int]
10
11
 
11
- def plot(image: Image, layouts: Iterable[Layout]) -> None:
12
- layout_font = load_default(size=35)
13
- fragment_font = load_default(size=25)
14
- draw = ImageDraw.Draw(image, mode="RGBA")
15
12
 
16
- def _draw_number(position: Point, number: int, font: FreeTypeFont, bold: bool, color: _Color) -> None:
17
- nonlocal draw
18
- x, y = position
19
- text = str(object=number)
20
- width = len(text) * font.size
21
- offset = round(font.size * 0.15)
13
+ def plot(image: Image, layouts: Iterable[Layout]) -> Image:
14
+ layout_font = cast(FreeTypeFont, load_default(size=35))
15
+ draw = ImageDraw.Draw(image, mode="RGBA")
22
16
 
23
- for dx, dy in _generate_delta(bold):
24
- draw.text(
25
- xy=(x + dx - width - offset, y + dy),
26
- text=text,
27
- font=font,
28
- fill=color,
29
- )
17
+ def _draw_text(
18
+ position: tuple[int, int],
19
+ text: str,
20
+ font: FreeTypeFont,
21
+ bold: bool,
22
+ color: _Color,
23
+ ) -> None:
24
+ nonlocal draw
25
+ x, y = position
26
+ bbox = font.getbbox(text)
27
+ text_width = bbox[2] - bbox[0]
28
+ offset = round(font.size * 0.15)
30
29
 
31
- for layout in layouts:
32
- draw.polygon(
33
- xy=[p for p in layout.rect],
34
- outline=_layout_color(layout),
35
- width=5,
36
- )
30
+ for dx, dy in _generate_delta(bold):
31
+ draw.text(
32
+ xy=(x + dx - text_width - offset, y + dy),
33
+ text=text,
34
+ font=font,
35
+ fill=color,
36
+ )
37
37
 
38
- for layout in layouts:
39
- for fragment in layout.fragments:
40
- draw.polygon(
41
- xy=[p for p in fragment.rect],
42
- outline=_FRAGMENT_COLOR,
43
- width=3,
44
- )
45
- _draw_number(
46
- position=fragment.rect.lt,
47
- number=fragment.order + 1,
48
- font=fragment_font,
49
- bold=False,
50
- color=_FRAGMENT_COLOR,
51
- )
38
+ for layout in layouts:
39
+ x1, y1, x2, y2 = layout.det
40
+ draw.polygon(
41
+ xy=[(x1, y1), (x2, y1), (x2, y2), (x1, y2)],
42
+ outline=_FRAGMENT_COLOR,
43
+ width=5,
44
+ )
52
45
 
53
- for i, layout in enumerate(layouts):
54
- _draw_number(
55
- position=layout.rect.lt,
56
- number=i + 1,
57
- font=layout_font,
58
- bold=True,
59
- color=_layout_color(layout),
60
- )
46
+ for i, layout in enumerate(layouts):
47
+ x1, y1, _, _ = layout.det
48
+ _draw_text(
49
+ position=(x1, y1),
50
+ text=f"{i + 1}. {layout.ref.strip()}",
51
+ font=layout_font,
52
+ bold=True,
53
+ color=_FRAGMENT_COLOR,
54
+ )
55
+ return image
61
56
 
62
- def _generate_delta(bold: bool):
63
- if bold:
64
- for dx in range(-1, 2):
65
- for dy in range(-1, 2):
66
- yield dx, dy
67
- else:
68
- yield 0, 0
69
57
 
70
- def _layout_color(layout: Layout) -> _Color:
71
- cls = layout.cls
72
- if cls == LayoutClass.TITLE:
73
- return (0x0A, 0x12, 0x2C) # Dark
74
- elif cls == LayoutClass.PLAIN_TEXT:
75
- return (0x3C, 0x67, 0x90) # Blue
76
- elif cls == LayoutClass.ABANDON:
77
- return (0xC0, 0xBB, 0xA9) # Gray
78
- elif cls == LayoutClass.FIGURE:
79
- return (0x5B, 0x91, 0x3C) # Dark Green
80
- elif cls == LayoutClass.FIGURE_CAPTION:
81
- return (0x77, 0xB3, 0x54) # Green
82
- elif cls == LayoutClass.TABLE:
83
- return (0x44, 0x17, 0x52) # Dark Purple
84
- elif cls == LayoutClass.TABLE_CAPTION:
85
- return (0x81, 0x75, 0xA0) # Purple
86
- elif cls == LayoutClass.TABLE_FOOTNOTE:
87
- return (0xEF, 0xB6, 0xC9) # Pink Purple
88
- elif cls == LayoutClass.ISOLATE_FORMULA:
89
- return (0xFA, 0x38, 0x27) # Red
90
- elif cls == LayoutClass.FORMULA_CAPTION:
91
- return (0xFF, 0x9D, 0x24) # Orange
58
def _generate_delta(bold: bool):
    """Yield the (dx, dy) offsets at which a piece of text is drawn.

    Non-bold text uses the single offset (0, 0). Bold text uses all nine
    combinations of -1, 0 and 1, with dx as the outer loop and dy as the
    inner one.
    """
    if not bold:
        yield 0, 0
        return

    steps = (-1, 0, 1)
    for shift_x in steps:
        for shift_y in steps:
            yield shift_x, shift_y