kreuzberg 3.11.4__py3-none-any.whl → 3.13.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. kreuzberg/__init__.py +14 -13
  2. kreuzberg/__main__.py +0 -2
  3. kreuzberg/_api/main.py +119 -9
  4. kreuzberg/_chunker.py +0 -15
  5. kreuzberg/_config.py +212 -292
  6. kreuzberg/_document_classification.py +20 -47
  7. kreuzberg/_entity_extraction.py +1 -122
  8. kreuzberg/_extractors/_base.py +4 -71
  9. kreuzberg/_extractors/_email.py +1 -15
  10. kreuzberg/_extractors/_html.py +9 -12
  11. kreuzberg/_extractors/_image.py +1 -25
  12. kreuzberg/_extractors/_pandoc.py +10 -147
  13. kreuzberg/_extractors/_pdf.py +38 -94
  14. kreuzberg/_extractors/_presentation.py +0 -99
  15. kreuzberg/_extractors/_spread_sheet.py +13 -55
  16. kreuzberg/_extractors/_structured.py +1 -4
  17. kreuzberg/_gmft.py +14 -199
  18. kreuzberg/_language_detection.py +1 -36
  19. kreuzberg/_mcp/__init__.py +0 -2
  20. kreuzberg/_mcp/server.py +3 -10
  21. kreuzberg/_mime_types.py +1 -19
  22. kreuzberg/_ocr/_base.py +4 -76
  23. kreuzberg/_ocr/_easyocr.py +124 -186
  24. kreuzberg/_ocr/_paddleocr.py +154 -224
  25. kreuzberg/_ocr/_table_extractor.py +184 -0
  26. kreuzberg/_ocr/_tesseract.py +797 -361
  27. kreuzberg/_playa.py +5 -31
  28. kreuzberg/_registry.py +0 -36
  29. kreuzberg/_types.py +588 -93
  30. kreuzberg/_utils/_cache.py +84 -138
  31. kreuzberg/_utils/_device.py +0 -74
  32. kreuzberg/_utils/_document_cache.py +0 -75
  33. kreuzberg/_utils/_errors.py +0 -50
  34. kreuzberg/_utils/_ocr_cache.py +136 -0
  35. kreuzberg/_utils/_pdf_lock.py +0 -16
  36. kreuzberg/_utils/_process_pool.py +17 -64
  37. kreuzberg/_utils/_quality.py +0 -60
  38. kreuzberg/_utils/_ref.py +32 -0
  39. kreuzberg/_utils/_serialization.py +0 -30
  40. kreuzberg/_utils/_string.py +9 -59
  41. kreuzberg/_utils/_sync.py +0 -77
  42. kreuzberg/_utils/_table.py +49 -101
  43. kreuzberg/_utils/_tmp.py +0 -9
  44. kreuzberg/cli.py +54 -74
  45. kreuzberg/extraction.py +39 -32
  46. {kreuzberg-3.11.4.dist-info → kreuzberg-3.13.1.dist-info}/METADATA +19 -15
  47. kreuzberg-3.13.1.dist-info/RECORD +57 -0
  48. kreuzberg-3.11.4.dist-info/RECORD +0 -54
  49. {kreuzberg-3.11.4.dist-info → kreuzberg-3.13.1.dist-info}/WHEEL +0 -0
  50. {kreuzberg-3.11.4.dist-info → kreuzberg-3.13.1.dist-info}/entry_points.txt +0 -0
  51. {kreuzberg-3.11.4.dist-info → kreuzberg-3.13.1.dist-info}/licenses/LICENSE +0 -0
@@ -1,5 +1,3 @@
1
- """General-purpose file-based caching layer for Kreuzberg."""
2
-
3
1
  from __future__ import annotations
4
2
 
5
3
  import hashlib
@@ -14,6 +12,7 @@ from typing import Any, Generic, TypeVar
14
12
  from anyio import Path as AsyncPath
15
13
 
16
14
  from kreuzberg._types import ExtractionResult
15
+ from kreuzberg._utils._ref import Ref
17
16
  from kreuzberg._utils._serialization import deserialize, serialize
18
17
  from kreuzberg._utils._sync import run_sync
