kreuzberg 3.8.1__py3-none-any.whl → 3.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kreuzberg/__init__.py +4 -0
- kreuzberg/_api/main.py +22 -1
- kreuzberg/_chunker.py +3 -3
- kreuzberg/_config.py +404 -0
- kreuzberg/_document_classification.py +156 -0
- kreuzberg/_entity_extraction.py +6 -6
- kreuzberg/_extractors/_image.py +4 -3
- kreuzberg/_extractors/_pdf.py +40 -29
- kreuzberg/_extractors/_spread_sheet.py +6 -8
- kreuzberg/_extractors/_structured.py +34 -25
- kreuzberg/_gmft.py +33 -42
- kreuzberg/_language_detection.py +1 -1
- kreuzberg/_mcp/server.py +58 -8
- kreuzberg/_mime_types.py +1 -1
- kreuzberg/_ocr/_base.py +1 -1
- kreuzberg/_ocr/_easyocr.py +5 -5
- kreuzberg/_ocr/_paddleocr.py +4 -4
- kreuzberg/_ocr/_tesseract.py +12 -21
- kreuzberg/_playa.py +2 -3
- kreuzberg/_types.py +65 -27
- kreuzberg/_utils/_cache.py +14 -17
- kreuzberg/_utils/_device.py +17 -27
- kreuzberg/_utils/_errors.py +41 -38
- kreuzberg/_utils/_quality.py +7 -11
- kreuzberg/_utils/_serialization.py +21 -16
- kreuzberg/_utils/_string.py +22 -12
- kreuzberg/_utils/_table.py +3 -4
- kreuzberg/cli.py +5 -5
- kreuzberg/exceptions.py +10 -0
- kreuzberg/extraction.py +20 -11
- kreuzberg-3.9.0.dist-info/METADATA +269 -0
- kreuzberg-3.9.0.dist-info/RECORD +54 -0
- kreuzberg/_cli_config.py +0 -175
- kreuzberg-3.8.1.dist-info/METADATA +0 -301
- kreuzberg-3.8.1.dist-info/RECORD +0 -53
- {kreuzberg-3.8.1.dist-info → kreuzberg-3.9.0.dist-info}/WHEEL +0 -0
- {kreuzberg-3.8.1.dist-info → kreuzberg-3.9.0.dist-info}/entry_points.txt +0 -0
- {kreuzberg-3.8.1.dist-info → kreuzberg-3.9.0.dist-info}/licenses/LICENSE +0 -0
kreuzberg/__init__.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1
1
|
from importlib.metadata import version
|
2
2
|
|
3
|
+
from kreuzberg._config import discover_and_load_config, load_config_from_path, try_discover_config
|
3
4
|
from kreuzberg._entity_extraction import SpacyEntityExtractionConfig
|
4
5
|
from kreuzberg._gmft import GMFTConfig
|
5
6
|
from kreuzberg._language_detection import LanguageDetectionConfig
|
@@ -48,8 +49,11 @@ __all__ = [
|
|
48
49
|
"batch_extract_bytes_sync",
|
49
50
|
"batch_extract_file",
|
50
51
|
"batch_extract_file_sync",
|
52
|
+
"discover_and_load_config",
|
51
53
|
"extract_bytes",
|
52
54
|
"extract_bytes_sync",
|
53
55
|
"extract_file",
|
54
56
|
"extract_file_sync",
|
57
|
+
"load_config_from_path",
|
58
|
+
"try_discover_config",
|
55
59
|
]
|
kreuzberg/_api/main.py
CHANGED
@@ -3,7 +3,10 @@ from __future__ import annotations
|
|
3
3
|
from json import dumps
|
4
4
|
from typing import TYPE_CHECKING, Annotated, Any
|
5
5
|
|
6
|
+
import msgspec
|
7
|
+
|
6
8
|
from kreuzberg import (
|
9
|
+
ExtractionConfig,
|
7
10
|
ExtractionResult,
|
8
11
|
KreuzbergError,
|
9
12
|
MissingDependencyError,
|
@@ -11,6 +14,7 @@ from kreuzberg import (
|
|
11
14
|
ValidationError,
|
12
15
|
batch_extract_bytes,
|
13
16
|
)
|
17
|
+
from kreuzberg._config import try_discover_config
|
14
18
|
|
15
19
|
if TYPE_CHECKING:
|
16
20
|
from litestar.datastructures import UploadFile
|
@@ -66,8 +70,12 @@ async def handle_files_upload(
|
|
66
70
|
data: Annotated[list[UploadFile], Body(media_type=RequestEncodingType.MULTI_PART)],
|
67
71
|
) -> list[ExtractionResult]:
|
68
72
|
"""Extracts text content from an uploaded file."""
|
73
|
+
# Try to discover configuration from files
|
74
|
+
config = try_discover_config()
|
75
|
+
|
69
76
|
return await batch_extract_bytes(
|
70
77
|
[(await file.read(), file.content_type) for file in data],
|
78
|
+
config=config or ExtractionConfig(),
|
71
79
|
)
|
72
80
|
|
73
81
|
|
@@ -77,8 +85,21 @@ async def health_check() -> dict[str, str]:
|
|
77
85
|
return {"status": "ok"}
|
78
86
|
|
79
87
|
|
88
|
+
@get("/config", operation_id="GetConfiguration")
async def get_configuration() -> dict[str, Any]:
    """Report the configuration discovered by ``try_discover_config``, if any.

    Returns:
        A mapping with a human-readable ``message`` and a ``config`` entry. The
        ``config`` entry holds the discovered configuration converted to builtin
        types, or ``None`` when discovery returned nothing.
    """
    discovered = try_discover_config()
    if discovered is not None:
        return {
            "message": "Configuration loaded successfully",
            # Deterministic ordering keeps the serialized payload stable between calls.
            "config": msgspec.to_builtins(discovered, order="deterministic"),
        }
    return {"message": "No configuration file found", "config": None}
|
99
|
+
|
100
|
+
|
80
101
|
app = Litestar(
|
81
|
-
route_handlers=[handle_files_upload, health_check],
|
102
|
+
route_handlers=[handle_files_upload, health_check, get_configuration],
|
82
103
|
plugins=[OpenTelemetryPlugin(OpenTelemetryConfig())],
|
83
104
|
logging_config=StructLoggingConfig(),
|
84
105
|
exception_handlers={
|
kreuzberg/_chunker.py
CHANGED
@@ -2,9 +2,9 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
from typing import TYPE_CHECKING
|
4
4
|
|
5
|
-
from kreuzberg import MissingDependencyError
|
6
5
|
from kreuzberg._constants import DEFAULT_MAX_CHARACTERS, DEFAULT_MAX_OVERLAP
|
7
6
|
from kreuzberg._mime_types import MARKDOWN_MIME_TYPE
|
7
|
+
from kreuzberg.exceptions import MissingDependencyError
|
8
8
|
|
9
9
|
if TYPE_CHECKING:
|
10
10
|
from semantic_text_splitter import MarkdownSplitter, TextSplitter
|
@@ -36,11 +36,11 @@ def get_chunker(
|
|
36
36
|
if key not in _chunkers:
|
37
37
|
try:
|
38
38
|
if mime_type == MARKDOWN_MIME_TYPE:
|
39
|
-
from semantic_text_splitter import MarkdownSplitter
|
39
|
+
from semantic_text_splitter import MarkdownSplitter # noqa: PLC0415
|
40
40
|
|
41
41
|
_chunkers[key] = MarkdownSplitter(max_characters, overlap_characters)
|
42
42
|
else:
|
43
|
-
from semantic_text_splitter import TextSplitter
|
43
|
+
from semantic_text_splitter import TextSplitter # noqa: PLC0415
|
44
44
|
|
45
45
|
_chunkers[key] = TextSplitter(max_characters, overlap_characters)
|
46
46
|
except ImportError as e:
|
kreuzberg/_config.py
ADDED
@@ -0,0 +1,404 @@
|
|
1
|
+
"""Configuration discovery and loading for Kreuzberg.
|
2
|
+
|
3
|
+
This module provides configuration loading from both kreuzberg.toml and pyproject.toml files.
|
4
|
+
Configuration is automatically discovered by searching up the directory tree from the current
|
5
|
+
working directory.
|
6
|
+
"""
|
7
|
+
|
8
|
+
from __future__ import annotations
|
9
|
+
|
10
|
+
import sys
|
11
|
+
from pathlib import Path
|
12
|
+
from typing import TYPE_CHECKING, Any
|
13
|
+
|
14
|
+
if sys.version_info >= (3, 11):
|
15
|
+
import tomllib
|
16
|
+
else:
|
17
|
+
import tomli as tomllib # type: ignore[import-not-found]
|
18
|
+
|
19
|
+
from kreuzberg._gmft import GMFTConfig
|
20
|
+
from kreuzberg._ocr._easyocr import EasyOCRConfig
|
21
|
+
from kreuzberg._ocr._paddleocr import PaddleOCRConfig
|
22
|
+
from kreuzberg._ocr._tesseract import TesseractConfig
|
23
|
+
from kreuzberg._types import ExtractionConfig, OcrBackendType
|
24
|
+
from kreuzberg.exceptions import ValidationError
|
25
|
+
|
26
|
+
if TYPE_CHECKING:
|
27
|
+
from collections.abc import MutableMapping
|
28
|
+
|
29
|
+
|
30
|
+
def load_config_from_file(config_path: Path) -> dict[str, Any]:
|
31
|
+
"""Load configuration from a TOML file.
|
32
|
+
|
33
|
+
Args:
|
34
|
+
config_path: Path to the configuration file.
|
35
|
+
|
36
|
+
Returns:
|
37
|
+
Dictionary containing the loaded configuration.
|
38
|
+
|
39
|
+
Raises:
|
40
|
+
ValidationError: If the file cannot be read or parsed.
|
41
|
+
"""
|
42
|
+
try:
|
43
|
+
with config_path.open("rb") as f:
|
44
|
+
data = tomllib.load(f)
|
45
|
+
except FileNotFoundError as e:
|
46
|
+
raise ValidationError(f"Configuration file not found: {config_path}") from e
|
47
|
+
except tomllib.TOMLDecodeError as e:
|
48
|
+
raise ValidationError(f"Invalid TOML in configuration file: {e}") from e
|
49
|
+
|
50
|
+
# Handle both kreuzberg.toml (root level) and pyproject.toml ([tool.kreuzberg])
|
51
|
+
if config_path.name == "kreuzberg.toml":
|
52
|
+
return data # type: ignore[no-any-return]
|
53
|
+
return data.get("tool", {}).get("kreuzberg", {}) # type: ignore[no-any-return]
|
54
|
+
|
55
|
+
|
56
|
+
def merge_configs(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]:
    """Merge two configuration dictionaries recursively.

    Nested dictionaries present on both sides are merged key by key; for every
    other key the value from ``override`` wins. Neither input is modified, and
    nested dictionaries in the result are fresh copies, so mutating the result
    cannot leak back into ``base`` or ``override``. Non-dict values (including
    lists) are carried over by reference.

    Args:
        base: Base configuration dictionary.
        override: Configuration dictionary to override base values.

    Returns:
        Merged configuration dictionary.
    """
    # Merging a nested table with an empty override yields a recursive copy of it.
    result: dict[str, Any] = {
        key: merge_configs(value, {}) if isinstance(value, dict) else value for key, value in base.items()
    }
    for key, value in override.items():
        if isinstance(value, dict):
            existing = result.get(key)
            # A non-dict (or missing) base value is replaced wholesale by a copy of the override table.
            result[key] = merge_configs(existing if isinstance(existing, dict) else {}, value)
        else:
            result[key] = value
    return result
|
73
|
+
|
74
|
+
|
75
|
+
def parse_ocr_backend_config(
    config_dict: dict[str, Any], backend: OcrBackendType
) -> TesseractConfig | EasyOCRConfig | PaddleOCRConfig | None:
    """Parse OCR backend-specific configuration.

    The options are read from the table stored under the backend's own name in
    ``config_dict`` (for example ``config_dict["tesseract"]``).

    Args:
        config_dict: Configuration dictionary.
        backend: The OCR backend type.

    Returns:
        Backend-specific configuration object, or None when there is no table
        for the backend or the backend name is not recognised.

    Raises:
        ValidationError: If the backend table contains options the backend
            configuration class rejects (unknown keys, invalid ``psm`` value).
    """
    backend_config = config_dict.get(backend)
    if not isinstance(backend_config, dict):
        return None

    try:
        if backend == "tesseract":
            # Convert psm integer to PSMMode enum if needed
            processed_config = backend_config.copy()
            if "psm" in processed_config and isinstance(processed_config["psm"], int):
                from kreuzberg._ocr._tesseract import PSMMode  # noqa: PLC0415

                processed_config["psm"] = PSMMode(processed_config["psm"])
            return TesseractConfig(**processed_config)
        if backend == "easyocr":
            return EasyOCRConfig(**backend_config)
        if backend == "paddleocr":
            return PaddleOCRConfig(**backend_config)
    except (TypeError, ValueError) as e:
        # Unknown keyword arguments raise TypeError and an out-of-range psm raises
        # ValueError; surface both as the package's configuration error so callers
        # that handle ValidationError also cover malformed backend tables.
        raise ValidationError(
            f"Invalid configuration for OCR backend '{backend}': {e}",
            context={"backend": backend},
        ) from e
    return None
|
107
|
+
|
108
|
+
|
109
|
+
def build_extraction_config_from_dict(config_dict: dict[str, Any]) -> ExtractionConfig:
    """Translate a parsed TOML mapping into an ``ExtractionConfig``.

    Args:
        config_dict: Configuration dictionary from TOML file.

    Returns:
        ExtractionConfig instance built from the recognised keys.
    """
    # Scalar options that are forwarded unchanged when present.
    simple_keys = (
        "force_ocr",
        "chunk_content",
        "extract_tables",
        "max_chars",
        "max_overlap",
        "ocr_backend",
        "extract_entities",
        "extract_keywords",
        "auto_detect_language",
        "enable_quality_processing",
    )
    kwargs: dict[str, Any] = {name: config_dict[name] for name in simple_keys if name in config_dict}

    backend = kwargs.get("ocr_backend")
    if backend == "none":
        # The literal string "none" in the file means "no OCR backend".
        kwargs["ocr_backend"] = None
    elif backend:
        backend_options = parse_ocr_backend_config(config_dict, backend)
        if backend_options:
            kwargs["ocr_config"] = backend_options

    # The [gmft] table is only honoured when table extraction is switched on.
    if kwargs.get("extract_tables") and isinstance(config_dict.get("gmft"), dict):
        kwargs["gmft_config"] = GMFTConfig(**config_dict["gmft"])

    return ExtractionConfig(**kwargs)
|
151
|
+
|
152
|
+
|
153
|
+
def find_config_file(start_path: Path | None = None) -> Path | None:
|
154
|
+
"""Find configuration file by searching up the directory tree.
|
155
|
+
|
156
|
+
Searches for configuration files in the following order:
|
157
|
+
1. kreuzberg.toml
|
158
|
+
2. pyproject.toml (with [tool.kreuzberg] section)
|
159
|
+
|
160
|
+
Args:
|
161
|
+
start_path: Directory to start searching from. Defaults to current working directory.
|
162
|
+
|
163
|
+
Returns:
|
164
|
+
Path to the configuration file or None if not found.
|
165
|
+
"""
|
166
|
+
current = start_path or Path.cwd()
|
167
|
+
|
168
|
+
while current != current.parent:
|
169
|
+
# First, look for kreuzberg.toml
|
170
|
+
kreuzberg_toml = current / "kreuzberg.toml"
|
171
|
+
if kreuzberg_toml.exists():
|
172
|
+
return kreuzberg_toml
|
173
|
+
|
174
|
+
# Then, look for pyproject.toml with [tool.kreuzberg] section
|
175
|
+
pyproject_toml = current / "pyproject.toml"
|
176
|
+
if pyproject_toml.exists():
|
177
|
+
try:
|
178
|
+
with pyproject_toml.open("rb") as f:
|
179
|
+
data = tomllib.load(f)
|
180
|
+
if "tool" in data and "kreuzberg" in data["tool"]:
|
181
|
+
return pyproject_toml
|
182
|
+
except Exception: # noqa: BLE001
|
183
|
+
pass
|
184
|
+
|
185
|
+
current = current.parent
|
186
|
+
return None
|
187
|
+
|
188
|
+
|
189
|
+
def load_default_config(start_path: Path | None = None) -> ExtractionConfig | None:
    """Best-effort loading of whatever configuration file can be discovered.

    Args:
        start_path: Directory to start searching from. Defaults to current working directory.

    Returns:
        ExtractionConfig instance, or None when no file is found, the file holds
        no Kreuzberg settings, or loading fails for any reason.
    """
    located = find_config_file(start_path)
    if located is None:
        return None

    try:
        settings = load_config_from_file(located)
        return build_extraction_config_from_dict(settings) if settings else None
    except Exception:  # noqa: BLE001
        # Default loading is deliberately silent: a broken file means "no config".
        return None
|
210
|
+
|
211
|
+
|
212
|
+
def load_config_from_path(config_path: Path | str) -> ExtractionConfig:
    """Build an ``ExtractionConfig`` from one explicitly named configuration file.

    Args:
        config_path: Path to the configuration file, as a ``Path`` or a string.

    Returns:
        ExtractionConfig instance.

    Raises:
        ValidationError: If the file cannot be read, parsed, or is invalid.
    """
    settings = load_config_from_file(Path(config_path))
    return build_extraction_config_from_dict(settings)
|
227
|
+
|
228
|
+
|
229
|
+
def discover_and_load_config(start_path: Path | str | None = None) -> ExtractionConfig:
    """Locate a configuration file in the directory tree and load it.

    Args:
        start_path: Directory to start searching from. Defaults to current working directory.

    Returns:
        ExtractionConfig instance.

    Raises:
        ValidationError: If no configuration file is found or if the file is invalid.
    """
    search_root = Path(start_path) if start_path else None
    located = find_config_file(search_root)

    if located is None:
        raise ValidationError(
            "No configuration file found. Searched for 'kreuzberg.toml' and 'pyproject.toml' with [tool.kreuzberg] section.",
            context={"search_path": str(search_root or Path.cwd())},
        )

    settings = load_config_from_file(located)
    if settings:
        return build_extraction_config_from_dict(settings)

    # A file was found, but it carries no Kreuzberg settings.
    raise ValidationError(
        f"Configuration file found but contains no Kreuzberg configuration: {located}",
        context={"config_path": str(located)},
    )
|
258
|
+
|
259
|
+
|
260
|
+
def try_discover_config(start_path: Path | str | None = None) -> ExtractionConfig | None:
    """Non-raising variant of ``discover_and_load_config``.

    Args:
        start_path: Directory to start searching from. Defaults to current working directory.

    Returns:
        ExtractionConfig instance, or None when discovery raises ``ValidationError``.
    """
    try:
        discovered = discover_and_load_config(start_path)
    except ValidationError:
        discovered = None
    return discovered
|
273
|
+
|
274
|
+
|
275
|
+
# Legacy functions for backward compatibility with CLI
|
276
|
+
|
277
|
+
# Define common configuration fields to avoid repetition
|
278
|
+
_CONFIG_FIELDS = [
|
279
|
+
"force_ocr",
|
280
|
+
"chunk_content",
|
281
|
+
"extract_tables",
|
282
|
+
"max_chars",
|
283
|
+
"max_overlap",
|
284
|
+
"ocr_backend",
|
285
|
+
"extract_entities",
|
286
|
+
"extract_keywords",
|
287
|
+
"auto_detect_language",
|
288
|
+
"enable_quality_processing",
|
289
|
+
]
|
290
|
+
|
291
|
+
|
292
|
+
def _merge_file_config(config_dict: dict[str, Any], file_config: dict[str, Any]) -> None:
    """Copy the recognised scalar options from ``file_config`` into ``config_dict``.

    ``config_dict`` is updated in place; a falsy ``file_config`` leaves it untouched.
    """
    if file_config:
        config_dict.update({name: file_config[name] for name in _CONFIG_FIELDS if name in file_config})
|
300
|
+
|
301
|
+
|
302
|
+
def _merge_cli_args(config_dict: dict[str, Any], cli_args: MutableMapping[str, Any]) -> None:
    """Overlay the recognised CLI options onto ``config_dict`` in place.

    Options that are absent from ``cli_args`` or set to ``None`` are skipped, so
    values already present in ``config_dict`` survive.
    """
    supplied = {name: cli_args[name] for name in _CONFIG_FIELDS if cli_args.get(name) is not None}
    config_dict.update(supplied)
|
307
|
+
|
308
|
+
|
309
|
+
def _build_ocr_config_from_cli(
|
310
|
+
ocr_backend: str, cli_args: MutableMapping[str, Any]
|
311
|
+
) -> TesseractConfig | EasyOCRConfig | PaddleOCRConfig | None:
|
312
|
+
"""Build OCR config from CLI arguments."""
|
313
|
+
config_key = f"{ocr_backend}_config"
|
314
|
+
if not cli_args.get(config_key):
|
315
|
+
return None
|
316
|
+
|
317
|
+
backend_args = cli_args[config_key]
|
318
|
+
if ocr_backend == "tesseract":
|
319
|
+
return TesseractConfig(**backend_args)
|
320
|
+
if ocr_backend == "easyocr":
|
321
|
+
return EasyOCRConfig(**backend_args)
|
322
|
+
if ocr_backend == "paddleocr":
|
323
|
+
return PaddleOCRConfig(**backend_args)
|
324
|
+
return None
|
325
|
+
|
326
|
+
|
327
|
+
def _configure_ocr_backend(
|
328
|
+
config_dict: dict[str, Any],
|
329
|
+
file_config: dict[str, Any],
|
330
|
+
cli_args: MutableMapping[str, Any],
|
331
|
+
) -> None:
|
332
|
+
"""Configure OCR backend in config dictionary."""
|
333
|
+
ocr_backend = config_dict.get("ocr_backend")
|
334
|
+
if not ocr_backend or ocr_backend == "none":
|
335
|
+
return
|
336
|
+
|
337
|
+
# Try CLI config first, then file config
|
338
|
+
ocr_config = _build_ocr_config_from_cli(ocr_backend, cli_args)
|
339
|
+
if not ocr_config and file_config:
|
340
|
+
ocr_config = parse_ocr_backend_config(file_config, ocr_backend)
|
341
|
+
|
342
|
+
if ocr_config:
|
343
|
+
config_dict["ocr_config"] = ocr_config
|
344
|
+
|
345
|
+
|
346
|
+
def _configure_gmft(
|
347
|
+
config_dict: dict[str, Any],
|
348
|
+
file_config: dict[str, Any],
|
349
|
+
cli_args: MutableMapping[str, Any],
|
350
|
+
) -> None:
|
351
|
+
"""Configure GMFT in config dictionary."""
|
352
|
+
if not config_dict.get("extract_tables"):
|
353
|
+
return
|
354
|
+
|
355
|
+
gmft_config = None
|
356
|
+
if cli_args.get("gmft_config"):
|
357
|
+
gmft_config = GMFTConfig(**cli_args["gmft_config"])
|
358
|
+
elif "gmft" in file_config and isinstance(file_config["gmft"], dict):
|
359
|
+
gmft_config = GMFTConfig(**file_config["gmft"])
|
360
|
+
|
361
|
+
if gmft_config:
|
362
|
+
config_dict["gmft_config"] = gmft_config
|
363
|
+
|
364
|
+
|
365
|
+
def build_extraction_config(
    file_config: dict[str, Any],
    cli_args: MutableMapping[str, Any],
) -> ExtractionConfig:
    """Combine file settings and CLI arguments into an ``ExtractionConfig``.

    Args:
        file_config: Configuration loaded from file.
        cli_args: CLI arguments; values given here override the file.

    Returns:
        ExtractionConfig instance.
    """
    merged: dict[str, Any] = {}

    # Scalar options: the file supplies the baseline, the CLI overrides it.
    _merge_file_config(merged, file_config)
    _merge_cli_args(merged, cli_args)

    # Structured options (OCR backend settings, then GMFT settings).
    for configure in (_configure_ocr_backend, _configure_gmft):
        configure(merged, file_config, cli_args)

    # The literal string "none" means "no OCR backend".
    if merged.get("ocr_backend") == "none":
        merged["ocr_backend"] = None

    return ExtractionConfig(**merged)
|
393
|
+
|
394
|
+
|
395
|
+
def find_default_config() -> Path | None:
    """Locate a configuration file starting from the current working directory.

    Thin wrapper that delegates to ``find_config_file()`` with its default
    start directory.

    Returns:
        Path to the configuration file or None if not found.

    Note:
        This function is deprecated. Use find_config_file() instead.
    """
    located = find_config_file()
    return located
|
kreuzberg/_document_classification.py
ADDED
@@ -0,0 +1,156 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import re
|
4
|
+
from typing import TYPE_CHECKING
|
5
|
+
|
6
|
+
from kreuzberg._ocr import get_ocr_backend
|
7
|
+
from kreuzberg.exceptions import MissingDependencyError
|
8
|
+
|
9
|
+
if TYPE_CHECKING:
|
10
|
+
from pathlib import Path
|
11
|
+
|
12
|
+
from kreuzberg._types import ExtractionConfig, ExtractionResult
|
13
|
+
|
14
|
+
|
15
|
+
# Regular-expression patterns searched for in the lower-cased, translated
# document text; every pattern that matches adds to its document type's score.
# The insertion order matters: ties between types are resolved in favour of
# the type listed first.
DOCUMENT_CLASSIFIERS: dict[str, list[str]] = {
    "invoice": [r"invoice", r"bill to", r"invoice number", r"total amount", r"tax id"],
    "receipt": [r"receipt", r"cash receipt", r"payment", r"subtotal", r"total due"],
    "contract": [
        r"agreement",
        r"contract",
        r"party a",
        r"party b",
        r"terms and conditions",
        r"signature",
    ],
    "report": [
        r"report",
        r"summary",
        r"analysis",
        r"findings",
        r"conclusion",
    ],
    "form": [
        r"form",
        r"fill out",
        r"signature",
        r"date",
        r"submit",
    ],
}
|
41
|
+
|
42
|
+
|
43
|
+
def _get_translated_text(result: ExtractionResult) -> str:
    """Translate extracted text to English using Google Translate API.

    Translation is best effort: if there is nothing to translate, or the
    translator rejects the request or fails, the original text is used so that
    keyword classification can still run on documents that are already in English.

    Args:
        result: ExtractionResult containing the text to be translated

    Returns:
        str: The translated text in lowercase, or the lowercased original
        content when no translation could be obtained

    Raises:
        MissingDependencyError: If the deep-translator package is not installed
    """
    try:
        from deep_translator import GoogleTranslator  # noqa: PLC0415
    except ImportError as e:
        raise MissingDependencyError(
            "The 'deep-translator' library is not installed. Please install it with: pip install 'kreuzberg[auto-classify-document-type]'"
        ) from e

    text = result.content
    if not text or not text.strip():
        # Nothing to translate; avoid a pointless remote call.
        return ""

    try:
        translated = GoogleTranslator(source="auto", target="en").translate(text)
    except Exception:  # noqa: BLE001
        # NOTE(review): the translator validates payload size and depends on the
        # network; a failure here should degrade classification rather than
        # abort it — confirm the exact exception types against deep-translator.
        translated = None

    # Guard against a missing translation so we never classify the string "none".
    return str(translated if translated is not None else text).lower()
|
63
|
+
|
64
|
+
|
65
|
+
def classify_document(result: ExtractionResult, config: ExtractionConfig) -> tuple[str | None, float | None]:
    """Guess the document type from keyword patterns found in the translated text.

    Each pattern in ``DOCUMENT_CLASSIFIERS`` that matches the text adds one point
    to its document type; a type's confidence is its share of all points.

    Args:
        result: The extraction result containing the content.
        config: The extraction configuration; supplies
            ``document_type_confidence_threshold``.

    Returns:
        A tuple containing the detected document type and the confidence score,
        or (None, None) if no type is detected with sufficient confidence.
    """
    text = _get_translated_text(result)

    hits = {
        doc_type: sum(1 for pattern in patterns if re.search(pattern, text))
        for doc_type, patterns in DOCUMENT_CLASSIFIERS.items()
    }

    total_hits = sum(hits.values())
    if total_hits == 0:
        return None, None

    confidences = {doc_type: count / total_hits for doc_type, count in hits.items()}
    best_type, best_confidence = max(confidences.items(), key=lambda item: item[1])

    if best_confidence < config.document_type_confidence_threshold:
        return None, None
    return best_type, best_confidence
|
96
|
+
|
97
|
+
|
98
|
+
def classify_document_from_layout(
    result: ExtractionResult, config: ExtractionConfig
) -> tuple[str | None, float | None]:
    """Classifies the document type based on layout information from OCR.

    Each pattern in ``DOCUMENT_CLASSIFIERS`` that matches the translated text
    scores one point for its document type, plus half a point when the first
    layout row lies in the top 30% of the page. ``result.layout`` is only read,
    never modified.

    Args:
        result: The extraction result containing the layout data.
        config: The extraction configuration; supplies
            ``document_type_confidence_threshold``.

    Returns:
        A tuple containing the detected document type and the confidence score,
        or (None, None) if no type is detected with sufficient confidence.
    """
    layout_df = result.layout
    # Check for usable layout data before translating: translation is a remote
    # call and is wasted when there is nothing to score against.
    if layout_df is None or layout_df.empty:
        return None, None
    if not all(col in layout_df.columns for col in ("text", "top", "height")):
        return None, None

    translated_text = _get_translated_text(result)

    page_height = layout_df["top"].max() + layout_df["height"].max()
    # NOTE(review): the translation covers the whole document, so a pattern either
    # matches for every layout row or for none; the position bonus is therefore
    # driven by the first row's "top". Per-word positions would need per-row
    # translation — confirm whether that was the intent.
    first_top = layout_df.iloc[0]["top"]
    scores = dict.fromkeys(DOCUMENT_CLASSIFIERS, 0.0)

    for doc_type, patterns in DOCUMENT_CLASSIFIERS.items():
        for pattern in patterns:
            if re.search(pattern, translated_text, flags=re.IGNORECASE):
                scores[doc_type] += 1.0
                if first_top < page_height * 0.3:
                    scores[doc_type] += 0.5

    total_score = sum(scores.values())
    if total_score == 0:
        return None, None

    confidences = {doc_type: score / total_score for doc_type, score in scores.items()}

    best_type, best_confidence = max(confidences.items(), key=lambda item: item[1])

    if best_confidence >= config.document_type_confidence_threshold:
        return best_type, best_confidence

    return None, None
|
146
|
+
|
147
|
+
|
148
|
+
def auto_detect_document_type(
    result: ExtractionResult, config: ExtractionConfig, file_path: Path | None = None
) -> ExtractionResult:
    """Fill in ``document_type`` and ``document_type_confidence`` on ``result``.

    When the classification mode is ``"vision"`` and a file path is given, the
    file is run through the tesseract OCR backend and classified from that
    output. Otherwise ``result`` itself is classified from its text content.

    Args:
        result: The extraction result to annotate (modified in place).
        config: The extraction configuration.
        file_path: Optional path of the source file.

    Returns:
        The same ``result`` object with the classification fields set.
    """
    if config.document_classification_mode == "vision" and file_path:
        backend = get_ocr_backend("tesseract")
        layout_result = backend.process_file_sync(file_path, **config.get_config_dict())
        classification = classify_document_from_layout(layout_result, config)
    else:
        classification = classify_document(result, config)

    result.document_type, result.document_type_confidence = classification
    return result
|
kreuzberg/_entity_extraction.py
CHANGED
@@ -14,7 +14,7 @@ if TYPE_CHECKING:
|
|
14
14
|
from pathlib import Path
|
15
15
|
|
16
16
|
|
17
|
-
@dataclass(unsafe_hash=True, frozen=True)
|
17
|
+
@dataclass(unsafe_hash=True, frozen=True, slots=True)
|
18
18
|
class SpacyEntityExtractionConfig:
|
19
19
|
"""Configuration for spaCy-based entity extraction."""
|
20
20
|
|
@@ -127,8 +127,8 @@ def extract_entities(
|
|
127
127
|
"""
|
128
128
|
entities: list[Entity] = []
|
129
129
|
if custom_patterns:
|
130
|
-
|
131
|
-
for ent_type, pattern in
|
130
|
+
# Direct iteration over frozenset - no need to convert to dict
|
131
|
+
for ent_type, pattern in custom_patterns:
|
132
132
|
entities.extend(
|
133
133
|
Entity(type=ent_type, text=match.group(), start=match.start(), end=match.end())
|
134
134
|
for match in re.finditer(pattern, text)
|
@@ -138,7 +138,7 @@ def extract_entities(
|
|
138
138
|
spacy_config = SpacyEntityExtractionConfig()
|
139
139
|
|
140
140
|
try:
|
141
|
-
import spacy # noqa: F401
|
141
|
+
import spacy # noqa: F401, PLC0415
|
142
142
|
except ImportError as e:
|
143
143
|
raise MissingDependencyError.create_for_package(
|
144
144
|
package_name="spacy",
|
@@ -179,7 +179,7 @@ def extract_entities(
|
|
179
179
|
def _load_spacy_model(model_name: str, spacy_config: SpacyEntityExtractionConfig) -> Any:
|
180
180
|
"""Load a spaCy model with caching."""
|
181
181
|
try:
|
182
|
-
import spacy
|
182
|
+
import spacy # noqa: PLC0415
|
183
183
|
|
184
184
|
if spacy_config.model_cache_dir:
|
185
185
|
os.environ["SPACY_DATA"] = str(spacy_config.model_cache_dir)
|
@@ -223,7 +223,7 @@ def extract_keywords(
|
|
223
223
|
MissingDependencyError: If `keybert` is not installed.
|
224
224
|
"""
|
225
225
|
try:
|
226
|
-
from keybert import KeyBERT
|
226
|
+
from keybert import KeyBERT # noqa: PLC0415
|
227
227
|
|
228
228
|
kw_model = KeyBERT()
|
229
229
|
keywords = kw_model.extract_keywords(text, top_n=keyword_count)
|