kreuzberg 3.0.0__py3-none-any.whl → 3.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kreuzberg/__init__.py +4 -1
- kreuzberg/_extractors/__init__.py +0 -0
- kreuzberg/_extractors/_base.py +92 -0
- kreuzberg/_extractors/_html.py +34 -0
- kreuzberg/_extractors/_image.py +74 -0
- kreuzberg/_extractors/_pandoc.py +613 -0
- kreuzberg/_extractors/_pdf.py +171 -0
- kreuzberg/_extractors/_presentation.py +233 -0
- kreuzberg/_extractors/_spread_sheet.py +125 -0
- kreuzberg/_gmft.py +174 -0
- kreuzberg/_ocr/__init__.py +17 -0
- kreuzberg/_ocr/_base.py +54 -0
- kreuzberg/_ocr/_easyocr.py +376 -0
- kreuzberg/_ocr/_paddleocr.py +283 -0
- kreuzberg/_ocr/_tesseract.py +342 -0
- kreuzberg/_types.py +31 -4
- kreuzberg/_utils/__init__.py +0 -0
- kreuzberg/_utils/_string.py +39 -0
- kreuzberg/_utils/_sync.py +121 -0
- kreuzberg/_utils/_tmp.py +37 -0
- {kreuzberg-3.0.0.dist-info → kreuzberg-3.1.0.dist-info}/METADATA +14 -19
- kreuzberg-3.1.0.dist-info/RECORD +33 -0
- {kreuzberg-3.0.0.dist-info → kreuzberg-3.1.0.dist-info}/WHEEL +1 -1
- kreuzberg-3.0.0.dist-info/RECORD +0 -15
- {kreuzberg-3.0.0.dist-info → kreuzberg-3.1.0.dist-info}/licenses/LICENSE +0 -0
- {kreuzberg-3.0.0.dist-info → kreuzberg-3.1.0.dist-info}/top_level.txt +0 -0
kreuzberg/_gmft.py
ADDED
@@ -0,0 +1,174 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from dataclasses import dataclass, field
|
4
|
+
from typing import TYPE_CHECKING, Literal
|
5
|
+
|
6
|
+
from kreuzberg._types import TableData
|
7
|
+
from kreuzberg._utils._sync import run_sync
|
8
|
+
from kreuzberg.exceptions import MissingDependencyError
|
9
|
+
|
10
|
+
if TYPE_CHECKING:
|
11
|
+
from os import PathLike
|
12
|
+
|
13
|
+
from gmft.detectors.base import CroppedTable
|
14
|
+
from pandas import DataFrame
|
15
|
+
|
16
|
+
|
17
|
+
@dataclass(unsafe_hash=True)
class GMFTConfig:
    """Settings that control GMFT-based table extraction.

    An instance bundles the detector and formatter options in one place. The
    ``cell_required_confidence`` mapping is excluded from the generated hash
    (``hash=False``) because dictionaries are not hashable.
    """

    verbosity: int = 0
    """Logging verbosity: 0 = errors only, 1 = plus warnings, 2 = plus info, 3 = plus debug."""

    formatter_base_threshold: float = 0.3
    """
    Baseline confidence a table feature (row/column) has to reach.

    A lower value is usually preferable: detecting too many rows generally keeps numbers aligned and only adds
    empty rows, whereas detecting too few rows merges cells, which is worse.
    """

    cell_required_confidence: dict[Literal[0, 1, 2, 3, 4, 5, 6], float] = field(
        default_factory=lambda: {0: 0.3, 1: 0.3, 2: 0.3, 3: 0.3, 4: 0.5, 5: 0.5, 6: 99},
        hash=False,
    )
    """
    Minimum confidence (>=) for a row/column feature to be considered good. See TATRFormattedTable.id2label.

    Confidences that are too low may still be better than ones that are too high (see formatter_base_threshold).
    """

    detector_base_threshold: float = 0.9
    """Lowest confidence score at which a detected table is accepted."""

    remove_null_rows: bool = True
    """Whether rows that contain no text are dropped."""

    enable_multi_header: bool = False
    """
    Whether the dataframe may use multi-indices.

    When disabled, multiple headers are merged column-wise.
    """

    semantic_spanning_cells: bool = False
    """[Experimental] Turn on semantic spanning cells, which often encode hierarchical multi-level indices."""

    semantic_hierarchical_left_fill: str | None = "algorithm"
    """
    [Experimental] With semantic spanning cells enabled, a detected left header that may represent a group of rows
    has its value repeated for each of those rows.

    Possible values: 'algorithm', 'deep', None.

    'algorithm': assumes that the higher-level header is always the first row followed by several empty rows.
    'deep': merges headers according to the spanning cells detected by the Table Transformer.
    None: headers are not duplicated.
    """

    large_table_if_n_rows_removed: int = 8
    """A table is classified as large once >= n rows are removed by non-maxima suppression (NMS)."""

    large_table_threshold: int = 10
    """
    The table transformer struggles with large tables because it places too many overlapping rows. More rows also
    mean more information about the usual text size, which is used to guess a row height such that no rows are
    merged or overlapping.

    The large table assumption is only applied when (# of rows > large_table_threshold) AND
    (total overlap > large_table_row_overlap_threshold). Set 9999 to disable; set 0 to force the large table
    assumption to run every time.
    """

    large_table_row_overlap_threshold: float = 0.2
    """
    Overlap threshold used together with large_table_threshold.

    The large table assumption is only applied when (# of rows > large_table_threshold) AND
    (total overlap > large_table_row_overlap_threshold).
    """

    large_table_maximum_rows: int = 1000
    """Upper bound on the number of rows a large table may have."""

    force_large_table_assumption: bool | None = None
    """Apply the large table assumption regardless of the number of rows and the overlap."""