19
18
 
@@ -21,12 +20,6 @@ T = TypeVar("T")
21
20
 
22
21
 
23
22
  class KreuzbergCache(Generic[T]):
24
- """File-based cache for Kreuzberg operations.
25
-
26
- Provides both sync and async interfaces for caching extraction results,
27
- OCR results, table data, and other expensive operations to disk.
28
- """
29
-
30
23
  def __init__(
31
24
  self,
32
25
  cache_type: str,
@@ -34,14 +27,6 @@ class KreuzbergCache(Generic[T]):
34
27
  max_cache_size_mb: float = 500.0,
35
28
  max_age_days: int = 30,
36
29
  ) -> None:
37
- """Initialize cache.
38
-
39
- Args:
40
- cache_type: Type of cache (e.g., 'ocr', 'tables', 'documents', 'mime')
41
- cache_dir: Cache directory (defaults to .kreuzberg/{cache_type} in cwd)
42
- max_cache_size_mb: Maximum cache size in MB (default: 500MB)
43
- max_age_days: Maximum age of cached results in days (default: 30 days)
44
- """
45
30
  if cache_dir is None:
46
31
  cache_dir = Path.cwd() / ".kreuzberg" / cache_type
47
32
 
@@ -57,22 +42,12 @@ class KreuzbergCache(Generic[T]):
57
42
  self._lock = threading.Lock()
58
43
 
59
44
  def _get_cache_key(self, **kwargs: Any) -> str:
60
- """Generate cache key from kwargs.
61
-
62
- Args:
63
- **kwargs: Key-value pairs to generate cache key from
64
-
65
- Returns:
66
- Unique cache key string
67
- """
68
45
  if not kwargs:
69
46
  return "empty"
70
47
 
71
- # Build cache key using list + join (faster than StringIO)
72
48
  parts = []
73
49
  for key in sorted(kwargs):
74
50
  value = kwargs[key]
75
- # Convert common types efficiently
76
51
  if isinstance(value, (str, int, float, bool)):
77
52
  parts.append(f"{key}={value}")
78
53
  elif isinstance(value, bytes):
@@ -81,15 +56,12 @@ class KreuzbergCache(Generic[T]):
81
56
  parts.append(f"{key}={type(value).__name__}:{value!s}")
82
57
 
83
58
  cache_str = "&".join(parts)
84
- # SHA256 is secure and fast enough for cache keys
85
59
  return hashlib.sha256(cache_str.encode()).hexdigest()[:16]
86
60
 
87
61
  def _get_cache_path(self, cache_key: str) -> Path:
88
- """Get cache file path for key."""
89
62
  return self.cache_dir / f"{cache_key}.msgpack"
90
63
 
91
64
  def _is_cache_valid(self, cache_path: Path) -> bool:
92
- """Check if cached result is still valid."""
93
65
  try:
94
66
  if not cache_path.exists():
95
67
  return False
@@ -102,18 +74,14 @@ class KreuzbergCache(Generic[T]):
102
74
  return False
103
75
 
104
76
  def _serialize_result(self, result: T) -> dict[str, Any]:
105
- """Serialize result for caching with metadata."""
106
- # Handle TableData objects that contain DataFrames
107
77
  if isinstance(result, list) and result and isinstance(result[0], dict) and "df" in result[0]:
108
78
  serialized_data = []
109
79
  for item in result:
110
80
  if isinstance(item, dict) and "df" in item:
111
- # Build new dict without unnecessary copy
112
81
  serialized_item = {k: v for k, v in item.items() if k != "df"}
113
82
  if hasattr(item["df"], "to_csv"):
114
83
  serialized_item["df_csv"] = item["df"].to_csv(index=False)
115
84
  else:
116
- # Fallback for non-DataFrame objects
117
85
  serialized_item["df_csv"] = str(item["df"])
118
86
  serialized_data.append(serialized_item)
119
87
  else:
@@ -123,7 +91,6 @@ class KreuzbergCache(Generic[T]):
123
91
  return {"type": type(result).__name__, "data": result, "cached_at": time.time()}
124
92
 
125
93
  def _deserialize_result(self, cached_data: dict[str, Any]) -> T:
