kreuzberg 3.4.2__py3-none-any.whl → 3.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kreuzberg/__init__.py +6 -1
- kreuzberg/_entity_extraction.py +239 -0
- kreuzberg/_extractors/_image.py +21 -1
- kreuzberg/_extractors/_pdf.py +44 -14
- kreuzberg/_extractors/_spread_sheet.py +2 -2
- kreuzberg/_gmft.py +4 -4
- kreuzberg/_language_detection.py +95 -0
- kreuzberg/_multiprocessing/gmft_isolated.py +2 -4
- kreuzberg/_multiprocessing/process_manager.py +2 -1
- kreuzberg/_multiprocessing/sync_easyocr.py +235 -0
- kreuzberg/_multiprocessing/sync_paddleocr.py +199 -0
- kreuzberg/_ocr/_easyocr.py +1 -1
- kreuzberg/_ocr/_tesseract.py +7 -3
- kreuzberg/_types.py +46 -4
- kreuzberg/_utils/_device.py +2 -2
- kreuzberg/_utils/_process_pool.py +2 -2
- kreuzberg/_utils/_sync.py +1 -5
- kreuzberg/_utils/_tmp.py +2 -2
- kreuzberg/extraction.py +39 -12
- {kreuzberg-3.4.2.dist-info → kreuzberg-3.6.0.dist-info}/METADATA +12 -4
- {kreuzberg-3.4.2.dist-info → kreuzberg-3.6.0.dist-info}/RECORD +24 -20
- {kreuzberg-3.4.2.dist-info → kreuzberg-3.6.0.dist-info}/WHEEL +0 -0
- {kreuzberg-3.4.2.dist-info → kreuzberg-3.6.0.dist-info}/entry_points.txt +0 -0
- {kreuzberg-3.4.2.dist-info → kreuzberg-3.6.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,235 @@
|
|
1
|
+
"""Pure synchronous EasyOCR without any async overhead."""
|
2
|
+
|
3
|
+
from __future__ import annotations
|
4
|
+
|
5
|
+
import tempfile
|
6
|
+
from pathlib import Path
|
7
|
+
from typing import Any
|
8
|
+
|
9
|
+
from PIL import Image
|
10
|
+
|
11
|
+
from kreuzberg._mime_types import PLAIN_TEXT_MIME_TYPE
|
12
|
+
from kreuzberg._ocr._easyocr import EasyOCRConfig
|
13
|
+
from kreuzberg._types import ExtractionResult
|
14
|
+
from kreuzberg._utils._string import normalize_spaces
|
15
|
+
from kreuzberg.exceptions import MissingDependencyError, OCRError
|
16
|
+
|
17
|
+
|
18
|
+
def _get_easyocr_instance(config: EasyOCRConfig) -> Any:
|
19
|
+
"""Get an EasyOCR Reader instance with the given configuration."""
|
20
|
+
try:
|
21
|
+
import easyocr
|
22
|
+
except ImportError as e:
|
23
|
+
raise MissingDependencyError("EasyOCR is not installed. Install it with: pip install easyocr") from e
|
24
|
+
|
25
|
+
gpu = False
|
26
|
+
if hasattr(config, "device"):
|
27
|
+
if config.device and config.device.lower() != "cpu":
|
28
|
+
gpu = True
|
29
|
+
elif hasattr(config, "use_gpu"):
|
30
|
+
gpu = config.use_gpu
|
31
|
+
|
32
|
+
language = config.language if hasattr(config, "language") else "en"
|
33
|
+
if isinstance(language, str):
|
34
|
+
lang_list = [lang.strip().lower() for lang in language.split(",")]
|
35
|
+
else:
|
36
|
+
lang_list = [lang.lower() for lang in language]
|
37
|
+
|
38
|
+
kwargs = {
|
39
|
+
"lang_list": lang_list,
|
40
|
+
"gpu": gpu,
|
41
|
+
"model_storage_directory": getattr(config, "model_storage_directory", None),
|
42
|
+
"user_network_directory": getattr(config, "user_network_directory", None),
|
43
|
+
"recog_network": getattr(config, "recog_network", None),
|
44
|
+
"detector": getattr(config, "detector", None),
|
45
|
+
"recognizer": getattr(config, "recognizer", None),
|
46
|
+
"verbose": False,
|
47
|
+
"quantize": getattr(config, "quantize", None),
|
48
|
+
"cudnn_benchmark": getattr(config, "cudnn_benchmark", None),
|
49
|
+
}
|
50
|
+
|
51
|
+
kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
52
|
+
|
53
|
+
return easyocr.Reader(**kwargs)
|
54
|
+
|
55
|
+
|
56
|
+
def process_image_sync_pure(
|
57
|
+
image_path: str | Path,
|
58
|
+
config: EasyOCRConfig | None = None,
|
59
|
+
) -> ExtractionResult:
|
60
|
+
"""Process an image with EasyOCR using pure sync implementation.
|
61
|
+
|
62
|
+
This bypasses all async overhead and calls EasyOCR directly.
|
63
|
+
|
64
|
+
Args:
|
65
|
+
image_path: Path to the image file.
|
66
|
+
config: EasyOCR configuration.
|
67
|
+
|
68
|
+
Returns:
|
69
|
+
Extraction result.
|
70
|
+
"""
|
71
|
+
cfg = config or EasyOCRConfig()
|
72
|
+
|
73
|
+
try:
|
74
|
+
reader = _get_easyocr_instance(cfg)
|
75
|
+
|
76
|
+
readtext_kwargs = {
|
77
|
+
"decoder": cfg.decoder,
|
78
|
+
"beamWidth": cfg.beam_width,
|
79
|
+
"batch_size": getattr(cfg, "batch_size", 1),
|
80
|
+
"workers": getattr(cfg, "workers", 0),
|
81
|
+
"allowlist": getattr(cfg, "allowlist", None),
|
82
|
+
"blocklist": getattr(cfg, "blocklist", None),
|
83
|
+
"detail": getattr(cfg, "detail", 1),
|
84
|
+
"rotation_info": cfg.rotation_info,
|
85
|
+
"paragraph": getattr(cfg, "paragraph", False),
|
86
|
+
"min_size": cfg.min_size,
|
87
|
+
"text_threshold": cfg.text_threshold,
|
88
|
+
"low_text": cfg.low_text,
|
89
|
+
"link_threshold": cfg.link_threshold,
|
90
|
+
"canvas_size": cfg.canvas_size,
|
91
|
+
"mag_ratio": cfg.mag_ratio,
|
92
|
+
"slope_ths": cfg.slope_ths,
|
93
|
+
"ycenter_ths": cfg.ycenter_ths,
|
94
|
+
"height_ths": cfg.height_ths,
|
95
|
+
"width_ths": cfg.width_ths,
|
96
|
+
"add_margin": cfg.add_margin,
|
97
|
+
"x_ths": cfg.x_ths,
|
98
|
+
"y_ths": cfg.y_ths,
|
99
|
+
}
|
100
|
+
|
101
|
+
readtext_kwargs = {k: v for k, v in readtext_kwargs.items() if v is not None}
|
102
|
+
|
103
|
+
results = reader.readtext(str(image_path), **readtext_kwargs)
|
104
|
+
|
105
|
+
if not results:
|
106
|
+
return ExtractionResult(
|
107
|
+
content="",
|
108
|
+
mime_type=PLAIN_TEXT_MIME_TYPE,
|
109
|
+
metadata={},
|
110
|
+
chunks=[],
|
111
|
+
)
|
112
|
+
|
113
|
+
texts = []
|
114
|
+
confidences = []
|
115
|
+
|
116
|
+
detail_value = getattr(cfg, "detail", 1)
|
117
|
+
if detail_value:
|
118
|
+
for result in results:
|
119
|
+
min_result_length = 2
|
120
|
+
max_confidence_index = 2
|
121
|
+
if len(result) >= min_result_length:
|
122
|
+
_bbox, text = result[0], result[1]
|
123
|
+
confidence = result[max_confidence_index] if len(result) > max_confidence_index else 1.0
|
124
|
+
texts.append(text)
|
125
|
+
confidences.append(confidence)
|
126
|
+
else:
|
127
|
+
texts = results
|
128
|
+
confidences = [1.0] * len(texts)
|
129
|
+
|
130
|
+
content = "\n".join(texts)
|
131
|
+
content = normalize_spaces(content)
|
132
|
+
|
133
|
+
avg_confidence = sum(confidences) / len(confidences) if confidences else 0.0
|
134
|
+
|
135
|
+
metadata = {"confidence": avg_confidence} if confidences else {}
|
136
|
+
|
137
|
+
return ExtractionResult(
|
138
|
+
content=content,
|
139
|
+
mime_type=PLAIN_TEXT_MIME_TYPE,
|
140
|
+
metadata=metadata, # type: ignore[arg-type]
|
141
|
+
chunks=[],
|
142
|
+
)
|
143
|
+
|
144
|
+
except Exception as e:
|
145
|
+
raise OCRError(f"EasyOCR processing failed: {e}") from e
|
146
|
+
|
147
|
+
|
148
|
+
def process_image_bytes_sync_pure(
    image_bytes: bytes,
    config: EasyOCRConfig | None = None,
) -> ExtractionResult:
    """Process image bytes with EasyOCR using pure sync implementation.

    The bytes are decoded with Pillow and re-encoded as a temporary PNG file, which
    is passed to ``process_image_sync_pure``. The temporary file is removed
    afterwards, whether decoding, encoding and OCR succeed or fail.

    Args:
        image_bytes: Image data as bytes.
        config: EasyOCR configuration.

    Returns:
        Extraction result.

    Raises:
        Whatever Pillow raises for undecodable ``image_bytes`` (propagated unchanged),
        plus the errors documented on ``process_image_sync_pure``.
    """
    import io

    # Create the file and close it right away so Pillow can reopen it by name;
    # an open NamedTemporaryFile cannot be reopened by name on Windows.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_image:
        image_path = Path(tmp_image.name)

    try:
        with Image.open(io.BytesIO(image_bytes)) as image:
            image.save(image_path, format="PNG")
        return process_image_sync_pure(image_path, config)
    finally:
        # Runs for a failed decode/encode as well, so the temp file never leaks.
        image_path.unlink(missing_ok=True)
|
174
|
+
|
175
|
+
|
176
|
+
def process_batch_images_sync_pure(
    image_paths: list[str | Path],
    config: EasyOCRConfig | None = None,
) -> list[ExtractionResult]:
    """Run EasyOCR over several images, one after another.

    Args:
        image_paths: Image file paths, processed in the given order.
        config: EasyOCR configuration shared by every image.

    Returns:
        One extraction result per input path, in input order. An error on any image
        propagates immediately; no partial list is returned.
    """
    return [process_image_sync_pure(path, config) for path in image_paths]
|
194
|
+
|
195
|
+
|
196
|
+
def process_batch_images_threaded(
    image_paths: list[str | Path],
    config: EasyOCRConfig | None = None,
    max_workers: int | None = None,
) -> list[ExtractionResult]:
    """Process a batch of images using threading.

    Each image is handled by ``process_image_sync_pure`` in a thread pool. A failure
    on one image does not abort the batch: that image's slot receives a result whose
    content is ``"Error: <message>"`` and whose metadata has an ``"error"`` key.

    Args:
        image_paths: List of image file paths.
        config: EasyOCR configuration shared by all workers.
        max_workers: Maximum number of threads. Defaults to the smaller of the number
            of images and the CPU count.

    Returns:
        List of extraction results in same order as input; an empty input yields an
        empty list.
    """
    import multiprocessing as mp
    from concurrent.futures import ThreadPoolExecutor

    # With no images the default below would be max_workers=0, which
    # ThreadPoolExecutor rejects; there is no work to do anyway.
    if not image_paths:
        return []

    if max_workers is None:
        max_workers = min(len(image_paths), mp.cpu_count())

    # NOTE(review): process_image_sync_pure builds a fresh EasyOCR reader per call,
    # so each concurrent task holds its own model in memory.
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(process_image_sync_pure, path, config) for path in image_paths]

    # The `with` block has waited for every task; collecting in submission order keeps
    # results aligned with `image_paths`.
    results: list[ExtractionResult] = []
    for future in futures:
        try:
            results.append(future.result())
        except Exception as e:  # noqa: BLE001
            results.append(
                ExtractionResult(
                    content=f"Error: {e}",
                    mime_type=PLAIN_TEXT_MIME_TYPE,
                    metadata={"error": str(e)},  # type: ignore[typeddict-unknown-key]
                    chunks=[],
                )
            )

    return results