|
106
|
+
|
107
|
+
|
108
|
+
async def extract_tables(file_path: str | PathLike[str], config: GMFTConfig | None = None) -> list[TableData]:
    """Detect and format the tables contained in a PDF file using GMFT.

    Every page of the document is scanned for tables; each detected table is
    then formatted into a dataframe. Blocking GMFT calls are dispatched through
    ``run_sync``, and the document is closed once processing ends, whether it
    succeeded or failed.

    Args:
        file_path: The path to the PDF file.
        config: Optional GMFT settings; a default ``GMFTConfig`` is used when omitted.

    Raises:
        MissingDependencyError: Raised when an ``ImportError`` occurs, i.e. the required
            dependencies are not installed.

    Returns:
        One ``TableData`` entry per detected table, carrying the cropped image, the page
        number, the markdown rendering of the table and the dataframe itself.
    """
    try:
        from gmft.auto import AutoTableDetector, AutoTableFormatter
        from gmft.detectors.tatr import TATRDetectorConfig
        from gmft.formatters.tatr import TATRFormatConfig
        from gmft.pdf_bindings.pdfium import PyPDFium2Document

        settings = config or GMFTConfig()

        format_config = TATRFormatConfig(
            verbosity=settings.verbosity,
            formatter_base_threshold=settings.formatter_base_threshold,
            cell_required_confidence=settings.cell_required_confidence,
            remove_null_rows=settings.remove_null_rows,
            enable_multi_header=settings.enable_multi_header,
            semantic_spanning_cells=settings.semantic_spanning_cells,
            semantic_hierarchical_left_fill=settings.semantic_hierarchical_left_fill,
            large_table_if_n_rows_removed=settings.large_table_if_n_rows_removed,
            large_table_threshold=settings.large_table_threshold,
            large_table_row_overlap_threshold=settings.large_table_row_overlap_threshold,
            large_table_maximum_rows=settings.large_table_maximum_rows,
            force_large_table_assumption=settings.force_large_table_assumption,
        )
        table_formatter = AutoTableFormatter(config=format_config)

        detect_config = TATRDetectorConfig(detector_base_threshold=settings.detector_base_threshold)
        table_detector = AutoTableDetector(config=detect_config)

        document = await run_sync(PyPDFium2Document, str(file_path))
        try:
            # Pass 1: collect the cropped tables found on every page.
            detected: list[CroppedTable] = []
            for page in document:
                detected.extend(await run_sync(table_detector.extract, page))

            # Pass 2: turn each cropped table into a dataframe.
            frames: list[DataFrame] = []
            for table in detected:
                formatted = await run_sync(table_formatter.extract, table)
                frames.append(await run_sync(formatted.df))

            # The results are built while the document is still open.
            return [
                TableData(
                    cropped_image=table.image(),
                    page_number=table.page.page_number,
                    text=frame.to_markdown(),
                    df=frame,
                )
                for frame, table in zip(frames, detected)
            ]
        finally:
            await run_sync(document.close)

    except ImportError as e:
        raise MissingDependencyError.create_for_package(
            dependency_group="gmft", functionality="table extraction", package_name="gmft"
        ) from e