126
- """Deserialize cached result."""
127
94
  data = cached_data["data"]
128
95
 
129
96
  if cached_data.get("type") == "TableDataList" and isinstance(data, list):
@@ -132,7 +99,6 @@ class KreuzbergCache(Generic[T]):
132
99
  deserialized_data = []
133
100
  for item in data:
134
101
  if isinstance(item, dict) and "df_csv" in item:
135
- # Build new dict without unnecessary copy
136
102
  deserialized_item = {k: v for k, v in item.items() if k != "df_csv"}
137
103
  deserialized_item["df"] = pd.read_csv(StringIO(item["df_csv"]))
138
104
  deserialized_data.append(deserialized_item)
@@ -146,7 +112,6 @@ class KreuzbergCache(Generic[T]):
146
112
  return data # type: ignore[no-any-return]
147
113
 
148
114
  def _cleanup_cache(self) -> None:
149
- """Clean up old and oversized cache entries."""
150
115
  try:
151
116
  cache_files = list(self.cache_dir.glob("*.msgpack"))
152
117
 
@@ -180,14 +145,6 @@ class KreuzbergCache(Generic[T]):
180
145
  pass
181
146
 
182
147
  def get(self, **kwargs: Any) -> T | None:
183
- """Get cached result (sync).
184
-
185
- Args:
186
- **kwargs: Key-value pairs to generate cache key from
187
-
188
- Returns:
189
- Cached result if available, None otherwise
190
- """
191
148
  cache_key = self._get_cache_key(**kwargs)
192
149
  cache_path = self._get_cache_path(cache_key)
193
150
 
@@ -204,12 +161,6 @@ class KreuzbergCache(Generic[T]):
204
161
  return None
205
162
 
206
163
  def set(self, result: T, **kwargs: Any) -> None:
207
- """Cache result (sync).
208
-
209
- Args:
210
- result: Result to cache
211
- **kwargs: Key-value pairs to generate cache key from
212
- """
213
164
  cache_key = self._get_cache_key(**kwargs)
214
165
  cache_path = self._get_cache_path(cache_key)
215
166
 
@@ -224,14 +175,6 @@ class KreuzbergCache(Generic[T]):
224
175
  pass
225
176
 
226
177
  async def aget(self, **kwargs: Any) -> T | None:
227
- """Get cached result (async).
228
-
229
- Args:
230
- **kwargs: Key-value pairs to generate cache key from
231
-
232
- Returns:
233
- Cached result if available, None otherwise
234
- """
235
178
  cache_key = self._get_cache_key(**kwargs)
236
179
  cache_path = AsyncPath(self._get_cache_path(cache_key))
237
180
 
@@ -248,12 +191,6 @@ class KreuzbergCache(Generic[T]):
248
191
  return None
249
192
 
250
193
  async def aset(self, result: T, **kwargs: Any) -> None:
251
- """Cache result (async).
252
-
253
- Args:
254
- result: Result to cache
255
- **kwargs: Key-value pairs to generate cache key from
256
- """
257
194
  cache_key = self._get_cache_key(**kwargs)
258
195
  cache_path = AsyncPath(self._get_cache_path(cache_key))
259
196
 
@@ -268,13 +205,11 @@ class KreuzbergCache(Generic[T]):
268
205
  pass
269
206
 
270
207
  def is_processing(self, **kwargs: Any) -> bool:
271
- """Check if operation is currently being processed."""
272
208
  cache_key = self._get_cache_key(**kwargs)
273
209
  with self._lock:
274
210
  return cache_key in self._processing
275
211
 
276
212
  def mark_processing(self, **kwargs: Any) -> threading.Event:
277
- """Mark operation as being processed and return event to wait on."""
278
213
  cache_key = self._get_cache_key(**kwargs)
279
214
 
280
215
  with self._lock:
@@ -283,7 +218,6 @@ class KreuzbergCache(Generic[T]):
283
218
  return self._processing[cache_key]
284
219
 
285
220
  def mark_complete(self, **kwargs: Any) -> None:
286
- """Mark operation processing as complete."""
287
221
  cache_key = self._get_cache_key(**kwargs)
288
222
 
289
223
  with self._lock:
