kreuzberg 3.13.3__py3-none-any.whl → 3.14.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,5 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import os
3
4
  import platform
4
5
  import warnings
5
6
  from importlib.util import find_spec
@@ -36,18 +37,31 @@ except ImportError: # pragma: no cover
36
37
  if TYPE_CHECKING:
37
38
  import numpy as np
38
39
  from paddleocr import PaddleOCR
40
+ else:
41
+ np: Any = None
42
+ PaddleOCR: Any = None
43
+
44
+ HAS_PADDLEOCR: bool = False
45
+
46
+
47
+ def _import_paddleocr() -> tuple[Any, Any]:
48
+ global HAS_PADDLEOCR, np, PaddleOCR
49
+
50
+ if HAS_PADDLEOCR:
51
+ return np, PaddleOCR
39
52
 
40
- HAS_PADDLEOCR: bool
41
- if not TYPE_CHECKING:
42
53
  try:
43
- import numpy as np
44
- from paddleocr import PaddleOCR
54
+ os.environ.setdefault("HUB_DATASET_ENDPOINT", "https://modelscope.cn/api/v1/datasets")
55
+
56
+ import numpy as _np # noqa: PLC0415, ICN001
57
+ from paddleocr import PaddleOCR as _PaddleOCR # noqa: PLC0415
45
58
 
59
+ np = _np
60
+ PaddleOCR = _PaddleOCR
46
61
  HAS_PADDLEOCR = True
62
+ return np, PaddleOCR
47
63
  except ImportError:
48
- HAS_PADDLEOCR = False
49
- np: Any = None
50
- PaddleOCR: Any = None
64
+ return None, None
51
65
 
52
66
 
53
67
  PADDLEOCR_SUPPORTED_LANGUAGE_CODES: Final[set[str]] = {"ch", "en", "french", "german", "japan", "korean"}
@@ -74,7 +88,12 @@ class PaddleBackend(OCRBackend[PaddleOCRConfig]):
74
88
  if image.mode != "RGB":
75
89
  image = image.convert("RGB")
76
90
 
77
- image_np = np.array(image)
91
+ _np, _ = _import_paddleocr()
92
+ if _np is None:
93
+ raise MissingDependencyError.create_for_package(
94
+ dependency_group="paddleocr", functionality="PaddleOCR as an OCR backend", package_name="paddleocr"
95
+ )
96
+ image_np = _np.array(image)
78
97
  use_textline_orientation = kwargs.get("use_textline_orientation", kwargs.get("use_angle_cls", True))
79
98
  result = await run_sync(self._paddle_ocr.ocr, image_np, cls=use_textline_orientation)
80
99
 
@@ -195,7 +214,8 @@ class PaddleBackend(OCRBackend[PaddleOCRConfig]):
195
214
  if cls._paddle_ocr is not None:
196
215
  return
197
216
 
198
- if not HAS_PADDLEOCR or PaddleOCR is None:
217
+ _np, _paddle_ocr = _import_paddleocr()
218
+ if _paddle_ocr is None:
199
219
  raise MissingDependencyError.create_for_package(
200
220
  dependency_group="paddleocr", functionality="PaddleOCR as an OCR backend", package_name="paddleocr"
201
221
  )
@@ -224,7 +244,7 @@ class PaddleBackend(OCRBackend[PaddleOCRConfig]):
224
244
  kwargs.setdefault("enable_mkldnn", cls._is_mkldnn_supported())
225
245
 
226
246
  try:
227
- cls._paddle_ocr = await run_sync(PaddleOCR, lang=language, **kwargs)
247
+ cls._paddle_ocr = await run_sync(_paddle_ocr, lang=language, **kwargs)
228
248
  except Exception as e:
229
249
  raise OCRError(f"Failed to initialize PaddleOCR: {e}") from e
230
250
 
@@ -304,7 +324,12 @@ class PaddleBackend(OCRBackend[PaddleOCRConfig]):
304
324
  if image.mode != "RGB":
305
325
  image = image.convert("RGB")
306
326
 
307
- image_np = np.array(image)
327
+ _np, _ = _import_paddleocr()
328
+ if _np is None:
329
+ raise MissingDependencyError.create_for_package(
330
+ dependency_group="paddleocr", functionality="PaddleOCR as an OCR backend", package_name="paddleocr"
331
+ )
332
+ image_np = _np.array(image)
308
333
  use_textline_orientation = kwargs.get("use_textline_orientation", kwargs.get("use_angle_cls", True))
309
334
  result = self._paddle_ocr.ocr(image_np, cls=use_textline_orientation)
310
335
 
@@ -352,7 +377,8 @@ class PaddleBackend(OCRBackend[PaddleOCRConfig]):
352
377
  if cls._paddle_ocr is not None:
