kreuzberg 3.14.1__py3-none-any.whl → 3.15.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. kreuzberg/__init__.py +6 -0
  2. kreuzberg/_api/_config_cache.py +247 -0
  3. kreuzberg/_api/main.py +127 -45
  4. kreuzberg/_chunker.py +7 -6
  5. kreuzberg/_constants.py +2 -0
  6. kreuzberg/_document_classification.py +4 -6
  7. kreuzberg/_entity_extraction.py +9 -4
  8. kreuzberg/_extractors/_base.py +269 -3
  9. kreuzberg/_extractors/_email.py +95 -27
  10. kreuzberg/_extractors/_html.py +85 -7
  11. kreuzberg/_extractors/_image.py +23 -22
  12. kreuzberg/_extractors/_pandoc.py +106 -75
  13. kreuzberg/_extractors/_pdf.py +209 -99
  14. kreuzberg/_extractors/_presentation.py +72 -8
  15. kreuzberg/_extractors/_spread_sheet.py +25 -30
  16. kreuzberg/_mcp/server.py +345 -25
  17. kreuzberg/_mime_types.py +42 -0
  18. kreuzberg/_ocr/_easyocr.py +2 -2
  19. kreuzberg/_ocr/_paddleocr.py +1 -1
  20. kreuzberg/_ocr/_tesseract.py +74 -34
  21. kreuzberg/_types.py +180 -21
  22. kreuzberg/_utils/_cache.py +10 -4
  23. kreuzberg/_utils/_device.py +2 -4
  24. kreuzberg/_utils/_image_preprocessing.py +12 -39
  25. kreuzberg/_utils/_process_pool.py +29 -8
  26. kreuzberg/_utils/_quality.py +7 -2
  27. kreuzberg/_utils/_resource_managers.py +65 -0
  28. kreuzberg/_utils/_sync.py +36 -6
  29. kreuzberg/_utils/_tmp.py +37 -1
  30. kreuzberg/cli.py +34 -20
  31. kreuzberg/extraction.py +43 -27
  32. {kreuzberg-3.14.1.dist-info → kreuzberg-3.15.0.dist-info}/METADATA +2 -1
  33. kreuzberg-3.15.0.dist-info/RECORD +60 -0
  34. kreuzberg-3.14.1.dist-info/RECORD +0 -58
  35. {kreuzberg-3.14.1.dist-info → kreuzberg-3.15.0.dist-info}/WHEEL +0 -0
  36. {kreuzberg-3.14.1.dist-info → kreuzberg-3.15.0.dist-info}/entry_points.txt +0 -0
  37. {kreuzberg-3.14.1.dist-info → kreuzberg-3.15.0.dist-info}/licenses/LICENSE +0 -0