@@ -292,7 +226,6 @@ class KreuzbergCache(Generic[T]):
292
226
  event.set()
293
227
 
294
228
  def clear(self) -> None:
295
- """Clear all cached results."""
296
229
  try:
297
230
  for cache_file in self.cache_dir.glob("*.msgpack"):
298
231
  cache_file.unlink(missing_ok=True)
@@ -303,7 +236,6 @@ class KreuzbergCache(Generic[T]):
303
236
  pass
304
237
 
305
238
  def get_stats(self) -> dict[str, Any]:
306
- """Get cache statistics."""
307
239
  try:
308
240
  cache_files = list(self.cache_dir.glob("*.msgpack"))
309
241
  total_size = sum(cache_file.stat().st_size for cache_file in cache_files if cache_file.exists())
@@ -331,87 +263,101 @@ class KreuzbergCache(Generic[T]):
331
263
  }
332
264
 
333
265
 
334
- _ocr_cache: KreuzbergCache[ExtractionResult] | None = None
335
- _document_cache: KreuzbergCache[ExtractionResult] | None = None
336
- _table_cache: KreuzbergCache[Any] | None = None
337
- _mime_cache: KreuzbergCache[str] | None = None
266
+ def _create_ocr_cache() -> KreuzbergCache[ExtractionResult]:
267
+ cache_dir_str = os.environ.get("KREUZBERG_CACHE_DIR")
268
+ cache_dir: Path | None = None
269
+ if cache_dir_str:
270
+ cache_dir = Path(cache_dir_str) / "ocr"
271
+
272
+ return KreuzbergCache[ExtractionResult](
273
+ cache_type="ocr",
274
+ cache_dir=cache_dir,
275
+ max_cache_size_mb=float(os.environ.get("KREUZBERG_OCR_CACHE_SIZE_MB", "500")),
276
+ max_age_days=int(os.environ.get("KREUZBERG_OCR_CACHE_AGE_DAYS", "30")),
277
+ )
278
+
279
+
280
+ _ocr_cache_ref = Ref("ocr_cache", _create_ocr_cache)
338
281
 
339
282
 
340
283
  def get_ocr_cache() -> KreuzbergCache[ExtractionResult]:
341
- """Get the global OCR cache instance."""
342
- global _ocr_cache
343
- if _ocr_cache is None:
344
- cache_dir_str = os.environ.get("KREUZBERG_CACHE_DIR")
345
- cache_dir: Path | None = None
346
- if cache_dir_str:
347
- cache_dir = Path(cache_dir_str) / "ocr"
348
-
349
- _ocr_cache = KreuzbergCache[ExtractionResult](
350
- cache_type="ocr",
351
- cache_dir=cache_dir,
352
- max_cache_size_mb=float(os.environ.get("KREUZBERG_OCR_CACHE_SIZE_MB", "500")),
353
- max_age_days=int(os.environ.get("KREUZBERG_OCR_CACHE_AGE_DAYS", "30")),
354
- )
355
- return _ocr_cache
284
+ return _ocr_cache_ref.get()
285
+
286
+
287
+ def _create_document_cache() -> KreuzbergCache[ExtractionResult]:
288
+ cache_dir_str = os.environ.get("KREUZBERG_CACHE_DIR")
289
+ cache_dir: Path | None = None
290
+ if cache_dir_str:
291
+ cache_dir = Path(cache_dir_str) / "documents"
292
+
293
+ return KreuzbergCache[ExtractionResult](
294
+ cache_type="documents",
295
+ cache_dir=cache_dir,
296
+ max_cache_size_mb=float(os.environ.get("KREUZBERG_DOCUMENT_CACHE_SIZE_MB", "1000")),
297
+ max_age_days=int(os.environ.get("KREUZBERG_DOCUMENT_CACHE_AGE_DAYS", "7")),
298
+ )
299
+
300
+
301
+ _document_cache_ref = Ref("document_cache", _create_document_cache)
356
302
 
357
303
 
358
304
  def get_document_cache() -> KreuzbergCache[ExtractionResult]:
