kreuzberg 3.8.1__py3-none-any.whl → 3.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kreuzberg/__init__.py +4 -0
- kreuzberg/_api/main.py +22 -1
- kreuzberg/_chunker.py +3 -3
- kreuzberg/_config.py +404 -0
- kreuzberg/_document_classification.py +156 -0
- kreuzberg/_entity_extraction.py +6 -6
- kreuzberg/_extractors/_image.py +4 -3
- kreuzberg/_extractors/_pdf.py +40 -29
- kreuzberg/_extractors/_spread_sheet.py +6 -8
- kreuzberg/_extractors/_structured.py +34 -25
- kreuzberg/_gmft.py +33 -42
- kreuzberg/_language_detection.py +1 -1
- kreuzberg/_mcp/server.py +58 -8
- kreuzberg/_mime_types.py +1 -1
- kreuzberg/_ocr/_base.py +1 -1
- kreuzberg/_ocr/_easyocr.py +5 -5
- kreuzberg/_ocr/_paddleocr.py +4 -4
- kreuzberg/_ocr/_tesseract.py +12 -21
- kreuzberg/_playa.py +2 -3
- kreuzberg/_types.py +65 -27
- kreuzberg/_utils/_cache.py +14 -17
- kreuzberg/_utils/_device.py +17 -27
- kreuzberg/_utils/_errors.py +41 -38
- kreuzberg/_utils/_quality.py +7 -11
- kreuzberg/_utils/_serialization.py +21 -16
- kreuzberg/_utils/_string.py +22 -12
- kreuzberg/_utils/_table.py +3 -4
- kreuzberg/cli.py +5 -5
- kreuzberg/exceptions.py +10 -0
- kreuzberg/extraction.py +20 -11
- kreuzberg-3.9.0.dist-info/METADATA +269 -0
- kreuzberg-3.9.0.dist-info/RECORD +54 -0
- kreuzberg/_cli_config.py +0 -175
- kreuzberg-3.8.1.dist-info/METADATA +0 -301
- kreuzberg-3.8.1.dist-info/RECORD +0 -53
- {kreuzberg-3.8.1.dist-info → kreuzberg-3.9.0.dist-info}/WHEEL +0 -0
- {kreuzberg-3.8.1.dist-info → kreuzberg-3.9.0.dist-info}/entry_points.txt +0 -0
- {kreuzberg-3.8.1.dist-info → kreuzberg-3.9.0.dist-info}/licenses/LICENSE +0 -0
kreuzberg/_utils/_device.py
CHANGED
@@ -5,6 +5,7 @@ from __future__ import annotations
|
|
5
5
|
|
6
6
|
import warnings
|
7
7
|
from dataclasses import dataclass
|
8
|
+
from itertools import chain
|
8
9
|
from typing import Literal
|
9
10
|
|
10
11
|
from kreuzberg.exceptions import ValidationError
|
@@ -12,7 +13,7 @@ from kreuzberg.exceptions import ValidationError
|
|
12
13
|
DeviceType = Literal["cpu", "cuda", "mps", "auto"]
|
13
14
|
|
14
15
|
|
15
|
-
@dataclass(frozen=True)
|
16
|
+
@dataclass(frozen=True, slots=True)
|
16
17
|
class DeviceInfo:
|
17
18
|
"""Information about a compute device."""
|
18
19
|
|
@@ -34,28 +35,17 @@ def detect_available_devices() -> list[DeviceInfo]:
|
|
34
35
|
Returns:
|
35
36
|
List of available devices, with the most preferred device first.
|
36
37
|
"""
|
37
|
-
|
38
|
-
|
39
|
-
devices.append(
|
40
|
-
DeviceInfo(
|
41
|
-
device_type="cpu",
|
42
|
-
name="CPU",
|
43
|
-
)
|
44
|
-
)
|
45
|
-
|
46
|
-
if _is_cuda_available():
|
47
|
-
cuda_devices = _get_cuda_devices()
|
48
|
-
devices.extend(cuda_devices)
|
38
|
+
# Build device lists efficiently using generators
|
39
|
+
cpu_device = DeviceInfo(device_type="cpu", name="CPU")
|
49
40
|
|
50
|
-
if
|
51
|
-
mps_device = _get_mps_device()
|
52
|
-
if mps_device:
|
53
|
-
devices.append(mps_device)
|
41
|
+
cuda_devices = _get_cuda_devices() if _is_cuda_available() else []
|
54
42
|
|
55
|
-
|
56
|
-
|
43
|
+
mps_device = _get_mps_device() if _is_mps_available() else None
|
44
|
+
mps_devices = [mps_device] if mps_device else []
|
57
45
|
|
58
|
-
|
46
|
+
# Return GPU devices first, then CPU using itertools.chain
|
47
|
+
gpu_devices = list(chain(cuda_devices, mps_devices))
|
48
|
+
return [*gpu_devices, cpu_device]
|
59
49
|
|
60
50
|
|
61
51
|
def get_optimal_device() -> DeviceInfo:
|
@@ -151,7 +141,7 @@ def get_device_memory_info(device: DeviceInfo) -> tuple[float | None, float | No
|
|
151
141
|
def _is_cuda_available() -> bool:
|
152
142
|
"""Check if CUDA is available."""
|
153
143
|
try:
|
154
|
-
import torch # type: ignore[import-not-found,unused-ignore]
|
144
|
+
import torch # type: ignore[import-not-found,unused-ignore] # noqa: PLC0415
|
155
145
|
|
156
146
|
return bool(torch.cuda.is_available())
|
157
147
|
except ImportError:
|
@@ -161,7 +151,7 @@ def _is_cuda_available() -> bool:
|
|
161
151
|
def _is_mps_available() -> bool:
|
162
152
|
"""Check if MPS (Apple Silicon) is available."""
|
163
153
|
try:
|
164
|
-
import torch # type: ignore[import-not-found,unused-ignore]
|
154
|
+
import torch # type: ignore[import-not-found,unused-ignore] # noqa: PLC0415
|
165
155
|
|
166
156
|
return bool(torch.backends.mps.is_available())
|
167
157
|
except ImportError:
|
@@ -173,7 +163,7 @@ def _get_cuda_devices() -> list[DeviceInfo]:
|
|
173
163
|
devices: list[DeviceInfo] = []
|
174
164
|
|
175
165
|
try:
|
176
|
-
import torch
|
166
|
+
import torch # noqa: PLC0415
|
177
167
|
|
178
168
|
if not torch.cuda.is_available():
|
179
169
|
return devices
|
@@ -209,7 +199,7 @@ def _get_cuda_devices() -> list[DeviceInfo]:
|
|
209
199
|
def _get_mps_device() -> DeviceInfo | None:
|
210
200
|
"""Get information about the MPS device."""
|
211
201
|
try:
|
212
|
-
import torch
|
202
|
+
import torch # noqa: PLC0415
|
213
203
|
|
214
204
|
if not torch.backends.mps.is_available():
|
215
205
|
return None
|
@@ -226,7 +216,7 @@ def _get_mps_device() -> DeviceInfo | None:
|
|
226
216
|
def _get_cuda_memory_info(device_id: int) -> tuple[float | None, float | None]:
|
227
217
|
"""Get CUDA memory information for a specific device."""
|
228
218
|
try:
|
229
|
-
import torch
|
219
|
+
import torch # noqa: PLC0415
|
230
220
|
|
231
221
|
if not torch.cuda.is_available():
|
232
222
|
return None, None
|
@@ -339,7 +329,7 @@ def cleanup_device_memory(device: DeviceInfo) -> None:
|
|
339
329
|
"""
|
340
330
|
if device.device_type == "cuda":
|
341
331
|
try:
|
342
|
-
import torch
|
332
|
+
import torch # noqa: PLC0415
|
343
333
|
|
344
334
|
if torch.cuda.is_available():
|
345
335
|
torch.cuda.empty_cache()
|
@@ -348,7 +338,7 @@ def cleanup_device_memory(device: DeviceInfo) -> None:
|
|
348
338
|
|
349
339
|
elif device.device_type == "mps":
|
350
340
|
try:
|
351
|
-
import torch
|
341
|
+
import torch # noqa: PLC0415
|
352
342
|
|
353
343
|
if torch.backends.mps.is_available():
|
354
344
|
torch.mps.empty_cache()
|
kreuzberg/_utils/_errors.py
CHANGED
@@ -12,6 +12,42 @@ import psutil
|
|
12
12
|
|
13
13
|
from kreuzberg.exceptions import ValidationError
|
14
14
|
|
15
|
+
# Define error keywords as frozensets for O(1) membership testing
|
16
|
+
_SYSTEM_ERROR_KEYWORDS = frozenset({"memory", "resource", "process", "thread"})
|
17
|
+
_TRANSIENT_ERROR_PATTERNS = frozenset(
|
18
|
+
{
|
19
|
+
"temporary",
|
20
|
+
"locked",
|
21
|
+
"in use",
|
22
|
+
"access denied",
|
23
|
+
"permission",
|
24
|
+
"timeout",
|
25
|
+
"connection",
|
26
|
+
"network",
|
27
|
+
"too many open files",
|
28
|
+
"cannot allocate memory",
|
29
|
+
"resource temporarily unavailable",
|
30
|
+
"broken pipe",
|
31
|
+
"subprocess",
|
32
|
+
"signal",
|
33
|
+
}
|
34
|
+
)
|
35
|
+
_RESOURCE_ERROR_PATTERNS = frozenset(
|
36
|
+
{
|
37
|
+
"memory",
|
38
|
+
"out of memory",
|
39
|
+
"cannot allocate",
|
40
|
+
"too many open files",
|
41
|
+
"file descriptor",
|
42
|
+
"resource",
|
43
|
+
"exhausted",
|
44
|
+
"limit",
|
45
|
+
"cpu",
|
46
|
+
"thread",
|
47
|
+
"process",
|
48
|
+
}
|
49
|
+
)
|
50
|
+
|
15
51
|
|
16
52
|
def create_error_context(
|
17
53
|
*,
|
@@ -52,11 +88,7 @@ def create_error_context(
|
|
52
88
|
"traceback": traceback.format_exception_only(type(error), error),
|
53
89
|
}
|
54
90
|
|
55
|
-
if (
|
56
|
-
any(keyword in str(error).lower() for keyword in ["memory", "resource", "process", "thread"])
|
57
|
-
if error
|
58
|
-
else False
|
59
|
-
):
|
91
|
+
if error and any(keyword in str(error).lower() for keyword in _SYSTEM_ERROR_KEYWORDS):
|
60
92
|
try:
|
61
93
|
mem = psutil.virtual_memory()
|
62
94
|
context["system"] = {
|
@@ -94,25 +126,8 @@ def is_transient_error(error: Exception) -> bool:
|
|
94
126
|
if isinstance(error, transient_types):
|
95
127
|
return True
|
96
128
|
|
97
|
-
transient_patterns = [
|
98
|
-
"temporary",
|
99
|
-
"locked",
|
100
|
-
"in use",
|
101
|
-
"access denied",
|
102
|
-
"permission",
|
103
|
-
"timeout",
|
104
|
-
"connection",
|
105
|
-
"network",
|
106
|
-
"too many open files",
|
107
|
-
"cannot allocate memory",
|
108
|
-
"resource temporarily unavailable",
|
109
|
-
"broken pipe",
|
110
|
-
"subprocess",
|
111
|
-
"signal",
|
112
|
-
]
|
113
|
-
|
114
129
|
error_str = str(error).lower()
|
115
|
-
return any(pattern in error_str for pattern in
|
130
|
+
return any(pattern in error_str for pattern in _TRANSIENT_ERROR_PATTERNS)
|
116
131
|
|
117
132
|
|
118
133
|
def is_resource_error(error: Exception) -> bool:
|
@@ -124,22 +139,8 @@ def is_resource_error(error: Exception) -> bool:
|
|
124
139
|
Returns:
|
125
140
|
True if the error is resource-related
|
126
141
|
"""
|
127
|
-
resource_patterns = [
|
128
|
-
"memory",
|
129
|
-
"out of memory",
|
130
|
-
"cannot allocate",
|
131
|
-
"too many open files",
|
132
|
-
"file descriptor",
|
133
|
-
"resource",
|
134
|
-
"exhausted",
|
135
|
-
"limit",
|
136
|
-
"cpu",
|
137
|
-
"thread",
|
138
|
-
"process",
|
139
|
-
]
|
140
|
-
|
141
142
|
error_str = str(error).lower()
|
142
|
-
return any(pattern in error_str for pattern in
|
143
|
+
return any(pattern in error_str for pattern in _RESOURCE_ERROR_PATTERNS)
|
143
144
|
|
144
145
|
|
145
146
|
def should_retry(error: Exception, attempt: int, max_attempts: int = 3) -> bool:
|
@@ -165,6 +166,8 @@ def should_retry(error: Exception, attempt: int, max_attempts: int = 3) -> bool:
|
|
165
166
|
class BatchExtractionResult:
|
166
167
|
"""Result container for batch operations with partial success support."""
|
167
168
|
|
169
|
+
__slots__ = ("failed", "successful", "total_count")
|
170
|
+
|
168
171
|
def __init__(self) -> None:
|
169
172
|
"""Initialize batch result container."""
|
170
173
|
self.successful: list[tuple[int, Any]] = []
|
kreuzberg/_utils/_quality.py
CHANGED
@@ -3,6 +3,7 @@
|
|
3
3
|
from __future__ import annotations
|
4
4
|
|
5
5
|
import re
|
6
|
+
from functools import reduce
|
6
7
|
from typing import Any
|
7
8
|
|
8
9
|
# Pre-compiled patterns for performance
|
@@ -102,9 +103,8 @@ def clean_extracted_text(text: str) -> str:
|
|
102
103
|
if not text:
|
103
104
|
return text
|
104
105
|
|
105
|
-
# Remove script and style content
|
106
|
-
|
107
|
-
text = pattern.sub(" ", text)
|
106
|
+
# Remove script and style content using functools.reduce for single pass
|
107
|
+
text = reduce(lambda t, pattern: pattern.sub(" ", t), _SCRIPT_PATTERNS.values(), text)
|
108
108
|
|
109
109
|
# Clean OCR artifacts
|
110
110
|
text = _clean_ocr_artifacts(text)
|
@@ -134,10 +134,8 @@ def _calculate_script_penalty(text: str, total_chars: int) -> float:
|
|
134
134
|
if total_chars == 0:
|
135
135
|
return 0.0
|
136
136
|
|
137
|
-
|
138
|
-
for pattern in _SCRIPT_PATTERNS.values()
|
139
|
-
matches = pattern.findall(text)
|
140
|
-
script_chars += sum(len(match) for match in matches)
|
137
|
+
# Use sum with generator expression for single-pass calculation
|
138
|
+
script_chars = sum(len(match) for pattern in _SCRIPT_PATTERNS.values() for match in pattern.findall(text))
|
141
139
|
|
142
140
|
return min(1.0, script_chars / total_chars)
|
143
141
|
|
@@ -147,10 +145,8 @@ def _calculate_navigation_penalty(text: str, total_chars: int) -> float:
|
|
147
145
|
if total_chars == 0:
|
148
146
|
return 0.0
|
149
147
|
|
150
|
-
|
151
|
-
for pattern in _NAVIGATION_PATTERNS.values()
|
152
|
-
matches = pattern.findall(text)
|
153
|
-
nav_chars += sum(len(match) for match in matches)
|
148
|
+
# Use sum with generator expression for single-pass calculation
|
149
|
+
nav_chars = sum(len(match) for pattern in _NAVIGATION_PATTERNS.values() for match in pattern.findall(text))
|
154
150
|
|
155
151
|
return min(1.0, nav_chars / total_chars)
|
156
152
|
|
@@ -2,16 +2,28 @@
|
|
2
2
|
|
3
3
|
from __future__ import annotations
|
4
4
|
|
5
|
-
from dataclasses import
|
6
|
-
from enum import Enum
|
5
|
+
from dataclasses import is_dataclass
|
7
6
|
from typing import Any, TypeVar, cast
|
8
7
|
|
8
|
+
import msgspec
|
9
9
|
from msgspec import MsgspecError
|
10
10
|
from msgspec.msgpack import decode, encode
|
11
11
|
|
12
12
|
T = TypeVar("T")
|
13
13
|
|
14
14
|
|
15
|
+
# Define dict method names in priority order
|
16
|
+
_DICT_METHOD_NAMES = (
|
17
|
+
"to_dict",
|
18
|
+
"as_dict",
|
19
|
+
"dict",
|
20
|
+
"model_dump",
|
21
|
+
"json",
|
22
|
+
"to_list",
|
23
|
+
"tolist",
|
24
|
+
)
|
25
|
+
|
26
|
+
|
15
27
|
def encode_hook(obj: Any) -> Any:
|
16
28
|
"""Custom encoder for complex objects."""
|
17
29
|
if callable(obj):
|
@@ -20,22 +32,15 @@ def encode_hook(obj: Any) -> Any:
|
|
20
32
|
if isinstance(obj, Exception):
|
21
33
|
return {"message": str(obj), "type": type(obj).__name__}
|
22
34
|
|
23
|
-
for
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
"json",
|
29
|
-
"to_list",
|
30
|
-
"tolist",
|
31
|
-
):
|
32
|
-
if hasattr(obj, key):
|
33
|
-
method = getattr(obj, key) # Cache the attribute lookup
|
34
|
-
if callable(method):
|
35
|
-
return method()
|
35
|
+
# Check for dict-like methods more efficiently using any() with generator
|
36
|
+
for attr_name in _DICT_METHOD_NAMES:
|
37
|
+
method = getattr(obj, attr_name, None)
|
38
|
+
if method is not None and callable(method):
|
39
|
+
return method()
|
36
40
|
|
37
41
|
if is_dataclass(obj) and not isinstance(obj, type):
|
38
|
-
|
42
|
+
# Use msgspec.to_builtins for more efficient conversion
|
43
|
+
return msgspec.to_builtins(obj)
|
39
44
|
|
40
45
|
if hasattr(obj, "save") and hasattr(obj, "format"):
|
41
46
|
return None
|
kreuzberg/_utils/_string.py
CHANGED
@@ -28,6 +28,7 @@ _encoding_cache: dict[str, str] = {}
|
|
28
28
|
@lru_cache(maxsize=128)
|
29
29
|
def _get_encoding_cache_key(data_hash: str, size: int) -> str:
|
30
30
|
"""Generate cache key for encoding detection."""
|
31
|
+
# Use string interpolation which is faster than format strings for simple cases
|
31
32
|
return f"{data_hash}:{size}"
|
32
33
|
|
33
34
|
|
@@ -104,25 +105,29 @@ def _calculate_text_confidence(text: str) -> float:
|
|
104
105
|
if not text:
|
105
106
|
return 0.0
|
106
107
|
|
107
|
-
# Check for common encoding problems
|
108
|
-
replacement_count = len(_MOJIBAKE_PATTERNS["replacement_chars"].findall(text))
|
109
|
-
control_count = len(_MOJIBAKE_PATTERNS["control_chars"].findall(text))
|
110
108
|
total_chars = len(text)
|
111
|
-
|
112
109
|
if total_chars == 0:
|
113
110
|
return 0.0
|
114
111
|
|
112
|
+
# Check for common encoding problems - compile patterns once
|
113
|
+
replacement_count = len(_MOJIBAKE_PATTERNS["replacement_chars"].findall(text))
|
114
|
+
control_count = len(_MOJIBAKE_PATTERNS["control_chars"].findall(text))
|
115
|
+
|
115
116
|
# Penalize replacement and control characters
|
116
117
|
penalty = (replacement_count + control_count * 2) / total_chars
|
117
118
|
|
118
|
-
# Bonus for readable character ranges
|
119
|
+
# Bonus for readable character ranges - more efficient counting
|
120
|
+
# Use generator expression with early termination
|
119
121
|
readable_chars = sum(1 for c in text if c.isprintable() or c.isspace())
|
120
122
|
readability_score = readable_chars / total_chars
|
121
123
|
|
122
124
|
# Check for suspicious Cyrillic that might be misencoded Hebrew
|
123
125
|
cyrillic_matches = _MOJIBAKE_PATTERNS["hebrew_as_cyrillic"].findall(text)
|
124
|
-
if cyrillic_matches
|
125
|
-
|
126
|
+
if cyrillic_matches:
|
127
|
+
# Calculate total length more efficiently
|
128
|
+
cyrillic_length = sum(len(match) for match in cyrillic_matches)
|
129
|
+
if cyrillic_length > total_chars * 0.1:
|
130
|
+
penalty += 0.3 # Heavy penalty for likely mojibake
|
126
131
|
|
127
132
|
return max(0.0, min(1.0, readability_score - penalty))
|
128
133
|
|
@@ -164,7 +169,8 @@ def normalize_spaces(text: str) -> str:
|
|
164
169
|
|
165
170
|
# Split by double newlines to preserve paragraph breaks
|
166
171
|
paragraphs = text.split("\n\n")
|
167
|
-
|
172
|
+
|
173
|
+
result_paragraphs = []
|
168
174
|
|
169
175
|
for paragraph in paragraphs:
|
170
176
|
# Use pre-compiled patterns for better performance
|
@@ -173,10 +179,14 @@ def normalize_spaces(text: str) -> str:
|
|
173
179
|
# Clean up multiple newlines within paragraph (keep single newlines)
|
174
180
|
cleaned = _NEWLINES_PATTERN.sub("\n", cleaned)
|
175
181
|
|
176
|
-
#
|
177
|
-
lines = [
|
182
|
+
# Process lines efficiently - manual loop avoids double strip() calls
|
183
|
+
lines = []
|
184
|
+
for line in cleaned.split("\n"):
|
185
|
+
stripped_line = line.strip()
|
186
|
+
if stripped_line:
|
187
|
+
lines.append(stripped_line)
|
178
188
|
|
179
189
|
if lines:
|
180
|
-
|
190
|
+
result_paragraphs.append("\n".join(lines))
|
181
191
|
|
182
|
-
return "\n\n".join(
|
192
|
+
return "\n\n".join(result_paragraphs)
|
kreuzberg/_utils/_table.py
CHANGED
@@ -3,7 +3,6 @@
|
|
3
3
|
from __future__ import annotations
|
4
4
|
|
5
5
|
import csv
|
6
|
-
from io import StringIO
|
7
6
|
from typing import TYPE_CHECKING, Any
|
8
7
|
|
9
8
|
if TYPE_CHECKING:
|
@@ -23,9 +22,9 @@ def export_table_to_csv(table: TableData, separator: str = ",") -> str:
|
|
23
22
|
if "df" not in table or table["df"] is None:
|
24
23
|
return ""
|
25
24
|
|
26
|
-
|
27
|
-
table["df"].to_csv(
|
28
|
-
return
|
25
|
+
# Use pandas to_csv() direct string return instead of StringIO
|
26
|
+
csv_output = table["df"].to_csv(sep=separator, index=False, quoting=csv.QUOTE_MINIMAL, lineterminator="\n")
|
27
|
+
return str(csv_output).strip()
|
29
28
|
|
30
29
|
|
31
30
|
def export_table_to_tsv(table: TableData) -> str:
|
kreuzberg/cli.py
CHANGED
@@ -18,7 +18,7 @@ except ImportError as e:
|
|
18
18
|
) from e
|
19
19
|
|
20
20
|
from kreuzberg import __version__, extract_bytes_sync, extract_file_sync
|
21
|
-
from kreuzberg.
|
21
|
+
from kreuzberg._config import build_extraction_config, find_config_file, load_config_from_file
|
22
22
|
from kreuzberg.exceptions import KreuzbergError, MissingDependencyError
|
23
23
|
|
24
24
|
DEFAULT_MAX_CHARACTERS = 4000
|
@@ -92,7 +92,7 @@ def _load_config(config: Path | None, verbose: bool) -> dict[str, Any]:
|
|
92
92
|
if config:
|
93
93
|
file_config = load_config_from_file(config)
|
94
94
|
else:
|
95
|
-
default_config =
|
95
|
+
default_config = find_config_file()
|
96
96
|
if default_config:
|
97
97
|
try:
|
98
98
|
file_config = load_config_from_file(default_config)
|
@@ -160,7 +160,7 @@ def _perform_extraction(file: Path | None, extraction_config: ExtractionConfig,
|
|
160
160
|
progress.add_task("Extracting text...", total=None)
|
161
161
|
|
162
162
|
try:
|
163
|
-
import magic # type: ignore[import-not-found]
|
163
|
+
import magic # type: ignore[import-not-found] # noqa: PLC0415
|
164
164
|
|
165
165
|
mime_type = magic.from_buffer(input_bytes, mime=True)
|
166
166
|
except ImportError:
|
@@ -260,7 +260,7 @@ def cli(ctx: click.Context) -> None:
|
|
260
260
|
@click.option("--paddleocr-languages", help="PaddleOCR language codes (comma-separated, e.g., 'en,german')")
|
261
261
|
@click.pass_context
|
262
262
|
def extract( # noqa: PLR0913
|
263
|
-
|
263
|
+
_: click.Context,
|
264
264
|
file: Path | None,
|
265
265
|
output: Path | None,
|
266
266
|
force_ocr: bool,
|
@@ -314,7 +314,7 @@ def extract( # noqa: PLR0913
|
|
314
314
|
def config(config: Path | None) -> None:
|
315
315
|
"""Show current configuration."""
|
316
316
|
try:
|
317
|
-
config_path = config or
|
317
|
+
config_path = config or find_config_file()
|
318
318
|
|
319
319
|
if config_path:
|
320
320
|
file_config = load_config_from_file(config_path)
|
kreuzberg/exceptions.py
CHANGED
@@ -7,6 +7,8 @@ from typing import Any
|
|
7
7
|
class KreuzbergError(Exception):
|
8
8
|
"""Base exception for all Kreuzberg errors."""
|
9
9
|
|
10
|
+
__slots__ = ("context",)
|
11
|
+
|
10
12
|
context: Any
|
11
13
|
"""The context of the error."""
|
12
14
|
|
@@ -43,14 +45,20 @@ class KreuzbergError(Exception):
|
|
43
45
|
class ParsingError(KreuzbergError):
|
44
46
|
"""Raised when a parsing error occurs."""
|
45
47
|
|
48
|
+
__slots__ = ()
|
49
|
+
|
46
50
|
|
47
51
|
class ValidationError(KreuzbergError):
|
48
52
|
"""Raised when a validation error occurs."""
|
49
53
|
|
54
|
+
__slots__ = ()
|
55
|
+
|
50
56
|
|
51
57
|
class MissingDependencyError(KreuzbergError):
|
52
58
|
"""Raised when a dependency is missing."""
|
53
59
|
|
60
|
+
__slots__ = ()
|
61
|
+
|
54
62
|
@classmethod
|
55
63
|
def create_for_package(
|
56
64
|
cls, *, dependency_group: str, functionality: str, package_name: str
|
@@ -79,3 +87,5 @@ class MissingDependencyError(KreuzbergError):
|
|
79
87
|
|
80
88
|
class OCRError(KreuzbergError):
|
81
89
|
"""Raised when an OCR error occurs."""
|
90
|
+
|
91
|
+
__slots__ = ()
|
kreuzberg/extraction.py
CHANGED
@@ -7,15 +7,15 @@ from typing import TYPE_CHECKING, Any, Final, cast
|
|
7
7
|
|
8
8
|
import anyio
|
9
9
|
|
10
|
-
from kreuzberg import ExtractionResult
|
11
10
|
from kreuzberg._chunker import get_chunker
|
11
|
+
from kreuzberg._document_classification import auto_detect_document_type
|
12
12
|
from kreuzberg._entity_extraction import extract_entities, extract_keywords
|
13
13
|
from kreuzberg._language_detection import detect_languages
|
14
14
|
from kreuzberg._mime_types import (
|
15
15
|
validate_mime_type,
|
16
16
|
)
|
17
17
|
from kreuzberg._registry import ExtractorRegistry
|
18
|
-
from kreuzberg._types import ExtractionConfig
|
18
|
+
from kreuzberg._types import ExtractionConfig, ExtractionResult
|
19
19
|
from kreuzberg._utils._document_cache import get_document_cache
|
20
20
|
from kreuzberg._utils._errors import create_error_context
|
21
21
|
from kreuzberg._utils._string import safe_decode
|
@@ -30,7 +30,9 @@ if TYPE_CHECKING:
|
|
30
30
|
DEFAULT_CONFIG: Final[ExtractionConfig] = ExtractionConfig()
|
31
31
|
|
32
32
|
|
33
|
-
def _validate_and_post_process_helper(
|
33
|
+
def _validate_and_post_process_helper(
|
34
|
+
result: ExtractionResult, config: ExtractionConfig, file_path: Path | None = None
|
35
|
+
) -> ExtractionResult:
|
34
36
|
if config.chunk_content:
|
35
37
|
result.chunks = _handle_chunk_content(
|
36
38
|
mime_type=result.mime_type,
|
@@ -62,14 +64,19 @@ def _validate_and_post_process_helper(result: ExtractionResult, config: Extracti
|
|
62
64
|
config=config.language_detection_config,
|
63
65
|
)
|
64
66
|
|
67
|
+
if config.auto_detect_document_type:
|
68
|
+
result = auto_detect_document_type(result, config, file_path=file_path)
|
69
|
+
|
65
70
|
return result
|
66
71
|
|
67
72
|
|
68
|
-
async def _validate_and_post_process_async(
|
73
|
+
async def _validate_and_post_process_async(
|
74
|
+
result: ExtractionResult, config: ExtractionConfig, file_path: Path | None = None
|
75
|
+
) -> ExtractionResult:
|
69
76
|
for validator in config.validators or []:
|
70
77
|
await run_maybe_sync(validator, result)
|
71
78
|
|
72
|
-
result = _validate_and_post_process_helper(result, config)
|
79
|
+
result = _validate_and_post_process_helper(result, config, file_path)
|
73
80
|
|
74
81
|
for post_processor in config.post_processing_hooks or []:
|
75
82
|
result = await run_maybe_sync(post_processor, result)
|
@@ -77,11 +84,13 @@ async def _validate_and_post_process_async(result: ExtractionResult, config: Ext
|
|
77
84
|
return result
|
78
85
|
|
79
86
|
|
80
|
-
def _validate_and_post_process_sync(
|
87
|
+
def _validate_and_post_process_sync(
|
88
|
+
result: ExtractionResult, config: ExtractionConfig, file_path: Path | None = None
|
89
|
+
) -> ExtractionResult:
|
81
90
|
for validator in config.validators or []:
|
82
91
|
run_sync_only(validator, result)
|
83
92
|
|
84
|
-
result = _validate_and_post_process_helper(result, config)
|
93
|
+
result = _validate_and_post_process_helper(result, config, file_path)
|
85
94
|
|
86
95
|
for post_processor in config.post_processing_hooks or []:
|
87
96
|
result = run_sync_only(post_processor, result)
|
@@ -172,7 +181,7 @@ async def extract_file(
|
|
172
181
|
metadata={},
|
173
182
|
)
|
174
183
|
|
175
|
-
result = await _validate_and_post_process_async(result=result, config=config)
|
184
|
+
result = await _validate_and_post_process_async(result=result, config=config, file_path=path)
|
176
185
|
|
177
186
|
cache.set(path, config, result)
|
178
187
|
|
@@ -357,7 +366,7 @@ def extract_file_sync(
|
|
357
366
|
metadata={},
|
358
367
|
)
|
359
368
|
|
360
|
-
result = _validate_and_post_process_sync(result=result, config=config)
|
369
|
+
result = _validate_and_post_process_sync(result=result, config=config, file_path=path)
|
361
370
|
|
362
371
|
cache.set(path, config, result)
|
363
372
|
|
@@ -460,8 +469,8 @@ def batch_extract_bytes_sync(
|
|
460
469
|
return (index, error_result)
|
461
470
|
|
462
471
|
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
463
|
-
|
464
|
-
future_to_index = {executor.submit(extract_single,
|
472
|
+
# Avoid creating intermediate list, use enumerate directly
|
473
|
+
future_to_index = {executor.submit(extract_single, (i, content)): i for i, content in enumerate(contents)}
|
465
474
|
|
466
475
|
results: list[ExtractionResult] = [None] * len(contents) # type: ignore[list-item]
|
467
476
|
for future in as_completed(future_to_index):
|