kreuzberg 1.7.0__py3-none-any.whl → 2.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
kreuzberg/_tesseract.py CHANGED
@@ -2,152 +2,33 @@ from __future__ import annotations
2
2
 
3
3
  import re
4
4
  import subprocess
5
- from asyncio import gather
5
+ import sys
6
6
  from enum import Enum
7
+ from functools import partial
7
8
  from os import PathLike
8
- from tempfile import NamedTemporaryFile
9
- from typing import Any, Literal, TypeVar, Union
9
+ from typing import Final, TypeVar, Union, cast
10
10
 
11
+ from anyio import CapacityLimiter, create_task_group, to_process
11
12
  from anyio import Path as AsyncPath
12
13
  from PIL.Image import Image
13
14
 
15
+ from kreuzberg import ExtractionResult, ParsingError
16
+ from kreuzberg._constants import DEFAULT_MAX_PROCESSES
17
+ from kreuzberg._mime_types import PLAIN_TEXT_MIME_TYPE
18
+ from kreuzberg._string import normalize_spaces
14
19
  from kreuzberg._sync import run_sync
20
+ from kreuzberg._tmp import create_temp_file
15
21
  from kreuzberg.exceptions import MissingDependencyError, OCRError
16
22
 
23
+ if sys.version_info < (3, 11): # pragma: no cover
24
+ from exceptiongroup import ExceptionGroup # type: ignore[import-not-found]
25
+
26
+ MINIMAL_SUPPORTED_TESSERACT_VERSION: Final[int] = 5
27
+
17
28
  version_ref = {"checked": False}
18
29
 
19
30
  T = TypeVar("T", bound=Union[Image, PathLike[str], str])
20
31
 
21
- SupportedLanguages = Literal[
22
- "afr",
23
- "amh",
24
- "ara",
25
- "asm",
26
- "aze",
27
- "aze_cyrl",
28
- "bel",
29
- "ben",
30
- "bod",
31
- "bos",
32
- "bre",
33
- "bul",
34
- "cat",
35
- "ceb",
36
- "ces",
37
- "chi_sim",
38
- "chi_tra",
39
- "chr",
40
- "cos",
41
- "cym",
42
- "dan",
43
- "dan_frak",
44
- "deu",
45
- "deu_frak",
46
- "deu_latf",
47
- "dzo",
48
- "ell",
49
- "eng",
50
- "enm",
51
- "epo",
52
- "equ",
53
- "est",
54
- "eus",
55
- "fao",
56
- "fas",
57
- "fil",
58
- "fin",
59
- "fra",
60
- "frk",
61
- "frm",
62
- "fry",
63
- "gla",
64
- "gle",
65
- "glg",
66
- "grc",
67
- "guj",
68
- "hat",
69
- "heb",
70
- "hin",
71
- "hrv",
72
- "hun",
73
- "hye",
74
- "iku",
75
- "ind",
76
- "isl",
77
- "ita",
78
- "ita_old",
79
- "jav",
80
- "jpn",
81
- "kan",
82
- "kat",
83
- "kat_old",
84
- "kaz",
85
- "khm",
86
- "kir",
87
- "kmr",
88
- "kor",
89
- "kor_vert",
90
- "kur",
91
- "lao",
92
- "lat",
93
- "lav",
94
- "lit",
95
- "ltz",
96
- "mal",
97
- "mar",
98
- "mkd",
99
- "mlt",
100
- "mon",
101
- "mri",
102
- "msa",
103
- "mya",
104
- "nep",
105
- "nld",
106
- "nor",
107
- "oci",
108
- "ori",
109
- "osd",
110
- "pan",
111
- "pol",
112
- "por",
113
- "pus",
114
- "que",
115
- "ron",
116
- "rus",
117
- "san",
118
- "sin",
119
- "slk",
120
- "slk_frak",
121
- "slv",
122
- "snd",
123
- "spa",
124
- "spa_old",
125
- "sqi",
126
- "srp",
127
- "srp_latn",
128
- "sun",
129
- "swa",
130
- "swe",
131
- "syr",
132
- "tam",
133
- "tat",
134
- "tel",
135
- "tgk",
136
- "tgl",
137
- "tha",
138
- "tir",
139
- "ton",
140
- "tur",
141
- "uig",
142
- "ukr",
143
- "urd",
144
- "uzb",
145
- "uzb_cyrl",
146
- "vie",
147
- "yid",
148
- "yor",
149
- ]
150
-
151
32
 