359
- """Get the global document cache instance."""
360
- global _document_cache
361
- if _document_cache is None:
362
- cache_dir_str = os.environ.get("KREUZBERG_CACHE_DIR")
363
- cache_dir: Path | None = None
364
- if cache_dir_str:
365
- cache_dir = Path(cache_dir_str) / "documents"
366
-
367
- _document_cache = KreuzbergCache[ExtractionResult](
368
- cache_type="documents",
369
- cache_dir=cache_dir,
370
- max_cache_size_mb=float(os.environ.get("KREUZBERG_DOCUMENT_CACHE_SIZE_MB", "1000")),
371
- max_age_days=int(os.environ.get("KREUZBERG_DOCUMENT_CACHE_AGE_DAYS", "7")),
372
- )
373
- return _document_cache
305
+ return _document_cache_ref.get()
306
+
307
+
308
+ def _create_table_cache() -> KreuzbergCache[Any]:
309
+ cache_dir_str = os.environ.get("KREUZBERG_CACHE_DIR")
310
+ cache_dir: Path | None = None
311
+ if cache_dir_str:
312
+ cache_dir = Path(cache_dir_str) / "tables"
313
+
314
+ return KreuzbergCache[Any](
315
+ cache_type="tables",
316
+ cache_dir=cache_dir,
317
+ max_cache_size_mb=float(os.environ.get("KREUZBERG_TABLE_CACHE_SIZE_MB", "200")),
318
+ max_age_days=int(os.environ.get("KREUZBERG_TABLE_CACHE_AGE_DAYS", "30")),
319
+ )
320
+
321
+
322
+ _table_cache_ref = Ref("table_cache", _create_table_cache)
374
323
 
375
324
 
376
325
  def get_table_cache() -> KreuzbergCache[Any]:
377
- """Get the global table cache instance."""
378
- global _table_cache
379
- if _table_cache is None:
380
- cache_dir_str = os.environ.get("KREUZBERG_CACHE_DIR")
381
- cache_dir: Path | None = None
382
- if cache_dir_str:
383
- cache_dir = Path(cache_dir_str) / "tables"
384
-
385
- _table_cache = KreuzbergCache[Any](
386
- cache_type="tables",
387
- cache_dir=cache_dir,
388
- max_cache_size_mb=float(os.environ.get("KREUZBERG_TABLE_CACHE_SIZE_MB", "200")),
389
- max_age_days=int(os.environ.get("KREUZBERG_TABLE_CACHE_AGE_DAYS", "30")),
390
- )
391
- return _table_cache
326
+ return _table_cache_ref.get()
327
+
328
+
329
+ def _create_mime_cache() -> KreuzbergCache[str]:
330
+ cache_dir_str = os.environ.get("KREUZBERG_CACHE_DIR")
331
+ cache_dir: Path | None = None
332
+ if cache_dir_str:
333
+ cache_dir = Path(cache_dir_str) / "mime"
334
+
335
+ return KreuzbergCache[str](
336
+ cache_type="mime",
337
+ cache_dir=cache_dir,
338
+ max_cache_size_mb=float(os.environ.get("KREUZBERG_MIME_CACHE_SIZE_MB", "50")),
339
+ max_age_days=int(os.environ.get("KREUZBERG_MIME_CACHE_AGE_DAYS", "60")),
340
+ )
341
+
342
+
343
+ _mime_cache_ref = Ref("mime_cache", _create_mime_cache)
392
344
 
393
345
 
394
346
  def get_mime_cache() -> KreuzbergCache[str]:
395
- """Get the global MIME type cache instance."""
396
- global _mime_cache
397
- if _mime_cache is None:
398
- cache_dir_str = os.environ.get("KREUZBERG_CACHE_DIR")
399
- cache_dir: Path | None = None
400
- if cache_dir_str:
401
- cache_dir = Path(cache_dir_str) / "mime"
402
-
403
- _mime_cache = KreuzbergCache[str](
404
- cache_type="mime",
405
- cache_dir=cache_dir,
406
- max_cache_size_mb=float(os.environ.get("KREUZBERG_MIME_CACHE_SIZE_MB", "50")),
407
- max_age_days=int(os.environ.get("KREUZBERG_MIME_CACHE_AGE_DAYS", "60")),
408
- )
409
- return _mime_cache
347
+ return _mime_cache_ref.get()
410
348
 
