kreuzberg 3.13.0__py3-none-any.whl → 3.13.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kreuzberg/_chunker.py +0 -15
- kreuzberg/_config.py +0 -124
- kreuzberg/_document_classification.py +20 -39
- kreuzberg/_entity_extraction.py +0 -29
- kreuzberg/_extractors/_base.py +4 -66
- kreuzberg/_extractors/_email.py +0 -4
- kreuzberg/_extractors/_image.py +0 -2
- kreuzberg/_extractors/_pandoc.py +0 -58
- kreuzberg/_extractors/_pdf.py +0 -3
- kreuzberg/_extractors/_presentation.py +0 -82
- kreuzberg/_extractors/_spread_sheet.py +0 -2
- kreuzberg/_gmft.py +0 -61
- kreuzberg/_language_detection.py +0 -14
- kreuzberg/_mime_types.py +0 -17
- kreuzberg/_ocr/_base.py +4 -76
- kreuzberg/_ocr/_easyocr.py +110 -85
- kreuzberg/_ocr/_paddleocr.py +146 -138
- kreuzberg/_ocr/_table_extractor.py +0 -76
- kreuzberg/_ocr/_tesseract.py +0 -206
- kreuzberg/_playa.py +0 -27
- kreuzberg/_registry.py +0 -36
- kreuzberg/_types.py +16 -119
- kreuzberg/_utils/_cache.py +0 -52
- kreuzberg/_utils/_device.py +0 -56
- kreuzberg/_utils/_document_cache.py +0 -73
- kreuzberg/_utils/_errors.py +0 -47
- kreuzberg/_utils/_ocr_cache.py +136 -0
- kreuzberg/_utils/_pdf_lock.py +0 -14
- kreuzberg/_utils/_process_pool.py +0 -47
- kreuzberg/_utils/_quality.py +0 -17
- kreuzberg/_utils/_ref.py +0 -16
- kreuzberg/_utils/_serialization.py +0 -25
- kreuzberg/_utils/_string.py +0 -20
- kreuzberg/_utils/_sync.py +0 -76
- kreuzberg/_utils/_table.py +0 -45
- kreuzberg/_utils/_tmp.py +0 -9
- {kreuzberg-3.13.0.dist-info → kreuzberg-3.13.1.dist-info}/METADATA +3 -2
- kreuzberg-3.13.1.dist-info/RECORD +57 -0
- kreuzberg-3.13.0.dist-info/RECORD +0 -56
- {kreuzberg-3.13.0.dist-info → kreuzberg-3.13.1.dist-info}/WHEEL +0 -0
- {kreuzberg-3.13.0.dist-info → kreuzberg-3.13.1.dist-info}/entry_points.txt +0 -0
- {kreuzberg-3.13.0.dist-info → kreuzberg-3.13.1.dist-info}/licenses/LICENSE +0 -0
kreuzberg/_types.py
CHANGED
@@ -35,18 +35,7 @@ OutputFormatType = Literal["text", "tsv", "hocr", "markdown"]
 
 
 class ConfigDict:
-    """Abstract base class for configuration objects that can be converted to dictionaries."""
-
     def to_dict(self, include_none: bool = False) -> dict[str, Any]:
-        """Convert configuration to dictionary.
-
-        Args:
-            include_none: If True, include fields with None values.
-                If False (default), exclude None values.
-
-        Returns:
-            Dictionary representation of the configuration.
-        """
         result = msgspec.to_builtins(
             self,
             builtin_types=(type(None),),
@@ -60,8 +49,6 @@ class ConfigDict:
 
 
 class PSMMode(Enum):
-    """Enum for Tesseract Page Segmentation Modes (PSM) with human-readable values."""
-
     OSD_ONLY = 0
     """Orientation and script detection only."""
     AUTO_OSD = 1
@@ -88,8 +75,6 @@ class PSMMode(Enum):
 
 @dataclass(unsafe_hash=True, frozen=True, slots=True)
 class TesseractConfig(ConfigDict):
-    """Configuration options for Tesseract OCR engine."""
-
     classify_use_pre_adapted_templates: bool = True
     """Whether to use pre-adapted templates during classification to improve recognition accuracy."""
     language: str = "eng"
@@ -132,8 +117,6 @@ class TesseractConfig(ConfigDict):
 
 @dataclass(unsafe_hash=True, frozen=True, slots=True)
 class EasyOCRConfig(ConfigDict):
-    """Configuration options for EasyOCR."""
-
     add_margin: float = 0.1
     """Extend bounding boxes in all directions."""
     adjust_contrast: float = 0.5
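The ConfigDict.to_dict body in the first hunk is untouched by this release: it serializes the frozen dataclass with msgspec.to_builtins and, unless include_none=True is passed, drops keys whose value is None. A minimal usage sketch; the top-level import location of TesseractConfig is assumed from kreuzberg's public API and is not shown in this diff:

from kreuzberg import TesseractConfig  # assumed top-level export

config = TesseractConfig(language="deu")   # frozen, hashable dataclass
slim = config.to_dict()                    # None-valued fields omitted by default
full = config.to_dict(include_none=True)   # keep None-valued fields
print(slim["language"])                    # -> "deu"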
@@ -185,21 +168,16 @@ class EasyOCRConfig(ConfigDict):
 
 @dataclass(unsafe_hash=True, frozen=True, slots=True)
 class PaddleOCRConfig(ConfigDict):
-    """Configuration options for PaddleOCR.
-
-    This dataclass provides type hints and documentation for all PaddleOCR parameters.
-    """
-
     cls_image_shape: str = "3,48,192"
     """Image shape for classification algorithm in format 'channels,height,width'."""
     det_algorithm: Literal["DB", "EAST", "SAST", "PSE", "FCE", "PAN", "CT", "DB++", "Layout"] = "DB"
     """Detection algorithm."""
     det_db_box_thresh: float = 0.5
-    """
+    """DEPRECATED in PaddleOCR 3.2.0+: Use 'text_det_box_thresh' instead. Score threshold for detected boxes."""
     det_db_thresh: float = 0.3
-    """Binarization threshold for DB output map."""
+    """DEPRECATED in PaddleOCR 3.2.0+: Use 'text_det_thresh' instead. Binarization threshold for DB output map."""
     det_db_unclip_ratio: float = 2.0
-    """Expansion ratio for detected text boxes."""
+    """DEPRECATED in PaddleOCR 3.2.0+: Use 'text_det_unclip_ratio' instead. Expansion ratio for detected text boxes."""
     det_east_cover_thresh: float = 0.1
     """Score threshold for EAST output boxes."""
     det_east_nms_thresh: float = 0.2
@@ -215,7 +193,7 @@ class PaddleOCRConfig(ConfigDict):
     enable_mkldnn: bool = False
     """Whether to enable MKL-DNN acceleration (Intel CPU only)."""
     gpu_mem: int = 8000
-    """GPU memory size (in MB) to use for initialization."""
+    """DEPRECATED in PaddleOCR 3.2.0+: Parameter no longer supported. GPU memory size (in MB) to use for initialization."""
     language: str = "en"
     """Language to use for OCR."""
     max_text_length: int = 25
@@ -245,13 +223,13 @@ class PaddleOCRConfig(ConfigDict):
     table: bool = True
     """Whether to enable table recognition."""
     use_angle_cls: bool = True
-    """Whether to use text orientation classification model."""
+    """DEPRECATED in PaddleOCR 3.2.0+: Use 'use_textline_orientation' instead. Whether to use text orientation classification model."""
     use_gpu: bool = False
-    """
+    """DEPRECATED in PaddleOCR 3.2.0+: Parameter no longer supported. Use hardware acceleration flags instead."""
     device: DeviceType = "auto"
     """Device to use for inference. Options: 'cpu', 'cuda', 'auto'. Note: MPS not supported by PaddlePaddle."""
     gpu_memory_limit: float | None = None
-    """Maximum GPU memory to use in GB.
+    """DEPRECATED in PaddleOCR 3.2.0+: Parameter no longer supported. Maximum GPU memory to use in GB."""
     fallback_to_cpu: bool = True
     """Whether to fallback to CPU if requested device is unavailable."""
     use_space_char: bool = True
@@ -259,14 +237,18 @@ class PaddleOCRConfig(ConfigDict):
     use_zero_copy_run: bool = False
     """Whether to enable zero_copy_run for inference optimization."""
 
+    text_det_thresh: float = 0.3
+    """Binarization threshold for text detection output map (replaces det_db_thresh)."""
+    text_det_box_thresh: float = 0.5
+    """Score threshold for detected text boxes (replaces det_db_box_thresh)."""
+    text_det_unclip_ratio: float = 2.0
+    """Expansion ratio for detected text boxes (replaces det_db_unclip_ratio)."""
+    use_textline_orientation: bool = True
+    """Whether to use text line orientation classification model (replaces use_angle_cls)."""
+
 
 @dataclass(unsafe_hash=True, frozen=True, slots=True)
 class GMFTConfig(ConfigDict):
-    """Configuration options for GMFT table extraction.
-
-    This class encapsulates the configuration options for GMFT, providing a way to customize its behavior.
-    """
-
     verbosity: int = 0
     """
     Verbosity level for logging.
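The PaddleOCRConfig hunks above carry the substantive change of 3.13.1: det_db_thresh, det_db_box_thresh, det_db_unclip_ratio, use_angle_cls, use_gpu, gpu_mem and gpu_memory_limit are now documented as deprecated for PaddleOCR 3.2.0+, and replacement fields with the same defaults are added. A migration sketch; the ExtractionConfig wiring (ocr_backend/ocr_config) follows kreuzberg's public API and is not part of this diff:

from kreuzberg import ExtractionConfig, PaddleOCRConfig  # assumed top-level exports

# 3.13.0-style fields, still accepted but documented as deprecated:
legacy = PaddleOCRConfig(det_db_thresh=0.3, det_db_box_thresh=0.5, use_angle_cls=True)

# replacement fields introduced by this release:
current = PaddleOCRConfig(
    text_det_thresh=0.3,            # replaces det_db_thresh
    text_det_box_thresh=0.5,        # replaces det_db_box_thresh
    text_det_unclip_ratio=2.0,      # replaces det_db_unclip_ratio
    use_textline_orientation=True,  # replaces use_angle_cls
)

config = ExtractionConfig(ocr_backend="paddleocr", ocr_config=current)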
@@ -369,8 +351,6 @@ class GMFTConfig(ConfigDict):
 
 @dataclass(frozen=True, slots=True)
 class LanguageDetectionConfig(ConfigDict):
-    """Configuration for language detection."""
-
     low_memory: bool = True
     """If True, uses a smaller model (~200MB). If False, uses a larger, more accurate model.
     Defaults to True for better memory efficiency."""
@@ -387,8 +367,6 @@ class LanguageDetectionConfig(ConfigDict):
 
 @dataclass(unsafe_hash=True, frozen=True, slots=True)
 class SpacyEntityExtractionConfig(ConfigDict):
-    """Configuration for spaCy-based entity extraction."""
-
     model_cache_dir: str | Path | None = None
     """Directory to cache spaCy models. If None, uses spaCy's default."""
     language_models: dict[str, str] | tuple[tuple[str, str], ...] | None = None
@@ -450,7 +428,6 @@ class SpacyEntityExtractionConfig(ConfigDict):
         }
 
     def get_model_for_language(self, language_code: str) -> str | None:
-        """Get the appropriate spaCy model for a language code."""
         if not self.language_models:
             return None
 
@@ -466,13 +443,10 @@ class SpacyEntityExtractionConfig(ConfigDict):
         return None
 
     def get_fallback_model(self) -> str | None:
-        """Get fallback multilingual model if enabled."""
         return "xx_ent_wiki_sm" if self.fallback_to_multilingual else None
 
 
 class BoundingBox(TypedDict):
-    """Bounding box coordinates for text elements."""
-
     left: int
     """X coordinate of the left edge."""
     top: int
@@ -484,8 +458,6 @@ class BoundingBox(TypedDict):
 
 
 class TSVWord(TypedDict):
-    """Represents a word from Tesseract TSV output."""
-
     level: int
     """Hierarchy level (1=page, 2=block, 3=para, 4=line, 5=word)."""
     page_num: int
@@ -513,8 +485,6 @@ class TSVWord(TypedDict):
 
 
 class TableCell(TypedDict):
-    """Represents a cell in a reconstructed table."""
-
     row: int
     """Row index (0-based)."""
     col: int
@@ -528,8 +498,6 @@ class TableCell(TypedDict):
 
 
 class TableData(TypedDict):
-    """Table data, returned from table extraction."""
-
     cropped_image: Image
     """The cropped image of the table."""
     df: DataFrame | None
@@ -541,12 +509,6 @@ class TableData(TypedDict):
 
 
 class Metadata(TypedDict, total=False):
-    """Base metadata common to all document types.
-
-    All fields will only be included if they contain non-empty values.
-    Any field that would be empty or None is omitted from the dictionary.
-    """
-
     authors: NotRequired[list[str]]
     """List of document authors."""
     categories: NotRequired[list[str]]
@@ -674,10 +636,6 @@ _VALID_METADATA_KEYS = {
 
 
 def normalize_metadata(data: dict[str, Any] | None) -> Metadata:
-    """Normalize any dict to proper Metadata TypedDict.
-
-    Filters out invalid keys and ensures type safety.
-    """
     if not data:
         return {}
 
@@ -691,8 +649,6 @@ def normalize_metadata(data: dict[str, Any] | None) -> Metadata:
 
 @dataclass(frozen=True, slots=True)
 class Entity:
-    """Represents an extracted entity with type, text, and position."""
-
     type: str
     """e.g., PERSON, ORGANIZATION, LOCATION, DATE, EMAIL, PHONE, or custom"""
     text: str
@@ -705,8 +661,6 @@ class Entity:
 
 @dataclass(slots=True)
 class ExtractionResult:
-    """The result of a file extraction."""
-
     content: str
     """The extracted content."""
     mime_type: str
@@ -731,15 +685,6 @@ class ExtractionResult:
     """Internal layout data from OCR, not for public use."""
 
     def to_dict(self, include_none: bool = False) -> dict[str, Any]:
-        """Converts the ExtractionResult to a dictionary.
-
-        Args:
-            include_none: If True, include fields with None values.
-                If False (default), exclude None values.
-
-        Returns:
-            Dictionary representation of the ExtractionResult.
-        """
         result = msgspec.to_builtins(
             self,
             builtin_types=(type(None),),
@@ -752,33 +697,18 @@ class ExtractionResult:
         return {k: v for k, v in result.items() if v is not None}
 
     def export_tables_to_csv(self) -> list[str]:
-        """Export all tables to CSV format.
-
-        Returns:
-            List of CSV strings, one per table
-        """
         if not self.tables:  # pragma: no cover
             return []
 
         return [export_table_to_csv(table) for table in self.tables]
 
     def export_tables_to_tsv(self) -> list[str]:
-        """Export all tables to TSV format.
-
-        Returns:
-            List of TSV strings, one per table
-        """
         if not self.tables:  # pragma: no cover
             return []
 
         return [export_table_to_tsv(table) for table in self.tables]
 
     def get_table_summaries(self) -> list[dict[str, Any]]:
-        """Get structural information for all tables.
-
-        Returns:
-            List of table structure dictionaries
-        """
         if not self.tables:  # pragma: no cover
             return []
 
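The ExtractionResult hunks above only strip docstrings: to_dict still filters out None values via msgspec, and the table helpers still return one CSV/TSV string or summary dict per extracted table. A short sketch of the call pattern; extract_file_sync and its config keyword are assumed from kreuzberg's public API rather than taken from this diff:

from kreuzberg import ExtractionConfig, extract_file_sync  # assumed public API

result = extract_file_sync("report.pdf", config=ExtractionConfig())
payload = result.to_dict()                   # None-valued fields dropped by default
csv_tables = result.export_tables_to_csv()   # one CSV string per table
summaries = result.get_table_summaries()     # structural info per table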
@@ -791,14 +721,6 @@ ValidationHook = Callable[[ExtractionResult], None | Awaitable[None]]
 
 @dataclass(unsafe_hash=True, slots=True)
 class ExtractionConfig(ConfigDict):
-    """Represents configuration settings for an extraction process.
-
-    This class encapsulates the configuration options for extracting text
-    from images or documents using Optical Character Recognition (OCR). It
-    provides options to customize the OCR behavior, select the backend
-    engine, and configure engine-specific parameters.
-    """
-
     force_ocr: bool = False
     """Whether to force OCR."""
     chunk_content: bool = False
@@ -876,11 +798,6 @@ class ExtractionConfig(ConfigDict):
         )
 
     def get_config_dict(self) -> dict[str, Any]:
-        """Returns the OCR configuration object based on the backend specified.
-
-        Returns:
-            A dict of the OCR configuration or an empty dict if no backend is provided.
-        """
         if self.ocr_backend is None:
             return {"use_cache": self.use_cache}
 
@@ -904,15 +821,6 @@ class ExtractionConfig(ConfigDict):
         return config_dict
 
     def to_dict(self, include_none: bool = False) -> dict[str, Any]:
-        """Convert configuration to dictionary recursively.
-
-        Args:
-            include_none: If True, include fields with None values.
-                If False (default), exclude None values.
-
-        Returns:
-            Dictionary representation of the configuration with nested configs converted.
-        """
         result = msgspec.to_builtins(
             self,
             builtin_types=(type(None),),
@@ -931,13 +839,6 @@ class ExtractionConfig(ConfigDict):
 
 @dataclass(frozen=True)
 class HTMLToMarkdownConfig:
-    """Configuration for HTML to Markdown conversion.
-
-    This configuration class provides fine-grained control over how HTML content
-    is converted to Markdown format. Most fields have sensible defaults that work
-    well for typical document extraction scenarios.
-    """
-
     stream_processing: bool = False
     """Enable streaming mode for processing large HTML documents."""
     chunk_size: int = 1024
@@ -1004,8 +905,4 @@ class HTMLToMarkdownConfig:
     """Remove form elements from HTML."""
 
     def to_dict(self) -> dict[str, Any]:
-        """Convert config to dictionary for passing to convert_to_markdown.
-
-        Excludes None values and handles special cases.
-        """
         return {key: value for key, value in self.__dict__.items() if value is not None}
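ExtractionConfig keeps its behaviour as well: get_config_dict returns only the use_cache flag when no OCR backend is configured, and otherwise builds the backend-specific dict, while to_dict converts nested configs recursively. A minimal sketch using only field and method names visible in this diff (the values passed here are explicit, not assumed defaults):

from kreuzberg import ExtractionConfig  # assumed top-level export

cfg = ExtractionConfig(ocr_backend=None, use_cache=True)
cfg.get_config_dict()   # -> {"use_cache": True} when ocr_backend is None
cfg.to_dict()           # recursive msgspec conversion, None-valued fields dropped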
kreuzberg/_utils/_cache.py
CHANGED
@@ -20,12 +20,6 @@ T = TypeVar("T")
 
 
 class KreuzbergCache(Generic[T]):
-    """File-based cache for Kreuzberg operations.
-
-    Provides both sync and async interfaces for caching extraction results,
-    OCR results, table data, and other expensive operations to disk.
-    """
-
     def __init__(
         self,
         cache_type: str,
@@ -33,14 +27,6 @@ class KreuzbergCache(Generic[T]):
         max_cache_size_mb: float = 500.0,
         max_age_days: int = 30,
     ) -> None:
-        """Initialize cache.
-
-        Args:
-            cache_type: Type of cache (e.g., 'ocr', 'tables', 'documents', 'mime')
-            cache_dir: Cache directory (defaults to .kreuzberg/{cache_type} in cwd)
-            max_cache_size_mb: Maximum cache size in MB (default: 500MB)
-            max_age_days: Maximum age of cached results in days (default: 30 days)
-        """
         if cache_dir is None:
             cache_dir = Path.cwd() / ".kreuzberg" / cache_type
 
@@ -159,14 +145,6 @@ class KreuzbergCache(Generic[T]):
             pass
 
     def get(self, **kwargs: Any) -> T | None:
-        """Get cached result (sync).
-
-        Args:
-            **kwargs: Key-value pairs to generate cache key from
-
-        Returns:
-            Cached result if available, None otherwise
-        """
         cache_key = self._get_cache_key(**kwargs)
         cache_path = self._get_cache_path(cache_key)
 
@@ -183,12 +161,6 @@ class KreuzbergCache(Generic[T]):
         return None
 
     def set(self, result: T, **kwargs: Any) -> None:
-        """Cache result (sync).
-
-        Args:
-            result: Result to cache
-            **kwargs: Key-value pairs to generate cache key from
-        """
         cache_key = self._get_cache_key(**kwargs)
         cache_path = self._get_cache_path(cache_key)
 
@@ -203,14 +175,6 @@ class KreuzbergCache(Generic[T]):
             pass
 
     async def aget(self, **kwargs: Any) -> T | None:
-        """Get cached result (async).
-
-        Args:
-            **kwargs: Key-value pairs to generate cache key from
-
-        Returns:
-            Cached result if available, None otherwise
-        """
         cache_key = self._get_cache_key(**kwargs)
         cache_path = AsyncPath(self._get_cache_path(cache_key))
 
@@ -227,12 +191,6 @@ class KreuzbergCache(Generic[T]):
         return None
 
    async def aset(self, result: T, **kwargs: Any) -> None:
-        """Cache result (async).
-
-        Args:
-            result: Result to cache
-            **kwargs: Key-value pairs to generate cache key from
-        """
         cache_key = self._get_cache_key(**kwargs)
         cache_path = AsyncPath(self._get_cache_path(cache_key))
 
@@ -247,13 +205,11 @@ class KreuzbergCache(Generic[T]):
             pass
 
     def is_processing(self, **kwargs: Any) -> bool:
-        """Check if operation is currently being processed."""
         cache_key = self._get_cache_key(**kwargs)
         with self._lock:
             return cache_key in self._processing
 
     def mark_processing(self, **kwargs: Any) -> threading.Event:
-        """Mark operation as being processed and return event to wait on."""
         cache_key = self._get_cache_key(**kwargs)
 
         with self._lock:
@@ -262,7 +218,6 @@ class KreuzbergCache(Generic[T]):
             return self._processing[cache_key]
 
     def mark_complete(self, **kwargs: Any) -> None:
-        """Mark operation processing as complete."""
         cache_key = self._get_cache_key(**kwargs)
 
         with self._lock:
@@ -271,7 +226,6 @@ class KreuzbergCache(Generic[T]):
             event.set()
 
     def clear(self) -> None:
-        """Clear all cached results."""
         try:
             for cache_file in self.cache_dir.glob("*.msgpack"):
                 cache_file.unlink(missing_ok=True)
@@ -282,7 +236,6 @@ class KreuzbergCache(Generic[T]):
             pass
 
     def get_stats(self) -> dict[str, Any]:
-        """Get cache statistics."""
         try:
             cache_files = list(self.cache_dir.glob("*.msgpack"))
             total_size = sum(cache_file.stat().st_size for cache_file in cache_files if cache_file.exists())
@@ -328,7 +281,6 @@ _ocr_cache_ref = Ref("ocr_cache", _create_ocr_cache)
 
 
 def get_ocr_cache() -> KreuzbergCache[ExtractionResult]:
-    """Get the OCR cache instance."""
     return _ocr_cache_ref.get()
 
 
@@ -350,7 +302,6 @@ _document_cache_ref = Ref("document_cache", _create_document_cache)
 
 
 def get_document_cache() -> KreuzbergCache[ExtractionResult]:
-    """Get the document cache instance."""
     return _document_cache_ref.get()
 
 
@@ -372,7 +323,6 @@ _table_cache_ref = Ref("table_cache", _create_table_cache)
 
 
 def get_table_cache() -> KreuzbergCache[Any]:
-    """Get the table cache instance."""
     return _table_cache_ref.get()
 
 
@@ -394,12 +344,10 @@ _mime_cache_ref = Ref("mime_cache", _create_mime_cache)
 
 
 def get_mime_cache() -> KreuzbergCache[str]:
-    """Get the MIME type cache instance."""
     return _mime_cache_ref.get()
 
 
 def clear_all_caches() -> None:
-    """Clear all caches."""
     if _ocr_cache_ref.is_initialized():
         get_ocr_cache().clear()
     if _document_cache_ref.is_initialized():
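kreuzberg/_utils/_cache.py likewise only loses docstrings; the sync/async interface and the lazily created module-level caches are unchanged (the new kreuzberg/_utils/_ocr_cache.py added in this release is not shown in this section). A sketch against the signatures visible in this diff — these are underscore-prefixed internal modules, so the import path is not a stable public API:

from kreuzberg._utils._cache import KreuzbergCache, clear_all_caches, get_ocr_cache

cache: KreuzbergCache[str] = KreuzbergCache(
    cache_type="demo",        # files land under .kreuzberg/demo in the cwd by default
    max_cache_size_mb=100.0,
    max_age_days=7,
)

if (hit := cache.get(path="a.pdf", dpi=300)) is None:   # kwargs form the cache key
    hit = "expensive result"
    cache.set(hit, path="a.pdf", dpi=300)

print(cache.get_stats())     # size and entry statistics
ocr_cache = get_ocr_cache()  # lazily created module-level instance
clear_all_caches()           # clears only caches that were actually initialized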
kreuzberg/_utils/_device.py
CHANGED
@@ -14,8 +14,6 @@ DeviceType = Literal["cpu", "cuda", "mps", "auto"]
 
 @dataclass(frozen=True, slots=True)
 class DeviceInfo:
-    """Information about a compute device."""
-
     device_type: Literal["cpu", "cuda", "mps"]
     """The type of device."""
     device_id: int | None = None
@@ -29,11 +27,6 @@ class DeviceInfo:
 
 
 def detect_available_devices() -> list[DeviceInfo]:
-    """Detect all available compute devices.
-
-    Returns:
-        List of available devices, with the most preferred device first.
-    """
     cpu_device = DeviceInfo(device_type="cpu", name="CPU")
 
     cuda_devices = _get_cuda_devices() if _is_cuda_available() else []
@@ -46,11 +39,6 @@ def detect_available_devices() -> list[DeviceInfo]:
 
 
 def get_optimal_device() -> DeviceInfo:
-    """Get the optimal device for OCR processing.
-
-    Returns:
-        The best available device, preferring GPU over CPU.
-    """
     devices = detect_available_devices()
     return devices[0] if devices else DeviceInfo(device_type="cpu", name="CPU")
 
@@ -62,20 +50,6 @@ def validate_device_request(
     memory_limit: float | None = None,
     fallback_to_cpu: bool = True,
 ) -> DeviceInfo:
-    """Validate and resolve a device request.
-
-    Args:
-        requested: The requested device type.
-        backend: Name of the OCR backend requesting the device.
-        memory_limit: Optional memory limit in GB.
-        fallback_to_cpu: Whether to fallback to CPU if requested device unavailable.
-
-    Returns:
-        A validated DeviceInfo object.
-
-    Raises:
-        ValidationError: If the requested device is not available and fallback is disabled.
-    """
     available_devices = detect_available_devices()
 
     if requested == "auto":
@@ -115,14 +89,6 @@ def validate_device_request(
 
 
 def get_device_memory_info(device: DeviceInfo) -> tuple[float | None, float | None]:
-    """Get memory information for a device.
-
-    Args:
-        device: The device to query.
-
-    Returns:
-        Tuple of (total_memory_gb, available_memory_gb). None values if unknown.
-    """
     if device.device_type == "cpu":
         return None, None
 
@@ -261,28 +227,11 @@ def _validate_memory_limit(device: DeviceInfo, memory_limit: float) -> None:
 
 
 def is_backend_gpu_compatible(backend: str) -> bool:
-    """Check if an OCR backend supports GPU acceleration.
-
-    Args:
-        backend: Name of the OCR backend.
-
-    Returns:
-        True if the backend supports GPU acceleration.
-    """
     # EasyOCR and PaddleOCR support GPU, Tesseract does not  # ~keep
     return backend.lower() in ("easyocr", "paddleocr")
 
 
 def get_recommended_batch_size(device: DeviceInfo, input_size_mb: float = 10.0) -> int:
-    """Get recommended batch size for OCR processing.
-
-    Args:
-        device: The device to optimize for.
-        input_size_mb: Estimated input size per item in MB.
-
-    Returns:
-        Recommended batch size.
-    """
     if device.device_type == "cpu":
         # Conservative batch size for CPU  # ~keep
         return 1
@@ -304,11 +253,6 @@ def get_recommended_batch_size(device: DeviceInfo, input_size_mb: float = 10.0)
 
 
 def cleanup_device_memory(device: DeviceInfo) -> None:
-    """Clean up device memory.
-
-    Args:
-        device: The device to clean up.
-    """
     if device.device_type == "cuda":
         try:
             import torch  # noqa: PLC0415