kreuzberg 3.8.1__py3-none-any.whl → 3.9.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
kreuzberg/__init__.py CHANGED
@@ -1,5 +1,6 @@
 from importlib.metadata import version
 
+from kreuzberg._config import discover_and_load_config, load_config_from_path, try_discover_config
 from kreuzberg._entity_extraction import SpacyEntityExtractionConfig
 from kreuzberg._gmft import GMFTConfig
 from kreuzberg._language_detection import LanguageDetectionConfig
@@ -48,8 +49,11 @@ __all__ = [
     "batch_extract_bytes_sync",
     "batch_extract_file",
     "batch_extract_file_sync",
+    "discover_and_load_config",
     "extract_bytes",
     "extract_bytes_sync",
     "extract_file",
     "extract_file_sync",
+    "load_config_from_path",
+    "try_discover_config",
 ]
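
The three new exports make file-based configuration part of the public API. A minimal usage sketch (not part of the diff, based only on the signatures added in kreuzberg/_config.py below):

from kreuzberg import discover_and_load_config, load_config_from_path, try_discover_config

# Returns an ExtractionConfig, or None when no config file is found; never raises for a missing file.
config = try_discover_config()

# Raises ValidationError when neither kreuzberg.toml nor a [tool.kreuzberg] section can be found.
strict_config = discover_and_load_config()

# Loads an explicit file instead of searching up the directory tree.
explicit_config = load_config_from_path("kreuzberg.toml")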
kreuzberg/_api/main.py CHANGED
@@ -3,7 +3,10 @@ from __future__ import annotations
 from json import dumps
 from typing import TYPE_CHECKING, Annotated, Any
 
+import msgspec
+
 from kreuzberg import (
+    ExtractionConfig,
     ExtractionResult,
     KreuzbergError,
     MissingDependencyError,
@@ -11,6 +14,7 @@ from kreuzberg import (
     ValidationError,
     batch_extract_bytes,
 )
+from kreuzberg._config import try_discover_config
 
 if TYPE_CHECKING:
     from litestar.datastructures import UploadFile
@@ -66,8 +70,12 @@ async def handle_files_upload(
     data: Annotated[list[UploadFile], Body(media_type=RequestEncodingType.MULTI_PART)],
 ) -> list[ExtractionResult]:
     """Extracts text content from an uploaded file."""
+    # Try to discover configuration from files
+    config = try_discover_config()
+
     return await batch_extract_bytes(
         [(await file.read(), file.content_type) for file in data],
+        config=config or ExtractionConfig(),
     )
 
 
@@ -77,8 +85,21 @@ async def health_check() -> dict[str, str]:
     return {"status": "ok"}
 
 
+@get("/config", operation_id="GetConfiguration")
+async def get_configuration() -> dict[str, Any]:
+    """Get the current configuration."""
+    config = try_discover_config()
+    if config is None:
+        return {"message": "No configuration file found", "config": None}
+
+    return {
+        "message": "Configuration loaded successfully",
+        "config": msgspec.to_builtins(config, order="deterministic"),
+    }
+
+
 app = Litestar(
-    route_handlers=[handle_files_upload, health_check],
+    route_handlers=[handle_files_upload, health_check, get_configuration],
     plugins=[OpenTelemetryPlugin(OpenTelemetryConfig())],
     logging_config=StructLoggingConfig(),
     exception_handlers={
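
With this change the upload handler applies a discovered configuration automatically, and the new GET /config route reports what was found. A hedged sketch of calling the endpoint; it assumes the Litestar app is served locally on port 8000 and that httpx is available, neither of which is part of the diff:

import httpx

# The endpoint returns the discovered config serialized via msgspec.to_builtins,
# or a null config when no kreuzberg.toml / [tool.kreuzberg] section exists.
response = httpx.get("http://localhost:8000/config")
payload = response.json()

if payload["config"] is None:
    print(payload["message"])  # "No configuration file found"
else:
    print(payload["config"])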
kreuzberg/_chunker.py CHANGED
@@ -2,9 +2,9 @@ from __future__ import annotations
 
 from typing import TYPE_CHECKING
 
-from kreuzberg import MissingDependencyError
 from kreuzberg._constants import DEFAULT_MAX_CHARACTERS, DEFAULT_MAX_OVERLAP
 from kreuzberg._mime_types import MARKDOWN_MIME_TYPE
+from kreuzberg.exceptions import MissingDependencyError
 
 if TYPE_CHECKING:
     from semantic_text_splitter import MarkdownSplitter, TextSplitter
@@ -36,11 +36,11 @@ def get_chunker(
     if key not in _chunkers:
         try:
             if mime_type == MARKDOWN_MIME_TYPE:
-                from semantic_text_splitter import MarkdownSplitter
+                from semantic_text_splitter import MarkdownSplitter  # noqa: PLC0415
 
                 _chunkers[key] = MarkdownSplitter(max_characters, overlap_characters)
             else:
-                from semantic_text_splitter import TextSplitter
+                from semantic_text_splitter import TextSplitter  # noqa: PLC0415
 
                 _chunkers[key] = TextSplitter(max_characters, overlap_characters)
         except ImportError as e:
kreuzberg/_config.py ADDED
@@ -0,0 +1,404 @@
+"""Configuration discovery and loading for Kreuzberg.
+
+This module provides configuration loading from both kreuzberg.toml and pyproject.toml files.
+Configuration is automatically discovered by searching up the directory tree from the current
+working directory.
+"""
+
+from __future__ import annotations
+
+import sys
+from pathlib import Path
+from typing import TYPE_CHECKING, Any
+
+if sys.version_info >= (3, 11):
+    import tomllib
+else:
+    import tomli as tomllib  # type: ignore[import-not-found]
+
+from kreuzberg._gmft import GMFTConfig
+from kreuzberg._ocr._easyocr import EasyOCRConfig
+from kreuzberg._ocr._paddleocr import PaddleOCRConfig
+from kreuzberg._ocr._tesseract import TesseractConfig
+from kreuzberg._types import ExtractionConfig, OcrBackendType
+from kreuzberg.exceptions import ValidationError
+
+if TYPE_CHECKING:
+    from collections.abc import MutableMapping
+
+
+def load_config_from_file(config_path: Path) -> dict[str, Any]:
+    """Load configuration from a TOML file.
+
+    Args:
+        config_path: Path to the configuration file.
+
+    Returns:
+        Dictionary containing the loaded configuration.
+
+    Raises:
+        ValidationError: If the file cannot be read or parsed.
+    """
+    try:
+        with config_path.open("rb") as f:
+            data = tomllib.load(f)
+    except FileNotFoundError as e:
+        raise ValidationError(f"Configuration file not found: {config_path}") from e
+    except tomllib.TOMLDecodeError as e:
+        raise ValidationError(f"Invalid TOML in configuration file: {e}") from e
+
+    # Handle both kreuzberg.toml (root level) and pyproject.toml ([tool.kreuzberg])
+    if config_path.name == "kreuzberg.toml":
+        return data  # type: ignore[no-any-return]
+    return data.get("tool", {}).get("kreuzberg", {})  # type: ignore[no-any-return]
+
+
+def merge_configs(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]:
+    """Merge two configuration dictionaries recursively.
+
+    Args:
+        base: Base configuration dictionary.
+        override: Configuration dictionary to override base values.
+
+    Returns:
+        Merged configuration dictionary.
+    """
+    result = base.copy()
+    for key, value in override.items():
+        if isinstance(value, dict) and key in result and isinstance(result[key], dict):
+            result[key] = merge_configs(result[key], value)
+        else:
+            result[key] = value
+    return result
+
+
+def parse_ocr_backend_config(
+    config_dict: dict[str, Any], backend: OcrBackendType
+) -> TesseractConfig | EasyOCRConfig | PaddleOCRConfig | None:
+    """Parse OCR backend-specific configuration.
+
+    Args:
+        config_dict: Configuration dictionary.
+        backend: The OCR backend type.
+
+    Returns:
+        Backend-specific configuration object or None.
+    """
+    if backend not in config_dict:
+        return None
+
+    backend_config = config_dict[backend]
+    if not isinstance(backend_config, dict):
+        return None
+
+    if backend == "tesseract":
+        # Convert psm integer to PSMMode enum if needed
+        processed_config = backend_config.copy()
+        if "psm" in processed_config and isinstance(processed_config["psm"], int):
+            from kreuzberg._ocr._tesseract import PSMMode  # noqa: PLC0415
+
+            processed_config["psm"] = PSMMode(processed_config["psm"])
+        return TesseractConfig(**processed_config)
+    if backend == "easyocr":
+        return EasyOCRConfig(**backend_config)
+    if backend == "paddleocr":
+        return PaddleOCRConfig(**backend_config)
+    return None
+
+
+def build_extraction_config_from_dict(config_dict: dict[str, Any]) -> ExtractionConfig:
+    """Build ExtractionConfig from a configuration dictionary.
+
+    Args:
+        config_dict: Configuration dictionary from TOML file.
+
+    Returns:
+        ExtractionConfig instance.
+    """
+    extraction_config: dict[str, Any] = {}
+
+    # Copy basic configuration fields using dictionary comprehension
+    basic_fields = {
+        "force_ocr",
+        "chunk_content",
+        "extract_tables",
+        "max_chars",
+        "max_overlap",
+        "ocr_backend",
+        "extract_entities",
+        "extract_keywords",
+        "auto_detect_language",
+        "enable_quality_processing",
+    }
+    extraction_config.update({field: config_dict[field] for field in basic_fields if field in config_dict})
+
+    # Handle OCR backend configuration
+    ocr_backend = extraction_config.get("ocr_backend")
+    if ocr_backend and ocr_backend != "none":
+        ocr_config = parse_ocr_backend_config(config_dict, ocr_backend)
+        if ocr_config:
+            extraction_config["ocr_config"] = ocr_config
+
+    # Handle GMFT configuration for table extraction
+    if extraction_config.get("extract_tables") and "gmft" in config_dict and isinstance(config_dict["gmft"], dict):
+        extraction_config["gmft_config"] = GMFTConfig(**config_dict["gmft"])
+
+    # Convert "none" to None for ocr_backend
+    if extraction_config.get("ocr_backend") == "none":
+        extraction_config["ocr_backend"] = None
+
+    return ExtractionConfig(**extraction_config)
+
+
+def find_config_file(start_path: Path | None = None) -> Path | None:
+    """Find configuration file by searching up the directory tree.
+
+    Searches for configuration files in the following order:
+    1. kreuzberg.toml
+    2. pyproject.toml (with [tool.kreuzberg] section)
+
+    Args:
+        start_path: Directory to start searching from. Defaults to current working directory.
+
+    Returns:
+        Path to the configuration file or None if not found.
+    """
+    current = start_path or Path.cwd()
+
+    while current != current.parent:
+        # First, look for kreuzberg.toml
+        kreuzberg_toml = current / "kreuzberg.toml"
+        if kreuzberg_toml.exists():
+            return kreuzberg_toml
+
+        # Then, look for pyproject.toml with [tool.kreuzberg] section
+        pyproject_toml = current / "pyproject.toml"
+        if pyproject_toml.exists():
+            try:
+                with pyproject_toml.open("rb") as f:
+                    data = tomllib.load(f)
+                if "tool" in data and "kreuzberg" in data["tool"]:
+                    return pyproject_toml
+            except Exception:  # noqa: BLE001
+                pass
+
+        current = current.parent
+    return None
+
+
+def load_default_config(start_path: Path | None = None) -> ExtractionConfig | None:
+    """Load the default configuration from discovered config file.
+
+    Args:
+        start_path: Directory to start searching from. Defaults to current working directory.
+
+    Returns:
+        ExtractionConfig instance or None if no configuration found.
+    """
+    config_path = find_config_file(start_path)
+    if not config_path:
+        return None
+
+    try:
+        config_dict = load_config_from_file(config_path)
+        if not config_dict:
+            return None
+        return build_extraction_config_from_dict(config_dict)
+    except Exception:  # noqa: BLE001
+        # Silently ignore configuration errors for default loading
+        return None
+
+
+def load_config_from_path(config_path: Path | str) -> ExtractionConfig:
+    """Load configuration from a specific file path.
+
+    Args:
+        config_path: Path to the configuration file.
+
+    Returns:
+        ExtractionConfig instance.
+
+    Raises:
+        ValidationError: If the file cannot be read, parsed, or is invalid.
+    """
+    path = Path(config_path)
+    config_dict = load_config_from_file(path)
+    return build_extraction_config_from_dict(config_dict)
+
+
+def discover_and_load_config(start_path: Path | str | None = None) -> ExtractionConfig:
+    """Load configuration by discovering config files in the directory tree.
+
+    Args:
+        start_path: Directory to start searching from. Defaults to current working directory.
+
+    Returns:
+        ExtractionConfig instance.
+
+    Raises:
+        ValidationError: If no configuration file is found or if the file is invalid.
+    """
+    search_path = Path(start_path) if start_path else None
+    config_path = find_config_file(search_path)
+
+    if not config_path:
+        raise ValidationError(
+            "No configuration file found. Searched for 'kreuzberg.toml' and 'pyproject.toml' with [tool.kreuzberg] section.",
+            context={"search_path": str(search_path or Path.cwd())},
+        )
+
+    config_dict = load_config_from_file(config_path)
+    if not config_dict:
+        raise ValidationError(
+            f"Configuration file found but contains no Kreuzberg configuration: {config_path}",
+            context={"config_path": str(config_path)},
+        )
+
+    return build_extraction_config_from_dict(config_dict)
+
+
+def try_discover_config(start_path: Path | str | None = None) -> ExtractionConfig | None:
+    """Try to discover and load configuration, returning None if not found.
+
+    Args:
+        start_path: Directory to start searching from. Defaults to current working directory.
+
+    Returns:
+        ExtractionConfig instance or None if no configuration found.
+    """
+    try:
+        return discover_and_load_config(start_path)
+    except ValidationError:
+        return None
+
+
+# Legacy functions for backward compatibility with CLI
+
+# Define common configuration fields to avoid repetition
+_CONFIG_FIELDS = [
+    "force_ocr",
+    "chunk_content",
+    "extract_tables",
+    "max_chars",
+    "max_overlap",
+    "ocr_backend",
+    "extract_entities",
+    "extract_keywords",
+    "auto_detect_language",
+    "enable_quality_processing",
+]
+
+
+def _merge_file_config(config_dict: dict[str, Any], file_config: dict[str, Any]) -> None:
+    """Merge file configuration into config dictionary."""
+    if not file_config:
+        return
+
+    for field in _CONFIG_FIELDS:
+        if field in file_config:
+            config_dict[field] = file_config[field]
+
+
+def _merge_cli_args(config_dict: dict[str, Any], cli_args: MutableMapping[str, Any]) -> None:
+    """Merge CLI arguments into config dictionary."""
+    for field in _CONFIG_FIELDS:
+        if field in cli_args and cli_args[field] is not None:
+            config_dict[field] = cli_args[field]
+
+
+def _build_ocr_config_from_cli(
+    ocr_backend: str, cli_args: MutableMapping[str, Any]
+) -> TesseractConfig | EasyOCRConfig | PaddleOCRConfig | None:
+    """Build OCR config from CLI arguments."""
+    config_key = f"{ocr_backend}_config"
+    if not cli_args.get(config_key):
+        return None
+
+    backend_args = cli_args[config_key]
+    if ocr_backend == "tesseract":
+        return TesseractConfig(**backend_args)
+    if ocr_backend == "easyocr":
+        return EasyOCRConfig(**backend_args)
+    if ocr_backend == "paddleocr":
+        return PaddleOCRConfig(**backend_args)
+    return None
+
+
+def _configure_ocr_backend(
+    config_dict: dict[str, Any],
+    file_config: dict[str, Any],
+    cli_args: MutableMapping[str, Any],
+) -> None:
+    """Configure OCR backend in config dictionary."""
+    ocr_backend = config_dict.get("ocr_backend")
+    if not ocr_backend or ocr_backend == "none":
+        return
+
+    # Try CLI config first, then file config
+    ocr_config = _build_ocr_config_from_cli(ocr_backend, cli_args)
+    if not ocr_config and file_config:
+        ocr_config = parse_ocr_backend_config(file_config, ocr_backend)
+
+    if ocr_config:
+        config_dict["ocr_config"] = ocr_config
+
+
+def _configure_gmft(
+    config_dict: dict[str, Any],
+    file_config: dict[str, Any],
+    cli_args: MutableMapping[str, Any],
+) -> None:
+    """Configure GMFT in config dictionary."""
+    if not config_dict.get("extract_tables"):
+        return
+
+    gmft_config = None
+    if cli_args.get("gmft_config"):
+        gmft_config = GMFTConfig(**cli_args["gmft_config"])
+    elif "gmft" in file_config and isinstance(file_config["gmft"], dict):
+        gmft_config = GMFTConfig(**file_config["gmft"])
+
+    if gmft_config:
+        config_dict["gmft_config"] = gmft_config
+
+
+def build_extraction_config(
+    file_config: dict[str, Any],
+    cli_args: MutableMapping[str, Any],
+) -> ExtractionConfig:
+    """Build ExtractionConfig from file config and CLI arguments.
+
+    Args:
+        file_config: Configuration loaded from file.
+        cli_args: CLI arguments.
+
+    Returns:
+        ExtractionConfig instance.
+    """
+    config_dict: dict[str, Any] = {}
+
+    # Merge configurations: file first, then CLI overrides
+    _merge_file_config(config_dict, file_config)
+    _merge_cli_args(config_dict, cli_args)
+
+    # Configure complex components
+    _configure_ocr_backend(config_dict, file_config, cli_args)
+    _configure_gmft(config_dict, file_config, cli_args)
+
+    # Convert "none" to None for ocr_backend
+    if config_dict.get("ocr_backend") == "none":
+        config_dict["ocr_backend"] = None
+
+    return ExtractionConfig(**config_dict)
+
+
+def find_default_config() -> Path | None:
+    """Find the default configuration file (pyproject.toml).
+
+    Returns:
+        Path to the configuration file or None if not found.
+
+    Note:
+        This function is deprecated. Use find_config_file() instead.
+    """
+    return find_config_file()
@@ -0,0 +1,156 @@
+from __future__ import annotations
+
+import re
+from typing import TYPE_CHECKING
+
+from kreuzberg._ocr import get_ocr_backend
+from kreuzberg.exceptions import MissingDependencyError
+
+if TYPE_CHECKING:
+    from pathlib import Path
+
+    from kreuzberg._types import ExtractionConfig, ExtractionResult
+
+
+DOCUMENT_CLASSIFIERS = {
+    "invoice": [
+        r"invoice",
+        r"bill to",
+        r"invoice number",
+        r"total amount",
+        r"tax id",
+    ],
+    "receipt": [
+        r"receipt",
+        r"cash receipt",
+        r"payment",
+        r"subtotal",
+        r"total due",
+    ],
+    "contract": [
+        r"agreement",
+        r"contract",
+        r"party a",
+        r"party b",
+        r"terms and conditions",
+        r"signature",
+    ],
+    "report": [r"report", r"summary", r"analysis", r"findings", r"conclusion"],
+    "form": [r"form", r"fill out", r"signature", r"date", r"submit"],
+}
+
+
+def _get_translated_text(result: ExtractionResult) -> str:
+    """Translate extracted text to English using Google Translate API.
+
+    Args:
+        result: ExtractionResult containing the text to be translated
+
+    Returns:
+        str: The translated text in lowercase English
+
+    Raises:
+        MissingDependencyError: If the deep-translator package is not installed
+    """
+    try:
+        from deep_translator import GoogleTranslator  # noqa: PLC0415
+    except ImportError as e:
+        raise MissingDependencyError(
+            "The 'deep-translator' library is not installed. Please install it with: pip install 'kreuzberg[auto-classify-document-type]'"
+        ) from e
+
+    return str(GoogleTranslator(source="auto", target="en").translate(result.content).lower())
+
+
+def classify_document(result: ExtractionResult, config: ExtractionConfig) -> tuple[str | None, float | None]:
+    """Classifies the document type based on keywords and patterns.
+
+    Args:
+        result: The extraction result containing the content.
+        config: The extraction configuration.
+
+    Returns:
+        A tuple containing the detected document type and the confidence score,
+        or (None, None) if no type is detected with sufficient confidence.
+    """
+    translated_text = _get_translated_text(result)
+    scores = dict.fromkeys(DOCUMENT_CLASSIFIERS, 0)
+
+    for doc_type, patterns in DOCUMENT_CLASSIFIERS.items():
+        for pattern in patterns:
+            if re.search(pattern, translated_text):
+                scores[doc_type] += 1
+
+    total_score = sum(scores.values())
+    if total_score == 0:
+        return None, None
+
+    confidences = {doc_type: score / total_score for doc_type, score in scores.items()}
+
+    best_type, best_confidence = max(confidences.items(), key=lambda item: item[1])
+
+    if best_confidence >= config.document_type_confidence_threshold:
+        return best_type, best_confidence
+
+    return None, None
+
+
+def classify_document_from_layout(
+    result: ExtractionResult, config: ExtractionConfig
+) -> tuple[str | None, float | None]:
+    """Classifies the document type based on layout information from OCR.
+
+    Args:
+        result: The extraction result containing the layout data.
+        config: The extraction configuration.
+
+    Returns:
+        A tuple containing the detected document type and the confidence score,
+        or (None, None) if no type is detected with sufficient confidence.
+    """
+    translated_text = _get_translated_text(result)
+
+    if result.layout is None or result.layout.empty:
+        return None, None
+
+    layout_df = result.layout
+    if not all(col in layout_df.columns for col in ["text", "top", "height"]):
+        return None, None
+
+    layout_df["translated_text"] = translated_text
+
+    page_height = layout_df["top"].max() + layout_df["height"].max()
+    scores = dict.fromkeys(DOCUMENT_CLASSIFIERS, 0.0)
+
+    for doc_type, patterns in DOCUMENT_CLASSIFIERS.items():
+        for pattern in patterns:
+            found_words = layout_df[layout_df["translated_text"].str.contains(pattern, case=False, na=False)]
+            if not found_words.empty:
+                scores[doc_type] += 1.0
+                word_top = found_words.iloc[0]["top"]
+                if word_top < page_height * 0.3:
+                    scores[doc_type] += 0.5
+
+    total_score = sum(scores.values())
+    if total_score == 0:
+        return None, None
+
+    confidences = {doc_type: score / total_score for doc_type, score in scores.items()}
+
+    best_type, best_confidence = max(confidences.items(), key=lambda item: item[1])
+
+    if best_confidence >= config.document_type_confidence_threshold:
+        return best_type, best_confidence
+
+    return None, None
+
+
+def auto_detect_document_type(
+    result: ExtractionResult, config: ExtractionConfig, file_path: Path | None = None
+) -> ExtractionResult:
+    if config.document_classification_mode == "vision" and file_path:
+        layout_result = get_ocr_backend("tesseract").process_file_sync(file_path, **config.get_config_dict())
+        result.document_type, result.document_type_confidence = classify_document_from_layout(layout_result, config)
+    else:
+        result.document_type, result.document_type_confidence = classify_document(result, config)
+    return result
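
The keyword classifier above turns raw pattern hits into a confidence by normalizing per-type scores against the total. A self-contained sketch of that arithmetic on hypothetical text (the pattern table is a trimmed copy of DOCUMENT_CLASSIFIERS; the real function first translates the content to English and then compares the best confidence to config.document_type_confidence_threshold):

import re

# Trimmed copy of the DOCUMENT_CLASSIFIERS table, enough for the example.
classifiers = {
    "invoice": [r"invoice", r"bill to", r"invoice number", r"total amount", r"tax id"],
    "receipt": [r"receipt", r"cash receipt", r"payment", r"subtotal", r"total due"],
}

text = "invoice number 42 - bill to acme corp, total amount: 100 eur"

scores = {doc_type: 0 for doc_type in classifiers}
for doc_type, patterns in classifiers.items():
    for pattern in patterns:
        if re.search(pattern, text):
            scores[doc_type] += 1

# Here scores == {"invoice": 4, "receipt": 0}, so the invoice confidence is 1.0.
total = sum(scores.values())
confidences = {doc_type: score / total for doc_type, score in scores.items()}
best_type, best_confidence = max(confidences.items(), key=lambda item: item[1])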
@@ -14,7 +14,7 @@ if TYPE_CHECKING:
     from pathlib import Path
 
 
-@dataclass(unsafe_hash=True, frozen=True)
+@dataclass(unsafe_hash=True, frozen=True, slots=True)
 class SpacyEntityExtractionConfig:
     """Configuration for spaCy-based entity extraction."""
 
@@ -127,8 +127,8 @@ def extract_entities(
     """
     entities: list[Entity] = []
     if custom_patterns:
-        custom_patterns_dict = dict(custom_patterns)
-        for ent_type, pattern in custom_patterns_dict.items():
+        # Direct iteration over frozenset - no need to convert to dict
+        for ent_type, pattern in custom_patterns:
             entities.extend(
                 Entity(type=ent_type, text=match.group(), start=match.start(), end=match.end())
                 for match in re.finditer(pattern, text)
@@ -138,7 +138,7 @@ def extract_entities(
         spacy_config = SpacyEntityExtractionConfig()
 
     try:
-        import spacy  # noqa: F401
+        import spacy  # noqa: F401, PLC0415
     except ImportError as e:
        raise MissingDependencyError.create_for_package(
            package_name="spacy",
@@ -179,7 +179,7 @@ def extract_entities(
 def _load_spacy_model(model_name: str, spacy_config: SpacyEntityExtractionConfig) -> Any:
     """Load a spaCy model with caching."""
     try:
-        import spacy
+        import spacy  # noqa: PLC0415
 
         if spacy_config.model_cache_dir:
             os.environ["SPACY_DATA"] = str(spacy_config.model_cache_dir)
@@ -223,7 +223,7 @@ def extract_keywords(
         MissingDependencyError: If `keybert` is not installed.
     """
     try:
-        from keybert import KeyBERT
+        from keybert import KeyBERT  # noqa: PLC0415
 
         kw_model = KeyBERT()
         keywords = kw_model.extract_keywords(text, top_n=keyword_count)
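
The custom-pattern hunk above now iterates the frozenset of (entity type, regex) pairs directly instead of converting it to a dict. A small stand-alone sketch of that path, using a stand-in Entity type since only the field names are visible in the diff:

import re
from typing import NamedTuple


class Entity(NamedTuple):  # stand-in for kreuzberg's Entity
    type: str
    text: str
    start: int
    end: int


custom_patterns = frozenset({("EMAIL", r"[\w.+-]+@[\w-]+\.[\w.]+"), ("ORDER_ID", r"ORD-\d{6}")})
text = "Contact ops@example.com about ORD-123456."

entities: list[Entity] = []
for ent_type, pattern in custom_patterns:
    entities.extend(
        Entity(type=ent_type, text=match.group(), start=match.start(), end=match.end())
        for match in re.finditer(pattern, text)
    )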