411
349
 
412
350
  def clear_all_caches() -> None:
413
- """Clear all caches."""
414
- get_ocr_cache().clear()
415
- get_document_cache().clear()
416
- get_table_cache().clear()
417
- get_mime_cache().clear()
351
+ if _ocr_cache_ref.is_initialized():
352
+ get_ocr_cache().clear()
353
+ if _document_cache_ref.is_initialized():
354
+ get_document_cache().clear()
355
+ if _table_cache_ref.is_initialized():
356
+ get_table_cache().clear()
357
+ if _mime_cache_ref.is_initialized():
358
+ get_mime_cache().clear()
359
+
360
+ _ocr_cache_ref.clear()
361
+ _document_cache_ref.clear()
362
+ _table_cache_ref.clear()
363
+ _mime_cache_ref.clear()
@@ -1,4 +1,3 @@
1
- """Device detection and management utilities for GPU acceleration."""
2
1
  # ruff: noqa: BLE001 # ~keep
3
2
 
4
3
  from __future__ import annotations
@@ -15,8 +14,6 @@ DeviceType = Literal["cpu", "cuda", "mps", "auto"]
15
14
 
16
15
  @dataclass(frozen=True, slots=True)
17
16
  class DeviceInfo:
18
- """Information about a compute device."""
19
-
20
17
  device_type: Literal["cpu", "cuda", "mps"]
21
18
  """The type of device."""
22
19
  device_id: int | None = None
@@ -30,12 +27,6 @@ class DeviceInfo:
30
27
 
31
28
 
32
29
  def detect_available_devices() -> list[DeviceInfo]:
33
- """Detect all available compute devices.
34
-
35
- Returns:
36
- List of available devices, with the most preferred device first.
37
- """
38
- # Build device lists efficiently using generators
39
30
  cpu_device = DeviceInfo(device_type="cpu", name="CPU")
40
31
 
41
32
  cuda_devices = _get_cuda_devices() if _is_cuda_available() else []
@@ -43,17 +34,11 @@ def detect_available_devices() -> list[DeviceInfo]:
43
34
  mps_device = _get_mps_device() if _is_mps_available() else None
44
35
  mps_devices = [mps_device] if mps_device else []
45
36
 
46
- # Return GPU devices first, then CPU using itertools.chain
47
37
  gpu_devices = list(chain(cuda_devices, mps_devices))
48
38
  return [*gpu_devices, cpu_device]
49
39
 
50
40
 
51
41
  def get_optimal_device() -> DeviceInfo:
52
- """Get the optimal device for OCR processing.
53
-
54
- Returns:
55
- The best available device, preferring GPU over CPU.
56
- """
57
42
  devices = detect_available_devices()
58
43
  return devices[0] if devices else DeviceInfo(device_type="cpu", name="CPU")
59
44
 
@@ -65,20 +50,6 @@ def validate_device_request(
65
50
  memory_limit: float | None = None,
66
51
  fallback_to_cpu: bool = True,
67
52
  ) -> DeviceInfo:
68
- """Validate and resolve a device request.
69
-
70
- Args:
71
- requested: The requested device type.
72
- backend: Name of the OCR backend requesting the device.
73
- memory_limit: Optional memory limit in GB.
74
- fallback_to_cpu: Whether to fallback to CPU if requested device unavailable.
75
-
76
- Returns:
77
- A validated DeviceInfo object.
78
-
79
- Raises:
80
- ValidationError: If the requested device is not available and fallback is disabled.
81
- """
82
53
  available_devices = detect_available_devices()
83
54
 
84
55
  if requested == "auto":