|
@@ -0,0 +1,199 @@
|
|
1
|
+
"""Pure synchronous PaddleOCR without any async overhead."""
|
2
|
+
|
3
|
+
from __future__ import annotations
|
4
|
+
|
5
|
+
import tempfile
|
6
|
+
from pathlib import Path
|
7
|
+
from typing import Any
|
8
|
+
|
9
|
+
from PIL import Image
|
10
|
+
|
11
|
+
from kreuzberg._mime_types import PLAIN_TEXT_MIME_TYPE
|
12
|
+
from kreuzberg._ocr._paddleocr import PaddleOCRConfig
|
13
|
+
from kreuzberg._types import ExtractionResult
|
14
|
+
from kreuzberg._utils._string import normalize_spaces
|
15
|
+
from kreuzberg.exceptions import MissingDependencyError, OCRError
|
16
|
+
|
17
|
+
|
18
|
+
def _get_paddleocr_instance(config: PaddleOCRConfig) -> Any:
|
19
|
+
"""Get a PaddleOCR instance with the given configuration."""
|
20
|
+
try:
|
21
|
+
import paddleocr
|
22
|
+
except ImportError as e:
|
23
|
+
raise MissingDependencyError("PaddleOCR is not installed. Install it with: pip install paddleocr") from e
|
24
|
+
|
25
|
+
if hasattr(config, "device"):
|
26
|
+
if config.device and config.device.lower() != "cpu":
|
27
|
+
pass
|
28
|
+
elif hasattr(config, "use_gpu"):
|
29
|
+
pass
|
30
|
+
|
31
|
+
kwargs = {
|
32
|
+
"lang": config.language,
|
33
|
+
"use_textline_orientation": config.use_angle_cls,
|
34
|
+
}
|
35
|
+
|
36
|
+
if hasattr(config, "det_db_thresh"):
|
37
|
+
kwargs["text_det_thresh"] = config.det_db_thresh
|
38
|
+
if hasattr(config, "det_db_box_thresh"):
|
39
|
+
kwargs["text_det_box_thresh"] = config.det_db_box_thresh
|
40
|
+
if hasattr(config, "det_db_unclip_ratio"):
|
41
|
+
kwargs["text_det_unclip_ratio"] = config.det_db_unclip_ratio
|
42
|
+
if hasattr(config, "det_max_side_len"):
|
43
|
+
kwargs["text_det_limit_side_len"] = config.det_max_side_len
|
44
|
+
if hasattr(config, "drop_score"):
|
45
|
+
kwargs["text_rec_score_thresh"] = config.drop_score
|
46
|
+
|
47
|
+
return paddleocr.PaddleOCR(**kwargs)
|
48
|
+
|
49
|
+
|
50
|
+
def process_image_sync_pure(
|
51
|
+
image_path: str | Path,
|
52
|
+
config: PaddleOCRConfig | None = None,
|
53
|
+
) -> ExtractionResult:
|
54
|
+
"""Process an image with PaddleOCR using pure sync implementation.
|
55
|
+
|
56
|
+
This bypasses all async overhead and calls PaddleOCR directly.
|
57
|
+
|
58
|
+
Args:
|
59
|
+
image_path: Path to the image file.
|
60
|
+
config: PaddleOCR configuration.
|
61
|
+
|
62
|
+
Returns:
|
63
|
+
Extraction result.
|
64
|
+
"""
|
65
|
+
cfg = config or PaddleOCRConfig()
|
66
|
+
|
67
|
+
try:
|
68
|
+
ocr_instance = _get_paddleocr_instance(cfg)
|
69
|
+
|
70
|
+
results = ocr_instance.ocr(str(image_path))
|
71
|
+
|
72
|
+
if not results or not results[0]:
|
73
|
+
return ExtractionResult(
|
74
|
+
content="",
|
75
|
+
mime_type=PLAIN_TEXT_MIME_TYPE,
|
76
|
+
metadata={},
|
77
|
+
chunks=[],
|
78
|
+
)
|
79
|
+
|
80
|
+
ocr_result = results[0]
|
81
|
+
result_data = ocr_result.json["res"]
|
82
|
+
|
83
|
+
texts = result_data.get("rec_texts", [])
|
84
|
+
scores = result_data.get("rec_scores", [])
|
85
|
+
|
86
|
+
if not texts:
|
87
|
+
return ExtractionResult(
|
88
|
+
content="",
|
89
|
+
mime_type=PLAIN_TEXT_MIME_TYPE,
|
90
|
+
metadata={},
|
91
|
+
chunks=[],
|
92
|
+
)
|
93
|
+
|
94
|
+
content = "\n".join(texts)
|
95
|
+
content = normalize_spaces(content)
|
96
|
+
|
97
|
+
avg_confidence = sum(scores) / len(scores) if scores else 0.0
|
98
|
+
|
99
|
+
metadata = {"confidence": avg_confidence} if scores else {}
|
100
|
+
|
101
|
+
return ExtractionResult(
|
102
|
+
content=content,
|
103
|
+
mime_type=PLAIN_TEXT_MIME_TYPE,
|
104
|
+
metadata=metadata, # type: ignore[arg-type]
|
105
|
+
chunks=[],
|
106
|
+
)
|
107
|
+
|
108
|
+
except Exception as e:
|
109
|
+
raise OCRError(f"PaddleOCR processing failed: {e}") from e
|
110
|
+
|
111
|
+
|
112
|
+
def process_image_bytes_sync_pure(
    image_bytes: bytes,
    config: PaddleOCRConfig | None = None,
) -> ExtractionResult:
    """Process image bytes with PaddleOCR using pure sync implementation.

    The bytes are decoded with Pillow and re-encoded as a temporary PNG file, which
    is passed to ``process_image_sync_pure``. The temporary file is removed
    afterwards, whether decoding, encoding and OCR succeed or fail.

    Args:
        image_bytes: Image data as bytes.
        config: PaddleOCR configuration.

    Returns:
        Extraction result.

    Raises:
        Whatever Pillow raises for undecodable ``image_bytes`` (propagated unchanged),
        plus the errors documented on ``process_image_sync_pure``.
    """
    import io

    # Create the file and close it right away so Pillow can reopen it by name;
    # an open NamedTemporaryFile cannot be reopened by name on Windows.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_image:
        image_path = Path(tmp_image.name)

    try:
        with Image.open(io.BytesIO(image_bytes)) as image:
            image.save(image_path, format="PNG")
        return process_image_sync_pure(image_path, config)
    finally:
        # Runs for a failed decode/encode as well, so the temp file never leaks.
        image_path.unlink(missing_ok=True)