|
@@ -0,0 +1,17 @@
|
|
1
|
+
from functools import lru_cache
|
2
|
+
from typing import Any
|
3
|
+
|
4
|
+
from kreuzberg._ocr._base import OCRBackend
|
5
|
+
from kreuzberg._ocr._easyocr import EasyOCRBackend
|
6
|
+
from kreuzberg._ocr._paddleocr import PaddleBackend
|
7
|
+
from kreuzberg._ocr._tesseract import TesseractBackend
|
8
|
+
from kreuzberg._types import OcrBackendType
|
9
|
+
|
10
|
+
|
11
|
+
@lru_cache
def get_ocr_backend(backend: OcrBackendType) -> OCRBackend[Any]:
    """Return the OCR backend instance that matches ``backend``.

    The result is memoised with ``lru_cache``, so repeated calls with the same
    name hand back the same backend object.

    Args:
        backend: Name of the requested backend.

    Returns:
        An ``EasyOCRBackend`` for "easyocr", a ``PaddleBackend`` for "paddleocr",
        and a ``TesseractBackend`` for any other value.
    """
    if backend == "paddleocr":
        return PaddleBackend()
    return EasyOCRBackend() if backend == "easyocr" else TesseractBackend()
|
kreuzberg/_ocr/_base.py
ADDED
@@ -0,0 +1,54 @@
|
|
1
|
+
from abc import ABC, abstractmethod
|
2
|
+
from pathlib import Path
|
3
|
+
from typing import Generic, TypeVar
|
4
|
+
|
5
|
+
from PIL.Image import Image
|
6
|
+
|
7
|
+
from kreuzberg._types import ExtractionResult
|
8
|
+
|
9
|
+
try: # pragma: no cover
|
10
|
+
from typing import Unpack # type: ignore[attr-defined]
|
11
|
+
except ImportError: # pragma: no cover
|
12
|
+
from typing_extensions import Unpack
|
13
|
+
|
14
|
+
|
15
|
+
T = TypeVar("T")
|
16
|
+
|
17
|
+
|
18
|
+
class OCRBackend(ABC, Generic[T]):
    """Contract that every Optical Character Recognition (OCR) backend fulfils.

    Concrete backends implement two asynchronous hooks: one that extracts text
    from an in-memory image and one that extracts text from a file on disk.
    The type parameter ``T`` describes the backend-specific keyword options.
    """

    @abstractmethod
    async def process_image(self, image: Image, **kwargs: Unpack[T]) -> ExtractionResult:
        """Extract text and metadata from an image asynchronously.

        Args:
            image: The PIL image to run OCR on.
            **kwargs: Options specific to the concrete backend.

        Returns:
            The extraction result object
        """

    @abstractmethod
    async def process_file(self, path: Path, **kwargs: Unpack[T]) -> ExtractionResult:
        """Extract text and metadata from a file asynchronously.

        Args:
            path: Location of the file to run OCR on.
            **kwargs: Options specific to the concrete backend.

        Returns:
            The extraction result object
        """

    def __hash__(self) -> int:
        """Hash by class name so that backend instances can be used with caching."""
        backend_name = type(self).__name__
        return hash(backend_name)
