kreuzberg 3.1.7__py3-none-any.whl → 3.3.0__py3-none-any.whl

Files changed (37)
  1. kreuzberg/__init__.py +3 -0
  2. kreuzberg/__main__.py +8 -0
  3. kreuzberg/_cli_config.py +175 -0
  4. kreuzberg/_extractors/_image.py +39 -4
  5. kreuzberg/_extractors/_pandoc.py +158 -18
  6. kreuzberg/_extractors/_pdf.py +199 -19
  7. kreuzberg/_extractors/_presentation.py +1 -1
  8. kreuzberg/_extractors/_spread_sheet.py +65 -7
  9. kreuzberg/_gmft.py +222 -16
  10. kreuzberg/_mime_types.py +62 -16
  11. kreuzberg/_multiprocessing/__init__.py +6 -0
  12. kreuzberg/_multiprocessing/gmft_isolated.py +332 -0
  13. kreuzberg/_multiprocessing/process_manager.py +188 -0
  14. kreuzberg/_multiprocessing/sync_tesseract.py +261 -0
  15. kreuzberg/_multiprocessing/tesseract_pool.py +359 -0
  16. kreuzberg/_ocr/_easyocr.py +66 -10
  17. kreuzberg/_ocr/_paddleocr.py +86 -7
  18. kreuzberg/_ocr/_tesseract.py +136 -46
  19. kreuzberg/_playa.py +43 -0
  20. kreuzberg/_utils/_cache.py +372 -0
  21. kreuzberg/_utils/_device.py +356 -0
  22. kreuzberg/_utils/_document_cache.py +220 -0
  23. kreuzberg/_utils/_errors.py +232 -0
  24. kreuzberg/_utils/_pdf_lock.py +72 -0
  25. kreuzberg/_utils/_process_pool.py +100 -0
  26. kreuzberg/_utils/_serialization.py +82 -0
  27. kreuzberg/_utils/_string.py +1 -1
  28. kreuzberg/_utils/_sync.py +21 -0
  29. kreuzberg/cli.py +338 -0
  30. kreuzberg/extraction.py +247 -36
  31. {kreuzberg-3.1.7.dist-info → kreuzberg-3.3.0.dist-info}/METADATA +95 -34
  32. kreuzberg-3.3.0.dist-info/RECORD +48 -0
  33. {kreuzberg-3.1.7.dist-info → kreuzberg-3.3.0.dist-info}/WHEEL +1 -2
  34. kreuzberg-3.3.0.dist-info/entry_points.txt +2 -0
  35. kreuzberg-3.1.7.dist-info/RECORD +0 -33
  36. kreuzberg-3.1.7.dist-info/top_level.txt +0 -1
  37. {kreuzberg-3.1.7.dist-info → kreuzberg-3.3.0.dist-info}/licenses/LICENSE +0 -0
kreuzberg/_utils/_device.py
@@ -0,0 +1,356 @@
+ """Device detection and management utilities for GPU acceleration."""
+ # ruff: noqa: BLE001 # ~keep
+
+ from __future__ import annotations
+
+ import warnings
+ from dataclasses import dataclass
+ from typing import Literal
+
+ from kreuzberg.exceptions import ValidationError
+
+ DeviceType = Literal["cpu", "cuda", "mps", "auto"]
+
+
+ @dataclass(frozen=True)
+ class DeviceInfo:
+     """Information about a compute device."""
+
+     device_type: Literal["cpu", "cuda", "mps"]
+     """The type of device."""
+     device_id: int | None = None
+     """Device ID for multi-GPU systems. None for CPU or single GPU."""
+     memory_total: float | None = None
+     """Total memory in GB. None if unknown."""
+     memory_available: float | None = None
+     """Available memory in GB. None if unknown."""
+     name: str | None = None
+     """Human-readable device name."""
+
+
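DeviceInfo is a frozen dataclass, so detection results are immutable once built. A minimal sketch of a populated entry, with made-up values for a hypothetical 8 GB CUDA card:

from kreuzberg._utils._device import DeviceInfo

# All field values below are illustrative, not real detection output.
gpu = DeviceInfo(
    device_type="cuda",
    device_id=0,
    memory_total=8.0,      # GB
    memory_available=6.5,  # GB
    name="Example GPU",
)
# frozen=True: assigning to gpu.device_id raises dataclasses.FrozenInstanceError.
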
+ def detect_available_devices() -> list[DeviceInfo]:
+     """Detect all available compute devices.
+
+     Returns:
+         List of available devices, with the most preferred device first.
+     """
+     devices: list[DeviceInfo] = []
+
+     devices.append(
+         DeviceInfo(
+             device_type="cpu",
+             name="CPU",
+         )
+     )
+
+     if _is_cuda_available():
+         cuda_devices = _get_cuda_devices()
+         devices.extend(cuda_devices)
+
+     if _is_mps_available():
+         mps_device = _get_mps_device()
+         if mps_device:
+             devices.append(mps_device)
+
+     gpu_devices = [d for d in devices if d.device_type != "cpu"]
+     cpu_devices = [d for d in devices if d.device_type == "cpu"]
+
+     return gpu_devices + cpu_devices
+
+
+ def get_optimal_device() -> DeviceInfo:
+     """Get the optimal device for OCR processing.
+
+     Returns:
+         The best available device, preferring GPU over CPU.
+     """
+     devices = detect_available_devices()
+     return devices[0] if devices else DeviceInfo(device_type="cpu", name="CPU")
+
+
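A short usage sketch for the two public detection helpers; the GPU-first ordering falls out of the gpu_devices + cpu_devices concatenation above:

from kreuzberg._utils._device import detect_available_devices, get_optimal_device

devices = detect_available_devices()
for device in devices:
    print(device.device_type, device.name)  # any "cuda"/"mps" entries first, "cpu CPU" last

best = get_optimal_device()  # same as devices[0]; the CPU entry guarantees a non-empty list
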
+ def validate_device_request(
+     requested: DeviceType,
+     backend: str,
+     *,
+     memory_limit: float | None = None,
+     fallback_to_cpu: bool = True,
+ ) -> DeviceInfo:
+     """Validate and resolve a device request.
+
+     Args:
+         requested: The requested device type.
+         backend: Name of the OCR backend requesting the device.
+         memory_limit: Optional memory limit in GB.
+         fallback_to_cpu: Whether to fall back to CPU if the requested device is unavailable.
+
+     Returns:
+         A validated DeviceInfo object.
+
+     Raises:
+         ValidationError: If the requested device is not available and fallback is disabled.
+     """
+     available_devices = detect_available_devices()
+
+     if requested == "auto":
+         device = get_optimal_device()
+         if memory_limit is not None:
+             _validate_memory_limit(device, memory_limit)
+         return device
+
+     matching_devices = [d for d in available_devices if d.device_type == requested]
+
+     if not matching_devices:
+         if fallback_to_cpu and requested != "cpu":
+             warnings.warn(
+                 f"Requested device '{requested}' not available for {backend}. Falling back to CPU.",
+                 UserWarning,
+                 stacklevel=2,
+             )
+             cpu_device = next((d for d in available_devices if d.device_type == "cpu"), None)
+             if cpu_device:
+                 return cpu_device
+
+         raise ValidationError(
+             f"Requested device '{requested}' is not available for {backend}",
+             context={
+                 "requested_device": requested,
+                 "backend": backend,
+                 "available_devices": [d.device_type for d in available_devices],
+             },
+         )
+
+     device = matching_devices[0]
+
+     if memory_limit is not None:
+         _validate_memory_limit(device, memory_limit)
+
+     return device
+
+
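How the three resolution paths look from the caller's side, sketched for a hypothetical CPU-only host (on a machine with working CUDA the second call would simply return the CUDA device):

from kreuzberg._utils._device import validate_device_request
from kreuzberg.exceptions import ValidationError

device = validate_device_request("auto", "easyocr")  # resolves to the optimal device

# Default fallback_to_cpu=True: emits a UserWarning and returns the CPU device.
device = validate_device_request("cuda", "easyocr")

# Fallback disabled: the same request raises instead.
try:
    validate_device_request("cuda", "easyocr", fallback_to_cpu=False)
except ValidationError:
    ...  # the error context lists the devices that were actually available
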
+ def get_device_memory_info(device: DeviceInfo) -> tuple[float | None, float | None]:
+     """Get memory information for a device.
+
+     Args:
+         device: The device to query.
+
+     Returns:
+         Tuple of (total_memory_gb, available_memory_gb). None values if unknown.
+     """
+     if device.device_type == "cpu":
+         return None, None
+
+     if device.device_type == "cuda":
+         return _get_cuda_memory_info(device.device_id or 0)
+
+     if device.device_type == "mps":
+         return _get_mps_memory_info()
+
+     return None, None
+
+
+ def _is_cuda_available() -> bool:
+     """Check if CUDA is available."""
+     try:
+         import torch  # type: ignore[import-not-found,unused-ignore]
+
+         return torch.cuda.is_available()
+     except ImportError:
+         return False
+
+
+ def _is_mps_available() -> bool:
+     """Check if MPS (Apple Silicon) is available."""
+     try:
+         import torch  # type: ignore[import-not-found,unused-ignore]
+
+         return torch.backends.mps.is_available()
+     except ImportError:
+         return False
+
+
+ def _get_cuda_devices() -> list[DeviceInfo]:
+     """Get information about available CUDA devices."""
+     devices: list[DeviceInfo] = []
+
+     try:
+         import torch
+
+         if not torch.cuda.is_available():
+             return devices
+
+         for i in range(torch.cuda.device_count()):
+             props = torch.cuda.get_device_properties(i)
+             total_memory = props.total_memory / (1024**3)
+
+             try:
+                 allocated = torch.cuda.memory_allocated(i) / (1024**3)
+                 available_memory = total_memory - allocated
+             except Exception:
+                 available_memory = total_memory
+
+             devices.append(
+                 DeviceInfo(
+                     device_type="cuda",
+                     device_id=i,
+                     memory_total=total_memory,
+                     memory_available=available_memory,
+                     name=props.name,
+                 )
+             )
+
+     except ImportError:
+         pass
+
+     return devices
+
+
+ def _get_mps_device() -> DeviceInfo | None:
+     """Get information about the MPS device."""
+     try:
+         import torch
+
+         if not torch.backends.mps.is_available():
+             return None
+
+         return DeviceInfo(
+             device_type="mps",
+             name="Apple Silicon GPU (MPS)",
+         )
+
+     except ImportError:
+         return None
+
+
+ def _get_cuda_memory_info(device_id: int) -> tuple[float | None, float | None]:
+     """Get CUDA memory information for a specific device."""
+     try:
+         import torch
+
+         if not torch.cuda.is_available():
+             return None, None
+
+         props = torch.cuda.get_device_properties(device_id)
+         total_memory = props.total_memory / (1024**3)
+
+         try:
+             allocated = torch.cuda.memory_allocated(device_id) / (1024**3)
+             available_memory = total_memory - allocated
+         except Exception:
+             available_memory = total_memory
+
+         return total_memory, available_memory
+
+     except ImportError:
+         return None, None
+
+
+ def _get_mps_memory_info() -> tuple[float | None, float | None]:
+     """Get MPS memory information."""
+     return None, None
+
+
+ def _validate_memory_limit(device: DeviceInfo, memory_limit: float) -> None:
+     """Validate that a device has enough memory for the requested limit.
+
+     Args:
+         device: The device to validate.
+         memory_limit: Required memory in GB.
+
+     Raises:
+         ValidationError: If the device doesn't have enough memory.
+     """
+     if device.device_type == "cpu":
+         # CPU memory validation is complex and OS-dependent, skip for now # ~keep
+         return
+
+     total_memory, available_memory = get_device_memory_info(device)
+
+     if total_memory is not None and memory_limit > total_memory:
+         raise ValidationError(
+             f"Requested memory limit ({memory_limit:.1f}GB) exceeds device capacity ({total_memory:.1f}GB)",
+             context={
+                 "device": device.device_type,
+                 "device_name": device.name,
+                 "requested_memory": memory_limit,
+                 "total_memory": total_memory,
+                 "available_memory": available_memory,
+             },
+         )
+
+     if available_memory is not None and memory_limit > available_memory:
+         warnings.warn(
+             f"Requested memory limit ({memory_limit:.1f}GB) exceeds available memory "
+             f"({available_memory:.1f}GB) on {device.name or device.device_type}",
+             UserWarning,
+             stacklevel=3,
+         )
+
+
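The memory check is two-tiered: a limit above the device's total capacity raises, while a limit that only exceeds currently available memory warns (stacklevel=3 points the warning at validate_device_request's caller). Sketched against a hypothetical 8 GB card with about 6.5 GB free:

validate_device_request("cuda", "easyocr", memory_limit=4.0)   # passes silently
validate_device_request("cuda", "easyocr", memory_limit=7.0)   # UserWarning: exceeds available memory
validate_device_request("cuda", "easyocr", memory_limit=16.0)  # ValidationError: exceeds capacity
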
+ def is_backend_gpu_compatible(backend: str) -> bool:
+     """Check if an OCR backend supports GPU acceleration.
+
+     Args:
+         backend: Name of the OCR backend.
+
+     Returns:
+         True if the backend supports GPU acceleration.
+     """
+     # EasyOCR and PaddleOCR support GPU, Tesseract does not # ~keep
+     return backend.lower() in ("easyocr", "paddleocr")
+
+
+ def get_recommended_batch_size(device: DeviceInfo, input_size_mb: float = 10.0) -> int:
+     """Get recommended batch size for OCR processing.
+
+     Args:
+         device: The device to optimize for.
+         input_size_mb: Estimated input size per item in MB.
+
+     Returns:
+         Recommended batch size.
+     """
+     if device.device_type == "cpu":
+         # Conservative batch size for CPU # ~keep
+         return 1
+
+     _, available_memory = get_device_memory_info(device)
+
+     if available_memory is None:
+         return 4
+
+     # Use approximately 50% of available memory for batching # ~keep
+     usable_memory_gb = available_memory * 0.5
+     usable_memory_mb = usable_memory_gb * 1024
+
+     # Estimate batch size (conservative) # ~keep
+     estimated_batch_size = max(1, int(usable_memory_mb / (input_size_mb * 4)))
+
+     # Cap at reasonable limits # ~keep
+     return min(estimated_batch_size, 32)
+
+
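The heuristic in numbers, assuming a GPU reporting 8 GB available and the default 10 MB per item (the 4x divisor is the function's conservative safety factor):

available_memory = 8.0                                  # GB, hypothetical
usable_memory_mb = available_memory * 0.5 * 1024        # 4096 MB usable for batching
estimated = max(1, int(usable_memory_mb / (10.0 * 4)))  # int(102.4) -> 102
batch_size = min(estimated, 32)                         # the cap wins: 32
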
+ def cleanup_device_memory(device: DeviceInfo) -> None:
+     """Clean up device memory.
+
+     Args:
+         device: The device to clean up.
+     """
+     if device.device_type == "cuda":
+         try:
+             import torch
+
+             if torch.cuda.is_available():
+                 torch.cuda.empty_cache()
+         except ImportError:
+             pass
+
+     elif device.device_type == "mps":
+         try:
+             import torch
+
+             if torch.backends.mps.is_available():
+                 torch.mps.empty_cache()
+         except (ImportError, AttributeError):
+             pass
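Taken together, a sketch of how an OCR caller might wire these utilities up; run_ocr_batch is a hypothetical stand-in for real backend work:

from kreuzberg._utils._device import (
    cleanup_device_memory,
    get_recommended_batch_size,
    is_backend_gpu_compatible,
    validate_device_request,
)

backend = "easyocr"
requested = "auto" if is_backend_gpu_compatible(backend) else "cpu"
device = validate_device_request(requested, backend, memory_limit=4.0)
batch_size = get_recommended_batch_size(device, input_size_mb=10.0)
try:
    ...  # run_ocr_batch(images, device, batch_size) -- hypothetical
finally:
    cleanup_device_memory(device)
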
kreuzberg/_utils/_document_cache.py
@@ -0,0 +1,220 @@
+ """Document-level caching to prevent pypdfium2 issues with duplicate processing."""
+
+ from __future__ import annotations
+
+ import hashlib
+ import threading
+ import time
+ from pathlib import Path
+ from typing import TYPE_CHECKING, Any
+
+ if TYPE_CHECKING:
+     from kreuzberg._types import ExtractionConfig, ExtractionResult
+
+
+ class DocumentCache:
+     """Session-scoped cache for document extraction results.
+
+     Ensures each unique document is processed only once per session,
+     preventing pypdfium2 state corruption issues with repeated processing.
+     """
+
+     def __init__(self) -> None:
+         """Initialize document cache."""
+         self._cache: dict[str, ExtractionResult] = {}
+         self._processing: dict[str, threading.Event] = {}
+         self._lock = threading.Lock()
+
+         self._file_metadata: dict[str, dict[str, Any]] = {}
+
+     def _get_cache_key(self, file_path: Path | str, config: ExtractionConfig | None = None) -> str:
+         """Generate cache key for a file and config combination.
+
+         Args:
+             file_path: Path to the file
+             config: Extraction configuration
+
+         Returns:
+             Unique cache key string
+         """
+         path = Path(file_path).resolve()
+
+         try:
+             stat = path.stat()
+             file_info = {
+                 "path": str(path),
+                 "size": stat.st_size,
+                 "mtime": stat.st_mtime,
+             }
+         except OSError:
+             file_info = {"path": str(path), "size": 0, "mtime": 0}
+
+         config_info = {}
+         if config:
+             config_info = {
+                 "force_ocr": config.force_ocr,
+                 "ocr_backend": config.ocr_backend,
+                 "extract_tables": config.extract_tables,
+                 "chunk_content": config.chunk_content,
+                 "max_chars": config.max_chars,
+                 "max_overlap": config.max_overlap,
+             }
+
+         cache_data = {**file_info, **config_info}
+         cache_str = str(sorted(cache_data.items()))
+
+         return hashlib.sha256(cache_str.encode()).hexdigest()[:16]
+
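The key is reproducible outside the class, which makes invalidation easy to reason about: any change to the file's size or mtime, or to one of the six config fields, yields a different key. A standalone sketch with made-up stat and config values:

import hashlib

file_info = {"path": "/tmp/report.pdf", "size": 1024, "mtime": 1700000000.0}
config_info = {"force_ocr": False, "ocr_backend": "tesseract", "extract_tables": False,
               "chunk_content": False, "max_chars": 2000, "max_overlap": 100}
cache_str = str(sorted({**file_info, **config_info}.items()))
key = hashlib.sha256(cache_str.encode()).hexdigest()[:16]  # 16 hex chars = 64 bits
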
+     def _is_cache_valid(self, cache_key: str, file_path: Path | str) -> bool:
+         """Check if cached result is still valid.
+
+         Args:
+             cache_key: The cache key to validate
+             file_path: Path to the file
+
+         Returns:
+             True if cache is valid, False if invalidated
+         """
+         if cache_key not in self._file_metadata:
+             return False
+
+         path = Path(file_path)
+         try:
+             current_stat = path.stat()
+             cached_metadata = self._file_metadata[cache_key]
+
+             return bool(
+                 cached_metadata["size"] == current_stat.st_size and cached_metadata["mtime"] == current_stat.st_mtime
+             )
+         except OSError:
+             return False
+
+     def get(self, file_path: Path | str, config: ExtractionConfig | None = None) -> ExtractionResult | None:
+         """Get cached extraction result if available and valid.
+
+         Args:
+             file_path: Path to the file
+             config: Extraction configuration
+
+         Returns:
+             Cached result if available, None otherwise
+         """
+         cache_key = self._get_cache_key(file_path, config)
+
+         with self._lock:
+             if cache_key in self._cache:
+                 if self._is_cache_valid(cache_key, file_path):
+                     return self._cache[cache_key]
+
+                 self._cache.pop(cache_key, None)
+                 self._file_metadata.pop(cache_key, None)
+
+         return None
+
+     def set(self, file_path: Path | str, config: ExtractionConfig | None, result: ExtractionResult) -> None:
+         """Cache extraction result.
+
+         Args:
+             file_path: Path to the file
+             config: Extraction configuration
+             result: Extraction result to cache
+         """
+         cache_key = self._get_cache_key(file_path, config)
+         path = Path(file_path)
+
+         try:
+             stat = path.stat()
+             file_metadata = {
+                 "size": stat.st_size,
+                 "mtime": stat.st_mtime,
+                 "cached_at": time.time(),
+             }
+         except OSError:
+             file_metadata = {
+                 "size": 0,
+                 "mtime": 0,
+                 "cached_at": time.time(),
+             }
+
+         with self._lock:
+             self._cache[cache_key] = result
+             self._file_metadata[cache_key] = file_metadata
+
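The intended read-through pattern, with a hypothetical extract_file call standing in for the real extraction pipeline:

from kreuzberg._utils._document_cache import DocumentCache

cache = DocumentCache()
config = None  # or an ExtractionConfig; it participates in the cache key

result = cache.get("report.pdf", config)         # None on a cold cache
if result is None:
    result = extract_file("report.pdf", config)  # hypothetical extraction call
    cache.set("report.pdf", config, result)
# Later gets return the cached result until the file's size or mtime changes,
# at which point get() evicts the stale entry and returns None again.
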
+     def is_processing(self, file_path: Path | str, config: ExtractionConfig | None = None) -> bool:
+         """Check if file is currently being processed.
+
+         Args:
+             file_path: Path to the file
+             config: Extraction configuration
+
+         Returns:
+             True if file is currently being processed
+         """
+         cache_key = self._get_cache_key(file_path, config)
+         with self._lock:
+             return cache_key in self._processing
+
+     def mark_processing(self, file_path: Path | str, config: ExtractionConfig | None = None) -> threading.Event:
+         """Mark file as being processed and return event to wait on.
+
+         Args:
+             file_path: Path to the file
+             config: Extraction configuration
+
+         Returns:
+             Event that will be set when processing completes
+         """
+         cache_key = self._get_cache_key(file_path, config)
+
+         with self._lock:
+             if cache_key not in self._processing:
+                 self._processing[cache_key] = threading.Event()
+             return self._processing[cache_key]
+
+     def mark_complete(self, file_path: Path | str, config: ExtractionConfig | None = None) -> None:
+         """Mark file processing as complete.
+
+         Args:
+             file_path: Path to the file
+             config: Extraction configuration
+         """
+         cache_key = self._get_cache_key(file_path, config)
+
+         with self._lock:
+             if cache_key in self._processing:
+                 event = self._processing.pop(cache_key)
+                 event.set()
+
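These three methods are the building blocks for deduplicating concurrent work on one document. A simplified sketch of the intended pattern (do_extract is hypothetical, and the is_processing/mark_processing pair is not atomic, so treat this as illustrative rather than a watertight recipe):

# cache, path, and config as in the previous sketch
if cache.is_processing(path, config):
    cache.mark_processing(path, config).wait()  # block until the other worker finishes
    result = cache.get(path, config)            # normally populated by that worker
else:
    cache.mark_processing(path, config)
    try:
        result = do_extract(path, config)       # hypothetical
        cache.set(path, config, result)
    finally:
        cache.mark_complete(path, config)       # sets the Event and wakes waiters
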
+     def clear(self) -> None:
+         """Clear all cached results."""
+         with self._lock:
+             self._cache.clear()
+             self._file_metadata.clear()
+
+     def get_stats(self) -> dict[str, Any]:
+         """Get cache statistics.
+
+         Returns:
+             Dictionary with cache statistics
+         """
+         with self._lock:
+             return {
+                 "cached_documents": len(self._cache),
+                 "processing_documents": len(self._processing),
+                 "total_cache_size_mb": sum(len(result.content.encode("utf-8")) for result in self._cache.values())
+                 / 1024
+                 / 1024,
+             }
+
+
+ _document_cache = DocumentCache()
+
+
+ def get_document_cache() -> DocumentCache:
+     """Get the global document cache instance."""
+     return _document_cache
+
+
+ def clear_document_cache() -> None:
+     """Clear the global document cache."""
+     _document_cache.clear()
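The module-level helpers expose one shared instance per process:

from kreuzberg._utils._document_cache import clear_document_cache, get_document_cache

cache = get_document_cache()  # same DocumentCache on every call
print(cache.get_stats())      # cached_documents / processing_documents / total_cache_size_mb
clear_document_cache()        # e.g. between test runs or after a bulk job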