kreuzberg 3.8.0__py3-none-any.whl → 3.8.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. kreuzberg/__init__.py +4 -0
  2. kreuzberg/_api/main.py +22 -1
  3. kreuzberg/_config.py +404 -0
  4. kreuzberg/_entity_extraction.py +4 -5
  5. kreuzberg/_extractors/_base.py +3 -5
  6. kreuzberg/_extractors/_image.py +18 -32
  7. kreuzberg/_extractors/_pandoc.py +3 -14
  8. kreuzberg/_extractors/_pdf.py +39 -57
  9. kreuzberg/_extractors/_spread_sheet.py +2 -3
  10. kreuzberg/_extractors/_structured.py +10 -7
  11. kreuzberg/_gmft.py +314 -10
  12. kreuzberg/_language_detection.py +1 -1
  13. kreuzberg/_mcp/server.py +58 -8
  14. kreuzberg/_ocr/__init__.py +1 -22
  15. kreuzberg/_ocr/_base.py +59 -0
  16. kreuzberg/_ocr/_easyocr.py +92 -1
  17. kreuzberg/_ocr/_paddleocr.py +90 -1
  18. kreuzberg/_ocr/_tesseract.py +556 -5
  19. kreuzberg/_playa.py +2 -3
  20. kreuzberg/_types.py +46 -24
  21. kreuzberg/_utils/_cache.py +35 -4
  22. kreuzberg/_utils/_device.py +10 -20
  23. kreuzberg/_utils/_errors.py +44 -45
  24. kreuzberg/_utils/_process_pool.py +2 -6
  25. kreuzberg/_utils/_quality.py +7 -11
  26. kreuzberg/_utils/_serialization.py +21 -16
  27. kreuzberg/_utils/_string.py +22 -12
  28. kreuzberg/_utils/_table.py +3 -4
  29. kreuzberg/cli.py +4 -5
  30. kreuzberg/exceptions.py +10 -0
  31. kreuzberg/extraction.py +6 -24
  32. kreuzberg-3.8.2.dist-info/METADATA +265 -0
  33. kreuzberg-3.8.2.dist-info/RECORD +53 -0
  34. kreuzberg/_cli_config.py +0 -175
  35. kreuzberg/_multiprocessing/__init__.py +0 -5
  36. kreuzberg/_multiprocessing/gmft_isolated.py +0 -330
  37. kreuzberg/_ocr/_pool.py +0 -357
  38. kreuzberg/_ocr/_sync.py +0 -566
  39. kreuzberg-3.8.0.dist-info/METADATA +0 -313
  40. kreuzberg-3.8.0.dist-info/RECORD +0 -57
  41. {kreuzberg-3.8.0.dist-info → kreuzberg-3.8.2.dist-info}/WHEEL +0 -0
  42. {kreuzberg-3.8.0.dist-info → kreuzberg-3.8.2.dist-info}/entry_points.txt +0 -0
  43. {kreuzberg-3.8.0.dist-info → kreuzberg-3.8.2.dist-info}/licenses/LICENSE +0 -0
kreuzberg/__init__.py CHANGED
@@ -1,5 +1,6 @@
1
1
  from importlib.metadata import version
2
2
 
3
+ from kreuzberg._config import discover_and_load_config, load_config_from_path, try_discover_config
3
4
  from kreuzberg._entity_extraction import SpacyEntityExtractionConfig
4
5
  from kreuzberg._gmft import GMFTConfig
5
6
  from kreuzberg._language_detection import LanguageDetectionConfig
@@ -48,8 +49,11 @@ __all__ = [
48
49
  "batch_extract_bytes_sync",
49
50
  "batch_extract_file",
50
51
  "batch_extract_file_sync",
52
+ "discover_and_load_config",
51
53
  "extract_bytes",
52
54
  "extract_bytes_sync",
53
55
  "extract_file",
54
56
  "extract_file_sync",
57
+ "load_config_from_path",
58
+ "try_discover_config",
55
59
  ]
