kreuzberg 3.14.1__py3-none-any.whl → 3.15.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kreuzberg/__init__.py +6 -0
- kreuzberg/_api/_config_cache.py +247 -0
- kreuzberg/_api/main.py +127 -45
- kreuzberg/_chunker.py +7 -6
- kreuzberg/_constants.py +2 -0
- kreuzberg/_document_classification.py +4 -6
- kreuzberg/_entity_extraction.py +9 -4
- kreuzberg/_extractors/_base.py +269 -3
- kreuzberg/_extractors/_email.py +95 -27
- kreuzberg/_extractors/_html.py +85 -7
- kreuzberg/_extractors/_image.py +23 -22
- kreuzberg/_extractors/_pandoc.py +106 -75
- kreuzberg/_extractors/_pdf.py +209 -99
- kreuzberg/_extractors/_presentation.py +72 -8
- kreuzberg/_extractors/_spread_sheet.py +25 -30
- kreuzberg/_mcp/server.py +345 -25
- kreuzberg/_mime_types.py +42 -0
- kreuzberg/_ocr/_easyocr.py +2 -2
- kreuzberg/_ocr/_paddleocr.py +1 -1
- kreuzberg/_ocr/_tesseract.py +74 -34
- kreuzberg/_types.py +180 -21
- kreuzberg/_utils/_cache.py +10 -4
- kreuzberg/_utils/_device.py +2 -4
- kreuzberg/_utils/_image_preprocessing.py +12 -39
- kreuzberg/_utils/_process_pool.py +29 -8
- kreuzberg/_utils/_quality.py +7 -2
- kreuzberg/_utils/_resource_managers.py +65 -0
- kreuzberg/_utils/_sync.py +36 -6
- kreuzberg/_utils/_tmp.py +37 -1
- kreuzberg/cli.py +34 -20
- kreuzberg/extraction.py +43 -27
- {kreuzberg-3.14.1.dist-info → kreuzberg-3.15.0.dist-info}/METADATA +2 -1
- kreuzberg-3.15.0.dist-info/RECORD +60 -0
- kreuzberg-3.14.1.dist-info/RECORD +0 -58
- {kreuzberg-3.14.1.dist-info → kreuzberg-3.15.0.dist-info}/WHEEL +0 -0
- {kreuzberg-3.14.1.dist-info → kreuzberg-3.15.0.dist-info}/entry_points.txt +0 -0
- {kreuzberg-3.14.1.dist-info → kreuzberg-3.15.0.dist-info}/licenses/LICENSE +0 -0
kreuzberg/_ocr/_tesseract.py
CHANGED
@@ -8,6 +8,7 @@ import re
 import subprocess
 import sys
 import tempfile
+from concurrent.futures import ProcessPoolExecutor, as_completed
 from io import StringIO
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, ClassVar, Final
@@ -28,10 +29,10 @@ from kreuzberg._ocr._base import OCRBackend
 from kreuzberg._ocr._table_extractor import extract_words, reconstruct_table, to_markdown
 from kreuzberg._types import ExtractionResult, HTMLToMarkdownConfig, PSMMode, TableData, TesseractConfig
 from kreuzberg._utils._cache import get_ocr_cache
-from kreuzberg._utils._process_pool import ProcessPoolManager
+from kreuzberg._utils._process_pool import ProcessPoolManager, get_optimal_worker_count
 from kreuzberg._utils._string import normalize_spaces
 from kreuzberg._utils._sync import run_sync
-from kreuzberg._utils._tmp import create_temp_file
+from kreuzberg._utils._tmp import create_temp_file, temporary_file_sync
 from kreuzberg.exceptions import MissingDependencyError, OCRError, ValidationError
 
 if TYPE_CHECKING:
@@ -257,18 +258,19 @@ class TesseractBackend(OCRBackend[TesseractConfig]):
         if enable_table_detection and output_format == "text":
             output_format = "tsv"
 
-
-
-
-
-
-
-
-
-
-
-
-
+        match output_format:
+            case "markdown":
+                tesseract_format = "hocr"
+                ext = ".hocr"
+            case "tsv":
+                tesseract_format = "tsv"
+                ext = ".tsv"
+            case "hocr":
+                tesseract_format = "hocr"
+                ext = ".hocr"
+            case _:
+                tesseract_format = "text"
+                ext = ".txt"
 
         return {
             "language": language,
@@ -344,11 +346,9 @@ class TesseractBackend(OCRBackend[TesseractConfig]):
         if output_format == "tsv":
             return self._extract_text_from_tsv(output)
         if output_format == "hocr":
-            return ExtractionResult(content=output, mime_type=HTML_MIME_TYPE, metadata={}
+            return ExtractionResult(content=output, mime_type=HTML_MIME_TYPE, metadata={})
 
-        return ExtractionResult(
-            content=normalize_spaces(output), mime_type=PLAIN_TEXT_MIME_TYPE, metadata={}, chunks=[]
-        )
+        return ExtractionResult(content=normalize_spaces(output), mime_type=PLAIN_TEXT_MIME_TYPE, metadata={})
 
     async def process_file(self, path: Path, **kwargs: Unpack[TesseractConfig]) -> ExtractionResult:
         use_cache = kwargs.pop("use_cache", True)
@@ -494,9 +494,7 @@ class TesseractBackend(OCRBackend[TesseractConfig]):
                 content += parts[11] + " "
         content = content.strip()
 
-        return ExtractionResult(
-            content=normalize_spaces(content), mime_type=PLAIN_TEXT_MIME_TYPE, metadata={}, chunks=[]
-        )
+        return ExtractionResult(content=normalize_spaces(content), mime_type=PLAIN_TEXT_MIME_TYPE, metadata={})
 
     async def _process_hocr_to_markdown(
         self,
@@ -517,7 +515,7 @@ class TesseractBackend(OCRBackend[TesseractConfig]):
 
         tables: list[TableData] = []
         if enable_table_detection:
-            soup = BeautifulSoup(hocr_content, "
+            soup = BeautifulSoup(hocr_content, "xml")
             tables = await self._extract_tables_from_hocr(
                 soup,
                 table_column_threshold,
@@ -539,7 +537,7 @@ class TesseractBackend(OCRBackend[TesseractConfig]):
             markdown_content = normalize_spaces(markdown_content)
         except (ValueError, TypeError, AttributeError):
             try:
-                soup = BeautifulSoup(hocr_content, "
+                soup = BeautifulSoup(hocr_content, "xml")
                 words = soup.find_all("span", class_="ocrx_word")
                 text_parts = []
                 for word in words:
@@ -690,7 +688,7 @@ class TesseractBackend(OCRBackend[TesseractConfig]):
 
         except (ValueError, TypeError, AttributeError):
             try:
-                soup = BeautifulSoup(hocr_content, "
+                soup = BeautifulSoup(hocr_content, "xml")
                 words = soup.find_all("span", class_="ocrx_word")
                 text_parts = []
                 for word in words:
@@ -948,11 +946,9 @@ class TesseractBackend(OCRBackend[TesseractConfig]):
         if output_format == "tsv":
            return self._extract_text_from_tsv(output)
         if output_format == "hocr":
-            return ExtractionResult(content=output, mime_type=HTML_MIME_TYPE, metadata={}
+            return ExtractionResult(content=output, mime_type=HTML_MIME_TYPE, metadata={})
 
-        return ExtractionResult(
-            content=normalize_spaces(output), mime_type=PLAIN_TEXT_MIME_TYPE, metadata={}, chunks=[]
-        )
+        return ExtractionResult(content=normalize_spaces(output), mime_type=PLAIN_TEXT_MIME_TYPE, metadata={})
 
     def process_image_sync(self, image: PILImage, **kwargs: Unpack[TesseractConfig]) -> ExtractionResult:
         use_cache = kwargs.pop("use_cache", True)
@@ -979,10 +975,8 @@ class TesseractBackend(OCRBackend[TesseractConfig]):
         ocr_cache = get_ocr_cache()
         try:
             self._validate_tesseract_version_sync()
-            with
-                image_path = Path(tmp_file.name)
+            with temporary_file_sync(".png") as image_path:
                 save_image.save(str(image_path), format="PNG")
-                try:
                 kwargs_with_cache = {**kwargs, "use_cache": use_cache}
                 result = self.process_file_sync(image_path, **kwargs_with_cache)
 
@@ -990,9 +984,6 @@ class TesseractBackend(OCRBackend[TesseractConfig]):
                     ocr_cache.set(result, **cache_kwargs)
 
                 return result
-                finally:
-                    if image_path.exists():
-                        image_path.unlink()
         finally:
             if use_cache:
                 ocr_cache.mark_complete(**cache_kwargs)
@@ -1092,6 +1083,55 @@ class TesseractBackend(OCRBackend[TesseractConfig]):
             "mtime": 0,
         }
 
+    def _result_from_dict(self, result_dict: dict[str, Any]) -> ExtractionResult:
+        """Convert a worker result dict to ExtractionResult."""
+        if result_dict.get("success"):
+            return ExtractionResult(
+                content=str(result_dict.get("text", "")),
+                mime_type=PLAIN_TEXT_MIME_TYPE,
+                metadata={},
+                chunks=[],
+            )
+        return ExtractionResult(
+            content=f"[OCR error: {result_dict.get('error', 'Unknown error')}]",
+            mime_type=PLAIN_TEXT_MIME_TYPE,
+            metadata={},
+            chunks=[],
+        )
+
+    def process_batch_sync(self, paths: list[Path], **kwargs: Unpack[TesseractConfig]) -> list[ExtractionResult]:
+        if not paths:
+            return []
+
+        results: list[ExtractionResult] = [
+            ExtractionResult(content="", mime_type=PLAIN_TEXT_MIME_TYPE, metadata={})
+        ] * len(paths)
+
+        run_config = self._prepare_tesseract_run_config(**kwargs)
+        config_dict: dict[str, Any] = {
+            **run_config["remaining_kwargs"],
+            "language": run_config["language"],
+            "psm": run_config["psm"],
+        }
+
+        optimal_workers = get_optimal_worker_count(len(paths), cpu_intensive=True)
+
+        with ProcessPoolExecutor(max_workers=optimal_workers) as pool:
+            future_to_idx = {
+                pool.submit(_process_image_with_tesseract, str(p), config_dict): idx for idx, p in enumerate(paths)
+            }
+            for future in as_completed(future_to_idx):
+                idx = future_to_idx[future]
+                try:
+                    result_dict = future.result()
+                    results[idx] = self._result_from_dict(result_dict)
+                except Exception as e:  # noqa: BLE001
+                    results[idx] = ExtractionResult(
+                        content=f"[OCR error: {e}]", mime_type=PLAIN_TEXT_MIME_TYPE, metadata={}
+                    )
+
+        return results
+
     def _build_tesseract_command(
         self,
         path: Path,
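Note: the process_batch_sync method added above fans per-image OCR out over a ProcessPoolExecutor sized by get_optimal_worker_count and returns results in input order, with per-image failures reported as "[OCR error: ...]" content instead of raising. A minimal usage sketch, assuming a local Tesseract install and that TesseractBackend can be constructed without arguments; the file names and language value are illustrative:

from pathlib import Path

from kreuzberg._ocr._tesseract import TesseractBackend

backend = TesseractBackend()
pages = [Path("page-001.png"), Path("page-002.png")]

# Results come back in the same order as `pages`.
results = backend.process_batch_sync(pages, language="eng")
for page, result in zip(pages, results):
    print(page.name, result.content[:80])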
kreuzberg/_types.py
CHANGED
@@ -4,6 +4,7 @@ import sys
 from collections.abc import Awaitable, Callable, Iterable, Mapping
 from dataclasses import asdict, dataclass, field
 from enum import Enum
+from pathlib import Path
 from typing import TYPE_CHECKING, Any, Literal, NamedTuple, TypedDict
 
 import msgspec
@@ -25,8 +26,6 @@ else: # pragma: no cover
     from typing import NotRequired
 
 if TYPE_CHECKING:
-    from pathlib import Path
-
     from PIL.Image import Image
     from polars import DataFrame
 
@@ -165,6 +164,12 @@ class EasyOCRConfig(ConfigDict):
     ycenter_ths: float = 0.5
     """Maximum shift in y direction for merging."""
 
+    def __post_init__(self) -> None:
+        if isinstance(self.language, list):
+            object.__setattr__(self, "language", tuple(self.language))
+        if isinstance(self.rotation_info, list):
+            object.__setattr__(self, "rotation_info", tuple(self.rotation_info))
+
 
 @dataclass(unsafe_hash=True, frozen=True, slots=True)
 class PaddleOCRConfig(ConfigDict):
@@ -349,6 +354,51 @@ class GMFTConfig(ConfigDict):
     """
 
 
+@dataclass(unsafe_hash=True, frozen=True, slots=True)
+class ImageOCRConfig(ConfigDict):
+    """Configuration for OCR processing of extracted images."""
+
+    enabled: bool = False
+    """Whether to perform OCR on extracted images."""
+    backend: OcrBackendType | None = None
+    """OCR backend for image OCR. Falls back to main ocr_backend when None."""
+    backend_config: TesseractConfig | PaddleOCRConfig | EasyOCRConfig | None = None
+    """Backend-specific configuration for image OCR."""
+    min_dimensions: tuple[int, int] = (50, 50)
+    """Minimum (width, height) in pixels for image OCR eligibility."""
+    max_dimensions: tuple[int, int] = (10000, 10000)
+    """Maximum (width, height) in pixels for image OCR eligibility."""
+    allowed_formats: frozenset[str] = frozenset(
+        {
+            "jpg",
+            "jpeg",
+            "png",
+            "gif",
+            "bmp",
+            "tiff",
+            "tif",
+            "webp",
+            "jp2",
+            "jpx",
+            "jpm",
+            "mj2",
+            "pnm",
+            "pbm",
+            "pgm",
+            "ppm",
+        }
+    )
+    """Allowed image formats for OCR processing (lowercase, without dot)."""
+    batch_size: int = 4
+    """Number of images to process in parallel for OCR."""
+    timeout_seconds: int = 30
+    """Maximum time in seconds for OCR processing per image."""
+
+    def __post_init__(self) -> None:
+        if isinstance(self.allowed_formats, list):
+            object.__setattr__(self, "allowed_formats", frozenset(self.allowed_formats))
+
+
 @dataclass(unsafe_hash=True, frozen=True, slots=True)
 class LanguageDetectionConfig(ConfigDict):
     low_memory: bool = True
@@ -391,6 +441,9 @@ class SpacyEntityExtractionConfig(ConfigDict):
     """Batch size for processing multiple texts."""
 
     def __post_init__(self) -> None:
+        if isinstance(self.model_cache_dir, Path):
+            object.__setattr__(self, "model_cache_dir", str(self.model_cache_dir))
+
         if self.language_models is None:
             object.__setattr__(self, "language_models", self._get_default_language_models())
 
@@ -622,6 +675,8 @@ class Metadata(TypedDict, total=False):
     """Source format of the extracted content."""
     error: NotRequired[str]
     """Error message if extraction failed."""
+    error_context: NotRequired[dict[str, Any]]
+    """Error context information for debugging."""
 
 
 _VALID_METADATA_KEYS = {
@@ -664,6 +719,9 @@ _VALID_METADATA_KEYS = {
     "tables_summary",
     "quality_score",
     "image_preprocessing",
+    "source_format",
+    "error",
+    "error_context",
 }
 
 
@@ -679,7 +737,7 @@ def normalize_metadata(data: dict[str, Any] | None) -> Metadata:
     return normalized
 
 
-@dataclass(frozen=True, slots=True)
+@dataclass(unsafe_hash=True, frozen=True, slots=True)
 class Entity:
     type: str
     """e.g., PERSON, ORGANIZATION, LOCATION, DATE, EMAIL, PHONE, or custom"""
@@ -691,18 +749,44 @@ class Entity:
     """End character offset in the content"""
 
 
+@dataclass(unsafe_hash=True, frozen=True, slots=True)
+class ExtractedImage:
+    data: bytes
+    format: str
+    filename: str | None = None
+    page_number: int | None = None
+    dimensions: tuple[int, int] | None = None
+    colorspace: str | None = None
+    bits_per_component: int | None = None
+    is_mask: bool = False
+    description: str | None = None
+
+
+@dataclass(slots=True)
+class ImageOCRResult:
+    image: ExtractedImage
+    ocr_result: ExtractionResult
+    confidence_score: float | None = None
+    processing_time: float | None = None
+    skipped_reason: str | None = None
+
+
 @dataclass(slots=True)
 class ExtractionResult:
     content: str
     """The extracted content."""
     mime_type: str
     """The mime type of the extracted content. Is either text/plain or text/markdown."""
-    metadata: Metadata
+    metadata: Metadata = field(default_factory=lambda: Metadata())
     """The metadata of the content."""
     tables: list[TableData] = field(default_factory=list)
     """Extracted tables. Is an empty list if 'extract_tables' is not set to True in the ExtractionConfig."""
     chunks: list[str] = field(default_factory=list)
     """The extracted content chunks. This is an empty list if 'chunk_content' is not set to True in the ExtractionConfig."""
+    images: list[ExtractedImage] = field(default_factory=list)
+    """Extracted images. Empty list if 'extract_images' is not enabled."""
+    image_ocr_results: list[ImageOCRResult] = field(default_factory=list)
+    """OCR results from extracted images. Empty list if disabled or none processed."""
     entities: list[Entity] | None = None
     """Extracted entities, if entity extraction is enabled."""
     keywords: list[tuple[str, float]] | None = None
@@ -761,6 +845,41 @@ class ExtractionConfig(ConfigDict):
     """Whether to extract tables from the content. This requires the 'gmft' dependency."""
     extract_tables_from_ocr: bool = False
     """Extract tables from OCR output using TSV format (Tesseract only)."""
+    extract_images: bool = False
+    """Whether to extract images from documents."""
+    deduplicate_images: bool = True
+    """Whether to remove duplicate images using CRC32 checksums."""
+    image_ocr_config: ImageOCRConfig | None = None
+    """Configuration for OCR processing of extracted images."""
+    ocr_extracted_images: bool = False
+    """Deprecated: Use image_ocr_config.enabled instead."""
+    image_ocr_backend: OcrBackendType | None = None
+    """Deprecated: Use image_ocr_config.backend instead."""
+    image_ocr_min_dimensions: tuple[int, int] = (50, 50)
+    """Deprecated: Use image_ocr_config.min_dimensions instead."""
+    image_ocr_max_dimensions: tuple[int, int] = (10000, 10000)
+    """Deprecated: Use image_ocr_config.max_dimensions instead."""
+    image_ocr_formats: frozenset[str] = frozenset(
+        {
+            "jpg",
+            "jpeg",
+            "png",
+            "gif",
+            "bmp",
+            "tiff",
+            "tif",
+            "webp",
+            "jp2",
+            "jpx",
+            "jpm",
+            "mj2",
+            "pnm",
+            "pbm",
+            "pgm",
+            "ppm",
+        }
+    )
+    """Deprecated: Use image_ocr_config.allowed_formats instead."""
     max_chars: int = DEFAULT_MAX_CHARACTERS
     """The size of each chunk in characters."""
     max_overlap: int = DEFAULT_MAX_OVERLAP
@@ -826,6 +945,51 @@ class ExtractionConfig(ConfigDict):
         if self.validators is not None and isinstance(self.validators, list):
             object.__setattr__(self, "validators", tuple(self.validators))
 
+        if isinstance(self.pdf_password, list):
+            object.__setattr__(self, "pdf_password", tuple(self.pdf_password))
+
+        if isinstance(self.image_ocr_formats, list):
+            object.__setattr__(self, "image_ocr_formats", frozenset(self.image_ocr_formats))
+
+        if self.image_ocr_config is None and (
+            self.ocr_extracted_images
+            or self.image_ocr_backend is not None
+            or self.image_ocr_min_dimensions != (50, 50)
+            or self.image_ocr_max_dimensions != (10000, 10000)
+            or self.image_ocr_formats
+            != frozenset(
+                {
+                    "jpg",
+                    "jpeg",
+                    "png",
+                    "gif",
+                    "bmp",
+                    "tiff",
+                    "tif",
+                    "webp",
+                    "jp2",
+                    "jpx",
+                    "jpm",
+                    "mj2",
+                    "pnm",
+                    "pbm",
+                    "pgm",
+                    "ppm",
+                }
+            )
+        ):
+            object.__setattr__(
+                self,
+                "image_ocr_config",
+                ImageOCRConfig(
+                    enabled=self.ocr_extracted_images,
+                    backend=self.image_ocr_backend,
+                    min_dimensions=self.image_ocr_min_dimensions,
+                    max_dimensions=self.image_ocr_max_dimensions,
+                    allowed_formats=self.image_ocr_formats,
+                ),
+            )
+
         if self.ocr_backend is None and self.ocr_config is not None:
             raise ValidationError("'ocr_backend' is None but 'ocr_config' is provided")
 
@@ -839,7 +1003,6 @@ class ExtractionConfig(ConfigDict):
                 context={"ocr_backend": self.ocr_backend, "ocr_config": type(self.ocr_config).__name__},
             )
 
-        # Validate DPI configuration
         if self.target_dpi <= 0:
             raise ValidationError("target_dpi must be positive", context={"target_dpi": self.target_dpi})
         if self.min_dpi <= 0:
@@ -861,27 +1024,22 @@ class ExtractionConfig(ConfigDict):
         )
 
     def get_config_dict(self) -> dict[str, Any]:
-        if self.ocr_backend is None:
-            return {"use_cache": self.use_cache}
-
-        if self.ocr_config is not None:
-            config_dict = asdict(self.ocr_config)
-            config_dict["use_cache"] = self.use_cache
-            return config_dict
-
         match self.ocr_backend:
-            case
-
+            case None:
+                return {"use_cache": self.use_cache}
+            case _ if self.ocr_config is not None:
+                config_dict = asdict(self.ocr_config)
                 config_dict["use_cache"] = self.use_cache
                 return config_dict
+            case "tesseract":
+                config_dict = asdict(TesseractConfig())
             case "easyocr":
                 config_dict = asdict(EasyOCRConfig())
-                config_dict["use_cache"] = self.use_cache
-                return config_dict
             case _:
                 config_dict = asdict(PaddleOCRConfig())
-
-
+
+        config_dict["use_cache"] = self.use_cache
+        return config_dict
 
     def to_dict(self, include_none: bool = False) -> dict[str, Any]:
         result = msgspec.to_builtins(
@@ -900,7 +1058,7 @@ class ExtractionConfig(ConfigDict):
         return {k: v for k, v in result.items() if v is not None}
 
 
-@dataclass(frozen=True)
+@dataclass(unsafe_hash=True, frozen=True, slots=True)
 class HTMLToMarkdownConfig:
     stream_processing: bool = False
     """Enable streaming mode for processing large HTML documents."""
@@ -968,4 +1126,5 @@ class HTMLToMarkdownConfig:
     """Remove form elements from HTML."""
 
     def to_dict(self) -> dict[str, Any]:
-
+        result = msgspec.to_builtins(self, builtin_types=(type(None),), order="deterministic")
+        return {k: v for k, v in result.items() if v is not None}
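Note: the _types.py changes above introduce image extraction (ExtractedImage, ImageOCRResult) and the grouped ImageOCRConfig settings, with the flat image_ocr_* fields kept only as deprecated aliases. A minimal configuration sketch based on the fields shown in the diff; the backend literal and the dimension/batch values are illustrative assumptions, not required settings:

from kreuzberg._types import ExtractionConfig, ImageOCRConfig

config = ExtractionConfig(
    extract_images=True,        # populate result.images with ExtractedImage entries
    deduplicate_images=True,    # drop duplicate images via CRC32 checksums
    image_ocr_config=ImageOCRConfig(
        enabled=True,           # run OCR over the extracted images
        backend="tesseract",    # None falls back to the main ocr_backend
        min_dimensions=(100, 100),
        batch_size=4,
    ),
)
# result.images and result.image_ocr_results remain empty lists when these
# options are left disabled.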
kreuzberg/_utils/_cache.py
CHANGED
@@ -20,6 +20,8 @@ from kreuzberg._utils._sync import run_sync
 
 T = TypeVar("T")
 
+CACHE_CLEANUP_FREQUENCY = 100
+
 
 class KreuzbergCache(Generic[T]):
     def __init__(
@@ -136,16 +138,20 @@
     def _cleanup_cache(self) -> None:
         try:
             cache_files = list(self.cache_dir.glob("*.msgpack"))
-
             cutoff_time = time.time() - (self.max_age_days * 24 * 3600)
-
+
+            remaining_files = []
+            for cache_file in cache_files:
                 try:
                     if cache_file.stat().st_mtime < cutoff_time:
                         cache_file.unlink(missing_ok=True)
-
+                    else:
+                        remaining_files.append(cache_file)
                 except OSError: # noqa: PERF203
                     continue
 
+            cache_files = remaining_files
+
             total_size = sum(cache_file.stat().st_size for cache_file in cache_files if cache_file.exists()) / (
                 1024 * 1024
             )
@@ -191,7 +197,7 @@
             content = serialize(serialized)
             cache_path.write_bytes(content)
 
-            if hash(cache_key) %
+            if hash(cache_key) % CACHE_CLEANUP_FREQUENCY == 0:
                 self._cleanup_cache()
         except (OSError, TypeError, ValueError):
             pass
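Note: the cache change above names the sweep interval (CACHE_CLEANUP_FREQUENCY) and keeps only the files that survive the age-based sweep for the size calculation. A standalone sketch of the trigger logic, for illustration only:

CACHE_CLEANUP_FREQUENCY = 100

def should_cleanup(cache_key: str) -> bool:
    # Roughly one write in every CACHE_CLEANUP_FREQUENCY triggers a sweep.
    # hash() of a str is salted per process, so this is a cheap probabilistic
    # trigger rather than an exact every-Nth-write counter.
    return hash(cache_key) % CACHE_CLEANUP_FREQUENCY == 0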
kreuzberg/_utils/_device.py
CHANGED
@@ -12,7 +12,7 @@ from kreuzberg.exceptions import ValidationError
 DeviceType = Literal["cpu", "cuda", "mps", "auto"]
 
 
-@dataclass(frozen=True, slots=True)
+@dataclass(unsafe_hash=True, frozen=True, slots=True)
 class DeviceInfo:
     device_type: Literal["cpu", "cuda", "mps"]
     """The type of device."""
@@ -30,12 +30,10 @@ def detect_available_devices() -> list[DeviceInfo]:
     cpu_device = DeviceInfo(device_type="cpu", name="CPU")
 
     cuda_devices = _get_cuda_devices() if _is_cuda_available() else []
-
     mps_device = _get_mps_device() if _is_mps_available() else None
     mps_devices = [mps_device] if mps_device else []
 
-
-    return [*gpu_devices, cpu_device]
+    return list(chain(cuda_devices, mps_devices, [cpu_device]))
 
 
 def get_optimal_device() -> DeviceInfo:
|