kreuzberg 3.2.0__py3-none-any.whl → 3.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kreuzberg/__init__.py +3 -0
- kreuzberg/__main__.py +8 -0
- kreuzberg/_api/__init__.py +0 -0
- kreuzberg/_api/main.py +87 -0
- kreuzberg/_cli_config.py +175 -0
- kreuzberg/_extractors/_image.py +39 -4
- kreuzberg/_extractors/_pandoc.py +158 -18
- kreuzberg/_extractors/_pdf.py +199 -19
- kreuzberg/_extractors/_presentation.py +1 -1
- kreuzberg/_extractors/_spread_sheet.py +65 -7
- kreuzberg/_gmft.py +222 -16
- kreuzberg/_mime_types.py +62 -16
- kreuzberg/_multiprocessing/__init__.py +6 -0
- kreuzberg/_multiprocessing/gmft_isolated.py +332 -0
- kreuzberg/_multiprocessing/process_manager.py +188 -0
- kreuzberg/_multiprocessing/sync_tesseract.py +261 -0
- kreuzberg/_multiprocessing/tesseract_pool.py +359 -0
- kreuzberg/_ocr/_easyocr.py +6 -12
- kreuzberg/_ocr/_paddleocr.py +15 -13
- kreuzberg/_ocr/_tesseract.py +136 -46
- kreuzberg/_playa.py +43 -0
- kreuzberg/_types.py +4 -0
- kreuzberg/_utils/_cache.py +372 -0
- kreuzberg/_utils/_device.py +10 -27
- kreuzberg/_utils/_document_cache.py +220 -0
- kreuzberg/_utils/_errors.py +232 -0
- kreuzberg/_utils/_pdf_lock.py +72 -0
- kreuzberg/_utils/_process_pool.py +100 -0
- kreuzberg/_utils/_serialization.py +82 -0
- kreuzberg/_utils/_string.py +1 -1
- kreuzberg/_utils/_sync.py +21 -0
- kreuzberg/cli.py +338 -0
- kreuzberg/extraction.py +247 -36
- kreuzberg-3.4.0.dist-info/METADATA +290 -0
- kreuzberg-3.4.0.dist-info/RECORD +50 -0
- {kreuzberg-3.2.0.dist-info → kreuzberg-3.4.0.dist-info}/WHEEL +1 -2
- kreuzberg-3.4.0.dist-info/entry_points.txt +2 -0
- kreuzberg-3.2.0.dist-info/METADATA +0 -166
- kreuzberg-3.2.0.dist-info/RECORD +0 -34
- kreuzberg-3.2.0.dist-info/top_level.txt +0 -1
- {kreuzberg-3.2.0.dist-info → kreuzberg-3.4.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,372 @@
|
|
1
|
+
"""General-purpose file-based caching layer for Kreuzberg."""
|
2
|
+
|
3
|
+
from __future__ import annotations
|
4
|
+
|
5
|
+
import hashlib
|
6
|
+
import os
|
7
|
+
import threading
|
8
|
+
import time
|
9
|
+
from contextlib import suppress
|
10
|
+
from pathlib import Path
|
11
|
+
from typing import Any, Generic, TypeVar
|
12
|
+
|
13
|
+
from anyio import Path as AsyncPath
|
14
|
+
|
15
|
+
from kreuzberg._types import ExtractionResult
|
16
|
+
from kreuzberg._utils._serialization import deserialize, serialize
|
17
|
+
from kreuzberg._utils._sync import run_sync
|
18
|
+
|
19
|
+
T = TypeVar("T")


class KreuzbergCache(Generic[T]):
    """File-based cache for Kreuzberg operations.

    Provides both sync and async interfaces for caching extraction results,
    OCR results, table data, and other expensive operations to disk. Entries
    are stored as msgpack files under ``cache_dir`` and evicted by age and by
    total size.
    """

    def __init__(
        self,
        cache_type: str,
        cache_dir: Path | str | None = None,
        max_cache_size_mb: float = 500.0,
        max_age_days: int = 30,
    ) -> None:
        """Initialize cache.

        Args:
            cache_type: Type of cache (e.g., 'ocr', 'tables', 'documents', 'mime')
            cache_dir: Cache directory (defaults to .kreuzberg/{cache_type} in cwd)
            max_cache_size_mb: Maximum cache size in MB (default: 500MB)
            max_age_days: Maximum age of cached results in days (default: 30 days)
        """
        if cache_dir is None:
            cache_dir = Path.cwd() / ".kreuzberg" / cache_type

        self.cache_dir = Path(cache_dir)
        self.cache_type = cache_type
        self.max_cache_size_mb = max_cache_size_mb
        self.max_age_days = max_age_days

        self.cache_dir.mkdir(parents=True, exist_ok=True)

        # In-memory tracking of processing state (session-scoped) # ~keep
        self._processing: dict[str, threading.Event] = {}
        self._lock = threading.Lock()

    def _get_cache_key(self, **kwargs: Any) -> str:
        """Generate cache key from kwargs.

        Args:
            **kwargs: Key-value pairs to generate cache key from

        Returns:
            Unique cache key string (16 hex chars of a SHA-256 digest)
        """
        # Sort for consistent hashing # ~keep
        cache_str = str(sorted(kwargs.items()))
        return hashlib.sha256(cache_str.encode()).hexdigest()[:16]

    def _get_cache_path(self, cache_key: str) -> Path:
        """Get cache file path for key."""
        return self.cache_dir / f"{cache_key}.msgpack"

    def _is_cache_valid(self, cache_path: Path) -> bool:
        """Check if cached result exists and is younger than ``max_age_days``."""
        try:
            if not cache_path.exists():
                return False

            mtime = cache_path.stat().st_mtime
            age_days = (time.time() - mtime) / (24 * 3600)

            return age_days <= self.max_age_days
        except OSError:
            return False

    def _serialize_result(self, result: T) -> dict[str, Any]:
        """Serialize result for caching with metadata (type name + timestamp)."""
        return {"type": type(result).__name__, "data": result, "cached_at": time.time()}

    def _deserialize_result(self, cached_data: dict[str, Any]) -> T:
        """Deserialize cached result, rehydrating ExtractionResult payloads."""
        data = cached_data["data"]

        if cached_data.get("type") == "ExtractionResult" and isinstance(data, dict):
            from kreuzberg._types import ExtractionResult

            return ExtractionResult(**data)  # type: ignore[return-value]

        return data  # type: ignore[no-any-return]

    def _cleanup_cache(self) -> None:
        """Clean up old and oversized cache entries.

        Best-effort: all filesystem races (files vanishing between glob and
        stat/unlink) are tolerated and never propagate to the caller.
        """
        try:
            cache_files = list(self.cache_dir.glob("*.msgpack"))

            # First pass: drop entries older than max_age_days.
            cutoff_time = time.time() - (self.max_age_days * 24 * 3600)
            for cache_file in cache_files[:]:
                try:
                    if cache_file.stat().st_mtime < cutoff_time:
                        cache_file.unlink(missing_ok=True)
                        cache_files.remove(cache_file)
                except OSError:  # noqa: PERF203
                    continue

            total_size = sum(cache_file.stat().st_size for cache_file in cache_files if cache_file.exists()) / (
                1024 * 1024
            )

            # Second pass: if still over budget, evict oldest-first until we
            # are below 80% of the size limit.
            if total_size > self.max_cache_size_mb:
                cache_files.sort(key=lambda f: f.stat().st_mtime if f.exists() else 0)

                for cache_file in cache_files:
                    try:
                        size_mb = cache_file.stat().st_size / (1024 * 1024)
                        cache_file.unlink(missing_ok=True)
                        total_size -= size_mb

                        if total_size <= self.max_cache_size_mb * 0.8:
                            break
                    except OSError:
                        continue
        except (OSError, ValueError, TypeError):
            pass

    def _should_cleanup(self, cache_key: str) -> bool:
        """Return True for ~1% of keys, deterministically.

        Uses the hex cache key itself rather than ``hash()`` so the sampling
        is stable across processes (str hashing is salted by PYTHONHASHSEED).
        """
        return int(cache_key, 16) % 100 == 0

    def get(self, **kwargs: Any) -> T | None:
        """Get cached result (sync).

        Args:
            **kwargs: Key-value pairs to generate cache key from

        Returns:
            Cached result if available, None otherwise
        """
        cache_key = self._get_cache_key(**kwargs)
        cache_path = self._get_cache_path(cache_key)

        if not self._is_cache_valid(cache_path):
            return None

        try:
            content = cache_path.read_bytes()
            cached_data = deserialize(content, dict)
            return self._deserialize_result(cached_data)
        except (OSError, ValueError, KeyError):
            # Corrupt or unreadable entry: discard it and treat as a miss.
            with suppress(OSError):
                cache_path.unlink(missing_ok=True)
            return None

    def set(self, result: T, **kwargs: Any) -> None:
        """Cache result (sync). Failures to write are silently ignored.

        Args:
            result: Result to cache
            **kwargs: Key-value pairs to generate cache key from
        """
        cache_key = self._get_cache_key(**kwargs)
        cache_path = self._get_cache_path(cache_key)

        try:
            serialized = self._serialize_result(result)
            content = serialize(serialized)
            cache_path.write_bytes(content)

            # Amortized maintenance: run cleanup on ~1% of writes.
            if self._should_cleanup(cache_key):
                self._cleanup_cache()
        except (OSError, TypeError, ValueError):
            pass

    async def aget(self, **kwargs: Any) -> T | None:
        """Get cached result (async).

        Args:
            **kwargs: Key-value pairs to generate cache key from

        Returns:
            Cached result if available, None otherwise
        """
        cache_key = self._get_cache_key(**kwargs)
        cache_path = AsyncPath(self._get_cache_path(cache_key))

        if not await run_sync(self._is_cache_valid, Path(cache_path)):
            return None

        try:
            content = await cache_path.read_bytes()
            cached_data = deserialize(content, dict)
            return self._deserialize_result(cached_data)
        except (OSError, ValueError, KeyError):
            # Corrupt or unreadable entry: discard it and treat as a miss.
            with suppress(Exception):
                await cache_path.unlink(missing_ok=True)
            return None

    async def aset(self, result: T, **kwargs: Any) -> None:
        """Cache result (async). Failures to write are silently ignored.

        Args:
            result: Result to cache
            **kwargs: Key-value pairs to generate cache key from
        """
        cache_key = self._get_cache_key(**kwargs)
        cache_path = AsyncPath(self._get_cache_path(cache_key))

        try:
            serialized = self._serialize_result(result)
            content = serialize(serialized)
            await cache_path.write_bytes(content)

            # Amortized maintenance: run cleanup on ~1% of writes.
            if self._should_cleanup(cache_key):
                await run_sync(self._cleanup_cache)
        except (OSError, TypeError, ValueError):
            pass

    def is_processing(self, **kwargs: Any) -> bool:
        """Check if operation is currently being processed."""
        cache_key = self._get_cache_key(**kwargs)
        with self._lock:
            return cache_key in self._processing

    def mark_processing(self, **kwargs: Any) -> threading.Event:
        """Mark operation as being processed and return event to wait on."""
        cache_key = self._get_cache_key(**kwargs)

        with self._lock:
            if cache_key not in self._processing:
                self._processing[cache_key] = threading.Event()
            return self._processing[cache_key]

    def mark_complete(self, **kwargs: Any) -> None:
        """Mark operation processing as complete."""
        cache_key = self._get_cache_key(**kwargs)

        with self._lock:
            if cache_key in self._processing:
                event = self._processing.pop(cache_key)
                event.set()

    def clear(self) -> None:
        """Clear all cached results and release pending processing events."""
        try:
            for cache_file in self.cache_dir.glob("*.msgpack"):
                cache_file.unlink(missing_ok=True)
        except OSError:
            pass

        # Wake any threads waiting on in-flight operations so they do not
        # block forever on results that were just discarded.
        with self._lock:
            for event in self._processing.values():
                event.set()
            self._processing.clear()

    def get_stats(self) -> dict[str, Any]:
        """Get cache statistics.

        Returns zeroed size figures if the cache directory cannot be read.
        """
        try:
            cache_files = list(self.cache_dir.glob("*.msgpack"))
            total_size = sum(cache_file.stat().st_size for cache_file in cache_files if cache_file.exists())

            return {
                "cache_type": self.cache_type,
                "cached_results": len(cache_files),
                "processing_results": len(self._processing),
                "total_cache_size_mb": total_size / 1024 / 1024,
                "avg_result_size_kb": (total_size / len(cache_files) / 1024) if cache_files else 0,
                "cache_dir": str(self.cache_dir),
                "max_cache_size_mb": self.max_cache_size_mb,
                "max_age_days": self.max_age_days,
            }
        except OSError:
            return {
                "cache_type": self.cache_type,
                "cached_results": 0,
                "processing_results": len(self._processing),
                "total_cache_size_mb": 0.0,
                "avg_result_size_kb": 0.0,
                "cache_dir": str(self.cache_dir),
                "max_cache_size_mb": self.max_cache_size_mb,
                "max_age_days": self.max_age_days,
            }
|
287
|
+
|
288
|
+
|
289
|
+
# Lazily-created, process-wide cache singletons. Each is None until the
# corresponding get_*_cache() accessor below instantiates it on first use.
_ocr_cache: KreuzbergCache[ExtractionResult] | None = None
_document_cache: KreuzbergCache[ExtractionResult] | None = None
_table_cache: KreuzbergCache[Any] | None = None
_mime_cache: KreuzbergCache[str] | None = None
|
293
|
+
|
294
|
+
|
295
|
+
def get_ocr_cache() -> KreuzbergCache[ExtractionResult]:
    """Return the process-wide OCR cache, creating it on first use.

    Honors the KREUZBERG_CACHE_DIR, KREUZBERG_OCR_CACHE_SIZE_MB and
    KREUZBERG_OCR_CACHE_AGE_DAYS environment variables.
    """
    global _ocr_cache
    if _ocr_cache is None:
        base_dir = os.environ.get("KREUZBERG_CACHE_DIR")
        _ocr_cache = KreuzbergCache[ExtractionResult](
            cache_type="ocr",
            cache_dir=Path(base_dir) / "ocr" if base_dir else None,
            max_cache_size_mb=float(os.environ.get("KREUZBERG_OCR_CACHE_SIZE_MB", "500")),
            max_age_days=int(os.environ.get("KREUZBERG_OCR_CACHE_AGE_DAYS", "30")),
        )
    return _ocr_cache
|
311
|
+
|
312
|
+
|
313
|
+
def get_document_cache() -> KreuzbergCache[ExtractionResult]:
    """Return the process-wide document cache, creating it on first use.

    Honors the KREUZBERG_CACHE_DIR, KREUZBERG_DOCUMENT_CACHE_SIZE_MB and
    KREUZBERG_DOCUMENT_CACHE_AGE_DAYS environment variables.
    """
    global _document_cache
    if _document_cache is None:
        base_dir = os.environ.get("KREUZBERG_CACHE_DIR")
        _document_cache = KreuzbergCache[ExtractionResult](
            cache_type="documents",
            cache_dir=Path(base_dir) / "documents" if base_dir else None,
            max_cache_size_mb=float(os.environ.get("KREUZBERG_DOCUMENT_CACHE_SIZE_MB", "1000")),
            max_age_days=int(os.environ.get("KREUZBERG_DOCUMENT_CACHE_AGE_DAYS", "7")),
        )
    return _document_cache
|
329
|
+
|
330
|
+
|
331
|
+
def get_table_cache() -> KreuzbergCache[Any]:
    """Return the process-wide table cache, creating it on first use.

    Honors the KREUZBERG_CACHE_DIR, KREUZBERG_TABLE_CACHE_SIZE_MB and
    KREUZBERG_TABLE_CACHE_AGE_DAYS environment variables.
    """
    global _table_cache
    if _table_cache is None:
        base_dir = os.environ.get("KREUZBERG_CACHE_DIR")
        _table_cache = KreuzbergCache[Any](
            cache_type="tables",
            cache_dir=Path(base_dir) / "tables" if base_dir else None,
            max_cache_size_mb=float(os.environ.get("KREUZBERG_TABLE_CACHE_SIZE_MB", "200")),
            max_age_days=int(os.environ.get("KREUZBERG_TABLE_CACHE_AGE_DAYS", "30")),
        )
    return _table_cache
|
347
|
+
|
348
|
+
|
349
|
+
def get_mime_cache() -> KreuzbergCache[str]:
    """Return the process-wide MIME-type cache, creating it on first use.

    Honors the KREUZBERG_CACHE_DIR, KREUZBERG_MIME_CACHE_SIZE_MB and
    KREUZBERG_MIME_CACHE_AGE_DAYS environment variables.
    """
    global _mime_cache
    if _mime_cache is None:
        base_dir = os.environ.get("KREUZBERG_CACHE_DIR")
        _mime_cache = KreuzbergCache[str](
            cache_type="mime",
            cache_dir=Path(base_dir) / "mime" if base_dir else None,
            max_cache_size_mb=float(os.environ.get("KREUZBERG_MIME_CACHE_SIZE_MB", "50")),
            max_age_days=int(os.environ.get("KREUZBERG_MIME_CACHE_AGE_DAYS", "60")),
        )
    return _mime_cache
|
365
|
+
|
366
|
+
|
367
|
+
def clear_all_caches() -> None:
    """Clear every global cache (OCR, documents, tables, MIME)."""
    # Iterate the accessors in their definition order above.
    for accessor in (get_ocr_cache, get_document_cache, get_table_cache, get_mime_cache):
        accessor().clear()
|
kreuzberg/_utils/_device.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1
1
|
"""Device detection and management utilities for GPU acceleration."""
|
2
|
-
# ruff: noqa: BLE001
|
2
|
+
# ruff: noqa: BLE001 # ~keep
|
3
3
|
|
4
4
|
from __future__ import annotations
|
5
5
|
|
@@ -36,7 +36,6 @@ def detect_available_devices() -> list[DeviceInfo]:
|
|
36
36
|
"""
|
37
37
|
devices: list[DeviceInfo] = []
|
38
38
|
|
39
|
-
# Always include CPU as fallback
|
40
39
|
devices.append(
|
41
40
|
DeviceInfo(
|
42
41
|
device_type="cpu",
|
@@ -44,18 +43,15 @@ def detect_available_devices() -> list[DeviceInfo]:
|
|
44
43
|
)
|
45
44
|
)
|
46
45
|
|
47
|
-
# Check for CUDA (NVIDIA GPUs)
|
48
46
|
if _is_cuda_available():
|
49
47
|
cuda_devices = _get_cuda_devices()
|
50
48
|
devices.extend(cuda_devices)
|
51
49
|
|
52
|
-
# Check for MPS (Apple Silicon)
|
53
50
|
if _is_mps_available():
|
54
51
|
mps_device = _get_mps_device()
|
55
52
|
if mps_device:
|
56
53
|
devices.append(mps_device)
|
57
54
|
|
58
|
-
# Reorder to put GPU devices first
|
59
55
|
gpu_devices = [d for d in devices if d.device_type != "cpu"]
|
60
56
|
cpu_devices = [d for d in devices if d.device_type == "cpu"]
|
61
57
|
|
@@ -95,14 +91,12 @@ def validate_device_request(
|
|
95
91
|
"""
|
96
92
|
available_devices = detect_available_devices()
|
97
93
|
|
98
|
-
# Handle auto device selection
|
99
94
|
if requested == "auto":
|
100
95
|
device = get_optimal_device()
|
101
96
|
if memory_limit is not None:
|
102
97
|
_validate_memory_limit(device, memory_limit)
|
103
98
|
return device
|
104
99
|
|
105
|
-
# Find requested device
|
106
100
|
matching_devices = [d for d in available_devices if d.device_type == requested]
|
107
101
|
|
108
102
|
if not matching_devices:
|
@@ -125,10 +119,8 @@ def validate_device_request(
|
|
125
119
|
},
|
126
120
|
)
|
127
121
|
|
128
|
-
# Use the first matching device (typically the best one)
|
129
122
|
device = matching_devices[0]
|
130
123
|
|
131
|
-
# Validate memory limit if specified
|
132
124
|
if memory_limit is not None:
|
133
125
|
_validate_memory_limit(device, memory_limit)
|
134
126
|
|
@@ -159,7 +151,7 @@ def get_device_memory_info(device: DeviceInfo) -> tuple[float | None, float | No
|
|
159
151
|
def _is_cuda_available() -> bool:
|
160
152
|
"""Check if CUDA is available."""
|
161
153
|
try:
|
162
|
-
import torch
|
154
|
+
import torch # type: ignore[import-not-found,unused-ignore]
|
163
155
|
|
164
156
|
return torch.cuda.is_available()
|
165
157
|
except ImportError:
|
@@ -169,7 +161,7 @@ def _is_cuda_available() -> bool:
|
|
169
161
|
def _is_mps_available() -> bool:
|
170
162
|
"""Check if MPS (Apple Silicon) is available."""
|
171
163
|
try:
|
172
|
-
import torch
|
164
|
+
import torch # type: ignore[import-not-found,unused-ignore]
|
173
165
|
|
174
166
|
return torch.backends.mps.is_available()
|
175
167
|
except ImportError:
|
@@ -188,17 +180,14 @@ def _get_cuda_devices() -> list[DeviceInfo]:
|
|
188
180
|
|
189
181
|
for i in range(torch.cuda.device_count()):
|
190
182
|
props = torch.cuda.get_device_properties(i)
|
191
|
-
total_memory = props.total_memory / (1024**3)
|
183
|
+
total_memory = props.total_memory / (1024**3)
|
192
184
|
|
193
|
-
# Get available memory
|
194
185
|
torch.cuda.set_device(i)
|
195
186
|
available_memory = torch.cuda.get_device_properties(i).total_memory / (1024**3)
|
196
187
|
try:
|
197
|
-
# Try to get current memory usage
|
198
188
|
allocated = torch.cuda.memory_allocated(i) / (1024**3)
|
199
189
|
available_memory = total_memory - allocated
|
200
190
|
except Exception:
|
201
|
-
# Fallback to total memory if we can't get allocation info
|
202
191
|
available_memory = total_memory
|
203
192
|
|
204
193
|
devices.append(
|
@@ -225,7 +214,6 @@ def _get_mps_device() -> DeviceInfo | None:
|
|
225
214
|
if not torch.backends.mps.is_available():
|
226
215
|
return None
|
227
216
|
|
228
|
-
# MPS doesn't provide detailed memory info
|
229
217
|
return DeviceInfo(
|
230
218
|
device_type="mps",
|
231
219
|
name="Apple Silicon GPU (MPS)",
|
@@ -260,8 +248,6 @@ def _get_cuda_memory_info(device_id: int) -> tuple[float | None, float | None]:
|
|
260
248
|
|
261
249
|
def _get_mps_memory_info() -> tuple[float | None, float | None]:
|
262
250
|
"""Get MPS memory information."""
|
263
|
-
# MPS doesn't provide detailed memory info through PyTorch
|
264
|
-
# We could potentially use system calls but that's platform-specific
|
265
251
|
return None, None
|
266
252
|
|
267
253
|
|
@@ -276,7 +262,7 @@ def _validate_memory_limit(device: DeviceInfo, memory_limit: float) -> None:
|
|
276
262
|
ValidationError: If the device doesn't have enough memory.
|
277
263
|
"""
|
278
264
|
if device.device_type == "cpu":
|
279
|
-
# CPU memory validation is complex and OS-dependent, skip for now
|
265
|
+
# CPU memory validation is complex and OS-dependent, skip for now # ~keep
|
280
266
|
return
|
281
267
|
|
282
268
|
total_memory, available_memory = get_device_memory_info(device)
|
@@ -311,7 +297,7 @@ def is_backend_gpu_compatible(backend: str) -> bool:
|
|
311
297
|
Returns:
|
312
298
|
True if the backend supports GPU acceleration.
|
313
299
|
"""
|
314
|
-
# EasyOCR and PaddleOCR support GPU, Tesseract does not
|
300
|
+
# EasyOCR and PaddleOCR support GPU, Tesseract does not # ~keep
|
315
301
|
return backend.lower() in ("easyocr", "paddleocr")
|
316
302
|
|
317
303
|
|
@@ -326,25 +312,22 @@ def get_recommended_batch_size(device: DeviceInfo, input_size_mb: float = 10.0)
|
|
326
312
|
Recommended batch size.
|
327
313
|
"""
|
328
314
|
if device.device_type == "cpu":
|
329
|
-
# Conservative batch size for CPU
|
315
|
+
# Conservative batch size for CPU # ~keep
|
330
316
|
return 1
|
331
317
|
|
332
|
-
# For GPU devices, estimate based on available memory
|
333
318
|
_, available_memory = get_device_memory_info(device)
|
334
319
|
|
335
320
|
if available_memory is None:
|
336
|
-
# Conservative default for unknown memory
|
337
321
|
return 4
|
338
322
|
|
339
|
-
#
|
340
|
-
# Use approximately 50% of available memory for batching
|
323
|
+
# Use approximately 50% of available memory for batching # ~keep
|
341
324
|
usable_memory_gb = available_memory * 0.5
|
342
325
|
usable_memory_mb = usable_memory_gb * 1024
|
343
326
|
|
344
|
-
# Estimate batch size (conservative)
|
327
|
+
# Estimate batch size (conservative) # ~keep
|
345
328
|
estimated_batch_size = max(1, int(usable_memory_mb / (input_size_mb * 4)))
|
346
329
|
|
347
|
-
# Cap at reasonable limits
|
330
|
+
# Cap at reasonable limits # ~keep
|
348
331
|
return min(estimated_batch_size, 32)
|
349
332
|
|
350
333
|
|