kreuzberg-3.8.1-py3-none-any.whl → kreuzberg-3.9.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kreuzberg/__init__.py +4 -0
- kreuzberg/_api/main.py +22 -1
- kreuzberg/_chunker.py +3 -3
- kreuzberg/_config.py +404 -0
- kreuzberg/_document_classification.py +156 -0
- kreuzberg/_entity_extraction.py +6 -6
- kreuzberg/_extractors/_image.py +4 -3
- kreuzberg/_extractors/_pdf.py +40 -29
- kreuzberg/_extractors/_spread_sheet.py +6 -8
- kreuzberg/_extractors/_structured.py +34 -25
- kreuzberg/_gmft.py +33 -42
- kreuzberg/_language_detection.py +1 -1
- kreuzberg/_mcp/server.py +58 -8
- kreuzberg/_mime_types.py +1 -1
- kreuzberg/_ocr/_base.py +1 -1
- kreuzberg/_ocr/_easyocr.py +5 -5
- kreuzberg/_ocr/_paddleocr.py +4 -4
- kreuzberg/_ocr/_tesseract.py +12 -21
- kreuzberg/_playa.py +2 -3
- kreuzberg/_types.py +65 -27
- kreuzberg/_utils/_cache.py +14 -17
- kreuzberg/_utils/_device.py +17 -27
- kreuzberg/_utils/_errors.py +41 -38
- kreuzberg/_utils/_quality.py +7 -11
- kreuzberg/_utils/_serialization.py +21 -16
- kreuzberg/_utils/_string.py +22 -12
- kreuzberg/_utils/_table.py +3 -4
- kreuzberg/cli.py +5 -5
- kreuzberg/exceptions.py +10 -0
- kreuzberg/extraction.py +20 -11
- kreuzberg-3.9.0.dist-info/METADATA +269 -0
- kreuzberg-3.9.0.dist-info/RECORD +54 -0
- kreuzberg/_cli_config.py +0 -175
- kreuzberg-3.8.1.dist-info/METADATA +0 -301
- kreuzberg-3.8.1.dist-info/RECORD +0 -53
- {kreuzberg-3.8.1.dist-info → kreuzberg-3.9.0.dist-info}/WHEEL +0 -0
- {kreuzberg-3.8.1.dist-info → kreuzberg-3.9.0.dist-info}/entry_points.txt +0 -0
- {kreuzberg-3.8.1.dist-info → kreuzberg-3.9.0.dist-info}/licenses/LICENSE +0 -0
kreuzberg/_mcp/server.py
CHANGED
@@ -3,11 +3,14 @@
 from __future__ import annotations

 import base64
+import json
 from typing import Any

+import msgspec
 from mcp.server import FastMCP
 from mcp.types import TextContent

+from kreuzberg._config import try_discover_config
 from kreuzberg._types import ExtractionConfig, OcrBackendType
 from kreuzberg.extraction import extract_bytes_sync, extract_file_sync

@@ -15,6 +18,44 @@ from kreuzberg.extraction import extract_bytes_sync, extract_file_sync
 mcp = FastMCP("Kreuzberg Text Extraction")


+def _create_config_with_overrides(**kwargs: Any) -> ExtractionConfig:
+    """Create ExtractionConfig with discovered config as base and tool parameters as overrides.
+
+    Args:
+        **kwargs: Tool parameters to override defaults/discovered config.
+
+    Returns:
+        ExtractionConfig instance.
+    """
+    # Try to discover configuration from files
+    base_config = try_discover_config()
+
+    if base_config is None:
+        # No config file found, use defaults
+        return ExtractionConfig(**kwargs)
+
+    # Merge discovered config with tool parameters (tool params take precedence)
+    config_dict: dict[str, Any] = {
+        "force_ocr": base_config.force_ocr,
+        "chunk_content": base_config.chunk_content,
+        "extract_tables": base_config.extract_tables,
+        "extract_entities": base_config.extract_entities,
+        "extract_keywords": base_config.extract_keywords,
+        "ocr_backend": base_config.ocr_backend,
+        "max_chars": base_config.max_chars,
+        "max_overlap": base_config.max_overlap,
+        "keyword_count": base_config.keyword_count,
+        "auto_detect_language": base_config.auto_detect_language,
+        "ocr_config": base_config.ocr_config,
+        "gmft_config": base_config.gmft_config,
+    }
+
+    # Override with provided parameters
+    config_dict.update(kwargs)
+
+    return ExtractionConfig(**config_dict)
+
+
 @mcp.tool()
 def extract_document(  # noqa: PLR0913
     file_path: str,
@@ -49,7 +90,7 @@ def extract_document(  # noqa: PLR0913
     Returns:
         Extracted content with metadata, tables, chunks, entities, and keywords
     """
-    config = ExtractionConfig(
+    config = _create_config_with_overrides(
         force_ocr=force_ocr,
         chunk_content=chunk_content,
         extract_tables=extract_tables,
@@ -63,7 +104,7 @@ def extract_document(  # noqa: PLR0913
     )

     result = extract_file_sync(file_path, mime_type, config)
-    return result.to_dict()
+    return result.to_dict(include_none=True)


 @mcp.tool()
@@ -102,7 +143,7 @@ def extract_bytes(  # noqa: PLR0913
     """
     content_bytes = base64.b64decode(content_base64)

-    config = ExtractionConfig(
+    config = _create_config_with_overrides(
         force_ocr=force_ocr,
         chunk_content=chunk_content,
         extract_tables=extract_tables,
@@ -116,7 +157,7 @@ def extract_bytes(  # noqa: PLR0913
     )

     result = extract_bytes_sync(content_bytes, mime_type, config)
-    return result.to_dict()
+    return result.to_dict(include_none=True)


 @mcp.tool()
@@ -133,7 +174,7 @@ def extract_simple(
     Returns:
         Extracted text content as a string
     """
-    config = ExtractionConfig()
+    config = _create_config_with_overrides()
     result = extract_file_sync(file_path, mime_type, config)
     return result.content

@@ -142,7 +183,16 @@ def extract_simple(
 def get_default_config() -> str:
     """Get the default extraction configuration."""
     config = ExtractionConfig()
-    return
+    return json.dumps(msgspec.to_builtins(config, order="deterministic"), indent=2)
+
+
+@mcp.resource("config://discovered")
+def get_discovered_config() -> str:
+    """Get the discovered configuration from config files."""
+    config = try_discover_config()
+    if config is None:
+        return "No configuration file found"
+    return json.dumps(msgspec.to_builtins(config, order="deterministic"), indent=2)


 @mcp.resource("config://available-backends")
@@ -175,7 +225,7 @@ def extract_and_summarize(file_path: str) -> list[TextContent]:
     Returns:
         Extracted content with summarization prompt
     """
-    result = extract_file_sync(file_path, None,
+    result = extract_file_sync(file_path, None, _create_config_with_overrides())

     return [
         TextContent(
@@ -195,7 +245,7 @@ def extract_structured(file_path: str) -> list[TextContent]:
     Returns:
         Extracted content with structured analysis prompt
     """
-    config = ExtractionConfig(
+    config = _create_config_with_overrides(
         extract_entities=True,
         extract_keywords=True,
         extract_tables=True,
kreuzberg/_mime_types.py
CHANGED
@@ -191,7 +191,7 @@ def validate_mime_type(
         return _validate_explicit_mime_type(mime_type)

     if file_path:
-        from kreuzberg._utils._cache import get_mime_cache
+        from kreuzberg._utils._cache import get_mime_cache  # noqa: PLC0415

         path = Path(file_path)

kreuzberg/_ocr/_base.py
CHANGED
@@ -103,7 +103,7 @@ class OCRBackend(ABC, Generic[T]):
         Returns:
             List of extraction result objects in the same order as input paths
         """
-        from kreuzberg._utils._sync import run_taskgroup
+        from kreuzberg._utils._sync import run_taskgroup  # noqa: PLC0415

         tasks = [self.process_file(path, **kwargs) for path in paths]
         return await run_taskgroup(*tasks)
kreuzberg/_ocr/_easyocr.py
CHANGED
@@ -111,7 +111,7 @@ EASYOCR_SUPPORTED_LANGUAGE_CODES: Final[set[str]] = {
 }


-@dataclass(unsafe_hash=True, frozen=True)
+@dataclass(unsafe_hash=True, frozen=True, slots=True)
 class EasyOCRConfig:
     """Configuration options for EasyOCR."""

@@ -180,7 +180,7 @@ class EasyOCRBackend(OCRBackend[EasyOCRConfig]):
         Raises:
             OCRError: If OCR processing fails.
         """
-        import numpy as np
+        import numpy as np  # noqa: PLC0415

         await self._init_easyocr(**kwargs)

@@ -318,7 +318,7 @@ class EasyOCRBackend(OCRBackend[EasyOCRConfig]):
             bool: True if GPU support is available.
         """
         try:
-            import torch
+            import torch  # noqa: PLC0415

             return bool(torch.cuda.is_available())
         except ImportError:
@@ -339,7 +339,7 @@ class EasyOCRBackend(OCRBackend[EasyOCRConfig]):
             return

         try:
-            import easyocr
+            import easyocr  # noqa: PLC0415
         except ImportError as e:
             raise MissingDependencyError.create_for_package(
                 dependency_group="easyocr", functionality="EasyOCR as an OCR backend", package_name="easyocr"
@@ -507,7 +507,7 @@ class EasyOCRBackend(OCRBackend[EasyOCRConfig]):
             return

         try:
-            import easyocr
+            import easyocr  # noqa: PLC0415
         except ImportError as e:
             raise MissingDependencyError.create_for_package(
                 dependency_group="easyocr", functionality="EasyOCR as an OCR backend", package_name="easyocr"
kreuzberg/_ocr/_paddleocr.py
CHANGED
@@ -31,7 +31,7 @@ except ImportError:  # pragma: no cover
 PADDLEOCR_SUPPORTED_LANGUAGE_CODES: Final[set[str]] = {"ch", "en", "french", "german", "japan", "korean"}


-@dataclass(unsafe_hash=True, frozen=True)
+@dataclass(unsafe_hash=True, frozen=True, slots=True)
 class PaddleOCRConfig:
     """Configuration options for PaddleOCR.

@@ -124,7 +124,7 @@ class PaddleBackend(OCRBackend[PaddleOCRConfig]):
         Raises:
             OCRError: If OCR processing fails.
         """
-        import numpy as np
+        import numpy as np  # noqa: PLC0415

         await self._init_paddle_ocr(**kwargs)

@@ -260,7 +260,7 @@ class PaddleBackend(OCRBackend[PaddleOCRConfig]):
             return

         try:
-            from paddleocr import PaddleOCR
+            from paddleocr import PaddleOCR  # noqa: PLC0415
         except ImportError as e:
             raise MissingDependencyError.create_for_package(
                 dependency_group="paddleocr", functionality="PaddleOCR as an OCR backend", package_name="paddleocr"
@@ -427,7 +427,7 @@ class PaddleBackend(OCRBackend[PaddleOCRConfig]):
             return

         try:
-            from paddleocr import PaddleOCR
+            from paddleocr import PaddleOCR  # noqa: PLC0415
         except ImportError as e:
             raise MissingDependencyError.create_for_package(
                 dependency_group="paddleocr", functionality="PaddleOCR as an OCR backend", package_name="paddleocr"
kreuzberg/_ocr/_tesseract.py
CHANGED
@@ -1,6 +1,7 @@
 from __future__ import annotations

 import hashlib
+import io
 import os
 import re
 import subprocess
@@ -11,8 +12,10 @@ from enum import Enum
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, ClassVar, Final

+import anyio
 from anyio import Path as AsyncPath
 from anyio import run_process
+from PIL import Image
 from typing_extensions import Self

 from kreuzberg._mime_types import PLAIN_TEXT_MIME_TYPE
@@ -24,7 +27,7 @@ from kreuzberg._utils._tmp import create_temp_file
 from kreuzberg.exceptions import MissingDependencyError, OCRError, ValidationError

 if TYPE_CHECKING:
-    from PIL.Image import Image
+    from PIL.Image import Image as PILImage

 try:  # pragma: no cover
     from typing import Unpack  # type: ignore[attr-defined]
@@ -192,7 +195,7 @@ class PSMMode(Enum):
     """Treat the image as a single character."""


-@dataclass(unsafe_hash=True, frozen=True)
+@dataclass(unsafe_hash=True, frozen=True, slots=True)
 class TesseractConfig:
     """Configuration options for Tesseract OCR engine."""

@@ -232,12 +235,10 @@ class TesseractBackend(OCRBackend[TesseractConfig]):

     async def process_image(
         self,
-        image: Image,
+        image: PILImage,
         **kwargs: Unpack[TesseractConfig],
     ) -> ExtractionResult:
-        import io
-
-        from kreuzberg._utils._cache import get_ocr_cache
+        from kreuzberg._utils._cache import get_ocr_cache  # noqa: PLC0415

         image_buffer = io.BytesIO()
         await run_sync(image.save, image_buffer, format="PNG")
@@ -255,8 +256,6 @@ class TesseractBackend(OCRBackend[TesseractConfig]):
             return cached_result

         if ocr_cache.is_processing(**cache_kwargs):
-            import anyio
-
             event = ocr_cache.mark_processing(**cache_kwargs)
             await anyio.to_thread.run_sync(event.wait)

@@ -287,7 +286,7 @@ class TesseractBackend(OCRBackend[TesseractConfig]):
         path: Path,
         **kwargs: Unpack[TesseractConfig],
     ) -> ExtractionResult:
-        from kreuzberg._utils._cache import get_ocr_cache
+        from kreuzberg._utils._cache import get_ocr_cache  # noqa: PLC0415

         try:
             stat = path.stat()
@@ -315,8 +314,6 @@ class TesseractBackend(OCRBackend[TesseractConfig]):
             return cached_result

         if ocr_cache.is_processing(**cache_kwargs):
-            import anyio
-
             event = ocr_cache.mark_processing(**cache_kwargs)
             await anyio.to_thread.run_sync(event.wait)

@@ -412,7 +409,7 @@ class TesseractBackend(OCRBackend[TesseractConfig]):

     def process_image_sync(
         self,
-        image: Image,
+        image: PILImage,
         **kwargs: Unpack[TesseractConfig],
     ) -> ExtractionResult:
         """Synchronously process an image and extract its text and metadata.
@@ -424,9 +421,7 @@ class TesseractBackend(OCRBackend[TesseractConfig]):
         Returns:
             The extraction result object
         """
-        import io
-
-        from kreuzberg._utils._cache import get_ocr_cache
+        from kreuzberg._utils._cache import get_ocr_cache  # noqa: PLC0415

         image_buffer = io.BytesIO()
         image.save(image_buffer, format="PNG")
@@ -485,7 +480,7 @@ class TesseractBackend(OCRBackend[TesseractConfig]):
         Returns:
             The extraction result object
         """
-        from kreuzberg._utils._cache import get_ocr_cache
+        from kreuzberg._utils._cache import get_ocr_cache  # noqa: PLC0415

         file_info = self._get_file_info(path)

@@ -774,10 +769,6 @@ def _process_image_bytes_with_tesseract(
         OCR result as dictionary.
     """
     try:
-        import io
-
-        from PIL import Image
-
         with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_image:
             with Image.open(io.BytesIO(image_bytes)) as image:
                 image.save(tmp_image.name, format="PNG")
@@ -815,7 +806,7 @@ class TesseractProcessPool:
             max_processes: Maximum number of processes.
             memory_limit_gb: Memory limit in GB.
         """
-        from kreuzberg._utils._process_pool import ProcessPoolManager
+        from kreuzberg._utils._process_pool import ProcessPoolManager  # noqa: PLC0415

         self.config = config or TesseractConfig()
         self.process_manager = ProcessPoolManager(
kreuzberg/_playa.py
CHANGED
@@ -114,9 +114,8 @@ def _extract_keyword_metadata(pdf_info: dict[str, Any], result: Metadata) -> None:
     if keywords := pdf_info.get("keywords"):
         if isinstance(keywords, (str, bytes)):
             kw_str = decode_text(keywords)
-
-
-            result["keywords"] = [k for k in kw_list if k]
+            # Combine multiple operations into a single comprehension
+            result["keywords"] = [k.strip() for part in kw_str.replace(";", ",").split(",") if (k := part.strip())]
         elif isinstance(keywords, list):
             result["keywords"] = [decode_text(k) for k in keywords]

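Note: the replacement comprehension normalizes the PDF keyword string in one pass: `;` is folded into `,`, entries are split and stripped, and empty entries are dropped. A quick standalone check of that expression (the sample string is made up):

```python
# Standalone check of the keyword normalization expression from the diff.
kw_str = "invoice; tax , 2024,, finance"
keywords = [k.strip() for part in kw_str.replace(";", ",").split(",") if (k := part.strip())]
print(keywords)  # ['invoice', 'tax', '2024', 'finance']
```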
kreuzberg/_types.py
CHANGED
@@ -5,7 +5,14 @@ from collections.abc import Awaitable, Callable
 from dataclasses import asdict, dataclass, field
 from typing import TYPE_CHECKING, Any, Literal, TypedDict

+import msgspec
+
 from kreuzberg._constants import DEFAULT_MAX_CHARACTERS, DEFAULT_MAX_OVERLAP
+from kreuzberg._utils._table import (
+    export_table_to_csv,
+    export_table_to_tsv,
+    extract_table_structure_info,
+)
 from kreuzberg.exceptions import ValidationError

 if sys.version_info < (3, 11):  # pragma: no cover
@@ -191,7 +198,7 @@ def normalize_metadata(data: dict[str, Any] | None) -> Metadata:
     return normalized


-@dataclass(frozen=True)
+@dataclass(frozen=True, slots=True)
 class Entity:
     """Represents an extracted entity with type, text, and position."""

@@ -205,7 +212,7 @@ class Entity:
     """End character offset in the content"""


-@dataclass
+@dataclass(slots=True)
 class ExtractionResult:
     """The result of a file extraction."""

@@ -225,10 +232,36 @@ class ExtractionResult:
     """Extracted keywords and their scores, if keyword extraction is enabled."""
     detected_languages: list[str] | None = None
     """Languages detected in the extracted content, if language detection is enabled."""
+    document_type: str | None = None
+    """Detected document type, if document type detection is enabled."""
+    document_type_confidence: float | None = None
+    """Confidence of the detected document type."""
+    layout: DataFrame | None = field(default=None, repr=False, hash=False)
+    """Internal layout data from OCR, not for public use."""
+
+    def to_dict(self, include_none: bool = False) -> dict[str, Any]:
+        """Converts the ExtractionResult to a dictionary.
+
+        Args:
+            include_none: If True, include fields with None values.
+                If False (default), exclude None values.
+
+        Returns:
+            Dictionary representation of the ExtractionResult.
+        """
+        # Use msgspec.to_builtins for efficient conversion
+        # The builtin_types parameter allows DataFrames to pass through
+        result = msgspec.to_builtins(
+            self,
+            builtin_types=(type(None),),  # Allow None to pass through
+            order="deterministic",  # Ensure consistent output
+        )
+
+        if include_none:
+            return result  # type: ignore[no-any-return]

-
-
-        return asdict(self)
+        # Remove None values to match expected behavior
+        return {k: v for k, v in result.items() if v is not None}

     def export_tables_to_csv(self) -> list[str]:
         """Export all tables to CSV format.
@@ -239,8 +272,6 @@ class ExtractionResult:
         if not self.tables:
             return []

-        from kreuzberg._utils._table import export_table_to_csv
-
         return [export_table_to_csv(table) for table in self.tables]

     def export_tables_to_tsv(self) -> list[str]:
@@ -252,8 +283,6 @@ class ExtractionResult:
         if not self.tables:
             return []

-        from kreuzberg._utils._table import export_table_to_tsv
-
         return [export_table_to_tsv(table) for table in self.tables]

     def get_table_summaries(self) -> list[dict[str, Any]]:
@@ -265,8 +294,6 @@ class ExtractionResult:
         if not self.tables:
             return []

-        from kreuzberg._utils._table import extract_table_structure_info
-
         return [extract_table_structure_info(table) for table in self.tables]


@@ -274,7 +301,7 @@ PostProcessingHook = Callable[[ExtractionResult], ExtractionResult | Awaitable[ExtractionResult]]
 ValidationHook = Callable[[ExtractionResult], None | Awaitable[None]]


-@dataclass(unsafe_hash=True)
+@dataclass(unsafe_hash=True, slots=True)
 class ExtractionConfig:
     """Represents configuration settings for an extraction process.

@@ -322,6 +349,12 @@ class ExtractionConfig:
     """Configuration for language detection. If None, uses default settings."""
     spacy_entity_extraction_config: SpacyEntityExtractionConfig | None = None
     """Configuration for spaCy entity extraction. If None, uses default settings."""
+    auto_detect_document_type: bool = False
+    """Whether to automatically detect the document type."""
+    document_type_confidence_threshold: float = 0.7
+    """Confidence threshold for document type detection."""
+    document_classification_mode: Literal["text", "vision"] = "text"
+    """The mode to use for document classification."""
     enable_quality_processing: bool = True
     """Whether to apply quality post-processing to improve extraction results."""

@@ -332,9 +365,9 @@ class ExtractionConfig:
             object.__setattr__(self, "post_processing_hooks", tuple(self.post_processing_hooks))
         if self.validators is not None and isinstance(self.validators, list):
             object.__setattr__(self, "validators", tuple(self.validators))
-        from kreuzberg._ocr._easyocr import EasyOCRConfig
-        from kreuzberg._ocr._paddleocr import PaddleOCRConfig
-        from kreuzberg._ocr._tesseract import TesseractConfig
+        from kreuzberg._ocr._easyocr import EasyOCRConfig  # noqa: PLC0415
+        from kreuzberg._ocr._paddleocr import PaddleOCRConfig  # noqa: PLC0415
+        from kreuzberg._ocr._tesseract import TesseractConfig  # noqa: PLC0415

         if self.ocr_backend is None and self.ocr_config is not None:
             raise ValidationError("'ocr_backend' is None but 'ocr_config' is provided")
@@ -355,18 +388,23 @@ class ExtractionConfig:
         Returns:
             A dict of the OCR configuration or an empty dict if no backend is provided.
         """
-        if self.ocr_backend is
-
-            return asdict(self.ocr_config)
-        if self.ocr_backend == "tesseract":
-            from kreuzberg._ocr._tesseract import TesseractConfig
+        if self.ocr_backend is None:
+            return {}

-
-
-
+        if self.ocr_config is not None:
+            # Use asdict for OCR configs to preserve enum objects correctly
+            return asdict(self.ocr_config)

-
-
+        # Lazy load and cache default configs instead of creating new instances
+        if self.ocr_backend == "tesseract":
+            from kreuzberg._ocr._tesseract import TesseractConfig  # noqa: PLC0415

-            return asdict(
-
+            return asdict(TesseractConfig())
+        if self.ocr_backend == "easyocr":
+            from kreuzberg._ocr._easyocr import EasyOCRConfig  # noqa: PLC0415
+
+            return asdict(EasyOCRConfig())
+        # paddleocr
+        from kreuzberg._ocr._paddleocr import PaddleOCRConfig  # noqa: PLC0415
+
+        return asdict(PaddleOCRConfig())
kreuzberg/_utils/_cache.py
CHANGED
@@ -7,6 +7,7 @@ import os
 import threading
 import time
 from contextlib import suppress
+from io import StringIO
 from pathlib import Path
 from typing import Any, Generic, TypeVar

@@ -64,11 +65,10 @@ class KreuzbergCache(Generic[T]):
         Returns:
             Unique cache key string
         """
-        # Use more efficient string building for cache key
         if not kwargs:
             return "empty"

-        # Build key
+        # Build cache key using list + join (faster than StringIO)
         parts = []
         for key in sorted(kwargs):
             value = kwargs[key]
@@ -81,6 +81,7 @@ class KreuzbergCache(Generic[T]):
                 parts.append(f"{key}={type(value).__name__}:{value!s}")

         cache_str = "&".join(parts)
+        # SHA256 is secure and fast enough for cache keys
         return hashlib.sha256(cache_str.encode()).hexdigest()[:16]

     def _get_cache_path(self, cache_key: str) -> Path:
@@ -107,15 +108,14 @@ class KreuzbergCache(Generic[T]):
         serialized_data = []
         for item in result:
             if isinstance(item, dict) and "df" in item:
-                #
-
+                # Build new dict without unnecessary copy
+                serialized_item = {k: v for k, v in item.items() if k != "df"}
                 if hasattr(item["df"], "to_csv"):
-
+                    serialized_item["df_csv"] = item["df"].to_csv(index=False)
                 else:
                     # Fallback for non-DataFrame objects
-
-
-                serialized_data.append(item_copy)
+                    serialized_item["df_csv"] = str(item["df"])
+                serialized_data.append(serialized_item)
             else:
                 serialized_data.append(item)
         return {"type": "TableDataList", "data": serialized_data, "cached_at": time.time()}
@@ -127,18 +127,15 @@ class KreuzbergCache(Generic[T]):
         data = cached_data["data"]

         if cached_data.get("type") == "TableDataList" and isinstance(data, list):
+            import pandas as pd  # noqa: PLC0415
+
             deserialized_data = []
             for item in data:
                 if isinstance(item, dict) and "df_csv" in item:
-                    #
-
-
-
-                    import pandas as pd
-
-                    item_copy["df"] = pd.read_csv(StringIO(item["df_csv"]))
-                    del item_copy["df_csv"]
-                    deserialized_data.append(item_copy)
+                    # Build new dict without unnecessary copy
+                    deserialized_item = {k: v for k, v in item.items() if k != "df_csv"}
+                    deserialized_item["df"] = pd.read_csv(StringIO(item["df_csv"]))
+                    deserialized_data.append(deserialized_item)
                 else:
                     deserialized_data.append(item)
             return deserialized_data  # type: ignore[return-value]
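Note: the reworked cache serialization keeps each table dict intact but swaps the DataFrame for a CSV snapshot under `df_csv`, rebuilding it with `pandas.read_csv` when the cache is read back. A self-contained sketch of that round trip (assumes pandas is installed; the `df`/`df_csv` keys mirror the diff):

```python
# Self-contained sketch of the cache's DataFrame round trip.
from io import StringIO

import pandas as pd

item = {"page_number": 1, "df": pd.DataFrame({"a": [1, 2], "b": [3, 4]})}

# Serialize: keep everything except the frame, store a CSV snapshot instead
serialized = {k: v for k, v in item.items() if k != "df"}
serialized["df_csv"] = item["df"].to_csv(index=False)

# Deserialize: drop the snapshot and rebuild the DataFrame from it
restored = {k: v for k, v in serialized.items() if k != "df_csv"}
restored["df"] = pd.read_csv(StringIO(serialized["df_csv"]))

print(restored["df"].equals(item["df"]))  # True
```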