|
138
|
+
|
139
|
+
|
140
|
+
def process_batch_images_sync_pure(
    image_paths: list[str | Path],
    config: PaddleOCRConfig | None = None,
) -> list[ExtractionResult]:
    """Run PaddleOCR over several images, one after another.

    Args:
        image_paths: Image file paths, processed in the given order.
        config: PaddleOCR configuration shared by every image.

    Returns:
        One extraction result per input path, in input order. An error on any image
        propagates immediately; no partial list is returned.
    """
    return [process_image_sync_pure(path, config) for path in image_paths]
|
158
|
+
|
159
|
+
|
160
|
+
def process_batch_images_threaded(
    image_paths: list[str | Path],
    config: PaddleOCRConfig | None = None,
    max_workers: int | None = None,
) -> list[ExtractionResult]:
    """Process a batch of images using threading.

    Each image is handled by ``process_image_sync_pure`` in a thread pool. A failure
    on one image does not abort the batch: that image's slot receives a result whose
    content is ``"Error: <message>"`` and whose metadata has an ``"error"`` key.

    Args:
        image_paths: List of image file paths.
        config: PaddleOCR configuration shared by all workers.
        max_workers: Maximum number of threads. Defaults to the smaller of the number
            of images and the CPU count.

    Returns:
        List of extraction results in same order as input; an empty input yields an
        empty list.
    """
    import multiprocessing as mp
    from concurrent.futures import ThreadPoolExecutor

    # With no images the default below would be max_workers=0, which
    # ThreadPoolExecutor rejects; there is no work to do anyway.
    if not image_paths:
        return []

    if max_workers is None:
        max_workers = min(len(image_paths), mp.cpu_count())

    # NOTE(review): process_image_sync_pure builds a fresh PaddleOCR instance per
    # call, so each concurrent task holds its own model in memory.
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(process_image_sync_pure, path, config) for path in image_paths]

    # The `with` block has waited for every task; collecting in submission order keeps
    # results aligned with `image_paths`.
    results: list[ExtractionResult] = []
    for future in futures:
        try:
            results.append(future.result())
        except Exception as e:  # noqa: BLE001
            results.append(
                ExtractionResult(
                    content=f"Error: {e}",
                    mime_type=PLAIN_TEXT_MIME_TYPE,
                    metadata={"error": str(e)},  # type: ignore[typeddict-unknown-key]
                    chunks=[],
                )
            )

    return results