|
@@ -0,0 +1,376 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from dataclasses import dataclass
|
4
|
+
from typing import TYPE_CHECKING, Any, ClassVar, Final, Literal
|
5
|
+
|
6
|
+
from PIL import Image
|
7
|
+
|
8
|
+
from kreuzberg._mime_types import PLAIN_TEXT_MIME_TYPE
|
9
|
+
from kreuzberg._ocr._base import OCRBackend
|
10
|
+
from kreuzberg._types import ExtractionResult, Metadata
|
11
|
+
from kreuzberg._utils._string import normalize_spaces
|
12
|
+
from kreuzberg._utils._sync import run_sync
|
13
|
+
from kreuzberg.exceptions import MissingDependencyError, OCRError, ValidationError
|
14
|
+
|
15
|
+
if TYPE_CHECKING:
|
16
|
+
from pathlib import Path
|
17
|
+
|
18
|
+
try: # pragma: no cover
|
19
|
+
from typing import Unpack # type: ignore[attr-defined]
|
20
|
+
except ImportError: # pragma: no cover
|
21
|
+
from typing_extensions import Unpack
|
22
|
+
|
23
|
+
|
24
|
+
# Language codes that the EasyOCR backend accepts; requested languages are validated against this set.
EASYOCR_SUPPORTED_LANGUAGE_CODES: Final[set[str]] = {
    "abq", "ady", "af", "ang", "ar", "as", "ava", "az",
    "be", "bg", "bh", "bho", "bn", "bs",
    "ch_sim", "ch_tra", "che", "cs", "cy",
    "da", "dar", "de",
    "en", "es", "et",
    "fa", "fr",
    "ga", "gom",
    "hi", "hr", "hu",
    "id", "inh", "is", "it",  # codespell:ignore
    "ja",
    "kbd", "kn", "ko", "ku",
    "la", "lbe", "lez", "lt", "lv",
    "mah", "mai", "mi", "mn", "mr", "ms", "mt",
    "ne", "new", "nl", "no",
    "oc",
    "pi", "pl", "pt",
    "ro", "ru", "rs_cyrillic", "rs_latin",
    "sck", "sk", "sl", "sq", "sv", "sw",
    "ta", "tab", "te", "th", "tjk", "tl", "tr",  # codespell:ignore
    "ug", "uk", "ur", "uz",
    "vi",
}
|
109
|
+
|
110
|
+
|
111
|
+
@dataclass(unsafe_hash=True, frozen=True)
class EasyOCRConfig:
    """Immutable set of options for the EasyOCR backend."""

    add_margin: float = 0.1
    """Amount by which bounding boxes are extended in every direction."""

    adjust_contrast: float = 0.5
    """Contrast level that low contrast text is adjusted towards."""

    beam_width: int = 5
    """Width of the beam used by beam search during recognition."""

    canvas_size: int = 2560
    """Largest image dimension used for detection."""

    contrast_ths: float = 0.1
    """Contrast threshold applied during preprocessing."""

    decoder: Literal["greedy", "beamsearch", "wordbeamsearch"] = "greedy"
    """Decoding method. Options: 'greedy', 'beamsearch', 'wordbeamsearch'."""

    height_ths: float = 0.5
    """Largest difference in box height at which boxes are still merged."""

    language: str | list[str] = "en"
    """A single language or a list of languages to use for OCR."""

    link_threshold: float = 0.4
    """Confidence threshold for links."""

    low_text: float = 0.4
    """Low-bound score for text."""

    mag_ratio: float = 1.0
    """Magnification ratio applied to the image."""

    min_size: int = 10
    """Smallest text box size, in pixels."""

    rotation_info: list[int] | None = None
    """Angles to try during detection."""

    slope_ths: float = 0.1
    """Largest slope at which text boxes are still merged."""

    text_threshold: float = 0.7
    """Confidence threshold for text."""

    use_gpu: bool = False
    """Whether inference should run on the GPU."""

    width_ths: float = 0.5
    """Largest horizontal distance at which boxes are still merged."""

    x_ths: float = 1.0
    """Largest horizontal distance for paragraph merging."""

    y_ths: float = 0.5
    """Largest vertical distance for paragraph merging."""

    ycenter_ths: float = 0.5
    """Largest shift in the y direction at which boxes are still merged."""
