kreuzberg 3.11.4__py3-none-any.whl → 3.13.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kreuzberg/__init__.py +14 -13
- kreuzberg/__main__.py +0 -2
- kreuzberg/_api/main.py +119 -9
- kreuzberg/_chunker.py +0 -15
- kreuzberg/_config.py +212 -292
- kreuzberg/_document_classification.py +20 -47
- kreuzberg/_entity_extraction.py +1 -122
- kreuzberg/_extractors/_base.py +4 -71
- kreuzberg/_extractors/_email.py +1 -15
- kreuzberg/_extractors/_html.py +9 -12
- kreuzberg/_extractors/_image.py +1 -25
- kreuzberg/_extractors/_pandoc.py +10 -147
- kreuzberg/_extractors/_pdf.py +38 -94
- kreuzberg/_extractors/_presentation.py +0 -99
- kreuzberg/_extractors/_spread_sheet.py +13 -55
- kreuzberg/_extractors/_structured.py +1 -4
- kreuzberg/_gmft.py +14 -199
- kreuzberg/_language_detection.py +1 -36
- kreuzberg/_mcp/__init__.py +0 -2
- kreuzberg/_mcp/server.py +3 -10
- kreuzberg/_mime_types.py +1 -19
- kreuzberg/_ocr/_base.py +4 -76
- kreuzberg/_ocr/_easyocr.py +124 -186
- kreuzberg/_ocr/_paddleocr.py +154 -224
- kreuzberg/_ocr/_table_extractor.py +184 -0
- kreuzberg/_ocr/_tesseract.py +797 -361
- kreuzberg/_playa.py +5 -31
- kreuzberg/_registry.py +0 -36
- kreuzberg/_types.py +588 -93
- kreuzberg/_utils/_cache.py +84 -138
- kreuzberg/_utils/_device.py +0 -74
- kreuzberg/_utils/_document_cache.py +0 -75
- kreuzberg/_utils/_errors.py +0 -50
- kreuzberg/_utils/_ocr_cache.py +136 -0
- kreuzberg/_utils/_pdf_lock.py +0 -16
- kreuzberg/_utils/_process_pool.py +17 -64
- kreuzberg/_utils/_quality.py +0 -60
- kreuzberg/_utils/_ref.py +32 -0
- kreuzberg/_utils/_serialization.py +0 -30
- kreuzberg/_utils/_string.py +9 -59
- kreuzberg/_utils/_sync.py +0 -77
- kreuzberg/_utils/_table.py +49 -101
- kreuzberg/_utils/_tmp.py +0 -9
- kreuzberg/cli.py +54 -74
- kreuzberg/extraction.py +39 -32
- {kreuzberg-3.11.4.dist-info → kreuzberg-3.13.1.dist-info}/METADATA +19 -15
- kreuzberg-3.13.1.dist-info/RECORD +57 -0
- kreuzberg-3.11.4.dist-info/RECORD +0 -54
- {kreuzberg-3.11.4.dist-info → kreuzberg-3.13.1.dist-info}/WHEEL +0 -0
- {kreuzberg-3.11.4.dist-info → kreuzberg-3.13.1.dist-info}/entry_points.txt +0 -0
- {kreuzberg-3.11.4.dist-info → kreuzberg-3.13.1.dist-info}/licenses/LICENSE +0 -0
kreuzberg/_utils/_cache.py
CHANGED
@@ -1,5 +1,3 @@
|
|
1
|
-
"""General-purpose file-based caching layer for Kreuzberg."""
|
2
|
-
|
3
1
|
from __future__ import annotations
|
4
2
|
|
5
3
|
import hashlib
|
@@ -14,6 +12,7 @@ from typing import Any, Generic, TypeVar
|
|
14
12
|
from anyio import Path as AsyncPath
|
15
13
|
|
16
14
|
from kreuzberg._types import ExtractionResult
|
15
|
+
from kreuzberg._utils._ref import Ref
|
17
16
|
from kreuzberg._utils._serialization import deserialize, serialize
|
18
17
|
from kreuzberg._utils._sync import run_sync
|
19
18
|
|
@@ -21,12 +20,6 @@ T = TypeVar("T")
|
|
21
20
|
|
22
21
|
|
23
22
|
class KreuzbergCache(Generic[T]):
|
24
|
-
"""File-based cache for Kreuzberg operations.
|
25
|
-
|
26
|
-
Provides both sync and async interfaces for caching extraction results,
|
27
|
-
OCR results, table data, and other expensive operations to disk.
|
28
|
-
"""
|
29
|
-
|
30
23
|
def __init__(
|
31
24
|
self,
|
32
25
|
cache_type: str,
|
@@ -34,14 +27,6 @@ class KreuzbergCache(Generic[T]):
|
|
34
27
|
max_cache_size_mb: float = 500.0,
|
35
28
|
max_age_days: int = 30,
|
36
29
|
) -> None:
|
37
|
-
"""Initialize cache.
|
38
|
-
|
39
|
-
Args:
|
40
|
-
cache_type: Type of cache (e.g., 'ocr', 'tables', 'documents', 'mime')
|
41
|
-
cache_dir: Cache directory (defaults to .kreuzberg/{cache_type} in cwd)
|
42
|
-
max_cache_size_mb: Maximum cache size in MB (default: 500MB)
|
43
|
-
max_age_days: Maximum age of cached results in days (default: 30 days)
|
44
|
-
"""
|
45
30
|
if cache_dir is None:
|
46
31
|
cache_dir = Path.cwd() / ".kreuzberg" / cache_type
|
47
32
|
|
@@ -57,22 +42,12 @@ class KreuzbergCache(Generic[T]):
|
|
57
42
|
self._lock = threading.Lock()
|
58
43
|
|
59
44
|
def _get_cache_key(self, **kwargs: Any) -> str:
|
60
|
-
"""Generate cache key from kwargs.
|
61
|
-
|
62
|
-
Args:
|
63
|
-
**kwargs: Key-value pairs to generate cache key from
|
64
|
-
|
65
|
-
Returns:
|
66
|
-
Unique cache key string
|
67
|
-
"""
|
68
45
|
if not kwargs:
|
69
46
|
return "empty"
|
70
47
|
|
71
|
-
# Build cache key using list + join (faster than StringIO)
|
72
48
|
parts = []
|
73
49
|
for key in sorted(kwargs):
|
74
50
|
value = kwargs[key]
|
75
|
-
# Convert common types efficiently
|
76
51
|
if isinstance(value, (str, int, float, bool)):
|
77
52
|
parts.append(f"{key}={value}")
|
78
53
|
elif isinstance(value, bytes):
|
@@ -81,15 +56,12 @@ class KreuzbergCache(Generic[T]):
|
|
81
56
|
parts.append(f"{key}={type(value).__name__}:{value!s}")
|
82
57
|
|
83
58
|
cache_str = "&".join(parts)
|
84
|
-
# SHA256 is secure and fast enough for cache keys
|
85
59
|
return hashlib.sha256(cache_str.encode()).hexdigest()[:16]
|
86
60
|
|
87
61
|
def _get_cache_path(self, cache_key: str) -> Path:
|
88
|
-
"""Get cache file path for key."""
|
89
62
|
return self.cache_dir / f"{cache_key}.msgpack"
|
90
63
|
|
91
64
|
def _is_cache_valid(self, cache_path: Path) -> bool:
|
92
|
-
"""Check if cached result is still valid."""
|
93
65
|
try:
|
94
66
|
if not cache_path.exists():
|
95
67
|
return False
|
@@ -102,18 +74,14 @@ class KreuzbergCache(Generic[T]):
|
|
102
74
|
return False
|
103
75
|
|
104
76
|
def _serialize_result(self, result: T) -> dict[str, Any]:
|
105
|
-
"""Serialize result for caching with metadata."""
|
106
|
-
# Handle TableData objects that contain DataFrames
|
107
77
|
if isinstance(result, list) and result and isinstance(result[0], dict) and "df" in result[0]:
|
108
78
|
serialized_data = []
|
109
79
|
for item in result:
|
110
80
|
if isinstance(item, dict) and "df" in item:
|
111
|
-
# Build new dict without unnecessary copy
|
112
81
|
serialized_item = {k: v for k, v in item.items() if k != "df"}
|
113
82
|
if hasattr(item["df"], "to_csv"):
|
114
83
|
serialized_item["df_csv"] = item["df"].to_csv(index=False)
|
115
84
|
else:
|
116
|
-
# Fallback for non-DataFrame objects
|
117
85
|
serialized_item["df_csv"] = str(item["df"])
|
118
86
|
serialized_data.append(serialized_item)
|
119
87
|
else:
|
@@ -123,7 +91,6 @@ class KreuzbergCache(Generic[T]):
|
|
123
91
|
return {"type": type(result).__name__, "data": result, "cached_at": time.time()}
|
124
92
|
|
125
93
|
def _deserialize_result(self, cached_data: dict[str, Any]) -> T:
|
126
|
-
"""Deserialize cached result."""
|
127
94
|
data = cached_data["data"]
|
128
95
|
|
129
96
|
if cached_data.get("type") == "TableDataList" and isinstance(data, list):
|
@@ -132,7 +99,6 @@ class KreuzbergCache(Generic[T]):
|
|
132
99
|
deserialized_data = []
|
133
100
|
for item in data:
|
134
101
|
if isinstance(item, dict) and "df_csv" in item:
|
135
|
-
# Build new dict without unnecessary copy
|
136
102
|
deserialized_item = {k: v for k, v in item.items() if k != "df_csv"}
|
137
103
|
deserialized_item["df"] = pd.read_csv(StringIO(item["df_csv"]))
|
138
104
|
deserialized_data.append(deserialized_item)
|
@@ -146,7 +112,6 @@ class KreuzbergCache(Generic[T]):
|
|
146
112
|
return data # type: ignore[no-any-return]
|
147
113
|
|
148
114
|
def _cleanup_cache(self) -> None:
|
149
|
-
"""Clean up old and oversized cache entries."""
|
150
115
|
try:
|
151
116
|
cache_files = list(self.cache_dir.glob("*.msgpack"))
|
152
117
|
|
@@ -180,14 +145,6 @@ class KreuzbergCache(Generic[T]):
|
|
180
145
|
pass
|
181
146
|
|
182
147
|
def get(self, **kwargs: Any) -> T | None:
|
183
|
-
"""Get cached result (sync).
|
184
|
-
|
185
|
-
Args:
|
186
|
-
**kwargs: Key-value pairs to generate cache key from
|
187
|
-
|
188
|
-
Returns:
|
189
|
-
Cached result if available, None otherwise
|
190
|
-
"""
|
191
148
|
cache_key = self._get_cache_key(**kwargs)
|
192
149
|
cache_path = self._get_cache_path(cache_key)
|
193
150
|
|
@@ -204,12 +161,6 @@ class KreuzbergCache(Generic[T]):
|
|
204
161
|
return None
|
205
162
|
|
206
163
|
def set(self, result: T, **kwargs: Any) -> None:
|
207
|
-
"""Cache result (sync).
|
208
|
-
|
209
|
-
Args:
|
210
|
-
result: Result to cache
|
211
|
-
**kwargs: Key-value pairs to generate cache key from
|
212
|
-
"""
|
213
164
|
cache_key = self._get_cache_key(**kwargs)
|
214
165
|
cache_path = self._get_cache_path(cache_key)
|
215
166
|
|
@@ -224,14 +175,6 @@ class KreuzbergCache(Generic[T]):
|
|
224
175
|
pass
|
225
176
|
|
226
177
|
async def aget(self, **kwargs: Any) -> T | None:
|
227
|
-
"""Get cached result (async).
|
228
|
-
|
229
|
-
Args:
|
230
|
-
**kwargs: Key-value pairs to generate cache key from
|
231
|
-
|
232
|
-
Returns:
|
233
|
-
Cached result if available, None otherwise
|
234
|
-
"""
|
235
178
|
cache_key = self._get_cache_key(**kwargs)
|
236
179
|
cache_path = AsyncPath(self._get_cache_path(cache_key))
|
237
180
|
|
@@ -248,12 +191,6 @@ class KreuzbergCache(Generic[T]):
|
|
248
191
|
return None
|
249
192
|
|
250
193
|
async def aset(self, result: T, **kwargs: Any) -> None:
|
251
|
-
"""Cache result (async).
|
252
|
-
|
253
|
-
Args:
|
254
|
-
result: Result to cache
|
255
|
-
**kwargs: Key-value pairs to generate cache key from
|
256
|
-
"""
|
257
194
|
cache_key = self._get_cache_key(**kwargs)
|
258
195
|
cache_path = AsyncPath(self._get_cache_path(cache_key))
|
259
196
|
|
@@ -268,13 +205,11 @@ class KreuzbergCache(Generic[T]):
|
|
268
205
|
pass
|
269
206
|
|
270
207
|
def is_processing(self, **kwargs: Any) -> bool:
|
271
|
-
"""Check if operation is currently being processed."""
|
272
208
|
cache_key = self._get_cache_key(**kwargs)
|
273
209
|
with self._lock:
|
274
210
|
return cache_key in self._processing
|
275
211
|
|
276
212
|
def mark_processing(self, **kwargs: Any) -> threading.Event:
|
277
|
-
"""Mark operation as being processed and return event to wait on."""
|
278
213
|
cache_key = self._get_cache_key(**kwargs)
|
279
214
|
|
280
215
|
with self._lock:
|
@@ -283,7 +218,6 @@ class KreuzbergCache(Generic[T]):
|
|
283
218
|
return self._processing[cache_key]
|
284
219
|
|
285
220
|
def mark_complete(self, **kwargs: Any) -> None:
|
286
|
-
"""Mark operation processing as complete."""
|
287
221
|
cache_key = self._get_cache_key(**kwargs)
|
288
222
|
|
289
223
|
with self._lock:
|
@@ -292,7 +226,6 @@ class KreuzbergCache(Generic[T]):
|
|
292
226
|
event.set()
|
293
227
|
|
294
228
|
def clear(self) -> None:
|
295
|
-
"""Clear all cached results."""
|
296
229
|
try:
|
297
230
|
for cache_file in self.cache_dir.glob("*.msgpack"):
|
298
231
|
cache_file.unlink(missing_ok=True)
|
@@ -303,7 +236,6 @@ class KreuzbergCache(Generic[T]):
|
|
303
236
|
pass
|
304
237
|
|
305
238
|
def get_stats(self) -> dict[str, Any]:
|
306
|
-
"""Get cache statistics."""
|
307
239
|
try:
|
308
240
|
cache_files = list(self.cache_dir.glob("*.msgpack"))
|
309
241
|
total_size = sum(cache_file.stat().st_size for cache_file in cache_files if cache_file.exists())
|
@@ -331,87 +263,101 @@ class KreuzbergCache(Generic[T]):
|
|
331
263
|
}
|
332
264
|
|
333
265
|
|
334
|
-
|
335
|
-
|
336
|
-
|
337
|
-
|
266
|
+
def _create_ocr_cache() -> KreuzbergCache[ExtractionResult]:
|
267
|
+
cache_dir_str = os.environ.get("KREUZBERG_CACHE_DIR")
|
268
|
+
cache_dir: Path | None = None
|
269
|
+
if cache_dir_str:
|
270
|
+
cache_dir = Path(cache_dir_str) / "ocr"
|
271
|
+
|
272
|
+
return KreuzbergCache[ExtractionResult](
|
273
|
+
cache_type="ocr",
|
274
|
+
cache_dir=cache_dir,
|
275
|
+
max_cache_size_mb=float(os.environ.get("KREUZBERG_OCR_CACHE_SIZE_MB", "500")),
|
276
|
+
max_age_days=int(os.environ.get("KREUZBERG_OCR_CACHE_AGE_DAYS", "30")),
|
277
|
+
)
|
278
|
+
|
279
|
+
|
280
|
+
_ocr_cache_ref = Ref("ocr_cache", _create_ocr_cache)
|
338
281
|
|
339
282
|
|
340
283
|
def get_ocr_cache() -> KreuzbergCache[ExtractionResult]:
|
341
|
-
|
342
|
-
|
343
|
-
|
344
|
-
|
345
|
-
|
346
|
-
|
347
|
-
|
348
|
-
|
349
|
-
|
350
|
-
|
351
|
-
|
352
|
-
|
353
|
-
|
354
|
-
)
|
355
|
-
|
284
|
+
return _ocr_cache_ref.get()
|
285
|
+
|
286
|
+
|
287
|
+
def _create_document_cache() -> KreuzbergCache[ExtractionResult]:
|
288
|
+
cache_dir_str = os.environ.get("KREUZBERG_CACHE_DIR")
|
289
|
+
cache_dir: Path | None = None
|
290
|
+
if cache_dir_str:
|
291
|
+
cache_dir = Path(cache_dir_str) / "documents"
|
292
|
+
|
293
|
+
return KreuzbergCache[ExtractionResult](
|
294
|
+
cache_type="documents",
|
295
|
+
cache_dir=cache_dir,
|
296
|
+
max_cache_size_mb=float(os.environ.get("KREUZBERG_DOCUMENT_CACHE_SIZE_MB", "1000")),
|
297
|
+
max_age_days=int(os.environ.get("KREUZBERG_DOCUMENT_CACHE_AGE_DAYS", "7")),
|
298
|
+
)
|
299
|
+
|
300
|
+
|
301
|
+
_document_cache_ref = Ref("document_cache", _create_document_cache)
|
356
302
|
|
357
303
|
|
358
304
|
def get_document_cache() -> KreuzbergCache[ExtractionResult]:
|
359
|
-
|
360
|
-
|
361
|
-
|
362
|
-
|
363
|
-
|
364
|
-
|
365
|
-
|
366
|
-
|
367
|
-
|
368
|
-
|
369
|
-
|
370
|
-
|
371
|
-
|
372
|
-
)
|
373
|
-
|
305
|
+
return _document_cache_ref.get()
|
306
|
+
|
307
|
+
|
308
|
+
def _create_table_cache() -> KreuzbergCache[Any]:
|
309
|
+
cache_dir_str = os.environ.get("KREUZBERG_CACHE_DIR")
|
310
|
+
cache_dir: Path | None = None
|
311
|
+
if cache_dir_str:
|
312
|
+
cache_dir = Path(cache_dir_str) / "tables"
|
313
|
+
|
314
|
+
return KreuzbergCache[Any](
|
315
|
+
cache_type="tables",
|
316
|
+
cache_dir=cache_dir,
|
317
|
+
max_cache_size_mb=float(os.environ.get("KREUZBERG_TABLE_CACHE_SIZE_MB", "200")),
|
318
|
+
max_age_days=int(os.environ.get("KREUZBERG_TABLE_CACHE_AGE_DAYS", "30")),
|
319
|
+
)
|
320
|
+
|
321
|
+
|
322
|
+
_table_cache_ref = Ref("table_cache", _create_table_cache)
|
374
323
|
|
375
324
|
|
376
325
|
def get_table_cache() -> KreuzbergCache[Any]:
|
377
|
-
|
378
|
-
|
379
|
-
|
380
|
-
|
381
|
-
|
382
|
-
|
383
|
-
|
384
|
-
|
385
|
-
|
386
|
-
|
387
|
-
|
388
|
-
|
389
|
-
|
390
|
-
)
|
391
|
-
|
326
|
+
return _table_cache_ref.get()
|
327
|
+
|
328
|
+
|
329
|
+
def _create_mime_cache() -> KreuzbergCache[str]:
|
330
|
+
cache_dir_str = os.environ.get("KREUZBERG_CACHE_DIR")
|
331
|
+
cache_dir: Path | None = None
|
332
|
+
if cache_dir_str:
|
333
|
+
cache_dir = Path(cache_dir_str) / "mime"
|
334
|
+
|
335
|
+
return KreuzbergCache[str](
|
336
|
+
cache_type="mime",
|
337
|
+
cache_dir=cache_dir,
|
338
|
+
max_cache_size_mb=float(os.environ.get("KREUZBERG_MIME_CACHE_SIZE_MB", "50")),
|
339
|
+
max_age_days=int(os.environ.get("KREUZBERG_MIME_CACHE_AGE_DAYS", "60")),
|
340
|
+
)
|
341
|
+
|
342
|
+
|
343
|
+
_mime_cache_ref = Ref("mime_cache", _create_mime_cache)
|
392
344
|
|
393
345
|
|
394
346
|
def get_mime_cache() -> KreuzbergCache[str]:
|
395
|
-
|
396
|
-
global _mime_cache
|
397
|
-
if _mime_cache is None:
|
398
|
-
cache_dir_str = os.environ.get("KREUZBERG_CACHE_DIR")
|
399
|
-
cache_dir: Path | None = None
|
400
|
-
if cache_dir_str:
|
401
|
-
cache_dir = Path(cache_dir_str) / "mime"
|
402
|
-
|
403
|
-
_mime_cache = KreuzbergCache[str](
|
404
|
-
cache_type="mime",
|
405
|
-
cache_dir=cache_dir,
|
406
|
-
max_cache_size_mb=float(os.environ.get("KREUZBERG_MIME_CACHE_SIZE_MB", "50")),
|
407
|
-
max_age_days=int(os.environ.get("KREUZBERG_MIME_CACHE_AGE_DAYS", "60")),
|
408
|
-
)
|
409
|
-
return _mime_cache
|
347
|
+
return _mime_cache_ref.get()
|
410
348
|
|
411
349
|
|
412
350
|
def clear_all_caches() -> None:
|
413
|
-
|
414
|
-
|
415
|
-
|
416
|
-
|
417
|
-
|
351
|
+
if _ocr_cache_ref.is_initialized():
|
352
|
+
get_ocr_cache().clear()
|
353
|
+
if _document_cache_ref.is_initialized():
|
354
|
+
get_document_cache().clear()
|
355
|
+
if _table_cache_ref.is_initialized():
|
356
|
+
get_table_cache().clear()
|
357
|
+
if _mime_cache_ref.is_initialized():
|
358
|
+
get_mime_cache().clear()
|
359
|
+
|
360
|
+
_ocr_cache_ref.clear()
|
361
|
+
_document_cache_ref.clear()
|
362
|
+
_table_cache_ref.clear()
|
363
|
+
_mime_cache_ref.clear()
|
kreuzberg/_utils/_device.py
CHANGED
@@ -1,4 +1,3 @@
|
|
1
|
-
"""Device detection and management utilities for GPU acceleration."""
|
2
1
|
# ruff: noqa: BLE001 # ~keep
|
3
2
|
|
4
3
|
from __future__ import annotations
|
@@ -15,8 +14,6 @@ DeviceType = Literal["cpu", "cuda", "mps", "auto"]
|
|
15
14
|
|
16
15
|
@dataclass(frozen=True, slots=True)
|
17
16
|
class DeviceInfo:
|
18
|
-
"""Information about a compute device."""
|
19
|
-
|
20
17
|
device_type: Literal["cpu", "cuda", "mps"]
|
21
18
|
"""The type of device."""
|
22
19
|
device_id: int | None = None
|
@@ -30,12 +27,6 @@ class DeviceInfo:
|
|
30
27
|
|
31
28
|
|
32
29
|
def detect_available_devices() -> list[DeviceInfo]:
|
33
|
-
"""Detect all available compute devices.
|
34
|
-
|
35
|
-
Returns:
|
36
|
-
List of available devices, with the most preferred device first.
|
37
|
-
"""
|
38
|
-
# Build device lists efficiently using generators
|
39
30
|
cpu_device = DeviceInfo(device_type="cpu", name="CPU")
|
40
31
|
|
41
32
|
cuda_devices = _get_cuda_devices() if _is_cuda_available() else []
|
@@ -43,17 +34,11 @@ def detect_available_devices() -> list[DeviceInfo]:
|
|
43
34
|
mps_device = _get_mps_device() if _is_mps_available() else None
|
44
35
|
mps_devices = [mps_device] if mps_device else []
|
45
36
|
|
46
|
-
# Return GPU devices first, then CPU using itertools.chain
|
47
37
|
gpu_devices = list(chain(cuda_devices, mps_devices))
|
48
38
|
return [*gpu_devices, cpu_device]
|
49
39
|
|
50
40
|
|
51
41
|
def get_optimal_device() -> DeviceInfo:
|
52
|
-
"""Get the optimal device for OCR processing.
|
53
|
-
|
54
|
-
Returns:
|
55
|
-
The best available device, preferring GPU over CPU.
|
56
|
-
"""
|
57
42
|
devices = detect_available_devices()
|
58
43
|
return devices[0] if devices else DeviceInfo(device_type="cpu", name="CPU")
|
59
44
|
|
@@ -65,20 +50,6 @@ def validate_device_request(
|
|
65
50
|
memory_limit: float | None = None,
|
66
51
|
fallback_to_cpu: bool = True,
|
67
52
|
) -> DeviceInfo:
|
68
|
-
"""Validate and resolve a device request.
|
69
|
-
|
70
|
-
Args:
|
71
|
-
requested: The requested device type.
|
72
|
-
backend: Name of the OCR backend requesting the device.
|
73
|
-
memory_limit: Optional memory limit in GB.
|
74
|
-
fallback_to_cpu: Whether to fallback to CPU if requested device unavailable.
|
75
|
-
|
76
|
-
Returns:
|
77
|
-
A validated DeviceInfo object.
|
78
|
-
|
79
|
-
Raises:
|
80
|
-
ValidationError: If the requested device is not available and fallback is disabled.
|
81
|
-
"""
|
82
53
|
available_devices = detect_available_devices()
|
83
54
|
|
84
55
|
if requested == "auto":
|
@@ -118,14 +89,6 @@ def validate_device_request(
|
|
118
89
|
|
119
90
|
|
120
91
|
def get_device_memory_info(device: DeviceInfo) -> tuple[float | None, float | None]:
|
121
|
-
"""Get memory information for a device.
|
122
|
-
|
123
|
-
Args:
|
124
|
-
device: The device to query.
|
125
|
-
|
126
|
-
Returns:
|
127
|
-
Tuple of (total_memory_gb, available_memory_gb). None values if unknown.
|
128
|
-
"""
|
129
92
|
if device.device_type == "cpu":
|
130
93
|
return None, None
|
131
94
|
|
@@ -139,7 +102,6 @@ def get_device_memory_info(device: DeviceInfo) -> tuple[float | None, float | No
|
|
139
102
|
|
140
103
|
|
141
104
|
def _is_cuda_available() -> bool:
|
142
|
-
"""Check if CUDA is available."""
|
143
105
|
try:
|
144
106
|
import torch # type: ignore[import-not-found,unused-ignore] # noqa: PLC0415
|
145
107
|
|
@@ -149,7 +111,6 @@ def _is_cuda_available() -> bool:
|
|
149
111
|
|
150
112
|
|
151
113
|
def _is_mps_available() -> bool:
|
152
|
-
"""Check if MPS (Apple Silicon) is available."""
|
153
114
|
try:
|
154
115
|
import torch # type: ignore[import-not-found,unused-ignore] # noqa: PLC0415
|
155
116
|
|
@@ -159,7 +120,6 @@ def _is_mps_available() -> bool:
|
|
159
120
|
|
160
121
|
|
161
122
|
def _get_cuda_devices() -> list[DeviceInfo]:
|
162
|
-
"""Get information about available CUDA devices."""
|
163
123
|
devices: list[DeviceInfo] = []
|
164
124
|
|
165
125
|
try:
|
@@ -197,7 +157,6 @@ def _get_cuda_devices() -> list[DeviceInfo]:
|
|
197
157
|
|
198
158
|
|
199
159
|
def _get_mps_device() -> DeviceInfo | None:
|
200
|
-
"""Get information about the MPS device."""
|
201
160
|
try:
|
202
161
|
import torch # noqa: PLC0415
|
203
162
|
|
@@ -214,7 +173,6 @@ def _get_mps_device() -> DeviceInfo | None:
|
|
214
173
|
|
215
174
|
|
216
175
|
def _get_cuda_memory_info(device_id: int) -> tuple[float | None, float | None]:
|
217
|
-
"""Get CUDA memory information for a specific device."""
|
218
176
|
try:
|
219
177
|
import torch # noqa: PLC0415
|
220
178
|
|
@@ -237,20 +195,10 @@ def _get_cuda_memory_info(device_id: int) -> tuple[float | None, float | None]:
|
|
237
195
|
|
238
196
|
|
239
197
|
def _get_mps_memory_info() -> tuple[float | None, float | None]:
|
240
|
-
"""Get MPS memory information."""
|
241
198
|
return None, None
|
242
199
|
|
243
200
|
|
244
201
|
def _validate_memory_limit(device: DeviceInfo, memory_limit: float) -> None:
|
245
|
-
"""Validate that a device has enough memory for the requested limit.
|
246
|
-
|
247
|
-
Args:
|
248
|
-
device: The device to validate.
|
249
|
-
memory_limit: Required memory in GB.
|
250
|
-
|
251
|
-
Raises:
|
252
|
-
ValidationError: If the device doesn't have enough memory.
|
253
|
-
"""
|
254
202
|
if device.device_type == "cpu":
|
255
203
|
# CPU memory validation is complex and OS-dependent, skip for now # ~keep
|
256
204
|
return
|
@@ -279,28 +227,11 @@ def _validate_memory_limit(device: DeviceInfo, memory_limit: float) -> None:
|
|
279
227
|
|
280
228
|
|
281
229
|
def is_backend_gpu_compatible(backend: str) -> bool:
|
282
|
-
"""Check if an OCR backend supports GPU acceleration.
|
283
|
-
|
284
|
-
Args:
|
285
|
-
backend: Name of the OCR backend.
|
286
|
-
|
287
|
-
Returns:
|
288
|
-
True if the backend supports GPU acceleration.
|
289
|
-
"""
|
290
230
|
# EasyOCR and PaddleOCR support GPU, Tesseract does not # ~keep
|
291
231
|
return backend.lower() in ("easyocr", "paddleocr")
|
292
232
|
|
293
233
|
|
294
234
|
def get_recommended_batch_size(device: DeviceInfo, input_size_mb: float = 10.0) -> int:
|
295
|
-
"""Get recommended batch size for OCR processing.
|
296
|
-
|
297
|
-
Args:
|
298
|
-
device: The device to optimize for.
|
299
|
-
input_size_mb: Estimated input size per item in MB.
|
300
|
-
|
301
|
-
Returns:
|
302
|
-
Recommended batch size.
|
303
|
-
"""
|
304
235
|
if device.device_type == "cpu":
|
305
236
|
# Conservative batch size for CPU # ~keep
|
306
237
|
return 1
|
@@ -322,11 +253,6 @@ def get_recommended_batch_size(device: DeviceInfo, input_size_mb: float = 10.0)
|
|
322
253
|
|
323
254
|
|
324
255
|
def cleanup_device_memory(device: DeviceInfo) -> None:
|
325
|
-
"""Clean up device memory.
|
326
|
-
|
327
|
-
Args:
|
328
|
-
device: The device to clean up.
|
329
|
-
"""
|
330
256
|
if device.device_type == "cuda":
|
331
257
|
try:
|
332
258
|
import torch # noqa: PLC0415
|