kreuzberg 3.8.1__py3-none-any.whl → 3.8.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kreuzberg/__init__.py +4 -0
- kreuzberg/_api/main.py +22 -1
- kreuzberg/_config.py +404 -0
- kreuzberg/_entity_extraction.py +3 -3
- kreuzberg/_extractors/_pdf.py +22 -19
- kreuzberg/_extractors/_spread_sheet.py +2 -3
- kreuzberg/_extractors/_structured.py +10 -7
- kreuzberg/_gmft.py +8 -11
- kreuzberg/_language_detection.py +1 -1
- kreuzberg/_mcp/server.py +58 -8
- kreuzberg/_ocr/_easyocr.py +1 -1
- kreuzberg/_ocr/_paddleocr.py +1 -1
- kreuzberg/_ocr/_tesseract.py +2 -7
- kreuzberg/_playa.py +2 -3
- kreuzberg/_types.py +46 -24
- kreuzberg/_utils/_cache.py +15 -17
- kreuzberg/_utils/_device.py +10 -20
- kreuzberg/_utils/_errors.py +41 -38
- kreuzberg/_utils/_quality.py +7 -11
- kreuzberg/_utils/_serialization.py +21 -16
- kreuzberg/_utils/_string.py +22 -12
- kreuzberg/_utils/_table.py +3 -4
- kreuzberg/cli.py +3 -3
- kreuzberg/exceptions.py +10 -0
- kreuzberg/extraction.py +2 -2
- kreuzberg-3.8.2.dist-info/METADATA +265 -0
- kreuzberg-3.8.2.dist-info/RECORD +53 -0
- kreuzberg/_cli_config.py +0 -175
- kreuzberg-3.8.1.dist-info/METADATA +0 -301
- kreuzberg-3.8.1.dist-info/RECORD +0 -53
- {kreuzberg-3.8.1.dist-info → kreuzberg-3.8.2.dist-info}/WHEEL +0 -0
- {kreuzberg-3.8.1.dist-info → kreuzberg-3.8.2.dist-info}/entry_points.txt +0 -0
- {kreuzberg-3.8.1.dist-info → kreuzberg-3.8.2.dist-info}/licenses/LICENSE +0 -0
kreuzberg/_gmft.py
CHANGED
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import io
 import multiprocessing as mp
 import os
 import queue
@@ -9,6 +10,8 @@ from dataclasses import dataclass, field
 from io import StringIO
 from typing import TYPE_CHECKING, Any, Literal
 
+import msgspec
+
 from kreuzberg._types import TableData
 from kreuzberg._utils._sync import run_sync
 from kreuzberg.exceptions import MissingDependencyError, ParsingError
@@ -20,7 +23,7 @@ if TYPE_CHECKING:
     from pandas import DataFrame
 
 
-@dataclass(unsafe_hash=True)
+@dataclass(unsafe_hash=True, slots=True)
 class GMFTConfig:
     """Configuration options for GMFT.
 
@@ -178,7 +181,7 @@ async def extract_tables( # noqa: PLR0915
     cache_kwargs = {
         "file_info": str(sorted(file_info.items())),
         "extractor": "gmft",
-        "config": str(sorted(config.
+        "config": str(sorted(msgspec.to_builtins(config).items())),
     }
 
     table_cache = get_table_cache()
@@ -308,7 +311,7 @@ def extract_tables_sync(
     cache_kwargs = {
         "file_info": str(sorted(file_info.items())),
         "extractor": "gmft",
-        "config": str(sorted(config.
+        "config": str(sorted(msgspec.to_builtins(config).items())),
     }
 
     table_cache = get_table_cache()
@@ -435,8 +438,6 @@ def _extract_tables_in_process(
 
     results = []
     for data_frame, cropped_table in zip(dataframes, cropped_tables, strict=False):
-        import io
-
        img_bytes = io.BytesIO()
        cropped_image = cropped_table.image()
        cropped_image.save(img_bytes, format="PNG")
@@ -480,7 +481,7 @@ def _extract_tables_isolated(
         RuntimeError: If extraction fails or times out
     """
     config = config or GMFTConfig()
-    config_dict =
+    config_dict = msgspec.to_builtins(config)
 
     ctx = mp.get_context("spawn")
     result_queue = ctx.Queue()
@@ -528,8 +529,6 @@ def _extract_tables_isolated(
     if success:
         tables = []
         for table_dict in result:
-            import io
-
             from PIL import Image
 
             img = Image.open(io.BytesIO(table_dict["cropped_image_bytes"]))
@@ -596,7 +595,7 @@ async def _extract_tables_isolated_async(
     import anyio
 
     config = config or GMFTConfig()
-    config_dict =
+    config_dict = msgspec.to_builtins(config)
 
     ctx = mp.get_context("spawn")
     result_queue = ctx.Queue()
@@ -640,8 +639,6 @@ async def _extract_tables_isolated_async(
     if success:
         tables = []
         for table_dict in result:
-            import io
-
             from PIL import Image
 
             img = Image.open(io.BytesIO(table_dict["cropped_image_bytes"]))
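The substantive change here is serializing `GMFTConfig` with `msgspec.to_builtins` before hashing it into a table-cache key. A minimal sketch of that pattern, using an illustrative stand-in dataclass rather than the real `GMFTConfig`:

```python
from dataclasses import dataclass

import msgspec


@dataclass(unsafe_hash=True, slots=True)
class DemoConfig:
    # Illustrative fields, not the actual GMFTConfig schema.
    detection_threshold: float = 0.9
    enable_multi_header: bool = False


config = DemoConfig()

# msgspec.to_builtins turns the dataclass into plain dicts/scalars, so the
# cache key depends only on field values, not on object identity or repr.
config_dict = msgspec.to_builtins(config)
cache_key_part = str(sorted(config_dict.items()))
print(cache_key_part)
# [('detection_threshold', 0.9), ('enable_multi_header', False)]
```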
kreuzberg/_language_detection.py
CHANGED
kreuzberg/_mcp/server.py
CHANGED
@@ -3,11 +3,14 @@
 from __future__ import annotations
 
 import base64
+import json
 from typing import Any
 
+import msgspec
 from mcp.server import FastMCP
 from mcp.types import TextContent
 
+from kreuzberg._config import try_discover_config
 from kreuzberg._types import ExtractionConfig, OcrBackendType
 from kreuzberg.extraction import extract_bytes_sync, extract_file_sync
 
@@ -15,6 +18,44 @@ from kreuzberg.extraction import extract_bytes_sync, extract_file_sync
 mcp = FastMCP("Kreuzberg Text Extraction")
 
 
+def _create_config_with_overrides(**kwargs: Any) -> ExtractionConfig:
+    """Create ExtractionConfig with discovered config as base and tool parameters as overrides.
+
+    Args:
+        **kwargs: Tool parameters to override defaults/discovered config.
+
+    Returns:
+        ExtractionConfig instance.
+    """
+    # Try to discover configuration from files
+    base_config = try_discover_config()
+
+    if base_config is None:
+        # No config file found, use defaults
+        return ExtractionConfig(**kwargs)
+
+    # Merge discovered config with tool parameters (tool params take precedence)
+    config_dict: dict[str, Any] = {
+        "force_ocr": base_config.force_ocr,
+        "chunk_content": base_config.chunk_content,
+        "extract_tables": base_config.extract_tables,
+        "extract_entities": base_config.extract_entities,
+        "extract_keywords": base_config.extract_keywords,
+        "ocr_backend": base_config.ocr_backend,
+        "max_chars": base_config.max_chars,
+        "max_overlap": base_config.max_overlap,
+        "keyword_count": base_config.keyword_count,
+        "auto_detect_language": base_config.auto_detect_language,
+        "ocr_config": base_config.ocr_config,
+        "gmft_config": base_config.gmft_config,
+    }
+
+    # Override with provided parameters
+    config_dict.update(kwargs)
+
+    return ExtractionConfig(**config_dict)
+
+
 @mcp.tool()
 def extract_document( # noqa: PLR0913
     file_path: str,
@@ -49,7 +90,7 @@ def extract_document( # noqa: PLR0913
     Returns:
         Extracted content with metadata, tables, chunks, entities, and keywords
     """
-    config =
+    config = _create_config_with_overrides(
         force_ocr=force_ocr,
         chunk_content=chunk_content,
         extract_tables=extract_tables,
@@ -63,7 +104,7 @@ def extract_document( # noqa: PLR0913
     )
 
     result = extract_file_sync(file_path, mime_type, config)
-    return result.to_dict()
+    return result.to_dict(include_none=True)
 
 
 @mcp.tool()
@@ -102,7 +143,7 @@ def extract_bytes( # noqa: PLR0913
     """
     content_bytes = base64.b64decode(content_base64)
 
-    config =
+    config = _create_config_with_overrides(
         force_ocr=force_ocr,
         chunk_content=chunk_content,
         extract_tables=extract_tables,
@@ -116,7 +157,7 @@ def extract_bytes( # noqa: PLR0913
     )
 
     result = extract_bytes_sync(content_bytes, mime_type, config)
-    return result.to_dict()
+    return result.to_dict(include_none=True)
 
 
 @mcp.tool()
@@ -133,7 +174,7 @@ def extract_simple(
     Returns:
         Extracted text content as a string
     """
-    config =
+    config = _create_config_with_overrides()
     result = extract_file_sync(file_path, mime_type, config)
     return result.content
 
@@ -142,7 +183,16 @@ def extract_simple(
 def get_default_config() -> str:
     """Get the default extraction configuration."""
     config = ExtractionConfig()
-    return
+    return json.dumps(msgspec.to_builtins(config, order="deterministic"), indent=2)
+
+
+@mcp.resource("config://discovered")
+def get_discovered_config() -> str:
+    """Get the discovered configuration from config files."""
+    config = try_discover_config()
+    if config is None:
+        return "No configuration file found"
+    return json.dumps(msgspec.to_builtins(config, order="deterministic"), indent=2)
 
 
 @mcp.resource("config://available-backends")
@@ -175,7 +225,7 @@ def extract_and_summarize(file_path: str) -> list[TextContent]:
     Returns:
         Extracted content with summarization prompt
     """
-    result = extract_file_sync(file_path, None,
+    result = extract_file_sync(file_path, None, _create_config_with_overrides())
 
     return [
         TextContent(
@@ -195,7 +245,7 @@ def extract_structured(file_path: str) -> list[TextContent]:
     Returns:
         Extracted content with structured analysis prompt
     """
-    config =
+    config = _create_config_with_overrides(
         extract_entities=True,
         extract_keywords=True,
         extract_tables=True,
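The new `_create_config_with_overrides` helper layers explicit MCP tool parameters on top of a configuration discovered by `try_discover_config`. A simplified sketch of that precedence, with plain dicts standing in for the real `ExtractionConfig`:

```python
from typing import Any

# Stand-in for a config discovered from a file; values are illustrative only.
discovered = {"force_ocr": False, "ocr_backend": "tesseract", "keyword_count": 10}


def merge_config(base: dict[str, Any] | None, **overrides: Any) -> dict[str, Any]:
    """Start from the discovered config (if any), then let call-site kwargs win."""
    merged = dict(base) if base else {}
    merged.update(overrides)
    return merged


print(merge_config(discovered, force_ocr=True))
# {'force_ocr': True, 'ocr_backend': 'tesseract', 'keyword_count': 10}
```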
kreuzberg/_ocr/_easyocr.py
CHANGED
kreuzberg/_ocr/_paddleocr.py
CHANGED
@@ -31,7 +31,7 @@ except ImportError: # pragma: no cover
 PADDLEOCR_SUPPORTED_LANGUAGE_CODES: Final[set[str]] = {"ch", "en", "french", "german", "japan", "korean"}
 
 
-@dataclass(unsafe_hash=True, frozen=True)
+@dataclass(unsafe_hash=True, frozen=True, slots=True)
 class PaddleOCRConfig:
     """Configuration options for PaddleOCR.
 
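Several config dataclasses in this release gain `slots=True` (here `PaddleOCRConfig`; the same flag appears on `TesseractConfig`, `GMFTConfig`, and the types in `_types.py`). A minimal sketch of what that flag changes, independent of kreuzberg:

```python
from dataclasses import dataclass


@dataclass(frozen=True, slots=True)
class SlottedConfig:
    # Illustrative fields only.
    language: str = "en"
    use_gpu: bool = False


cfg = SlottedConfig()

# slots=True (Python 3.10+) generates __slots__, so instances carry no __dict__:
# smaller per-instance memory and slightly faster attribute access.
assert not hasattr(cfg, "__dict__")

# Undeclared attributes are rejected instead of silently creating new ones
# (object.__setattr__ is used here only to bypass the frozen= protection).
try:
    object.__setattr__(cfg, "typo_field", 1)
except AttributeError as exc:
    print(f"rejected: {exc}")
```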
kreuzberg/_ocr/_tesseract.py
CHANGED
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import hashlib
+import io
 import os
 import re
 import subprocess
@@ -192,7 +193,7 @@ class PSMMode(Enum):
     """Treat the image as a single character."""
 
 
-@dataclass(unsafe_hash=True, frozen=True)
+@dataclass(unsafe_hash=True, frozen=True, slots=True)
 class TesseractConfig:
     """Configuration options for Tesseract OCR engine."""
 
@@ -235,8 +236,6 @@ class TesseractBackend(OCRBackend[TesseractConfig]):
         image: Image,
         **kwargs: Unpack[TesseractConfig],
     ) -> ExtractionResult:
-        import io
-
         from kreuzberg._utils._cache import get_ocr_cache
 
         image_buffer = io.BytesIO()
@@ -424,8 +423,6 @@ class TesseractBackend(OCRBackend[TesseractConfig]):
         Returns:
             The extraction result object
         """
-        import io
-
         from kreuzberg._utils._cache import get_ocr_cache
 
         image_buffer = io.BytesIO()
@@ -774,8 +771,6 @@ def _process_image_bytes_with_tesseract(
         OCR result as dictionary.
     """
     try:
-        import io
-
         from PIL import Image
 
         with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_image:
kreuzberg/_playa.py
CHANGED
@@ -114,9 +114,8 @@ def _extract_keyword_metadata(pdf_info: dict[str, Any], result: Metadata) -> Non
     if keywords := pdf_info.get("keywords"):
         if isinstance(keywords, (str, bytes)):
             kw_str = decode_text(keywords)
-
-
-            result["keywords"] = [k for k in kw_list if k]
+            # Combine multiple operations into a single comprehension
+            result["keywords"] = [k.strip() for part in kw_str.replace(";", ",").split(",") if (k := part.strip())]
         elif isinstance(keywords, list):
             result["keywords"] = [decode_text(k) for k in keywords]
 
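The new comprehension normalizes a PDF keyword string in one pass: semicolons become commas, the string is split, and empty pieces are dropped via the walrus-bound `k`. A quick worked example of the same expression in isolation:

```python
# The expression from _extract_keyword_metadata, applied to a sample string.
kw_str = "invoices; OCR, , tables ;  kreuzberg"
keywords = [k.strip() for part in kw_str.replace(";", ",").split(",") if (k := part.strip())]
print(keywords)
# ['invoices', 'OCR', 'tables', 'kreuzberg']
```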
kreuzberg/_types.py
CHANGED
@@ -5,7 +5,10 @@ from collections.abc import Awaitable, Callable
 from dataclasses import asdict, dataclass, field
 from typing import TYPE_CHECKING, Any, Literal, TypedDict
 
+import msgspec
+
 from kreuzberg._constants import DEFAULT_MAX_CHARACTERS, DEFAULT_MAX_OVERLAP
+from kreuzberg._utils._table import export_table_to_csv, export_table_to_tsv, extract_table_structure_info
 from kreuzberg.exceptions import ValidationError
 
 if sys.version_info < (3, 11): # pragma: no cover
@@ -191,7 +194,7 @@ def normalize_metadata(data: dict[str, Any] | None) -> Metadata:
     return normalized
 
 
-@dataclass(frozen=True)
+@dataclass(frozen=True, slots=True)
 class Entity:
     """Represents an extracted entity with type, text, and position."""
 
@@ -205,7 +208,7 @@ class Entity:
     """End character offset in the content"""
 
 
-@dataclass
+@dataclass(slots=True)
 class ExtractionResult:
     """The result of a file extraction."""
 
@@ -226,9 +229,29 @@ class ExtractionResult:
     detected_languages: list[str] | None = None
     """Languages detected in the extracted content, if language detection is enabled."""
 
-    def to_dict(self) -> dict[str, Any]:
-        """Converts the ExtractionResult to a dictionary.
-
+    def to_dict(self, include_none: bool = False) -> dict[str, Any]:
+        """Converts the ExtractionResult to a dictionary.
+
+        Args:
+            include_none: If True, include fields with None values.
+                If False (default), exclude None values.
+
+        Returns:
+            Dictionary representation of the ExtractionResult.
+        """
+        # Use msgspec.to_builtins for efficient conversion
+        # The builtin_types parameter allows DataFrames to pass through
+        result = msgspec.to_builtins(
+            self,
+            builtin_types=(type(None),),  # Allow None to pass through
+            order="deterministic",  # Ensure consistent output
+        )
+
+        if include_none:
+            return result  # type: ignore[no-any-return]
+
+        # Remove None values to match expected behavior
+        return {k: v for k, v in result.items() if v is not None}
 
     def export_tables_to_csv(self) -> list[str]:
         """Export all tables to CSV format.
@@ -239,8 +262,6 @@ class ExtractionResult:
         if not self.tables:
             return []
 
-        from kreuzberg._utils._table import export_table_to_csv
-
         return [export_table_to_csv(table) for table in self.tables]
 
     def export_tables_to_tsv(self) -> list[str]:
@@ -252,8 +273,6 @@ class ExtractionResult:
         if not self.tables:
             return []
 
-        from kreuzberg._utils._table import export_table_to_tsv
-
         return [export_table_to_tsv(table) for table in self.tables]
 
     def get_table_summaries(self) -> list[dict[str, Any]]:
@@ -265,8 +284,6 @@ class ExtractionResult:
         if not self.tables:
             return []
 
-        from kreuzberg._utils._table import extract_table_structure_info
-
         return [extract_table_structure_info(table) for table in self.tables]
 
 
@@ -274,7 +291,7 @@ PostProcessingHook = Callable[[ExtractionResult], ExtractionResult | Awaitable[E
 ValidationHook = Callable[[ExtractionResult], None | Awaitable[None]]
 
 
-@dataclass(unsafe_hash=True)
+@dataclass(unsafe_hash=True, slots=True)
 class ExtractionConfig:
     """Represents configuration settings for an extraction process.
 
@@ -355,18 +372,23 @@ class ExtractionConfig:
         Returns:
             A dict of the OCR configuration or an empty dict if no backend is provided.
         """
-        if self.ocr_backend is
-
-            return asdict(self.ocr_config)
-        if self.ocr_backend == "tesseract":
-            from kreuzberg._ocr._tesseract import TesseractConfig
+        if self.ocr_backend is None:
+            return {}
 
-
-
-
+        if self.ocr_config is not None:
+            # Use asdict for OCR configs to preserve enum objects correctly
+            return asdict(self.ocr_config)
 
-
-
+        # Lazy load and cache default configs instead of creating new instances
+        if self.ocr_backend == "tesseract":
+            from kreuzberg._ocr._tesseract import TesseractConfig
 
-            return asdict(
-
+            return asdict(TesseractConfig())
+        if self.ocr_backend == "easyocr":
+            from kreuzberg._ocr._easyocr import EasyOCRConfig
+
+            return asdict(EasyOCRConfig())
+        # paddleocr
+        from kreuzberg._ocr._paddleocr import PaddleOCRConfig
+
+        return asdict(PaddleOCRConfig())
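`ExtractionResult.to_dict` now delegates to `msgspec.to_builtins` and filters out `None` values unless `include_none=True`. A small standalone sketch of that behavior with a toy dataclass, not the real `ExtractionResult`:

```python
from dataclasses import dataclass
from typing import Any

import msgspec


@dataclass(slots=True)
class DemoResult:
    content: str = "hello"
    mime_type: str = "text/plain"
    detected_languages: list[str] | None = None


def to_dict(obj: Any, include_none: bool = False) -> dict[str, Any]:
    # Same options as the new to_dict in the diff above.
    result = msgspec.to_builtins(obj, builtin_types=(type(None),), order="deterministic")
    if include_none:
        return result
    return {k: v for k, v in result.items() if v is not None}


print(to_dict(DemoResult()))
# {'content': 'hello', 'mime_type': 'text/plain'}  (detected_languages dropped)
print(to_dict(DemoResult(), include_none=True))
# includes 'detected_languages': None
```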
kreuzberg/_utils/_cache.py
CHANGED
@@ -64,11 +64,10 @@ class KreuzbergCache(Generic[T]):
         Returns:
             Unique cache key string
         """
-        # Use more efficient string building for cache key
         if not kwargs:
             return "empty"
 
-        # Build key
+        # Build cache key using list + join (faster than StringIO)
         parts = []
         for key in sorted(kwargs):
             value = kwargs[key]
@@ -81,6 +80,7 @@ class KreuzbergCache(Generic[T]):
                 parts.append(f"{key}={type(value).__name__}:{value!s}")
 
         cache_str = "&".join(parts)
+        # SHA256 is secure and fast enough for cache keys
         return hashlib.sha256(cache_str.encode()).hexdigest()[:16]
 
     def _get_cache_path(self, cache_key: str) -> Path:
@@ -107,15 +107,14 @@ class KreuzbergCache(Generic[T]):
             serialized_data = []
             for item in result:
                 if isinstance(item, dict) and "df" in item:
-                    #
-
+                    # Build new dict without unnecessary copy
+                    serialized_item = {k: v for k, v in item.items() if k != "df"}
                     if hasattr(item["df"], "to_csv"):
-
+                        serialized_item["df_csv"] = item["df"].to_csv(index=False)
                     else:
                         # Fallback for non-DataFrame objects
-
-
-                    serialized_data.append(item_copy)
+                        serialized_item["df_csv"] = str(item["df"])
+                    serialized_data.append(serialized_item)
                 else:
                     serialized_data.append(item)
             return {"type": "TableDataList", "data": serialized_data, "cached_at": time.time()}
@@ -127,18 +126,17 @@ class KreuzbergCache(Generic[T]):
             data = cached_data["data"]
 
             if cached_data.get("type") == "TableDataList" and isinstance(data, list):
+                from io import StringIO
+
+                import pandas as pd
+
                 deserialized_data = []
                 for item in data:
                     if isinstance(item, dict) and "df_csv" in item:
-                        #
-
-
-
-                        import pandas as pd
-
-                        item_copy["df"] = pd.read_csv(StringIO(item["df_csv"]))
-                        del item_copy["df_csv"]
-                        deserialized_data.append(item_copy)
+                        # Build new dict without unnecessary copy
+                        deserialized_item = {k: v for k, v in item.items() if k != "df_csv"}
+                        deserialized_item["df"] = pd.read_csv(StringIO(item["df_csv"]))
+                        deserialized_data.append(deserialized_item)
                     else:
                         deserialized_data.append(item)
             return deserialized_data  # type: ignore[return-value]
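The table cache swaps each DataFrame for its CSV text on write and rebuilds it on read, now via new dicts instead of copy-and-mutate. A standalone sketch of the round trip (illustrative keys, pandas only):

```python
from io import StringIO

import pandas as pd

# Serialize: replace the DataFrame with its CSV text so the item is easy to persist.
item = {"page_number": 1, "df": pd.DataFrame({"col": [1, 2]})}
serialized = {k: v for k, v in item.items() if k != "df"}
serialized["df_csv"] = item["df"].to_csv(index=False)

# Deserialize: rebuild the DataFrame from the cached CSV text.
restored = {k: v for k, v in serialized.items() if k != "df_csv"}
restored["df"] = pd.read_csv(StringIO(serialized["df_csv"]))

assert restored["df"].equals(item["df"])
```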
kreuzberg/_utils/_device.py
CHANGED
@@ -5,6 +5,7 @@ from __future__ import annotations
 
 import warnings
 from dataclasses import dataclass
+from itertools import chain
 from typing import Literal
 
 from kreuzberg.exceptions import ValidationError
@@ -12,7 +13,7 @@ from kreuzberg.exceptions import ValidationError
 DeviceType = Literal["cpu", "cuda", "mps", "auto"]
 
 
-@dataclass(frozen=True)
+@dataclass(frozen=True, slots=True)
 class DeviceInfo:
     """Information about a compute device."""
 
@@ -34,28 +35,17 @@ def detect_available_devices() -> list[DeviceInfo]:
     Returns:
         List of available devices, with the most preferred device first.
     """
-
-
-    devices.append(
-        DeviceInfo(
-            device_type="cpu",
-            name="CPU",
-        )
-    )
-
-    if _is_cuda_available():
-        cuda_devices = _get_cuda_devices()
-        devices.extend(cuda_devices)
+    # Build device lists efficiently using generators
+    cpu_device = DeviceInfo(device_type="cpu", name="CPU")
 
-    if
-        mps_device = _get_mps_device()
-        if mps_device:
-            devices.append(mps_device)
+    cuda_devices = _get_cuda_devices() if _is_cuda_available() else []
 
-
-
+    mps_device = _get_mps_device() if _is_mps_available() else None
+    mps_devices = [mps_device] if mps_device else []
 
-
+    # Return GPU devices first, then CPU using itertools.chain
+    gpu_devices = list(chain(cuda_devices, mps_devices))
+    return [*gpu_devices, cpu_device]
 
 
 def get_optimal_device() -> DeviceInfo:
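`detect_available_devices` now builds per-backend lists and concatenates them GPU-first with `itertools.chain`, keeping CPU as the trailing fallback. A minimal sketch of that ordering with stubbed availability inputs (the stubs are not kreuzberg's detection helpers):

```python
from dataclasses import dataclass
from itertools import chain


@dataclass(frozen=True, slots=True)
class DeviceInfo:
    device_type: str
    name: str


def detect(cuda_names: list[str], mps_available: bool) -> list[DeviceInfo]:
    """GPU devices first, CPU always last as the fallback."""
    cpu_device = DeviceInfo(device_type="cpu", name="CPU")
    cuda_devices = [DeviceInfo(device_type="cuda", name=n) for n in cuda_names]
    mps_devices = [DeviceInfo(device_type="mps", name="Apple MPS")] if mps_available else []
    gpu_devices = list(chain(cuda_devices, mps_devices))
    return [*gpu_devices, cpu_device]


print([d.device_type for d in detect(["NVIDIA A100"], mps_available=False)])
# ['cuda', 'cpu']
```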
kreuzberg/_utils/_errors.py
CHANGED
@@ -12,6 +12,42 @@ import psutil
 
 from kreuzberg.exceptions import ValidationError
 
+# Define error keywords as frozensets for O(1) membership testing
+_SYSTEM_ERROR_KEYWORDS = frozenset({"memory", "resource", "process", "thread"})
+_TRANSIENT_ERROR_PATTERNS = frozenset(
+    {
+        "temporary",
+        "locked",
+        "in use",
+        "access denied",
+        "permission",
+        "timeout",
+        "connection",
+        "network",
+        "too many open files",
+        "cannot allocate memory",
+        "resource temporarily unavailable",
+        "broken pipe",
+        "subprocess",
+        "signal",
+    }
+)
+_RESOURCE_ERROR_PATTERNS = frozenset(
+    {
+        "memory",
+        "out of memory",
+        "cannot allocate",
+        "too many open files",
+        "file descriptor",
+        "resource",
+        "exhausted",
+        "limit",
+        "cpu",
+        "thread",
+        "process",
+    }
+)
+
 
 def create_error_context(
     *,
@@ -52,11 +88,7 @@ def create_error_context(
         "traceback": traceback.format_exception_only(type(error), error),
     }
 
-    if (
-        any(keyword in str(error).lower() for keyword in ["memory", "resource", "process", "thread"])
-        if error
-        else False
-    ):
+    if error and any(keyword in str(error).lower() for keyword in _SYSTEM_ERROR_KEYWORDS):
         try:
             mem = psutil.virtual_memory()
             context["system"] = {
@@ -94,25 +126,8 @@ def is_transient_error(error: Exception) -> bool:
     if isinstance(error, transient_types):
         return True
 
-    transient_patterns = [
-        "temporary",
-        "locked",
-        "in use",
-        "access denied",
-        "permission",
-        "timeout",
-        "connection",
-        "network",
-        "too many open files",
-        "cannot allocate memory",
-        "resource temporarily unavailable",
-        "broken pipe",
-        "subprocess",
-        "signal",
-    ]
-
     error_str = str(error).lower()
-    return any(pattern in error_str for pattern in
+    return any(pattern in error_str for pattern in _TRANSIENT_ERROR_PATTERNS)
 
 
 def is_resource_error(error: Exception) -> bool:
@@ -124,22 +139,8 @@ def is_resource_error(error: Exception) -> bool:
     Returns:
         True if the error is resource-related
     """
-    resource_patterns = [
-        "memory",
-        "out of memory",
-        "cannot allocate",
-        "too many open files",
-        "file descriptor",
-        "resource",
-        "exhausted",
-        "limit",
-        "cpu",
-        "thread",
-        "process",
-    ]
-
     error_str = str(error).lower()
-    return any(pattern in error_str for pattern in
+    return any(pattern in error_str for pattern in _RESOURCE_ERROR_PATTERNS)
 
 
 def should_retry(error: Exception, attempt: int, max_attempts: int = 3) -> bool:
@@ -165,6 +166,8 @@ def should_retry(error: Exception, attempt: int, max_attempts: int = 3) -> bool:
 class BatchExtractionResult:
     """Result container for batch operations with partial success support."""
 
+    __slots__ = ("failed", "successful", "total_count")
+
    def __init__(self) -> None:
        """Initialize batch result container."""
        self.successful: list[tuple[int, Any]] = []
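The transient/resource keyword lists become module-level frozensets shared by the classifier helpers; the check itself remains a substring scan over the error message. A trimmed-down sketch of that pattern with a subset of the transient keywords:

```python
# A reduced version of the transient-error check from _errors.py.
_TRANSIENT_ERROR_PATTERNS = frozenset({"temporary", "timeout", "broken pipe", "locked"})


def is_transient_error(error: Exception) -> bool:
    """Return True when the error message looks retryable."""
    error_str = str(error).lower()
    return any(pattern in error_str for pattern in _TRANSIENT_ERROR_PATTERNS)


print(is_transient_error(OSError("database is locked")))        # True ("locked")
print(is_transient_error(ValueError("invalid configuration")))  # False
```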