|
kreuzberg/_ocr/_easyocr.py
CHANGED
kreuzberg/_ocr/_tesseract.py
CHANGED
@@ -202,9 +202,11 @@ class TesseractConfig:
|
|
202
202
|
- 'deu' for German
|
203
203
|
- multiple languages combined with '+', e.g. 'eng+deu')
|
204
204
|
"""
|
205
|
-
language_model_ngram_on: bool =
|
206
|
-
"""Enable or disable the use of n-gram-based language models for improved text recognition.
|
207
|
-
|
205
|
+
language_model_ngram_on: bool = False
|
206
|
+
"""Enable or disable the use of n-gram-based language models for improved text recognition.
|
207
|
+
|
208
|
+
Default is False for optimal performance on modern documents. Enable for degraded or historical text."""
|
209
|
+
psm: PSMMode = PSMMode.AUTO_ONLY
|
208
210
|
"""Page segmentation mode (PSM) to guide Tesseract on how to segment the image (e.g., single block, single line)."""
|
209
211
|
tessedit_dont_blkrej_good_wds: bool = True
|
210
212
|
"""If True, prevents block rejection of words identified as good, improving text output quality."""
|
@@ -212,6 +214,8 @@ class TesseractConfig:
|
|
212
214
|
"""If True, prevents row rejection of words identified as good, avoiding unnecessary omissions."""
|
213
215
|
tessedit_enable_dict_correction: bool = True
|
214
216
|
"""Enable or disable dictionary-based correction for recognized text to improve word accuracy."""
|
217
|
+
tessedit_char_whitelist: str = ""
|
218
|
+
"""Whitelist of characters that Tesseract is allowed to recognize. Empty string means no restriction."""
|
215
219
|
tessedit_use_primary_params_model: bool = True
|
216
220
|
"""If True, forces the use of the primary parameters model for text recognition."""
|
217
221
|
textord_space_size_is_variable: bool = True
|
kreuzberg/_types.py
CHANGED
@@ -1,9 +1,9 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
3
|
import sys
|
4
|
-
from collections.abc import Awaitable
|
4
|
+
from collections.abc import Awaitable, Callable
|
5
5
|
from dataclasses import asdict, dataclass, field
|
6
|
-
from typing import TYPE_CHECKING, Any,
|
6
|
+
from typing import TYPE_CHECKING, Any, Literal, TypedDict
|
7
7
|
|
8
8
|
from kreuzberg._constants import DEFAULT_MAX_CHARACTERS, DEFAULT_MAX_OVERLAP
|
9
9
|
from kreuzberg.exceptions import ValidationError
|
@@ -17,7 +17,9 @@ if TYPE_CHECKING:
|
|
17
17
|
from pandas import DataFrame
|
18
18
|
from PIL.Image import Image
|
19
19
|
|
20
|
+
from kreuzberg._entity_extraction import SpacyEntityExtractionConfig
|
20
21
|
from kreuzberg._gmft import GMFTConfig
|
22
|
+
from kreuzberg._language_detection import LanguageDetectionConfig
|
21
23
|
from kreuzberg._ocr._easyocr import EasyOCRConfig
|
22
24
|
from kreuzberg._ocr._paddleocr import PaddleOCRConfig
|
23
25
|
from kreuzberg._ocr._tesseract import TesseractConfig
|
@@ -99,6 +101,20 @@ class Metadata(TypedDict, total=False):
|
|
99
101
|
"""Width of the document page/slide/image, if applicable."""
|
100
102
|
|
101
103
|
|
104
|
+
@dataclass(frozen=True)
class Entity:
    """A single entity found in extracted content.

    Declared as a frozen dataclass: fields cannot be reassigned after construction,
    instances compare by value, and they are hashable.
    """

    type: str
    """Category label such as PERSON, ORGANIZATION, LOCATION, DATE, EMAIL, PHONE, or a custom label."""
    text: str
    """The text that was matched."""
    start: int
    """Character offset in the content at which the match starts."""
    end: int
    """Character offset in the content at which the match ends."""