353
378
  return
354
379
 
355
- if not HAS_PADDLEOCR or PaddleOCR is None:
380
+ _np, _paddle_ocr = _import_paddleocr()
381
+ if _paddle_ocr is None:
356
382
  raise MissingDependencyError.create_for_package(
357
383
  dependency_group="paddleocr", functionality="PaddleOCR as an OCR backend", package_name="paddleocr"
358
384
  )
@@ -381,6 +407,6 @@ class PaddleBackend(OCRBackend[PaddleOCRConfig]):
381
407
  kwargs.setdefault("enable_mkldnn", cls._is_mkldnn_supported())
382
408
 
383
409
  try:
384
- cls._paddle_ocr = PaddleOCR(lang=language, **kwargs)
410
+ cls._paddle_ocr = _paddle_ocr(lang=language, **kwargs)
385
411
  except Exception as e:
386
412
  raise OCRError(f"Failed to initialize PaddleOCR: {e}") from e
@@ -28,6 +28,7 @@ from kreuzberg._ocr._base import OCRBackend
28
28
  from kreuzberg._ocr._table_extractor import extract_words, reconstruct_table, to_markdown
29
29
  from kreuzberg._types import ExtractionResult, HTMLToMarkdownConfig, PSMMode, TableData, TesseractConfig
30
30
  from kreuzberg._utils._cache import get_ocr_cache
31
+ from kreuzberg._utils._process_pool import ProcessPoolManager
31
32
  from kreuzberg._utils._string import normalize_spaces
32
33
  from kreuzberg._utils._sync import run_sync
33
34
  from kreuzberg._utils._tmp import create_temp_file
@@ -467,7 +468,7 @@ class TesseractBackend(OCRBackend[TesseractConfig]):
467
468
  last_para = -1
468
469
 
469
470
  for line_key in sorted(lines.keys()):
470
- page_num, block_num, par_num, line_num = line_key
471
+ _page_num, block_num, par_num, _line_num = line_key
471
472
 
472
473
  if block_num != last_block:
473
474
  if text_parts: # ~keep
@@ -1297,8 +1298,6 @@ class TesseractProcessPool:
1297
1298
  max_processes: int | None = None,
1298
1299
  memory_limit_gb: float | None = None,
1299
1300
  ) -> None:
1300
- from kreuzberg._utils._process_pool import ProcessPoolManager # noqa: PLC0415
1301
-
1302
1301
  self.config = config or TesseractConfig()