@@ -118,14 +89,6 @@ def validate_device_request(
118
89
 
119
90
 
120
91
  def get_device_memory_info(device: DeviceInfo) -> tuple[float | None, float | None]:
121
- """Get memory information for a device.
122
-
123
- Args:
124
- device: The device to query.
125
-
126
- Returns:
127
- Tuple of (total_memory_gb, available_memory_gb). None values if unknown.
128
- """
129
92
  if device.device_type == "cpu":
130
93
  return None, None
131
94
 
@@ -139,7 +102,6 @@ def get_device_memory_info(device: DeviceInfo) -> tuple[float | None, float | No
139
102
 
140
103
 
141
104
  def _is_cuda_available() -> bool:
142
- """Check if CUDA is available."""
143
105
  try:
144
106
  import torch # type: ignore[import-not-found,unused-ignore] # noqa: PLC0415
145
107
 
@@ -149,7 +111,6 @@ def _is_cuda_available() -> bool:
149
111
 
150
112
 
151
113
  def _is_mps_available() -> bool:
152
- """Check if MPS (Apple Silicon) is available."""
153
114
  try:
154
115
  import torch # type: ignore[import-not-found,unused-ignore] # noqa: PLC0415
155
116
 
@@ -159,7 +120,6 @@ def _is_mps_available() -> bool:
159
120
 
160
121
 
161
122
  def _get_cuda_devices() -> list[DeviceInfo]:
162
- """Get information about available CUDA devices."""
163
123
  devices: list[DeviceInfo] = []
164
124
 
165
125
  try:
@@ -197,7 +157,6 @@ def _get_cuda_devices() -> list[DeviceInfo]:
197
157
 
198
158
 
199
159
  def _get_mps_device() -> DeviceInfo | None:
200
- """Get information about the MPS device."""
201
160
  try:
202
161
  import torch # noqa: PLC0415
203
162
 
@@ -214,7 +173,6 @@ def _get_mps_device() -> DeviceInfo | None:
214
173
 
215
174
 
216
175
  def _get_cuda_memory_info(device_id: int) -> tuple[float | None, float | None]:
217
- """Get CUDA memory information for a specific device."""
218
176
  try:
219
177
  import torch # noqa: PLC0415
220
178
 
@@ -237,20 +195,10 @@ def _get_cuda_memory_info(device_id: int) -> tuple[float | None, float | None]:
237
195
 
238
196
 
239
197
  def _get_mps_memory_info() -> tuple[float | None, float | None]:
240
- """Get MPS memory information."""
241
198
  return None, None
242
199
 
243
200
 
244
201
  def _validate_memory_limit(device: DeviceInfo, memory_limit: float) -> None:
245
- """Validate that a device has enough memory for the requested limit.
246
-
247
- Args:
248
- device: The device to validate.
249
- memory_limit: Required memory in GB.
250
-
251
- Raises:
252
- ValidationError: If the device doesn't have enough memory.
253
- """
254
202
  if device.device_type == "cpu":
255
203
  # CPU memory validation is complex and OS-dependent, skip for now # ~keep
256
204
  return
@@ -279,28 +227,11 @@ def _validate_memory_limit(device: DeviceInfo, memory_limit: float) -> None:
279
227
 
280
228
 
281
229
  def is_backend_gpu_compatible(backend: str) -> bool:
282
- """Check if an OCR backend supports GPU acceleration.
283
-
284
- Args:
285
- backend: Name of the OCR backend.
286
-
287
- Returns:
288
- True if the backend supports GPU acceleration.
289
- """
290
230
  # EasyOCR and PaddleOCR support GPU, Tesseract does not # ~keep
291
231
  return backend.lower() in ("easyocr", "paddleocr")
292
232
 
293
233
 
294
234
  def get_recommended_batch_size(device: DeviceInfo, input_size_mb: float = 10.0) -> int:
295
- """Get recommended batch size for OCR processing.
296
-
297
- Args:
298
- device: The device to optimize for.
299
- input_size_mb: Estimated input size per item in MB.
300
-
301
- Returns:
302
- Recommended batch size.
303
- """
304
235
  if device.device_type == "cpu":
305
236
  # Conservative batch size for CPU # ~keep
306
237
  return 1
@@ -322,11 +253,6 @@ def get_recommended_batch_size(device: DeviceInfo, input_size_mb: float = 10.0)
322
253
 
323
254
 
324
255
  def cleanup_device_memory(device: DeviceInfo) -> None:
325
- """Clean up device memory.
326
-
327
- Args:
328
- device: The device to clean up.
329
- """
330
256
  if device.device_type == "cuda":
331
257
  try:
332
258
  import torch # noqa: PLC0415