|
116
|
+
|
117
|
+
|
102
118
|
@dataclass
|
103
119
|
class ExtractionResult:
|
104
120
|
"""The result of a file extraction."""
|
@@ -113,14 +129,20 @@ class ExtractionResult:
|
|
113
129
|
"""Extracted tables. Is an empty list if 'extract_tables' is not set to True in the ExtractionConfig."""
|
114
130
|
chunks: list[str] = field(default_factory=list)
|
115
131
|
"""The extracted content chunks. This is an empty list if 'chunk_content' is not set to True in the ExtractionConfig."""
|
132
|
+
entities: list[Entity] | None = None
|
133
|
+
"""Extracted entities, if entity extraction is enabled."""
|
134
|
+
keywords: list[tuple[str, float]] | None = None
|
135
|
+
"""Extracted keywords and their scores, if keyword extraction is enabled."""
|
136
|
+
detected_languages: list[str] | None = None
|
137
|
+
"""Languages detected in the extracted content, if language detection is enabled."""
|
116
138
|
|
117
139
|
def to_dict(self) -> dict[str, Any]:
|
118
140
|
"""Converts the ExtractionResult to a dictionary."""
|
119
141
|
return asdict(self)
|
120
142
|
|
121
143
|
|
122
|
-
PostProcessingHook = Callable[[ExtractionResult],
|
123
|
-
ValidationHook = Callable[[ExtractionResult],
|
144
|
+
PostProcessingHook = Callable[[ExtractionResult], ExtractionResult | Awaitable[ExtractionResult]]
|
145
|
+
ValidationHook = Callable[[ExtractionResult], None | Awaitable[None]]
|
124
146
|
|
125
147
|
|
126
148
|
@dataclass(unsafe_hash=True)
|
@@ -157,8 +179,28 @@ class ExtractionConfig:
|
|
157
179
|
"""Post processing hooks to call after processing is done and before the final result is returned."""
|
158
180
|
validators: list[ValidationHook] | None = None
|
159
181
|
"""Validation hooks to call after processing is done and before post-processing and result return."""
|
182
|
+
extract_entities: bool = False
|
183
|
+
"""Whether to extract named entities from the content."""
|
184
|
+
extract_keywords: bool = False
|
185
|
+
"""Whether to extract keywords from the content."""
|
186
|
+
keyword_count: int = 10
|
187
|
+
"""Number of keywords to extract if extract_keywords is True."""
|
188
|
+
custom_entity_patterns: frozenset[tuple[str, str]] | None = None
|
189
|
+
"""Custom entity patterns as a frozenset of (entity_type, regex_pattern) tuples."""
|
190
|
+
auto_detect_language: bool = False
|
191
|
+
"""Whether to automatically detect language and configure OCR accordingly."""
|
192
|
+
language_detection_config: LanguageDetectionConfig | None = None
|
193
|
+
"""Configuration for language detection. If None, uses default settings."""
|
194
|
+
spacy_entity_extraction_config: SpacyEntityExtractionConfig | None = None
|
195
|
+
"""Configuration for spaCy entity extraction. If None, uses default settings."""
|
160
196
|
|
161
197
|
def __post_init__(self) -> None:
|
198
|
+
if self.custom_entity_patterns is not None and isinstance(self.custom_entity_patterns, dict):
|
199
|
+
object.__setattr__(self, "custom_entity_patterns", frozenset(self.custom_entity_patterns.items()))
|
200
|
+
if self.post_processing_hooks is not None and isinstance(self.post_processing_hooks, list):
|
201
|
+
object.__setattr__(self, "post_processing_hooks", tuple(self.post_processing_hooks))
|
202
|
+
if self.validators is not None and isinstance(self.validators, list):
|
203
|
+
object.__setattr__(self, "validators", tuple(self.validators))
|
162
204
|
from kreuzberg._ocr._easyocr import EasyOCRConfig
|
163
205
|
from kreuzberg._ocr._paddleocr import PaddleOCRConfig
|
164
206
|
from kreuzberg._ocr._tesseract import TesseractConfig
|
kreuzberg/_utils/_device.py
CHANGED
@@ -153,7 +153,7 @@ def _is_cuda_available() -> bool:
|
|
153
153
|
try:
|
154
154
|
import torch # type: ignore[import-not-found,unused-ignore]
|
155
155
|
|
156
|
-
return torch.cuda.is_available()
|
156
|
+
return bool(torch.cuda.is_available())
|
157
157
|
except ImportError:
|
158
158
|
return False
|
159
159
|
|
@@ -163,7 +163,7 @@ def _is_mps_available() -> bool:
|
|
163
163
|
try:
|
164
164
|
import torch # type: ignore[import-not-found,unused-ignore]
|
165
165
|
|
166
|
-
return torch.backends.mps.is_available()
|
166
|
+
return bool(torch.backends.mps.is_available())
|
167
167
|
except ImportError:
|
168
168
|
return False
|
169
169
|
|
@@ -5,10 +5,10 @@ from __future__ import annotations
|
|
5
5
|
import multiprocessing as mp
|
6
6
|
from concurrent.futures import ProcessPoolExecutor
|
7
7
|
from contextlib import contextmanager
|
8
|
-
from typing import TYPE_CHECKING, Any,
|
8
|
+
from typing import TYPE_CHECKING, Any, TypeVar
|
9
9
|
|
10
10
|
if TYPE_CHECKING:
|
11
|
-
from collections.abc import Generator
|
11
|
+
from collections.abc import Callable, Generator
|
12
12
|
|
13
13
|
T = TypeVar("T")
|
14
14
|
|
kreuzberg/_utils/_sync.py
CHANGED
@@ -1,6 +1,5 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
-
import sys
|
4
3
|
from functools import partial
|
5
4
|
from inspect import isawaitable, iscoroutinefunction
|
6
5
|
from typing import TYPE_CHECKING, Any, TypeVar, cast
|
@@ -12,10 +11,7 @@ from anyio.to_thread import run_sync as any_io_run_sync
|
|
12
11
|
if TYPE_CHECKING: # pragma: no cover
|
13
12
|
from collections.abc import Awaitable, Callable
|
14
13
|
|
15
|
-
|
16
|
-
from typing import ParamSpec
|
17
|
-
else: # pragma: no cover
|
18
|
-
from typing_extensions import ParamSpec
|
14
|
+
from typing import ParamSpec
|
19
15
|
|
20
16
|
T = TypeVar("T")
|
21
17
|
P = ParamSpec("P")
|
kreuzberg/_utils/_tmp.py
CHANGED
@@ -3,14 +3,14 @@ from __future__ import annotations
|
|
3
3
|
from contextlib import suppress
|
4
4
|
from pathlib import Path
|
5
5
|
from tempfile import NamedTemporaryFile
|
6
|
-
from typing import TYPE_CHECKING
|
6
|
+
from typing import TYPE_CHECKING
|
7
7
|
|
8
8
|
from anyio import Path as AsyncPath
|
9
9
|
|
10
10
|
from kreuzberg._utils._sync import run_sync
|
11
11
|
|
12
12
|
if TYPE_CHECKING: # pragma: no cover
|
13
|
-
from collections.abc import Coroutine
|
13
|
+
from collections.abc import Callable, Coroutine
|
14
14
|
|
15
15
|
|
16
16
|
async def create_temp_file(
|