152
33
  class PSMMode(Enum):
153
34
  """Enum for Tesseract Page Segmentation Modes (PSM) with human-readable values."""
@@ -189,7 +70,7 @@ async def validate_tesseract_version() -> None:
189
70
  command = ["tesseract", "--version"]
190
71
  result = await run_sync(subprocess.run, command, capture_output=True)
191
72
  version_match = re.search(r"tesseract\s+v?(\d+)", result.stdout.decode())
192
- if not version_match or int(version_match.group(1)) < 5:
73
+ if not version_match or int(version_match.group(1)) < MINIMAL_SUPPORTED_TESSERACT_VERSION:
193
74
  raise MissingDependencyError("Tesseract version 5 or above is required.")
194
75
 
195
76
  version_ref["checked"] = True
@@ -198,94 +79,96 @@ async def validate_tesseract_version() -> None:
198
79
 
199
80
 
200
81
async def process_file(
    input_file: str | PathLike[str],
    *,
    language: str,
    psm: PSMMode,
    max_processes: int = DEFAULT_MAX_PROCESSES,
) -> ExtractionResult:
    """Process a single image file using Tesseract OCR.

    Args:
        input_file: The path to the image file to process.
        language: The language code for OCR.
        psm: Page segmentation mode.
        max_processes: Capacity of the limiter handed to the worker-process call.
            Defaults to DEFAULT_MAX_PROCESSES.

    Raises:
        OCRError: If Tesseract exits with a non-zero return code, or if a
            RuntimeError/OSError occurs while running it or reading its output.

    Returns:
        ExtractionResult: The extracted text from the image, whitespace-normalized,
        with the plain-text mime type and empty metadata.
    """
    output_path, unlink = await create_temp_file(".txt")
    try:
        # Tesseract appends ".txt" to the output base on its own, so only the
        # trailing suffix must be dropped. A plain str.replace would also strip
        # any ".txt" occurring earlier in the path (e.g. in a directory name).
        output_base = str(output_path.with_suffix(""))
        command = [
            "tesseract",
            str(input_file),
            output_base,
            "-l",
            language,
            "--psm",
            str(psm.value),
        ]

        # NOTE(review): a fresh CapacityLimiter is built on every call, so it is
        # not shared between concurrent invocations - confirm this is intended.
        result = await to_process.run_sync(
            partial(subprocess.run, capture_output=True),
            command,
            limiter=CapacityLimiter(max_processes),
            cancellable=True,
        )

        if result.returncode != 0:
            raise OCRError("OCR failed with a non-0 return code.")

        output = await AsyncPath(output_path).read_text("utf-8")
        return ExtractionResult(content=normalize_spaces(output), mime_type=PLAIN_TEXT_MIME_TYPE, metadata={})
    except (RuntimeError, OSError) as e:
        raise OCRError("Failed to OCR using tesseract") from e
    finally:
        # Always remove the temporary output file, on success and on failure.
        await unlink()
131
+
132
+
133
async def process_image(
    image: Image,
    *,
    language: str,
    psm: PSMMode,
    max_processes: int = DEFAULT_MAX_PROCESSES,
) -> ExtractionResult:
    """Process a single Pillow Image using Tesseract OCR.

    The image is saved to a temporary PNG file, which is then handed to
    ``process_file``. The temporary file is removed afterwards, whether or not
    the OCR succeeded.

    Args:
        image: The Pillow Image to process.
        language: The language code for OCR.
        psm: Page segmentation mode.
        max_processes: Forwarded unchanged to ``process_file``.
            Defaults to DEFAULT_MAX_PROCESSES.

    Returns:
        ExtractionResult: The extracted text from the image.
    """
    image_path, unlink = await create_temp_file(".png")
    try:
        await run_sync(image.save, str(image_path), format="PNG")
        return await process_file(image_path, language=language, psm=psm, max_processes=max_processes)
    finally:
        # Clean up even when saving or OCR raises; otherwise the PNG would leak.
        await unlink()