|
155
|
+
|
156
|
+
|
157
|
+
class EasyOCRBackend(OCRBackend[EasyOCRConfig]):
    """OCR backend that delegates text recognition to EasyOCR.

    The EasyOCR reader is created lazily on first use and stored on the class,
    so all instances share it. Once a reader exists, later calls reuse it as-is:
    a different ``language`` or ``use_gpu`` passed afterwards does not rebuild it.
    """

    _reader: ClassVar[Any] = None

    async def process_image(self, image: Image.Image, **kwargs: Unpack[EasyOCRConfig]) -> ExtractionResult:
        """Asynchronously process an image and extract its text and metadata using EasyOCR.

        Args:
            image: An instance of PIL.Image representing the input image.
            **kwargs: Configuration parameters for EasyOCR including language, detection thresholds, etc.

        Returns:
            ExtractionResult: The extraction result containing text content, mime type, and metadata.

        Raises:
            OCRError: If OCR processing fails.
        """
        await self._init_easyocr(**kwargs)

        # ``language`` and ``use_gpu`` configure the reader itself and are not ``readtext`` parameters;
        # ``beam_width`` is forwarded under EasyOCR's own spelling, ``beamWidth``.
        read_options = {
            key: value for key, value in kwargs.items() if key not in ("language", "use_gpu", "beam_width")
        }
        beam_width = kwargs.get("beam_width", 5)

        try:
            # numpy is imported lazily: it is only needed once EasyOCR itself is available.
            import numpy as np

            # Hand EasyOCR decoded pixels as an array. ``Image.tobytes()`` yields raw pixel data without
            # any image header, which EasyOCR cannot decode as an encoded image file.
            source = image if image.mode in ("L", "RGB") else image.convert("RGB")
            result = await run_sync(
                self._reader.readtext,
                np.array(source),
                beamWidth=beam_width,
                **read_options,
            )

            return self._process_easyocr_result(result, image)
        except Exception as e:
            raise OCRError(f"Failed to OCR using EasyOCR: {e}") from e

    async def process_file(self, path: Path, **kwargs: Unpack[EasyOCRConfig]) -> ExtractionResult:
        """Asynchronously process a file and extract its text and metadata using EasyOCR.

        Args:
            path: A Path object representing the file to be processed.
            **kwargs: Configuration parameters for EasyOCR including language, detection thresholds, etc.

        Returns:
            ExtractionResult: The extraction result containing text content, mime type, and metadata.

        Raises:
            OCRError: If file loading or OCR processing fails.
        """
        await self._init_easyocr(**kwargs)
        try:
            image = await run_sync(Image.open, path)
            try:
                return await self.process_image(image, **kwargs)
            finally:
                # Release the underlying file handle once OCR has finished with the image.
                image.close()
        except OCRError:
            # Already a descriptive OCR failure from process_image; do not wrap it a second time.
            raise
        except Exception as e:
            raise OCRError(f"Failed to load or process image using EasyOCR: {e}") from e

    @staticmethod
    def _process_easyocr_result(result: list[Any], image: Image.Image) -> ExtractionResult:
        """Process EasyOCR result into an ExtractionResult with metadata.

        Two result shapes are handled: items of length two, unpacked as ``(text, confidence)``
        and emitted one per line in the given order, and items of length three, unpacked as
        ``(box, text, confidence)``, which are grouped into visual lines by vertical position
        and ordered left to right within each line.

        Args:
            result: The raw result from EasyOCR.
            image: The original PIL image.

        Returns:
            ExtractionResult: The extraction result containing text content, mime type, and metadata.
        """
        metadata = Metadata(width=image.width, height=image.height)

        if not result:
            return ExtractionResult(content="", mime_type=PLAIN_TEXT_MIME_TYPE, metadata=metadata, chunks=[])

        expected_tuple_length = 2

        if all(len(item) == expected_tuple_length for item in result):
            text_content = "".join(text + "\n" for text, _confidence in result if text)
            return ExtractionResult(
                content=normalize_spaces(text_content), mime_type=PLAIN_TEXT_MIME_TYPE, metadata=metadata, chunks=[]
            )

        # Order the detections top to bottom using the y coordinates of the first and third box corners.
        ordered = sorted(result, key=lambda item: item[0][0][1] + item[0][2][1])

        # A detection starts a new line when its vertical centre is further than this from the previous one.
        line_height_threshold = 20
        line_groups: list[list[Any]] = []
        prev_y_center: float | None = None

        for item in ordered:
            box, _text, _confidence = item
            y_center = sum(point[1] for point in box) / 4

            if prev_y_center is None or abs(y_center - prev_y_center) > line_height_threshold:
                line_groups.append([item])
            else:
                line_groups[-1].append(item)

            prev_y_center = y_center

        # Within a line, read left to right by the x coordinate of the first box corner.
        text_content = "".join(
            "".join(text + " " for _box, text, _confidence in sorted(line, key=lambda item: item[0][0][0]) if text)
            + "\n"
            for line in line_groups
        )

        return ExtractionResult(
            content=normalize_spaces(text_content), mime_type=PLAIN_TEXT_MIME_TYPE, metadata=metadata, chunks=[]
        )

    @classmethod
    def _is_gpu_available(cls) -> bool:
        """Check if GPU is available for EasyOCR.

        Returns:
            bool: True if torch is importable and reports CUDA as available, False otherwise.
        """
        try:
            import torch
        except ImportError:
            return False

        return bool(torch.cuda.is_available())

    @classmethod
    async def _init_easyocr(cls, **kwargs: Unpack[EasyOCRConfig]) -> None:
        """Initialize EasyOCR with the provided configuration.

        Does nothing when a reader has already been created. Only ``language`` and
        ``use_gpu`` are read from ``kwargs``; when ``use_gpu`` is not given, the GPU
        is used if one is detected.

        Args:
            **kwargs: Configuration parameters for EasyOCR including language, etc.

        Raises:
            MissingDependencyError: If EasyOCR is not installed.
            OCRError: If initialization fails.
        """
        if cls._reader is not None:
            return

        try:
            import easyocr
        except ImportError as e:
            raise MissingDependencyError.create_for_package(
                dependency_group="easyocr", functionality="EasyOCR as an OCR backend", package_name="easyocr"
            ) from e

        languages = cls._validate_language_code(kwargs.get("language", "en"))

        # Always hand the reader an explicit boolean: honour the caller's choice, otherwise auto-detect.
        requested_gpu = kwargs.get("use_gpu")
        use_gpu = cls._is_gpu_available() if requested_gpu is None else bool(requested_gpu)

        try:
            cls._reader = await run_sync(
                easyocr.Reader,
                languages,
                gpu=use_gpu,
                verbose=False,
            )
        except Exception as e:
            raise OCRError(f"Failed to initialize EasyOCR: {e}") from e

    @staticmethod
    def _validate_language_code(language_codes: str | list[str]) -> list[str]:
        """Validate and normalize the provided language code(s).

        Args:
            language_codes: A single language code or a collection of language codes.

        Raises:
            ValidationError: If any of the languages is not supported by EasyOCR

        Returns:
            A list with the lower-cased language codes.
        """
        requested = [language_codes] if isinstance(language_codes, str) else list(language_codes)
        languages = [lang.lower() for lang in requested]

        unsupported = [lang for lang in languages if lang not in EASYOCR_SUPPORTED_LANGUAGE_CODES]
        if not unsupported:
            return languages

        raise ValidationError(
            "The provided language codes are not supported by EasyOCR",
            context={
                "language_code": ",".join(unsupported),
                "supported_languages": ",".join(sorted(EASYOCR_SUPPORTED_LANGUAGE_CODES)),
            },
        )
|