kreuzberg/_api/main.py CHANGED
@@ -3,7 +3,10 @@ from __future__ import annotations
3
3
  from json import dumps
4
4
  from typing import TYPE_CHECKING, Annotated, Any
5
5
 
6
+ import msgspec
7
+
6
8
  from kreuzberg import (
9
+ ExtractionConfig,
7
10
  ExtractionResult,
8
11
  KreuzbergError,
9
12
  MissingDependencyError,
@@ -11,6 +14,7 @@ from kreuzberg import (
11
14
  ValidationError,
12
15
  batch_extract_bytes,
13
16
  )
17
+ from kreuzberg._config import try_discover_config
14
18
 
15
19
  if TYPE_CHECKING:
16
20
  from litestar.datastructures import UploadFile
@@ -66,8 +70,12 @@ async def handle_files_upload(
66
70
  data: Annotated[list[UploadFile], Body(media_type=RequestEncodingType.MULTI_PART)],
67
71
  ) -> list[ExtractionResult]:
68
72
  """Extracts text content from an uploaded file."""
73
+ # Try to discover configuration from files
74
+ config = try_discover_config()
75
+
69
76
  return await batch_extract_bytes(
70
77
  [(await file.read(), file.content_type) for file in data],
78
+ config=config or ExtractionConfig(),
71
79
  )
72
80
 
73
81
 
@@ -77,8 +85,21 @@ async def health_check() -> dict[str, str]:
77
85
  return {"status": "ok"}
78
86
 
79
87
 
88
+ @get("/config", operation_id="GetConfiguration")
89
+ async def get_configuration() -> dict[str, Any]:
90
+ """Get the current configuration."""
91
+ config = try_discover_config()
92
+ if config is None:
93
+ return {"message": "No configuration file found", "config": None}
94
+
95
+ return {
96
+ "message": "Configuration loaded successfully",
97
+ "config": msgspec.to_builtins(config, order="deterministic"),
98
+ }
99
+
100
+
80
101
  app = Litestar(
81
- route_handlers=[handle_files_upload, health_check],
102
+ route_handlers=[handle_files_upload, health_check, get_configuration],
82
103
  plugins=[OpenTelemetryPlugin(OpenTelemetryConfig())],
83
104
  logging_config=StructLoggingConfig(),
84
105
  exception_handlers={
kreuzberg/_config.py ADDED
@@ -0,0 +1,404 @@
1
+ """Configuration discovery and loading for Kreuzberg.
2
+
3
+ This module provides configuration loading from both kreuzberg.toml and pyproject.toml files.
4
+ Configuration is automatically discovered by searching up the directory tree from the current
5
+ working directory.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import sys
11
+ from pathlib import Path
12
+ from typing import TYPE_CHECKING, Any
13
+
14
+ if sys.version_info >= (3, 11):
15
+ import tomllib
16
+ else:
17
+ import tomli as tomllib # type: ignore[import-not-found]
18
+
19
+ from kreuzberg._gmft import GMFTConfig
20
+ from kreuzberg._ocr._easyocr import EasyOCRConfig
21
+ from kreuzberg._ocr._paddleocr import PaddleOCRConfig
22
+ from kreuzberg._ocr._tesseract import TesseractConfig
23
+ from kreuzberg._types import ExtractionConfig, OcrBackendType
24
+ from kreuzberg.exceptions import ValidationError
25
+
26
+ if TYPE_CHECKING:
27
+ from collections.abc import MutableMapping
28
+
29
+
30
+ def load_config_from_file(config_path: Path) -> dict[str, Any]:
31
+ """Load configuration from a TOML file.
32
+
33
+ Args:
34
+ config_path: Path to the configuration file.
35
+
36
+ Returns:
37
+ Dictionary containing the loaded configuration.
38
+
39
+ Raises:
40
+ ValidationError: If the file cannot be read or parsed.
41
+ """
42
+ try:
43
+ with config_path.open("rb") as f:
44
+ data = tomllib.load(f)
45
+ except FileNotFoundError as e:
46
+ raise ValidationError(f"Configuration file not found: {config_path}") from e
47
+ except tomllib.TOMLDecodeError as e:
48
+ raise ValidationError(f"Invalid TOML in configuration file: {e}") from e
49
+
50
+ # Handle both kreuzberg.toml (root level) and pyproject.toml ([tool.kreuzberg])
51
+ if config_path.name == "kreuzberg.toml":
52
+ return data # type: ignore[no-any-return]
53
+ return data.get("tool", {}).get("kreuzberg", {}) # type: ignore[no-any-return]
54
+
55
+
56
+ def merge_configs(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]:
57
+ """Merge two configuration dictionaries recursively.
58
+
59
+ Args:
60
+ base: Base configuration dictionary.
61
+ override: Configuration dictionary to override base values.
62
+
63
+ Returns:
64
+ Merged configuration dictionary.
65
+ """
66
+ result = base.copy()
67
+ for key, value in override.items():
68
+ if isinstance(value, dict) and key in result and isinstance(result[key], dict):
69
+ result[key] = merge_configs(result[key], value)
70
+ else:
71
+ result[key] = value
72
+ return result
73
+
74
+
75
+ def parse_ocr_backend_config(
76
+ config_dict: dict[str, Any], backend: OcrBackendType
77
+ ) -> TesseractConfig | EasyOCRConfig | PaddleOCRConfig | None:
78
+ """Parse OCR backend-specific configuration.
79
+
80
+ Args:
81
+ config_dict: Configuration dictionary.
82
+ backend: The OCR backend type.
83
+
84
+ Returns:
85
+ Backend-specific configuration object or None.
86
+ """
87
+ if backend not in config_dict:
88
+ return None
89
+
90
+ backend_config = config_dict[backend]
91
+ if not isinstance(backend_config, dict):
92
+ return None
93
+
94
+ if backend == "tesseract":
95
+ # Convert psm integer to PSMMode enum if needed
96
+ processed_config = backend_config.copy()
97
+ if "psm" in processed_config and isinstance(processed_config["psm"], int):
98
+ from kreuzberg._ocr._tesseract import PSMMode
99
+
100
+ processed_config["psm"] = PSMMode(processed_config["psm"])
101
+ return TesseractConfig(**processed_config)
102
+ if backend == "easyocr":
103
+ return EasyOCRConfig(**backend_config)
104
+ if backend == "paddleocr":
105
+ return PaddleOCRConfig(**backend_config)
106
+ return None
107
+
108
+
109
+ def build_extraction_config_from_dict(config_dict: dict[str, Any]) -> ExtractionConfig:
110
+ """Build ExtractionConfig from a configuration dictionary.
111
+
112
+ Args:
113
+ config_dict: Configuration dictionary from TOML file.
114
+
115
+ Returns:
116
+ ExtractionConfig instance.
117
+ """
118
+ extraction_config: dict[str, Any] = {}
119
+
120
+ # Copy basic configuration fields using dictionary comprehension
121
+ basic_fields = {
122
+ "force_ocr",
123
+ "chunk_content",
124
+ "extract_tables",
125
+ "max_chars",
126
+ "max_overlap",
127
+ "ocr_backend",
128
+ "extract_entities",
129
+ "extract_keywords",
130
+ "auto_detect_language",
131
+ "enable_quality_processing",
132
+ }
133
+ extraction_config.update({field: config_dict[field] for field in basic_fields if field in config_dict})
134
+
135
+ # Handle OCR backend configuration
136
+ ocr_backend = extraction_config.get("ocr_backend")
137
+ if ocr_backend and ocr_backend != "none":
138
+ ocr_config = parse_ocr_backend_config(config_dict, ocr_backend)
139
+ if ocr_config:
140
+ extraction_config["ocr_config"] = ocr_config
141
+
142
+ # Handle GMFT configuration for table extraction
143
+ if extraction_config.get("extract_tables") and "gmft" in config_dict and isinstance(config_dict["gmft"], dict):
144
+ extraction_config["gmft_config"] = GMFTConfig(**config_dict["gmft"])
145
+
146
+ # Convert "none" to None for ocr_backend
147
+ if extraction_config.get("ocr_backend") == "none":
148
+ extraction_config["ocr_backend"] = None
149
+
150
+ return ExtractionConfig(**extraction_config)
151
+
152
+
153
+ def find_config_file(start_path: Path | None = None) -> Path | None:
154
+ """Find configuration file by searching up the directory tree.
155
+
156
+ Searches for configuration files in the following order:
157
+ 1. kreuzberg.toml
158
+ 2. pyproject.toml (with [tool.kreuzberg] section)
159
+
160
+ Args:
161
+ start_path: Directory to start searching from. Defaults to current working directory.
162
+
163
+ Returns:
164
+ Path to the configuration file or None if not found.
165
+ """
166
+ current = start_path or Path.cwd()
167
+
168
+ while current != current.parent:
169
+ # First, look for kreuzberg.toml
170
+ kreuzberg_toml = current / "kreuzberg.toml"
171
+ if kreuzberg_toml.exists():
172
+ return kreuzberg_toml
173
+
174
+ # Then, look for pyproject.toml with [tool.kreuzberg] section
175
+ pyproject_toml = current / "pyproject.toml"
176
+ if pyproject_toml.exists():
177
+ try:
178
+ with pyproject_toml.open("rb") as f:
179
+ data = tomllib.load(f)
180
+ if "tool" in data and "kreuzberg" in data["tool"]:
181
+ return pyproject_toml
182
+ except Exception: # noqa: BLE001
183
+ pass
184
+
185
+ current = current.parent
186
+ return None
187
+
188
+
189
+ def load_default_config(start_path: Path | None = None) -> ExtractionConfig | None:
190
+ """Load the default configuration from discovered config file.
191
+
192
+ Args:
193
+ start_path: Directory to start searching from. Defaults to current working directory.
194
+
195
+ Returns:
196
+ ExtractionConfig instance or None if no configuration found.
197
+ """
198
+ config_path = find_config_file(start_path)
199
+ if not config_path:
200
+ return None
201
+
202
+ try:
203
+ config_dict = load_config_from_file(config_path)
204
+ if not config_dict:
205
+ return None
206
+ return build_extraction_config_from_dict(config_dict)
207
+ except Exception: # noqa: BLE001
208
+ # Silently ignore configuration errors for default loading
209
+ return None
210
+
211
+
212
+ def load_config_from_path(config_path: Path | str) -> ExtractionConfig:
213
+ """Load configuration from a specific file path.
214
+
215
+ Args:
216
+ config_path: Path to the configuration file.
217
+
218
+ Returns:
219
+ ExtractionConfig instance.
220
+
221
+ Raises:
222
+ ValidationError: If the file cannot be read, parsed, or is invalid.
223
+ """
224
+ path = Path(config_path)
225
+ config_dict = load_config_from_file(path)
226
+ return build_extraction_config_from_dict(config_dict)
227
+
228
+
229
+ def discover_and_load_config(start_path: Path | str | None = None) -> ExtractionConfig:
230
+ """Load configuration by discovering config files in the directory tree.
231
+
232
+ Args:
233
+ start_path: Directory to start searching from. Defaults to current working directory.
234
+
235
+ Returns:
236
+ ExtractionConfig instance.
237
+
238
+ Raises:
239
+ ValidationError: If no configuration file is found or if the file is invalid.
240
+ """
241
+ search_path = Path(start_path) if start_path else None
242
+ config_path = find_config_file(search_path)
243
+
244
+ if not config_path:
245
+ raise ValidationError(
246
+ "No configuration file found. Searched for 'kreuzberg.toml' and 'pyproject.toml' with [tool.kreuzberg] section.",
247
+ context={"search_path": str(search_path or Path.cwd())},
248
+ )
249
+
250
+ config_dict = load_config_from_file(config_path)
251
+ if not config_dict:
252
+ raise ValidationError(
253
+ f"Configuration file found but contains no Kreuzberg configuration: {config_path}",
254
+ context={"config_path": str(config_path)},
255
+ )
256
+
257
+ return build_extraction_config_from_dict(config_dict)
258
+
259
+
260
+ def try_discover_config(start_path: Path | str | None = None) -> ExtractionConfig | None:
261
+ """Try to discover and load configuration, returning None if not found.
262
+
263
+ Args:
264
+ start_path: Directory to start searching from. Defaults to current working directory.
265
+
266
+ Returns:
267
+ ExtractionConfig instance or None if no configuration found.
268
+ """
269
+ try:
270
+ return discover_and_load_config(start_path)
271
+ except ValidationError:
272
+ return None
273
+
274
+
275
+ # Legacy functions for backward compatibility with CLI
276
+
277
+ # Define common configuration fields to avoid repetition
278
+ _CONFIG_FIELDS = [
279
+ "force_ocr",
280
+ "chunk_content",
281
+ "extract_tables",
282
+ "max_chars",
283
+ "max_overlap",
284
+ "ocr_backend",
285
+ "extract_entities",
286
+ "extract_keywords",
287
+ "auto_detect_language",
288
+ "enable_quality_processing",
289
+ ]
290
+
291
+
292
+ def _merge_file_config(config_dict: dict[str, Any], file_config: dict[str, Any]) -> None:
293
+ """Merge file configuration into config dictionary."""
294
+ if not file_config:
295
+ return
296
+
297
+ for field in _CONFIG_FIELDS:
298
+ if field in file_config:
299
+ config_dict[field] = file_config[field]
300
+
301
+
302
+ def _merge_cli_args(config_dict: dict[str, Any], cli_args: MutableMapping[str, Any]) -> None:
303
+ """Merge CLI arguments into config dictionary."""
304
+ for field in _CONFIG_FIELDS:
305
+ if field in cli_args and cli_args[field] is not None:
306
+ config_dict[field] = cli_args[field]
307
+
308
+
309
+ def _build_ocr_config_from_cli(
310
+ ocr_backend: str, cli_args: MutableMapping[str, Any]
311
+ ) -> TesseractConfig | EasyOCRConfig | PaddleOCRConfig | None:
312
+ """Build OCR config from CLI arguments."""
313
+ config_key = f"{ocr_backend}_config"
314
+ if not cli_args.get(config_key):
315
+ return None
316
+
317
+ backend_args = cli_args[config_key]
318
+ if ocr_backend == "tesseract":
319
+ return TesseractConfig(**backend_args)
320
+ if ocr_backend == "easyocr":
321
+ return EasyOCRConfig(**backend_args)
322
+ if ocr_backend == "paddleocr":
323
+ return PaddleOCRConfig(**backend_args)
324
+ return None
325
+
326
+
327
+ def _configure_ocr_backend(
328
+ config_dict: dict[str, Any],
329
+ file_config: dict[str, Any],
330
+ cli_args: MutableMapping[str, Any],
331
+ ) -> None:
332
+ """Configure OCR backend in config dictionary."""
333
+ ocr_backend = config_dict.get("ocr_backend")
334
+ if not ocr_backend or ocr_backend == "none":
335
+ return
336
+
337
+ # Try CLI config first, then file config
338
+ ocr_config = _build_ocr_config_from_cli(ocr_backend, cli_args)
339
+ if not ocr_config and file_config:
340
+ ocr_config = parse_ocr_backend_config(file_config, ocr_backend)
341
+
342
+ if ocr_config:
343
+ config_dict["ocr_config"] = ocr_config
344
+
345
+
346
+ def _configure_gmft(
347
+ config_dict: dict[str, Any],
348
+ file_config: dict[str, Any],
349
+ cli_args: MutableMapping[str, Any],
350
+ ) -> None:
351
+ """Configure GMFT in config dictionary."""
352
+ if not config_dict.get("extract_tables"):
353
+ return
354
+
355
+ gmft_config = None
356
+ if cli_args.get("gmft_config"):
357
+ gmft_config = GMFTConfig(**cli_args["gmft_config"])
358
+ elif "gmft" in file_config and isinstance(file_config["gmft"], dict):
359
+ gmft_config = GMFTConfig(**file_config["gmft"])
360
+
361
+ if gmft_config:
362
+ config_dict["gmft_config"] = gmft_config
363
+
364
+
365
+ def build_extraction_config(
366
+ file_config: dict[str, Any],
367
+ cli_args: MutableMapping[str, Any],
368
+ ) -> ExtractionConfig:
369
+ """Build ExtractionConfig from file config and CLI arguments.
370
+
371
+ Args:
372
+ file_config: Configuration loaded from file.
373
+ cli_args: CLI arguments.
374
+
375
+ Returns:
376
+ ExtractionConfig instance.
377
+ """
378
+ config_dict: dict[str, Any] = {}
379
+
380
+ # Merge configurations: file first, then CLI overrides
381
+ _merge_file_config(config_dict, file_config)
382
+ _merge_cli_args(config_dict, cli_args)
383
+
384
+ # Configure complex components
385
+ _configure_ocr_backend(config_dict, file_config, cli_args)
386
+ _configure_gmft(config_dict, file_config, cli_args)
387
+
388
+ # Convert "none" to None for ocr_backend
389
+ if config_dict.get("ocr_backend") == "none":
390
+ config_dict["ocr_backend"] = None
391
+
392
+ return ExtractionConfig(**config_dict)
393
+
394
+
395
+ def find_default_config() -> Path | None:
396
+ """Find the default configuration file (pyproject.toml).
397
+
398
+ Returns:
399
+ Path to the configuration file or None if not found.
400
+
401
+ Note:
402
+ This function is deprecated. Use find_config_file() instead.
403
+ """
404
+ return find_config_file()
kreuzberg/_entity_extraction.py CHANGED
@@ -1,5 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import os
3
4
  import re
4
5
  from dataclasses import dataclass
5
6
  from functools import lru_cache
@@ -13,7 +14,7 @@ if TYPE_CHECKING:
13
14
  from pathlib import Path
14
15
 
15
16
 
16
- @dataclass(unsafe_hash=True, frozen=True)
17
+ @dataclass(unsafe_hash=True, frozen=True, slots=True)
17
18
  class SpacyEntityExtractionConfig:
18
19
  """Configuration for spaCy-based entity extraction."""
19
20
 
@@ -126,8 +127,8 @@ def extract_entities(
126
127
  """
127
128
  entities: list[Entity] = []
128
129
  if custom_patterns:
129
- custom_patterns_dict = dict(custom_patterns)
130
- for ent_type, pattern in custom_patterns_dict.items():
130
+ # Direct iteration over frozenset - no need to convert to dict
131
+ for ent_type, pattern in custom_patterns:
131
132
  entities.extend(
132
133
  Entity(type=ent_type, text=match.group(), start=match.start(), end=match.end())
133
134
  for match in re.finditer(pattern, text)
@@ -181,8 +182,6 @@ def _load_spacy_model(model_name: str, spacy_config: SpacyEntityExtractionConfig
181
182
  import spacy
182
183
 
183
184
  if spacy_config.model_cache_dir:
184
- import os
185
-
186
185
  os.environ["SPACY_DATA"] = str(spacy_config.model_cache_dir)
187
186
 
188
187
  nlp = spacy.load(model_name)
kreuzberg/_extractors/_base.py CHANGED
@@ -3,10 +3,12 @@ from __future__ import annotations
3
3
  from abc import ABC, abstractmethod
4
4
  from typing import TYPE_CHECKING, ClassVar
5
5
 
6
+ from kreuzberg._types import ExtractionResult, normalize_metadata
7
+ from kreuzberg._utils._quality import calculate_quality_score, clean_extracted_text
8
+
6
9
  if TYPE_CHECKING:
7
10
  from pathlib import Path
8
11
 
9
- from kreuzberg import ExtractionResult
10
12
  from kreuzberg._types import ExtractionConfig
11
13
 
12
14
 
@@ -104,8 +106,6 @@ class Extractor(ABC):
104
106
  if not self.config.enable_quality_processing:
105
107
  return result
106
108
 
107
- from kreuzberg._utils._quality import calculate_quality_score, clean_extracted_text
108
-
109
109
  if not result.content:
110
110
  return result
111
111
 
@@ -120,8 +120,6 @@ class Extractor(ABC):
120
120
  enhanced_metadata["quality_score"] = quality_score
121
121
 
122
122
  # Return enhanced result
123
- from kreuzberg._types import ExtractionResult, normalize_metadata
124
-
125
123
  return ExtractionResult(
126
124
  content=cleaned_content,
127
125
  mime_type=result.mime_type,
kreuzberg/_extractors/_image.py CHANGED
@@ -11,13 +11,17 @@ from anyio import Path as AsyncPath
11
11
  from kreuzberg._extractors._base import Extractor
12
12
  from kreuzberg._mime_types import IMAGE_MIME_TYPES
13
13
  from kreuzberg._ocr import get_ocr_backend
14
- from kreuzberg._types import ExtractionResult
14
+ from kreuzberg._ocr._easyocr import EasyOCRConfig
15
+ from kreuzberg._ocr._paddleocr import PaddleOCRConfig
16
+ from kreuzberg._ocr._tesseract import TesseractConfig
15
17
  from kreuzberg._utils._tmp import create_temp_file
16
18
  from kreuzberg.exceptions import ValidationError
17
19
 
18
20
  if TYPE_CHECKING: # pragma: no cover
19
21
  from collections.abc import Mapping
20
22
 
23
+ from kreuzberg._types import ExtractionResult
24
+
21
25
 
22
26
  class ImageExtractor(Extractor):
23
27
  SUPPORTED_MIME_TYPES: ClassVar[set[str]] = IMAGE_MIME_TYPES
@@ -78,44 +82,26 @@ class ImageExtractor(Extractor):
78
82
  if self.config.ocr_backend is None:
79
83
  raise ValidationError("ocr_backend is None, cannot perform OCR")
80
84
 
81
- if self.config.ocr_backend == "tesseract":
82
- from kreuzberg._ocr._sync import process_batch_images_sync
83
- from kreuzberg._ocr._tesseract import TesseractConfig
84
-
85
- if isinstance(self.config.ocr_config, TesseractConfig):
86
- config = self.config.ocr_config
87
- else:
88
- config = TesseractConfig()
89
-
90
- results = process_batch_images_sync([str(path)], config, backend="tesseract")
91
- if results:
92
- result = results[0]
93
- return self._apply_quality_processing(result)
94
- return ExtractionResult(content="", mime_type="text/plain", metadata={}, chunks=[])
95
-
96
- if self.config.ocr_backend == "paddleocr":
97
- from kreuzberg._ocr._paddleocr import PaddleOCRConfig
98
- from kreuzberg._ocr._sync import process_image_paddleocr_sync as paddle_process
85
+ backend = get_ocr_backend(self.config.ocr_backend)
99
86
 
87
+ if self.config.ocr_backend == "tesseract":
88
+ config = (
89
+ self.config.ocr_config if isinstance(self.config.ocr_config, TesseractConfig) else TesseractConfig()
90
+ )
91
+ result = backend.process_file_sync(path, **config.__dict__)
92
+ elif self.config.ocr_backend == "paddleocr":
100
93
  paddle_config = (
101
94
  self.config.ocr_config if isinstance(self.config.ocr_config, PaddleOCRConfig) else PaddleOCRConfig()
102
95
  )
103
-
104
- result = paddle_process(path, paddle_config)
105
- return self._apply_quality_processing(result)
106
-
107
- if self.config.ocr_backend == "easyocr":
108
- from kreuzberg._ocr._easyocr import EasyOCRConfig
109
- from kreuzberg._ocr._sync import process_image_easyocr_sync as easy_process
110
-
96
+ result = backend.process_file_sync(path, **paddle_config.__dict__)
97
+ elif self.config.ocr_backend == "easyocr":
111
98
  easy_config = (
112
99
  self.config.ocr_config if isinstance(self.config.ocr_config, EasyOCRConfig) else EasyOCRConfig()
113
100
  )
114
-
115
- result = easy_process(path, easy_config)
116
- return self._apply_quality_processing(result)
117
-
118
- raise NotImplementedError(f"Sync OCR not implemented for {self.config.ocr_backend}")
101
+ result = backend.process_file_sync(path, **easy_config.__dict__)
102
+ else:
103
+ raise NotImplementedError(f"Sync OCR not implemented for {self.config.ocr_backend}")
104
+ return self._apply_quality_processing(result)
119
105
 
120
106
  def _get_extension_from_mime_type(self, mime_type: str) -> str:
121
107
  if mime_type in self.IMAGE_MIME_TYPE_EXT_MAP:
kreuzberg/_extractors/_pandoc.py CHANGED
@@ -1,8 +1,11 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import contextlib
4
+ import os
4
5
  import re
6
+ import subprocess
5
7
  import sys
8
+ import tempfile
6
9
  from json import JSONDecodeError, loads
7
10
  from pathlib import Path
8
11
  from typing import TYPE_CHECKING, Any, ClassVar, Final, Literal, cast
@@ -203,10 +206,6 @@ class PandocExtractor(Extractor):
203
206
  Returns:
204
207
  ExtractionResult with the extracted text and metadata.
205
208
  """
206
- import os
207
- import tempfile
208
- from pathlib import Path
209
-
210
209
  extension = self._get_pandoc_type_from_mime_type(self.mime_type)
211
210
  fd, temp_path = tempfile.mkstemp(suffix=f".{extension}")
212
211
 
@@ -579,8 +578,6 @@ class PandocExtractor(Extractor):
579
578
 
580
579
  def _validate_pandoc_version_sync(self) -> None:
581
580
  """Synchronous version of _validate_pandoc_version."""
582
- import subprocess
583
-
584
581
  try:
585
582
  if self._checked_version:
586
583
  return
@@ -625,10 +622,6 @@ class PandocExtractor(Extractor):
625
622
 
626
623
  def _extract_metadata_sync(self, path: Path) -> Metadata:
627
624
  """Synchronous version of _handle_extract_metadata."""
628
- import os
629
- import subprocess
630
- import tempfile
631
-
632
625
  pandoc_type = self._get_pandoc_type_from_mime_type(self.mime_type)
633
626
  fd, metadata_file = tempfile.mkstemp(suffix=".json")
634
627
  os.close(fd)
@@ -663,10 +656,6 @@ class PandocExtractor(Extractor):
663
656
 
664
657
  def _extract_file_sync(self, path: Path) -> str:
665
658
  """Synchronous version of _handle_extract_file."""
666
- import os
667
- import subprocess
668
- import tempfile
669
-
670
659
  pandoc_type = self._get_pandoc_type_from_mime_type(self.mime_type)
671
660
  fd, output_path = tempfile.mkstemp(suffix=".md")
672
661
  os.close(fd)