273
156
 
274
157
 
275
158
  async def process_image_with_tesseract(
276
159
  image: Image | PathLike[str] | str,
277
160
  *,
278
- language: SupportedLanguages = "eng",
161
+ language: str = "eng",
279
162
  psm: PSMMode = PSMMode.AUTO,
280
- **kwargs: Any,
281
- ) -> str:
163
+ max_processes: int = DEFAULT_MAX_PROCESSES,
164
+ ) -> ExtractionResult:
282
165
  """Run Tesseract OCR asynchronously on a single Pillow Image or a list of Pillow Images.
283
166
 
284
167
  Args:
285
168
  image: A single Pillow Image, a pathlike or a string or a list of Pillow Images to process.
286
169
  language: The language code for OCR (default: "eng").
287
170
  psm: Page segmentation mode (default: PSMMode.AUTO).
288
- **kwargs: Additional Tesseract configuration options as key-value pairs.
171
+ max_processes: Maximum number of concurrent processes. Defaults to CPU count / 2 (minimum 1).
289
172
 
290
173
  Raises:
291
174
  ValueError: If the input is not a Pillow Image or a list of Pillow Images.
@@ -296,10 +179,10 @@ async def process_image_with_tesseract(
296
179
  await validate_tesseract_version()
297
180
 
298
181
  if isinstance(image, Image):
299
- return await process_image(image, language=language, psm=psm, **kwargs)
182
+ return await process_image(image, language=language, psm=psm, max_processes=max_processes)
300
183
 
301
184
  if isinstance(image, (PathLike, str)):
302
- return await process_file(image, language=language, psm=psm, **kwargs)
185
+ return await process_file(image, language=language, psm=psm, max_processes=max_processes)
303
186
 
304
187
  raise ValueError("Input must be one of: str, Pathlike or Pillow Image.")
305
188
 
@@ -307,22 +190,36 @@ async def process_image_with_tesseract(
307
190
async def batch_process_images(
    images: list[T],
    *,
    language: str = "eng",
    psm: PSMMode = PSMMode.AUTO,
    max_processes: int = DEFAULT_MAX_PROCESSES,
) -> list[ExtractionResult]:
    """OCR several images concurrently with Tesseract, preserving input order.

    Args:
        images: A list of Pillow Images, paths or strings to process.
        language: The language code for OCR (default: "eng").
        psm: Page segmentation mode (default: PSMMode.AUTO).
        max_processes: Forwarded to each per-image OCR call.
            Defaults to DEFAULT_MAX_PROCESSES.

    Raises:
        ParsingError: If the task group reports a failure for any of the images.

    Returns:
        List of ExtractionResult objects, one per input image, in input order.
    """
    await validate_tesseract_version()

    # Pre-sized placeholder list: every task writes its result into the slot
    # matching the position of its input, so completion order does not matter.
    ordered = cast(list[ExtractionResult], [None] * len(images))

    async def _process_image(index: int, image: T) -> None:
        ordered[index] = await process_image_with_tesseract(
            image, language=language, psm=psm, max_processes=max_processes
        )

    try:
        async with create_task_group() as task_group:
            for position, item in enumerate(images):
                task_group.start_soon(_process_image, position, item)
    except ExceptionGroup as eg:
        raise ParsingError("Failed to process images with Tesseract") from eg

    return ordered
kreuzberg/_tmp.py ADDED
@@ -0,0 +1,37 @@
1
+ from __future__ import annotations
2
+
3
+ from contextlib import suppress
4
+ from pathlib import Path
5
+ from tempfile import NamedTemporaryFile
6
+ from typing import TYPE_CHECKING, Callable
7
+
8
+ from anyio import Path as AsyncPath
9
+
10
+ from kreuzberg._sync import run_sync
11
+
12
+ if TYPE_CHECKING: # pragma: no cover
13
+ from collections.abc import Coroutine
14
+
15
+
16
async def create_temp_file(
    extension: str, content: bytes | None = None
) -> tuple[Path, Callable[[], Coroutine[None, None, None]]]:
    """Create a closed temporary file together with an async cleanup callback.

    Args:
        extension: The file extension, used as the suffix of the temporary file.
        content: Optional bytes to write to the file. Nothing is written when
            this is None or empty.

    Returns:
        A ``(path, unlink)`` tuple: the path of the temporary file and an async
        callable that deletes it. The file is created with ``delete=False``, so
        the caller is responsible for awaiting ``unlink`` when done.
    """
    file = await run_sync(NamedTemporaryFile, suffix=extension, delete=False)
    # Close our own handle before anything else touches the path, so the
    # content below is written through a single handle only.
    await run_sync(file.close)

    async def unlink() -> None:
        # Best-effort removal; PermissionError is already covered by OSError.
        with suppress(OSError):
            await AsyncPath(file.name).unlink(missing_ok=True)

    if content:
        try:
            await AsyncPath(file.name).write_bytes(content)
        except OSError:
            # The caller never receives the cleanup callback on failure,
            # so remove the file here instead of leaking it.
            await unlink()
            raise

    return Path(file.name), unlink
kreuzberg/_types.py ADDED
@@ -0,0 +1,71 @@
1
+ from __future__ import annotations
2
+
3
+ import sys
4
+ from typing import NamedTuple, TypedDict
5
+
6
+ if sys.version_info < (3, 11): # pragma: no cover
7
+ from typing_extensions import NotRequired
8
+ else: # pragma: no cover
9
+ from typing import NotRequired
10
+
11
+
12
+ class Metadata(TypedDict, total=False):
13
+ """Document metadata.
14
+
15
+ All fields are optional but will only be included if they contain non-empty values.
16
+ Any field that would be empty or None is omitted from the dictionary.
17
+
18
+ Different documents and extraction methods will yield different metadata.
19
+ """
20
+
21
+ title: NotRequired[str]
22
+ """Document title."""
23
+ subtitle: NotRequired[str]
24
+ """Document subtitle."""
25
+ abstract: NotRequired[str | list[str]]
26
+ """Document abstract, summary or description."""
27
+ authors: NotRequired[list[str]]
28
+ """List of document authors."""
29
+ date: NotRequired[str]
30
+ """Document date as string to preserve original format."""
31
+ subject: NotRequired[str]
32
+ """Document subject or topic."""
33
+ description: NotRequired[str]
34
+ """Extended description."""
35
+ keywords: NotRequired[list[str]]
36
+ """Keywords or tags."""
37
+ categories: NotRequired[list[str]]
38
+ """Categories or classifications."""
39
+ version: NotRequired[str]
40
+ """Version identifier."""
41
+ language: NotRequired[str]
42
+ """Document language code."""
43
+ references: NotRequired[list[str]]
44
+ """Reference entries."""
45
+ citations: NotRequired[list[str]]
46
+ """Citation identifiers."""
47
+ copyright: NotRequired[str]
48
+ """Copyright information."""
49
+ license: NotRequired[str]
50
+ """License information."""
51
+ identifier: NotRequired[str]
52
+ """Document identifier."""
53
+ publisher: NotRequired[str]
54
+ """Publisher name."""
55
+ contributors: NotRequired[list[str]]
56
+ """Additional contributors."""
57
+ creator: NotRequired[str]
58
+ """Document creator."""
59
+ institute: NotRequired[str | list[str]]
60
+ """Institute or organization."""
61
+
62
+
63
class ExtractionResult(NamedTuple):
    """Outcome of extracting a file: its text, mime type and metadata."""

    content: str
    """Text extracted from the source."""

    mime_type: str
    """Mime type of ``content``."""

    metadata: Metadata
    """Metadata associated with the content."""
kreuzberg/_xlsx.py ADDED
@@ -0,0 +1,92 @@
1
+ from __future__ import annotations
2
+
3
+ import csv
4
+ from io import StringIO
5
+ from typing import TYPE_CHECKING, cast
6
+
7
+ from anyio import Path as AsyncPath
8
+ from anyio import create_task_group
9
+ from python_calamine import CalamineWorkbook
10
+
11
+ from kreuzberg import ExtractionResult, ParsingError
12
+ from kreuzberg._mime_types import MARKDOWN_MIME_TYPE
13
+ from kreuzberg._pandoc import process_file_with_pandoc
14
+ from kreuzberg._string import normalize_spaces
15
+ from kreuzberg._sync import run_sync
16
+ from kreuzberg._tmp import create_temp_file
17
+
18
+ if TYPE_CHECKING: # pragma: no cover
19
+ from pathlib import Path
20
+
21
+
22
async def extract_xlsx_file(input_file: Path) -> ExtractionResult:
    """Extract text from an XLSX file by converting each sheet to CSV and then to markdown.

    Every sheet is dumped to a temporary CSV file and converted with pandoc.
    The per-sheet results are joined in workbook order, each under a
    ``## <sheet name>`` heading.

    Args:
        input_file: The path to the XLSX file.

    Returns:
        The extracted text content as markdown, with empty metadata.

    Raises:
        ParsingError: If the XLSX file could not be parsed or converted.
    """
    try:
        workbook: CalamineWorkbook = await run_sync(CalamineWorkbook.from_path, str(input_file))
        sheet_names = workbook.sheet_names

        # Pre-sized so each task can store its text at the sheet's own position.
        results = cast(list[str], [None] * len(sheet_names))

        async def convert_sheet_to_text(index: int, sheet_name: str) -> None:
            values = await run_sync(workbook.get_sheet_by_name(sheet_name).to_python)

            with StringIO() as csv_buffer:
                csv.writer(csv_buffer).writerows(values)
                csv_data = csv_buffer.getvalue()

            csv_path, unlink = await create_temp_file(".csv")
            try:
                await AsyncPath(csv_path).write_text(csv_data)
                result = await process_file_with_pandoc(csv_path, mime_type="text/csv")
            finally:
                # Remove the temporary CSV even when writing or pandoc fails.
                await unlink()

            results[index] = f"## {sheet_name}\n\n{normalize_spaces(result.content)}"

        async with create_task_group() as tg:
            # The index comes from enumerate rather than a list lookup per sheet.
            for index, sheet_name in enumerate(sheet_names):
                tg.start_soon(convert_sheet_to_text, index, sheet_name)

        return ExtractionResult(
            content="\n\n".join(results),
            mime_type=MARKDOWN_MIME_TYPE,
            metadata={},
        )
    except Exception as e:
        raise ParsingError(
            "Could not extract text from XLSX",
            context={
                "error": str(e),
            },
        ) from e
76
+
77
+
78
async def extract_xlsx_content(content: bytes) -> ExtractionResult:
    """Extract text from an XLSX file content.

    The bytes are written to a temporary ``.xlsx`` file, which is passed to
    ``extract_xlsx_file`` and removed afterwards.

    Args:
        content: The XLSX file content.

    Returns:
        The extracted text content.

    Raises:
        ParsingError: Propagated from ``extract_xlsx_file`` when the workbook
            cannot be parsed.
    """
    xlsx_path, unlink = await create_temp_file(".xlsx")
    try:
        await AsyncPath(xlsx_path).write_bytes(content)
        return await extract_xlsx_file(xlsx_path)
    finally:
        # Delete the temporary workbook on success and on failure alike.
        await unlink()