kreuzberg/__init__.py CHANGED
@@ -4,9 +4,12 @@ from ._registry import ExtractorRegistry
4
4
  from ._types import (
5
5
  EasyOCRConfig,
6
6
  Entity,
7
+ ExtractedImage,
7
8
  ExtractionConfig,
8
9
  ExtractionResult,
9
10
  GMFTConfig,
11
+ ImageOCRConfig,
12
+ ImageOCRResult,
10
13
  LanguageDetectionConfig,
11
14
  Metadata,
12
15
  PaddleOCRConfig,
@@ -32,10 +35,13 @@ __version__ = version("kreuzberg")
32
35
  __all__ = [
33
36
  "EasyOCRConfig",
34
37
  "Entity",
38
+ "ExtractedImage",
35
39
  "ExtractionConfig",
36
40
  "ExtractionResult",
37
41
  "ExtractorRegistry",
38
42
  "GMFTConfig",
43
+ "ImageOCRConfig",
44
+ "ImageOCRResult",
39
45
  "KreuzbergError",
40
46
  "LanguageDetectionConfig",
41
47
  "Metadata",
kreuzberg/_api/_config_cache.py ADDED
@@ -0,0 +1,247 @@
1
+ """API Configuration Caching Module.
2
+
3
+ This module provides LRU cached functions for API config operations to improve performance
4
+ by avoiding repeated file system operations and object creation.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import hashlib
10
+ import json
11
+ from functools import lru_cache
12
+ from pathlib import Path
13
+ from typing import Any
14
+
15
+ from kreuzberg._config import discover_config
16
+ from kreuzberg._types import (
17
+ EasyOCRConfig,
18
+ ExtractionConfig,
19
+ GMFTConfig,
20
+ HTMLToMarkdownConfig,
21
+ LanguageDetectionConfig,
22
+ PaddleOCRConfig,
23
+ SpacyEntityExtractionConfig,
24
+ TesseractConfig,
25
+ )
26
+
27
+
28
+ @lru_cache(maxsize=16)
29
+ def _cached_discover_config(
30
+ search_path: str,
31
+ config_file_mtime: float, # noqa: ARG001
32
+ config_file_size: int, # noqa: ARG001
33
+ ) -> ExtractionConfig | None:
34
+ """Cache config discovery with file modification time validation."""
35
+ return discover_config(Path(search_path))
36
+
37
+
38
+ def discover_config_cached(search_path: Path | str | None = None) -> ExtractionConfig | None:
39
+ """Cached version of discover_config with automatic invalidation.
40
+
41
+ This function caches the result of discover_config() and automatically invalidates
42
+ the cache when config files are modified.
43
+
44
+ Args:
45
+ search_path: Path to start searching for config files from
46
+
47
+ Returns:
48
+ ExtractionConfig if found, None otherwise
49
+ """
50
+ search_path = Path.cwd() if search_path is None else Path(search_path)
51
+
52
+ config_files = ["kreuzberg.toml", "pyproject.toml"]
53
+ for config_file_name in config_files:
54
+ config_path = search_path / config_file_name
55
+ if config_path.exists():
56
+ try:
57
+ stat = config_path.stat()
58
+ return _cached_discover_config(
59
+ str(search_path),
60
+ stat.st_mtime,
61
+ stat.st_size,
62
+ )
63
+ except OSError:
64
+ return discover_config(search_path)
65
+
66
+ return _cached_discover_config(str(search_path), 0.0, 0)
67
+
68
+
69
+ @lru_cache(maxsize=128)
70
+ def _cached_create_ocr_config(
71
+ config_type: str,
72
+ config_json: str,
73
+ ) -> TesseractConfig | EasyOCRConfig | PaddleOCRConfig:
74
+ """Cache OCR config object creation."""
75
+ config_dict = json.loads(config_json)
76
+
77
+ if config_type == "tesseract":
78
+ return TesseractConfig(**config_dict)
79
+ if config_type == "easyocr":
80
+ return EasyOCRConfig(**config_dict)
81
+ if config_type == "paddleocr":
82
+ return PaddleOCRConfig(**config_dict)
83
+ msg = f"Unknown OCR config type: {config_type}"
84
+ raise ValueError(msg)
85
+
86
+
87
+ @lru_cache(maxsize=64)
88
+ def _cached_create_gmft_config(config_json: str) -> GMFTConfig:
89
+ """Cache GMFT config creation."""
90
+ return GMFTConfig(**json.loads(config_json))
91
+
92
+
93
+ @lru_cache(maxsize=64)
94
+ def _cached_create_language_detection_config(config_json: str) -> LanguageDetectionConfig:
95
+ """Cache language detection config creation."""
96
+ return LanguageDetectionConfig(**json.loads(config_json))
97
+
98
+
99
+ @lru_cache(maxsize=64)
100
+ def _cached_create_spacy_config(config_json: str) -> SpacyEntityExtractionConfig:
101
+ """Cache spaCy entity extraction config creation."""
102
+ return SpacyEntityExtractionConfig(**json.loads(config_json))
103
+
104
+
105
+ @lru_cache(maxsize=64)
106
+ def _cached_create_html_markdown_config(config_json: str) -> HTMLToMarkdownConfig:
107
+ """Cache HTML to Markdown config creation."""
108
+ return HTMLToMarkdownConfig(**json.loads(config_json))
109
+
110
+
111
+ @lru_cache(maxsize=256)
112
+ def _cached_parse_header_config(header_value: str) -> dict[str, Any]:
113
+ """Cache parsed header configurations."""
114
+ parsed_config: dict[str, Any] = json.loads(header_value)
115
+ return parsed_config
116
+
117
+
118
+ def create_ocr_config_cached(
119
+ ocr_backend: str | None, config_dict: dict[str, Any]
120
+ ) -> TesseractConfig | EasyOCRConfig | PaddleOCRConfig:
121
+ """Cached version of OCR config creation.
122
+
123
+ Args:
124
+ ocr_backend: The OCR backend type
125
+ config_dict: Configuration dictionary
126
+
127
+ Returns:
128
+ Configured OCR config object
129
+ """
130
+ if not ocr_backend:
131
+ return TesseractConfig()
132
+
133
+ config_json = json.dumps(config_dict, sort_keys=True)
134
+ return _cached_create_ocr_config(ocr_backend, config_json)
135
+
136
+
137
+ def create_gmft_config_cached(config_dict: dict[str, Any]) -> GMFTConfig:
138
+ """Cached version of GMFT config creation."""
139
+ config_json = json.dumps(config_dict, sort_keys=True)
140
+ return _cached_create_gmft_config(config_json)
141
+
142
+
143
+ def create_language_detection_config_cached(config_dict: dict[str, Any]) -> LanguageDetectionConfig:
144
+ """Cached version of language detection config creation."""
145
+ config_json = json.dumps(config_dict, sort_keys=True)
146
+ return _cached_create_language_detection_config(config_json)
147
+
148
+
149
+ def create_spacy_config_cached(config_dict: dict[str, Any]) -> SpacyEntityExtractionConfig:
150
+ """Cached version of spaCy config creation."""
151
+ config_json = json.dumps(config_dict, sort_keys=True)
152
+ return _cached_create_spacy_config(config_json)
153
+
154
+
155
+ def create_html_markdown_config_cached(config_dict: dict[str, Any]) -> HTMLToMarkdownConfig:
156
+ """Cached version of HTML to Markdown config creation."""
157
+ config_json = json.dumps(config_dict, sort_keys=True)
158
+ return _cached_create_html_markdown_config(config_json)
159
+
160
+
161
+ def parse_header_config_cached(header_value: str) -> dict[str, Any]:
162
+ """Cached version of header config parsing.
163
+
164
+ Args:
165
+ header_value: JSON string from X-Extraction-Config header
166
+
167
+ Returns:
168
+ Parsed configuration dictionary
169
+ """
170
+ return _cached_parse_header_config(header_value)
171
+
172
+
173
+ @lru_cache(maxsize=512)
174
+ def _cached_merge_configs(
175
+ static_config_hash: str,
176
+ query_params_hash: str,
177
+ header_config_hash: str,
178
+ ) -> ExtractionConfig:
179
+ """Cache the complete config merging process.
180
+
181
+ This is the ultimate optimization - cache the entire result of merge_configs()
182
+ based on content hashes of all inputs.
183
+ """
184
+ msg = "Not implemented yet - use individual component caching"
185
+ raise NotImplementedError(msg)
186
+
187
+
188
+ def _hash_dict(data: dict[str, Any] | None) -> str:
189
+ """Create a hash string from a dictionary for cache keys."""
190
+ if data is None:
191
+ return "none"
192
+
193
+ json_str = json.dumps(data, sort_keys=True, default=str)
194
+ return hashlib.sha256(json_str.encode()).hexdigest()[:16]
195
+
196
+
197
+ def get_cache_stats() -> dict[str, Any]:
198
+ """Get cache statistics for monitoring performance."""
199
+ return {
200
+ "discover_config": {
201
+ "hits": _cached_discover_config.cache_info().hits,
202
+ "misses": _cached_discover_config.cache_info().misses,
203
+ "size": _cached_discover_config.cache_info().currsize,
204
+ "max_size": _cached_discover_config.cache_info().maxsize,
205
+ },
206
+ "ocr_config": {
207
+ "hits": _cached_create_ocr_config.cache_info().hits,
208
+ "misses": _cached_create_ocr_config.cache_info().misses,
209
+ "size": _cached_create_ocr_config.cache_info().currsize,
210
+ "max_size": _cached_create_ocr_config.cache_info().maxsize,
211
+ },
212
+ "header_parsing": {
213
+ "hits": _cached_parse_header_config.cache_info().hits,
214
+ "misses": _cached_parse_header_config.cache_info().misses,
215
+ "size": _cached_parse_header_config.cache_info().currsize,
216
+ "max_size": _cached_parse_header_config.cache_info().maxsize,
217
+ },
218
+ "gmft_config": {
219
+ "hits": _cached_create_gmft_config.cache_info().hits,
220
+ "misses": _cached_create_gmft_config.cache_info().misses,
221
+ "size": _cached_create_gmft_config.cache_info().currsize,
222
+ "max_size": _cached_create_gmft_config.cache_info().maxsize,
223
+ },
224
+ "language_detection_config": {
225
+ "hits": _cached_create_language_detection_config.cache_info().hits,
226
+ "misses": _cached_create_language_detection_config.cache_info().misses,
227
+ "size": _cached_create_language_detection_config.cache_info().currsize,
228
+ "max_size": _cached_create_language_detection_config.cache_info().maxsize,
229
+ },
230
+ "spacy_config": {
231
+ "hits": _cached_create_spacy_config.cache_info().hits,
232
+ "misses": _cached_create_spacy_config.cache_info().misses,
233
+ "size": _cached_create_spacy_config.cache_info().currsize,
234
+ "max_size": _cached_create_spacy_config.cache_info().maxsize,
235
+ },
236
+ }
237
+
238
+
239
+ def clear_all_caches() -> None:
240
+ """Clear all API configuration caches."""
241
+ _cached_discover_config.cache_clear()
242
+ _cached_create_ocr_config.cache_clear()
243
+ _cached_create_gmft_config.cache_clear()
244
+ _cached_create_language_detection_config.cache_clear()
245
+ _cached_create_spacy_config.cache_clear()
246
+ _cached_create_html_markdown_config.cache_clear()
247
+ _cached_parse_header_config.cache_clear()
kreuzberg/_api/main.py CHANGED
@@ -3,8 +3,7 @@ from __future__ import annotations
3
3
  import base64
4
4
  import io
5
5
  import traceback
6
- from functools import lru_cache
7
- from json import dumps, loads
6
+ from json import dumps
8
7
  from typing import TYPE_CHECKING, Annotated, Any, Literal
9
8
 
10
9
  import msgspec
@@ -14,26 +13,57 @@ from typing_extensions import TypedDict
14
13
 
15
14
  from kreuzberg import (
16
15
  EasyOCRConfig,
16
+ ExtractedImage,
17
17
  ExtractionConfig,
18
18
  ExtractionResult,
19
- GMFTConfig,
19
+ ImageOCRResult,
20
20
  KreuzbergError,
21
- LanguageDetectionConfig,
22
21
  MissingDependencyError,
23
22
  PaddleOCRConfig,
24
23
  ParsingError,
25
- SpacyEntityExtractionConfig,
26
24
  TesseractConfig,
27
25
  ValidationError,
28
26
  batch_extract_bytes,
29
27
  )
28
+ from kreuzberg._api._config_cache import (
29
+ create_gmft_config_cached,
30
+ create_html_markdown_config_cached,
31
+ create_language_detection_config_cached,
32
+ create_ocr_config_cached,
33
+ create_spacy_config_cached,
34
+ discover_config_cached,
35
+ parse_header_config_cached,
36
+ )
30
37
  from kreuzberg._config import discover_config
31
- from kreuzberg._types import HTMLToMarkdownConfig
32
38
 
33
39
  if TYPE_CHECKING:
34
40
  from litestar.datastructures import UploadFile
35
41
 
36
42
 
43
+ class ExtractedImageDict(TypedDict):
44
+ """TypedDict for extracted image JSON representation."""
45
+
46
+ data: str
47
+ format: str
48
+ filename: str | None
49
+ page_number: int | None
50
+ dimensions: tuple[int, int] | None
51
+ colorspace: str | None
52
+ bits_per_component: int | None
53
+ is_mask: bool
54
+ description: str | None
55
+
56
+
57
+ class ImageOCRResultDict(TypedDict):
58
+ """TypedDict for image OCR result JSON representation."""
59
+
60
+ image: ExtractedImageDict
61
+ ocr_result: Any
62
+ confidence_score: float | None
63
+ processing_time: float | None
64
+ skipped_reason: str | None
65
+
66
+
37
67
  class HealthResponse(TypedDict):
38
68
  """Response model for health check endpoint."""
39
69
 
@@ -146,68 +176,65 @@ def _create_ocr_config(
146
176
  return config_dict
147
177
 
148
178
 
149
- @lru_cache(maxsize=128)
150
- def _merge_configs_cached(
179
+ def _create_dimension_tuple(width: int | None, height: int | None) -> tuple[int, int] | None:
180
+ """Create a dimension tuple from width and height values.
181
+
182
+ Args:
183
+ width: Width value or None
184
+ height: Height value or None
185
+
186
+ Returns:
187
+ Tuple of (width, height) if both values are not None, otherwise None
188
+ """
189
+ if width is not None and height is not None:
190
+ return (width, height)
191
+ return None
192
+
193
+
194
+ def merge_configs(
151
195
  static_config: ExtractionConfig | None,
152
- query_params: tuple[tuple[str, Any], ...],
153
- header_config: tuple[tuple[str, Any], ...] | None,
196
+ query_params: dict[str, Any],
197
+ header_config: dict[str, Any] | None,
154
198
  ) -> ExtractionConfig:
155
199
  base_config = static_config or ExtractionConfig()
156
200
  config_dict = base_config.to_dict()
157
201
 
158
- query_dict = dict(query_params) if query_params else {}
159
- for key, value in query_dict.items():
202
+ for key, value in query_params.items():
160
203
  if value is not None and key in config_dict:
161
204
  config_dict[key] = _convert_value_type(config_dict[key], value)
162
205
 
163
206
  if header_config:
164
- header_dict = dict(header_config)
165
- for key, value in header_dict.items():
207
+ for key, value in header_config.items():
166
208
  if key in config_dict:
167
209
  config_dict[key] = value
168
210
 
169
211
  if "ocr_config" in config_dict and isinstance(config_dict["ocr_config"], dict):
170
212
  ocr_backend = config_dict.get("ocr_backend")
171
- config_dict["ocr_config"] = _create_ocr_config(ocr_backend, config_dict["ocr_config"])
213
+ config_dict["ocr_config"] = create_ocr_config_cached(ocr_backend, config_dict["ocr_config"])
172
214
 
173
215
  if "gmft_config" in config_dict and isinstance(config_dict["gmft_config"], dict):
174
- config_dict["gmft_config"] = GMFTConfig(**config_dict["gmft_config"])
216
+ config_dict["gmft_config"] = create_gmft_config_cached(config_dict["gmft_config"])
175
217
 
176
218
  if "language_detection_config" in config_dict and isinstance(config_dict["language_detection_config"], dict):
177
- config_dict["language_detection_config"] = LanguageDetectionConfig(**config_dict["language_detection_config"])
219
+ config_dict["language_detection_config"] = create_language_detection_config_cached(
220
+ config_dict["language_detection_config"]
221
+ )
178
222
 
179
223
  if "spacy_entity_extraction_config" in config_dict and isinstance(
180
224
  config_dict["spacy_entity_extraction_config"], dict
181
225
  ):
182
- config_dict["spacy_entity_extraction_config"] = SpacyEntityExtractionConfig(
183
- **config_dict["spacy_entity_extraction_config"]
226
+ config_dict["spacy_entity_extraction_config"] = create_spacy_config_cached(
227
+ config_dict["spacy_entity_extraction_config"]
184
228
  )
185
229
 
186
230
  if "html_to_markdown_config" in config_dict and isinstance(config_dict["html_to_markdown_config"], dict):
187
- config_dict["html_to_markdown_config"] = HTMLToMarkdownConfig(**config_dict["html_to_markdown_config"])
231
+ config_dict["html_to_markdown_config"] = create_html_markdown_config_cached(
232
+ config_dict["html_to_markdown_config"]
233
+ )
188
234
 
189
235
  return ExtractionConfig(**config_dict)
190
236
 
191
237
 
192
- def _make_hashable(obj: Any) -> Any:
193
- if isinstance(obj, dict):
194
- return tuple(sorted((k, _make_hashable(v)) for k, v in obj.items()))
195
- if isinstance(obj, list):
196
- return tuple(_make_hashable(item) for item in obj)
197
- return obj
198
-
199
-
200
- def merge_configs(
201
- static_config: ExtractionConfig | None,
202
- query_params: dict[str, Any],
203
- header_config: dict[str, Any] | None,
204
- ) -> ExtractionConfig:
205
- query_tuple = tuple(sorted(query_params.items())) if query_params else ()
206
- header_tuple = _make_hashable(header_config) if header_config else None
207
-
208
- return _merge_configs_cached(static_config, query_tuple, header_tuple)
209
-
210
-
211
238
  @post("/extract", operation_id="ExtractFiles")
212
239
  async def handle_files_upload( # noqa: PLR0913
213
240
  request: Request[Any, Any, Any],
@@ -223,6 +250,13 @@ async def handle_files_upload( # noqa: PLR0913
223
250
  ocr_backend: Literal["tesseract", "easyocr", "paddleocr"] | None = None,
224
251
  auto_detect_language: str | bool | None = None,
225
252
  pdf_password: str | None = None,
253
+ extract_images: str | bool | None = None,
254
+ ocr_extracted_images: str | bool | None = None,
255
+ image_ocr_backend: Literal["tesseract", "easyocr", "paddleocr"] | None = None,
256
+ image_ocr_min_width: int | None = None,
257
+ image_ocr_min_height: int | None = None,
258
+ image_ocr_max_width: int | None = None,
259
+ image_ocr_max_height: int | None = None,
226
260
  ) -> list[ExtractionResult]:
227
261
  """Extract text, metadata, and structured data from uploaded documents.
228
262
 
@@ -250,11 +284,30 @@ async def handle_files_upload( # noqa: PLR0913
250
284
  ocr_backend: OCR engine to use (tesseract, easyocr, paddleocr)
251
285
  auto_detect_language: Enable automatic language detection
252
286
  pdf_password: Password for encrypted PDF files
287
+ extract_images: Enable image extraction for supported formats
288
+ ocr_extracted_images: Run OCR over extracted images
289
+ image_ocr_backend: Optional backend override for image OCR
290
+ image_ocr_min_width: Minimum image width for OCR eligibility
291
+ image_ocr_min_height: Minimum image height for OCR eligibility
292
+ image_ocr_max_width: Maximum image width for OCR eligibility
293
+ image_ocr_max_height: Maximum image height for OCR eligibility
253
294
 
254
295
  Returns:
255
296
  List of extraction results, one per uploaded file
297
+
298
+ Additional query parameters:
299
+ extract_images: Enable image extraction for supported formats
300
+ ocr_extracted_images: Run OCR over extracted images
301
+ image_ocr_backend: Optional backend override for image OCR
302
+ image_ocr_min_width: Minimum image width for OCR eligibility
303
+ image_ocr_min_height: Minimum image height for OCR eligibility
304
+ image_ocr_max_width: Maximum image width for OCR eligibility
305
+ image_ocr_max_height: Maximum image height for OCR eligibility
256
306
  """
257
- static_config = discover_config()
307
+ static_config = discover_config_cached()
308
+
309
+ min_dims = _create_dimension_tuple(image_ocr_min_width, image_ocr_min_height)
310
+ max_dims = _create_dimension_tuple(image_ocr_max_width, image_ocr_max_height)
258
311
 
259
312
  query_params = {
260
313
  "chunk_content": chunk_content,
@@ -268,12 +321,17 @@ async def handle_files_upload( # noqa: PLR0913
268
321
  "ocr_backend": ocr_backend,
269
322
  "auto_detect_language": auto_detect_language,
270
323
  "pdf_password": pdf_password,
324
+ "extract_images": extract_images,
325
+ "ocr_extracted_images": ocr_extracted_images,
326
+ "image_ocr_backend": image_ocr_backend,
327
+ "image_ocr_min_dimensions": min_dims,
328
+ "image_ocr_max_dimensions": max_dims,
271
329
  }
272
330
 
273
331
  header_config = None
274
332
  if config_header := request.headers.get("X-Extraction-Config"):
275
333
  try:
276
- header_config = loads(config_header)
334
+ header_config = parse_header_config_cached(config_header)
277
335
  except Exception as e:
278
336
  raise ValidationError(f"Invalid JSON in X-Extraction-Config header: {e}", context={"error": str(e)}) from e
279
337
 
@@ -316,18 +374,41 @@ async def get_configuration() -> ConfigurationResponse:
316
374
 
317
375
 
318
376
  def _polars_dataframe_encoder(obj: Any) -> Any:
319
- """Convert polars DataFrame to dict for JSON serialization."""
320
377
  return obj.to_dicts()
321
378
 
322
379
 
323
380
  def _pil_image_encoder(obj: Any) -> str:
324
- """Convert PIL Image to base64 string for JSON serialization."""
325
381
  buffer = io.BytesIO()
326
382
  obj.save(buffer, format="PNG")
327
383
  img_str = base64.b64encode(buffer.getvalue()).decode()
328
384
  return f"data:image/png;base64,{img_str}"
329
385
 
330
386
 
387
+ def _extracted_image_encoder(obj: ExtractedImage) -> ExtractedImageDict:
388
+ encoded_data = base64.b64encode(obj.data).decode()
389
+ return ExtractedImageDict(
390
+ data=f"data:image/{obj.format};base64,{encoded_data}",
391
+ format=obj.format,
392
+ filename=obj.filename,
393
+ page_number=obj.page_number,
394
+ dimensions=obj.dimensions,
395
+ colorspace=obj.colorspace,
396
+ bits_per_component=obj.bits_per_component,
397
+ is_mask=obj.is_mask,
398
+ description=obj.description,
399
+ )
400
+
401
+
402
+ def _image_ocr_result_encoder(obj: ImageOCRResult) -> ImageOCRResultDict:
403
+ return ImageOCRResultDict(
404
+ image=_extracted_image_encoder(obj.image),
405
+ ocr_result=obj.ocr_result,
406
+ confidence_score=obj.confidence_score,
407
+ processing_time=obj.processing_time,
408
+ skipped_reason=obj.skipped_reason,
409
+ )
410
+
411
+
331
412
  openapi_config = OpenAPIConfig(
332
413
  title="Kreuzberg API",
333
414
  version="3.14.0",
@@ -344,10 +425,11 @@ openapi_config = OpenAPIConfig(
344
425
  create_examples=True,
345
426
  )
346
427
 
347
- # Type encoders for custom serialization
348
428
  type_encoders = {
349
429
  pl.DataFrame: _polars_dataframe_encoder,
350
430
  Image.Image: _pil_image_encoder,
431
+ ExtractedImage: _extracted_image_encoder,
432
+ ImageOCRResult: _image_ocr_result_encoder,
351
433
  }
352
434
 
353
435
  app = Litestar(
@@ -360,5 +442,5 @@ app = Litestar(
360
442
  Exception: general_exception_handler,
361
443
  },
362
444
  type_encoders=type_encoders,
363
- request_max_body_size=1024 * 1024 * 1024, # 1GB limit for large file uploads
445
+ request_max_body_size=1024 * 1024 * 1024,
364
446
  )
kreuzberg/_chunker.py CHANGED
@@ -20,14 +20,15 @@ def get_chunker(
20
20
  key = (max_characters, overlap_characters, mime_type)
21
21
  if key not in _chunkers:
22
22
  try:
23
- if mime_type == MARKDOWN_MIME_TYPE:
24
- from semantic_text_splitter import MarkdownSplitter # noqa: PLC0415
23
+ match mime_type:
24
+ case x if x == MARKDOWN_MIME_TYPE:
25
+ from semantic_text_splitter import MarkdownSplitter # noqa: PLC0415
25
26
 
26
- _chunkers[key] = MarkdownSplitter(max_characters, overlap_characters)
27
- else:
28
- from semantic_text_splitter import TextSplitter # noqa: PLC0415
27
+ _chunkers[key] = MarkdownSplitter(max_characters, overlap_characters)
28
+ case _:
29
+ from semantic_text_splitter import TextSplitter # noqa: PLC0415
29
30
 
30
- _chunkers[key] = TextSplitter(max_characters, overlap_characters)
31
+ _chunkers[key] = TextSplitter(max_characters, overlap_characters)
31
32
  except ImportError as e: # pragma: no cover
32
33
  raise MissingDependencyError.create_for_package(
33
34
  dependency_group="chunking", functionality="chunking", package_name="semantic-text-splitter"
kreuzberg/_constants.py CHANGED
@@ -5,3 +5,5 @@ from typing import Final
5
5
  MINIMAL_SUPPORTED_PANDOC_VERSION: Final[int] = 2
6
6
  DEFAULT_MAX_CHARACTERS: Final[int] = 2000
7
7
  DEFAULT_MAX_OVERLAP: Final[int] = 100
8
+
9
+ PDF_POINTS_PER_INCH: Final[float] = 72.0 # Standard PDF unit conversion
kreuzberg/_document_classification.py CHANGED
@@ -65,12 +65,10 @@ def classify_document(result: ExtractionResult, config: ExtractionConfig) -> tup
65
65
  return None, None
66
66
 
67
67
  translated_text = _get_translated_text(result)
68
- scores = dict.fromkeys(DOCUMENT_CLASSIFIERS, 0)
69
-
70
- for doc_type, patterns in DOCUMENT_CLASSIFIERS.items():
71
- for pattern in patterns:
72
- if re.search(pattern, translated_text):
73
- scores[doc_type] += 1
68
+ scores = {
69
+ doc_type: sum(1 for pattern in patterns if re.search(pattern, translated_text))
70
+ for doc_type, patterns in DOCUMENT_CLASSIFIERS.items()
71
+ }
74
72
 
75
73
  total_score = sum(scores.values())
76
74
  if total_score == 0:
kreuzberg/_entity_extraction.py CHANGED
@@ -3,6 +3,7 @@ from __future__ import annotations
3
3
  import os
4
4
  import re
5
5
  from functools import lru_cache
6
+ from itertools import chain
6
7
  from typing import TYPE_CHECKING, Any
7
8
 
8
9
  from kreuzberg._types import Entity, SpacyEntityExtractionConfig
@@ -21,11 +22,15 @@ def extract_entities(
21
22
  ) -> list[Entity]:
22
23
  entities: list[Entity] = []
23
24
  if custom_patterns:
24
- for ent_type, pattern in custom_patterns:
25
- entities.extend(
26
- Entity(type=ent_type, text=match.group(), start=match.start(), end=match.end())
27
- for match in re.finditer(pattern, text)
25
+ entities.extend(
26
+ chain.from_iterable(
27
+ (
28
+ Entity(type=ent_type, text=match.group(), start=match.start(), end=match.end())
29
+ for match in re.finditer(pattern, text)
30
+ )
31
+ for ent_type, pattern in custom_patterns
28
32
  )
33
+ )
29
34
 
30
35
  if spacy_config is None:
31
36
  spacy_config = SpacyEntityExtractionConfig()