1303
1302
  self.process_manager = ProcessPoolManager(
1304
1303
  max_processes=max_processes,
kreuzberg/_registry.py CHANGED
@@ -28,6 +28,13 @@ if TYPE_CHECKING:
28
28
 
29
29
 
30
30
  class ExtractorRegistry:
31
+ """Registry for managing document extractors.
32
+
33
+ This class maintains a registry of extractors for different file types and provides
34
+ functionality to get the appropriate extractor for a given MIME type, as well as
35
+ add or remove custom extractors.
36
+ """
37
+
31
38
  _default_extractors: ClassVar[list[type[Extractor]]] = [
32
39
  PDFExtractor,
33
40
  OfficeDocumentExtractor,
@@ -51,6 +58,15 @@ class ExtractorRegistry:
51
58
  @classmethod
52
59
  @lru_cache
53
60
  def get_extractor(cls, mime_type: str | None, config: ExtractionConfig) -> Extractor | None:
61
+ """Get an appropriate extractor for the given MIME type.
62
+
63
+ Args:
64
+ mime_type: The MIME type to find an extractor for.
65
+ config: The extraction configuration.
66
+
67
+ Returns:
68
+ An extractor instance if one supports the MIME type, None otherwise.
69
+ """
54
70
  extractors: list[type[Extractor]] = [
55
71
  *cls._registered_extractors,
56
72
  *cls._default_extractors,
@@ -64,11 +80,21 @@ class ExtractorRegistry:
64
80
 
65
81
  @classmethod
66
82
  def add_extractor(cls, extractor: type[Extractor]) -> None:
83
+ """Add a custom extractor to the registry.
84
+
85
+ Args:
86
+ extractor: The extractor class to add to the registry.
87
+ """
67
88
  cls._registered_extractors.append(extractor)
68
89
  cls.get_extractor.cache_clear()
69
90
 
70
91
  @classmethod
71
92
  def remove_extractor(cls, extractor: type[Extractor]) -> None:
93
+ """Remove a custom extractor from the registry.
94
+
95
+ Args:
96
+ extractor: The extractor class to remove from the registry.
97
+ """
72
98
  try:
73
99
  cls._registered_extractors.remove(extractor)
74
100
  cls.get_extractor.cache_clear()
kreuzberg/_types.py CHANGED
@@ -4,7 +4,7 @@ import sys
4
4
  from collections.abc import Awaitable, Callable, Iterable, Mapping
5
5
  from dataclasses import asdict, dataclass, field
6
6
  from enum import Enum
7
- from typing import TYPE_CHECKING, Any, Literal, TypedDict
7
+ from typing import TYPE_CHECKING, Any, Literal, NamedTuple, TypedDict
8
8
 
9
9
  import msgspec
10
10
 
@@ -508,6 +508,35 @@ class TableData(TypedDict):
508
508
  """The table text as a markdown string."""
509
509
 
510
510
 
511
+ class ImagePreprocessingMetadata(NamedTuple):
512
+ """Metadata about image preprocessing operations for OCR."""
513
+
514
+ original_dimensions: tuple[int, int]
515
+ """Original image dimensions (width, height) in pixels."""
516
+ original_dpi: tuple[float, float]
517
+ """Original image DPI (horizontal, vertical)."""
518
+ target_dpi: int
519
+ """Target DPI that was requested."""
520
+ scale_factor: float
521
+ """Scale factor applied to the image."""
522
+ auto_adjusted: bool
523
+ """Whether DPI was automatically adjusted due to size constraints."""
524
+ final_dpi: int | None = None
525
+ """Final DPI used after processing."""
526
+ new_dimensions: tuple[int, int] | None = None
527
+ """New image dimensions after processing (width, height) in pixels."""
528
+ resample_method: str | None = None
529
+ """Resampling method used (LANCZOS, BICUBIC, etc.)."""
530
+ skipped_resize: bool = False
531
+ """Whether resizing was skipped (no change needed)."""
532
+ dimension_clamped: bool = False
533
+ """Whether image was clamped to maximum dimension constraints."""
534
+ calculated_dpi: int | None = None
535
+ """DPI calculated during auto-adjustment."""
536
+ resize_error: str | None = None
537
+ """Error message if resizing failed."""
538
+
539
+
511
540
  class Metadata(TypedDict, total=False):
512
541
  authors: NotRequired[list[str]]
513
542
  """List of document authors."""
@@ -587,6 +616,8 @@ class Metadata(TypedDict, total=False):
587
616
  """Summary of table extraction results."""
588
617
  quality_score: NotRequired[float]
589
618
  """Quality score for extracted content (0.0-1.0)."""
619
+ image_preprocessing: NotRequired[ImagePreprocessingMetadata]
620
+ """Metadata about image preprocessing operations (DPI adjustments, scaling, etc.)."""
590
621
  source_format: NotRequired[str]
591
622
  """Source format of the extracted content."""
592
623
  error: NotRequired[str]
@@ -632,6 +663,7 @@ _VALID_METADATA_KEYS = {
632
663
  "table_count",
633
664
  "tables_summary",
634
665
  "quality_score",
666
+ "image_preprocessing",
635
667
  }
636
668
 
637
669
 
@@ -775,6 +807,16 @@ class ExtractionConfig(ConfigDict):
775
807
  """Configuration for HTML to Markdown conversion. If None, uses default settings."""
776
808
  use_cache: bool = True
777
809
  """Whether to use caching for extraction results. Set to False to disable all caching."""
810
+ target_dpi: int = 150
811
+ """Target DPI for OCR processing. Images and PDF pages will be scaled to this DPI for optimal OCR results."""
812
+ max_image_dimension: int = 25000
813
+ """Maximum allowed pixel dimension (width or height) for processed images to prevent memory issues."""
814
+ auto_adjust_dpi: bool = True
815
+ """Whether to automatically adjust DPI based on image dimensions to stay within max_image_dimension limits."""
816
+ min_dpi: int = 72
817
+ """Minimum DPI threshold when auto-adjusting DPI."""
818
+ max_dpi: int = 600
819
+ """Maximum DPI threshold when auto-adjusting DPI."""
778
820
 
779
821
  def __post_init__(self) -> None:
780
822
  if self.custom_entity_patterns is not None and isinstance(self.custom_entity_patterns, dict):
@@ -797,6 +839,27 @@ class ExtractionConfig(ConfigDict):
797
839
  context={"ocr_backend": self.ocr_backend, "ocr_config": type(self.ocr_config).__name__},
798
840
  )
799
841
 
842
+ # Validate DPI configuration
843
+ if self.target_dpi <= 0:
844
+ raise ValidationError("target_dpi must be positive", context={"target_dpi": self.target_dpi})
845
+ if self.min_dpi <= 0:
846
+ raise ValidationError("min_dpi must be positive", context={"min_dpi": self.min_dpi})
847
+ if self.max_dpi <= 0:
848
+ raise ValidationError("max_dpi must be positive", context={"max_dpi": self.max_dpi})
849
+ if self.min_dpi >= self.max_dpi:
850
+ raise ValidationError(
851
+ "min_dpi must be less than max_dpi", context={"min_dpi": self.min_dpi, "max_dpi": self.max_dpi}
852
+ )
853
+ if self.max_image_dimension <= 0:
854
+ raise ValidationError(
855
+ "max_image_dimension must be positive", context={"max_image_dimension": self.max_image_dimension}
856
+ )
857
+ if not (self.min_dpi <= self.target_dpi <= self.max_dpi):
858
+ raise ValidationError(
859
+ "target_dpi must be between min_dpi and max_dpi",
860
+ context={"target_dpi": self.target_dpi, "min_dpi": self.min_dpi, "max_dpi": self.max_dpi},
861
+ )
862
+
800
863
  def get_config_dict(self) -> dict[str, Any]:
801
864
  if self.ocr_backend is None:
802
865
  return {"use_cache": self.use_cache}
@@ -1,14 +1,16 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import hashlib
4
+ import io
4
5
  import os
5
6
  import threading
6
7
  import time
7
8
  from contextlib import suppress
8
9
  from io import StringIO
9
10
  from pathlib import Path
10
- from typing import Any, Generic, TypeVar
11
+ from typing import Any, Generic, TypeVar, cast
11
12
 
13
+ import polars as pl
12
14
  from anyio import Path as AsyncPath
13
15
 
14
16
  from kreuzberg._types import ExtractionResult
@@ -79,10 +81,18 @@ class KreuzbergCache(Generic[T]):
79
81
  for item in result:
80
82
  if isinstance(item, dict) and "df" in item:
81
83
  serialized_item = {k: v for k, v in item.items() if k != "df"}
82
- if hasattr(item["df"], "to_csv"):
83
- serialized_item["df_csv"] = item["df"].to_csv(index=False)
84
+ if item["df"] is not None:
85
+ buffer = io.BytesIO()
86
+ if hasattr(item["df"], "write_parquet"):
87
+ item["df"].write_parquet(buffer)
88
+ serialized_item["df_parquet"] = buffer.getvalue()
89
+ elif hasattr(item["df"], "write_csv"):
90
+ item["df"].write_csv(buffer)
91
+ serialized_item["df_parquet"] = buffer.getvalue()
92
+ else:
93
+ serialized_item["df_parquet"] = None
84
94
  else:
85
- serialized_item["df_csv"] = str(item["df"])
95
+ serialized_item["df_parquet"] = None
86
96
  serialized_data.append(serialized_item)
87
97
  else:
88
98
  serialized_data.append(item)
@@ -94,22 +104,34 @@ class KreuzbergCache(Generic[T]):
94
104
  data = cached_data["data"]
95
105
 
96
106
  if cached_data.get("type") == "TableDataList" and isinstance(data, list):
97
- import pandas as pd # noqa: PLC0415
98
-
99
107
  deserialized_data = []
100
108
  for item in data:
101
- if isinstance(item, dict) and "df_csv" in item:
102
- deserialized_item = {k: v for k, v in item.items() if k != "df_csv"}
103
- deserialized_item["df"] = pd.read_csv(StringIO(item["df_csv"]))
109
+ if isinstance(item, dict) and ("df_parquet" in item or "df_csv" in item):
110
+ deserialized_item = {k: v for k, v in item.items() if k not in ("df_parquet", "df_csv")}
111
+
112
+ if "df_parquet" in item:
113
+ if item["df_parquet"] is None:
114
+ deserialized_item["df"] = pl.DataFrame()
115
+ else:
116
+ buffer = io.BytesIO(item["df_parquet"])
117
+ try:
118
+ deserialized_item["df"] = pl.read_parquet(buffer)
119
+ except Exception: # noqa: BLE001
120
+ deserialized_item["df"] = pl.DataFrame()
121
+ elif "df_csv" in item:
122
+ if item["df_csv"] is None or item["df_csv"] == "" or item["df_csv"] == "\n":
123
+ deserialized_item["df"] = pl.DataFrame()
124
+ else:
125
+ deserialized_item["df"] = pl.read_csv(StringIO(item["df_csv"]))
104
126
  deserialized_data.append(deserialized_item)
105
127
  else:
106
128
  deserialized_data.append(item)
107
- return deserialized_data # type: ignore[return-value]
129
+ return cast("T", deserialized_data)
108
130
 
109
131
  if cached_data.get("type") == "ExtractionResult" and isinstance(data, dict):
110
- return ExtractionResult(**data) # type: ignore[return-value]
132
+ return cast("T", ExtractionResult(**data))
111
133
 
112
- return data # type: ignore[no-any-return]
134
+ return cast("T", data)
113
135
 
114
136
  def _cleanup_cache(self) -> None:
115
137
  try: