vlm4ocr 0.1.0__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
vlm4ocr/ocr_engines.py CHANGED
@@ -1,15 +1,18 @@
1
1
  import os
2
- from typing import List, Dict, Union, Generator, Iterable
2
+ from typing import Tuple, List, Dict, Union, Generator, AsyncGenerator, Iterable
3
3
  import importlib
4
4
  import asyncio
5
- from vlm4ocr.utils import get_images_from_pdf, get_images_from_tiff, get_image_from_file, clean_markdown
5
+ from colorama import Fore, Style
6
+ import json
7
+ from vlm4ocr.utils import DataLoader, PDFDataLoader, TIFFDataLoader, ImageDataLoader, ImageProcessor, clean_markdown, extract_json, get_default_page_delimiter
8
+ from vlm4ocr.data_types import OCRResult
6
9
  from vlm4ocr.vlm_engines import VLMEngine
7
10
 
8
11
  SUPPORTED_IMAGE_EXTS = ['.pdf', '.tif', '.tiff', '.png', '.jpg', '.jpeg', '.bmp', '.gif', '.webp']
9
12
 
10
13
 
11
14
  class OCREngine:
12
- def __init__(self, vlm_engine:VLMEngine, output_mode:str="markdown", system_prompt:str=None, user_prompt:str=None, page_delimiter:str="auto"):
15
+ def __init__(self, vlm_engine:VLMEngine, output_mode:str="markdown", system_prompt:str=None, user_prompt:str=None):
13
16
  """
14
17
  This class takes an image or PDF file path and processes it using a VLM inference engine. Outputs plain text or markdown.
15
18
 
@@ -18,17 +21,11 @@ class OCREngine:
18
21
  inference_engine : InferenceEngine
19
22
  The inference engine to use for OCR.
20
23
  output_mode : str, Optional
21
- The output format. Must be 'markdown', 'HTML', or 'text'.
24
+ The output format. Must be 'markdown', 'HTML', 'text', or 'JSON'.
22
25
  system_prompt : str, Optional
23
26
  Custom system prompt. We recommend using the default system prompt by leaving this blank.
24
27
  user_prompt : str, Optional
25
28
  Custom user prompt. It is good to include some information regarding the document. If not specified, a default will be used.
26
- page_delimiter : str, Optional
27
- The delimiter to use between PDF pages.
28
- if 'auto', it will be set to the default page delimiter for the output mode:
29
- 'markdown' -> '\n\n---\n\n'
30
- 'HTML' -> '<br><br>'
31
- 'text' -> '\n\n---\n\n'
32
29
  """
33
30
  # Check inference engine
34
31
  if not isinstance(vlm_engine, VLMEngine):
@@ -36,42 +33,34 @@ class OCREngine:
36
33
  self.vlm_engine = vlm_engine
37
34
 
38
35
  # Check output mode
39
- if output_mode not in ["markdown", "HTML", "text"]:
40
- raise ValueError("output_mode must be 'markdown', 'HTML', or 'text'")
36
+ if output_mode not in ["markdown", "HTML", "text", "JSON"]:
37
+ raise ValueError("output_mode must be 'markdown', 'HTML', 'text', or 'JSON'.")
41
38
  self.output_mode = output_mode
42
39
 
43
40
  # System prompt
44
41
  if isinstance(system_prompt, str) and system_prompt:
45
42
  self.system_prompt = system_prompt
46
43
  else:
47
- file_path = importlib.resources.files('vlm4ocr.assets.default_prompt_templates').joinpath(f'ocr_{self.output_mode}_system_prompt.txt')
48
- with open(file_path, 'r', encoding='utf-8') as f:
44
+ prompt_template_path = importlib.resources.files('vlm4ocr.assets.default_prompt_templates').joinpath(f'ocr_{self.output_mode}_system_prompt.txt')
45
+ with prompt_template_path.open('r', encoding='utf-8') as f:
49
46
  self.system_prompt = f.read()
50
47
 
51
48
  # User prompt
52
49
  if isinstance(user_prompt, str) and user_prompt:
53
50
  self.user_prompt = user_prompt
54
51
  else:
55
- file_path = importlib.resources.files('vlm4ocr.assets.default_prompt_templates').joinpath(f'ocr_{self.output_mode}_user_prompt.txt')
56
- with open(file_path, 'r', encoding='utf-8') as f:
52
+ if self.output_mode == "JSON":
53
+ raise ValueError("user_prompt must be provided when output_mode is 'JSON' to define the JSON structure.")
54
+
55
+ prompt_template_path = importlib.resources.files('vlm4ocr.assets.default_prompt_templates').joinpath(f'ocr_{self.output_mode}_user_prompt.txt')
56
+ with prompt_template_path.open('r', encoding='utf-8') as f:
57
57
  self.user_prompt = f.read()
58
58
 
59
- # Page delimiter
60
- if isinstance(page_delimiter, str):
61
- if page_delimiter == "auto":
62
- if self.output_mode == "markdown":
63
- self.page_delimiter = "\n\n---\n\n"
64
- elif self.output_mode == "HTML":
65
- self.page_delimiter = "<br><br>"
66
- else:
67
- self.page_delimiter = "\n\n---\n\n"
68
- else:
69
- self.page_delimiter = page_delimiter
70
- else:
71
- raise ValueError("page_delimiter must be a string")
72
-
59
+ # Image processor
60
+ self.image_processor = ImageProcessor()
73
61
 
74
- def stream_ocr(self, file_path: str, max_new_tokens:int=4096, temperature:float=0.0, **kwrs) -> Generator[str, None, None]:
62
+
63
+ def stream_ocr(self, file_path: str, rotate_correction:bool=False, max_dimension_pixels:int=None) -> Generator[Dict[str, str], None, None]:
75
64
  """
76
65
  This method inputs a file path (image or PDF) and streams OCR results in real-time. This is useful for frontend applications.
77
66
  Yields dictionaries with 'type' ('ocr_chunk' or 'page_delimiter') and 'data'.
@@ -80,15 +69,18 @@ class OCREngine:
80
69
  -----------
81
70
  file_path : str
82
71
  The path to the image or PDF file. Must be one of '.pdf', '.tiff', '.png', '.jpg', '.jpeg', '.bmp', '.gif', '.webp'
83
- max_new_tokens : int, Optional
84
- The maximum number of tokens to generate.
85
- temperature : float, Optional
86
- The temperature to use for sampling.
72
+ rotate_correction : bool, Optional
73
+ If True, applies rotate correction to the images using pytesseract.
74
+ max_dimension_pixels : int, Optional
75
+ The maximum dimension of the image in pixels. Original dimensions will be resized to fit in. If None, no resizing is applied.
87
76
 
88
77
  Returns:
89
78
  --------
90
- Generator[str, None, None]
91
- A generator that yields the output.
79
+ Generator[Dict[str, str], None, None]
80
+ A generator that yields the output:
81
+ {"type": "info", "data": msg}
82
+ {"type": "ocr_chunk", "data": chunk}
83
+ {"type": "page_delimiter", "data": page_delimiter}
92
84
  """
93
85
  # Check file path
94
86
  if not isinstance(file_path, str):
@@ -98,227 +90,403 @@ class OCREngine:
98
90
  file_ext = os.path.splitext(file_path)[1].lower()
99
91
  if file_ext not in SUPPORTED_IMAGE_EXTS:
100
92
  raise ValueError(f"Unsupported file type: {file_ext}. Supported types are: {SUPPORTED_IMAGE_EXTS}")
93
+
94
+ # Check if image preprocessing can be applied
95
+ if self.image_processor.has_tesseract==False and rotate_correction:
96
+ raise ImportError("pytesseract is not installed. Please install it to use rotate correction.")
101
97
 
102
98
  # PDF or TIFF
103
99
  if file_ext in ['.pdf', '.tif', '.tiff']:
104
- images = get_images_from_pdf(file_path) if file_ext == '.pdf' else get_images_from_tiff(file_path)
100
+ data_loader = PDFDataLoader(file_path) if file_ext == '.pdf' else TIFFDataLoader(file_path)
101
+ images = data_loader.get_all_pages()
102
+ # Check if images were extracted
105
103
  if not images:
106
104
  raise ValueError(f"No images extracted from file: {file_path}")
105
+
106
+ # OCR each image
107
107
  for i, image in enumerate(images):
108
+ # Apply rotate correction if specified and tesseract is available
109
+ if rotate_correction and self.image_processor.has_tesseract:
110
+ try:
111
+ image, _ = self.image_processor.rotate_correction(image)
112
+
113
+ except Exception as e:
114
+ yield {"type": "info", "data": f"Error during rotate correction: {str(e)}"}
115
+
116
+ # Resize the image if max_dimension_pixels is specified
117
+ if max_dimension_pixels is not None:
118
+ try:
119
+ image, _ = self.image_processor.resize(image, max_dimension_pixels=max_dimension_pixels)
120
+ except Exception as e:
121
+ yield {"type": "info", "data": f"Error resizing image: {str(e)}"}
122
+
108
123
  messages = self.vlm_engine.get_ocr_messages(self.system_prompt, self.user_prompt, image)
109
124
  response_stream = self.vlm_engine.chat(
110
125
  messages,
111
- max_new_tokens=max_new_tokens,
112
- temperature=temperature,
113
- stream=True,
114
- **kwrs
126
+ stream=True
115
127
  )
116
128
  for chunk in response_stream:
117
129
  yield {"type": "ocr_chunk", "data": chunk}
118
130
 
119
131
  if i < len(images) - 1:
120
- yield {"type": "page_delimiter", "data": self.page_delimiter}
132
+ yield {"type": "page_delimiter", "data": get_default_page_delimiter(self.output_mode)}
121
133
 
122
134
  # Image
123
135
  else:
124
- image = get_image_from_file(file_path)
136
+ data_loader = ImageDataLoader(file_path)
137
+ image = data_loader.get_page(0)
138
+
139
+ # Apply rotate correction if specified and tesseract is available
140
+ if rotate_correction and self.image_processor.has_tesseract:
141
+ try:
142
+ image, _ = self.image_processor.rotate_correction(image)
143
+
144
+ except Exception as e:
145
+ yield {"type": "info", "data": f"Error during rotate correction: {str(e)}"}
146
+
147
+ # Resize the image if max_dimension_pixels is specified
148
+ if max_dimension_pixels is not None:
149
+ try:
150
+ image, _ = self.image_processor.resize(image, max_dimension_pixels=max_dimension_pixels)
151
+ except Exception as e:
152
+ yield {"type": "info", "data": f"Error resizing image: {str(e)}"}
153
+
125
154
  messages = self.vlm_engine.get_ocr_messages(self.system_prompt, self.user_prompt, image)
126
155
  response_stream = self.vlm_engine.chat(
127
156
  messages,
128
- max_new_tokens=max_new_tokens,
129
- temperature=temperature,
130
- stream=True,
131
- **kwrs
157
+ stream=True
132
158
  )
133
159
  for chunk in response_stream:
134
160
  yield {"type": "ocr_chunk", "data": chunk}
135
161
 
136
162
 
137
- def run_ocr(self, file_paths: Union[str, Iterable[str]], max_new_tokens:int=4096, temperature:float=0.0,
138
- verbose:bool=False, concurrent:bool=False, concurrent_batch_size:int=32, **kwrs) -> Union[str, Generator[str, None, None]]:
163
+ def sequential_ocr(self, file_paths: Union[str, Iterable[str]], rotate_correction:bool=False,
164
+ max_dimension_pixels:int=None, verbose:bool=False) -> List[OCRResult]:
139
165
  """
140
- This method takes a list of file paths (image, PDF, TIFF) and perform OCR using the VLM inference engine.
166
+ This method inputs a file path or a list of file paths (image, PDF, TIFF) and performs OCR using the VLM inference engine.
141
167
 
142
168
  Parameters:
143
169
  -----------
144
170
  file_paths : Union[str, Iterable[str]]
145
- A file path or a list of file paths to process. Must be one of '.pdf', '.tif', '.tiff, '.png', '.jpg', '.jpeg', '.bmp', '.gif', '.webp'
146
- max_new_tokens : int, Optional
147
- The maximum number of tokens to generate.
148
- temperature : float, Optional
149
- The temperature to use for sampling.
171
+ A file path or a list of file paths to process. Must be one of '.pdf', '.tif', '.tiff', '.png', '.jpg', '.jpeg', '.bmp', '.gif', '.webp'
172
+ rotate_correction : bool, Optional
173
+ If True, applies rotate correction to the images using pytesseract.
174
+ max_dimension_pixels : int, Optional
175
+ The maximum dimension of the image in pixels. Original dimensions will be resized to fit in. If None, no resizing is applied.
150
176
  verbose : bool, Optional
151
- If True, the function will print the output in terminal.
152
- concurrent : bool, Optional
153
- If True, the function will process the files concurrently.
154
- concurrent_batch_size : int, Optional
155
- The number of images/pages to process concurrently.
177
+ If True, the function will print the output in terminal.
178
+
179
+ Returns:
180
+ --------
181
+ List[OCRResult]
182
+ A list of OCR result objects.
156
183
  """
157
- # if file_paths is a string, convert it to a list
158
184
  if isinstance(file_paths, str):
159
185
  file_paths = [file_paths]
160
-
161
- if not isinstance(file_paths, Iterable):
162
- raise TypeError("file_paths must be a string or an iterable of strings")
163
-
164
- # check if all file paths are valid
186
+
187
+ ocr_results = []
165
188
  for file_path in file_paths:
166
- if not isinstance(file_path, str):
167
- raise TypeError("file_paths must be a string or an iterable of strings")
189
+ # Define OCRResult object
190
+ ocr_result = OCRResult(input_dir=file_path, output_mode=self.output_mode)
191
+ # get file extension
168
192
  file_ext = os.path.splitext(file_path)[1].lower()
193
+ # Check file extension
169
194
  if file_ext not in SUPPORTED_IMAGE_EXTS:
170
- raise ValueError(f"Unsupported file type: {file_ext}. Supported types are: {SUPPORTED_IMAGE_EXTS}")
195
+ if verbose:
196
+ print(f"{Fore.RED}Unsupported file type:{Style.RESET_ALL} {file_ext}. Supported types are: {SUPPORTED_IMAGE_EXTS}")
197
+ ocr_result.status = "error"
198
+ ocr_result.add_page(text=f"Unsupported file type: {file_ext}. Supported types are: {SUPPORTED_IMAGE_EXTS}",
199
+ image_processing_status={})
200
+ ocr_results.append(ocr_result)
201
+ continue
202
+
203
+ filename = os.path.basename(file_path)
204
+
205
+ try:
206
+ # Load images from file
207
+ if file_ext == '.pdf':
208
+ data_loader = PDFDataLoader(file_path)
209
+ elif file_ext in ['.tif', '.tiff']:
210
+ data_loader = TIFFDataLoader(file_path)
211
+ else:
212
+ data_loader = ImageDataLoader(file_path)
213
+
214
+ images = data_loader.get_all_pages()
215
+ except Exception as e:
216
+ if verbose:
217
+ print(f"{Fore.RED}Error processing file {filename}:{Style.RESET_ALL} {str(e)}")
218
+ ocr_result.status = "error"
219
+ ocr_result.add_page(text=f"Error processing file {filename}: {str(e)}", image_processing_status={})
220
+ ocr_results.append(ocr_result)
221
+ continue
222
+
223
+ # Check if images were extracted
224
+ if not images:
225
+ if verbose:
226
+ print(f"{Fore.RED}No images extracted from file:{Style.RESET_ALL} {filename}. It might be empty or corrupted.")
227
+ ocr_result.status = "error"
228
+ ocr_result.add_page(text=f"No images extracted from file: {filename}. It might be empty or corrupted.",
229
+ image_processing_status={})
230
+ ocr_results.append(ocr_result)
231
+ continue
232
+
233
+ # OCR images
234
+ for i, image in enumerate(images):
235
+ image_processing_status = {}
236
+ # Apply rotate correction if specified and tesseract is available
237
+ if rotate_correction and self.image_processor.has_tesseract:
238
+ try:
239
+ image, rotation_angle = self.image_processor.rotate_correction(image)
240
+ image_processing_status["rotate_correction"] = {
241
+ "status": "success",
242
+ "rotation_angle": rotation_angle
243
+ }
244
+ if verbose:
245
+ print(f"{Fore.GREEN}Rotate correction applied for {filename} page {i} with angle {rotation_angle} degrees.{Style.RESET_ALL}")
246
+ except Exception as e:
247
+ image_processing_status["rotate_correction"] = {
248
+ "status": "error",
249
+ "error": str(e)
250
+ }
251
+ if verbose:
252
+ print(f"{Fore.RED}Error during rotate correction for {filename}:{Style.RESET_ALL} {rotation_angle['error']}. OCR continues without rotate correction.")
253
+
254
+ # Resize the image if max_dimension_pixels is specified
255
+ if max_dimension_pixels is not None:
256
+ try:
257
+ image, resized = self.image_processor.resize(image, max_dimension_pixels=max_dimension_pixels)
258
+ image_processing_status["resize"] = {
259
+ "status": "success",
260
+ "resized": resized
261
+ }
262
+ if verbose and resized:
263
+ print(f"{Fore.GREEN}Image resized for {filename} page {i} to fit within {max_dimension_pixels} pixels.{Style.RESET_ALL}")
264
+ except Exception as e:
265
+ image_processing_status["resize"] = {
266
+ "status": "error",
267
+ "error": str(e)
268
+ }
269
+ if verbose:
270
+ print(f"{Fore.RED}Error resizing image for {filename}:{Style.RESET_ALL} {resized['error']}. OCR continues without resizing.")
271
+
272
+ try:
273
+ messages = self.vlm_engine.get_ocr_messages(self.system_prompt, self.user_prompt, image)
274
+ response = self.vlm_engine.chat(
275
+ messages,
276
+ verbose=verbose,
277
+ stream=False
278
+ )
279
+ # Clean the response if output mode is markdown
280
+ if self.output_mode == "markdown":
281
+ response = clean_markdown(response)
171
282
 
172
- # Concurrent processing
173
- if concurrent:
174
- # Check concurrent_batch_size
175
- if concurrent_batch_size <= 0:
176
- raise ValueError("concurrent_batch_size must be greater than 0")
283
+ # Parse the response if output mode is JSON
284
+ if self.output_mode == "JSON":
285
+ json_list = extract_json(response)
286
+ # Serialize the JSON list to a string
287
+ response = json.dumps(json_list, indent=4)
288
+
289
+ # Add the page to the OCR result
290
+ ocr_result.add_page(text=response,
291
+ image_processing_status=image_processing_status)
292
+
293
+ except Exception as page_e:
294
+ ocr_result.status = "error"
295
+ ocr_result.add_page(text=f"Error during OCR for a page in {filename}: {str(page_e)}",
296
+ image_processing_status={})
297
+ if verbose:
298
+ print(f"{Fore.RED}Error during OCR for a page in {filename}:{Style.RESET_ALL} {page_e}")
299
+
300
+ # Add the OCR result to the list
301
+ ocr_result.status = "success"
302
+ ocr_results.append(ocr_result)
177
303
 
178
304
  if verbose:
179
- Warning("verbose is not supported for concurrent processing.", UserWarning)
305
+ print(f"{Fore.BLUE}Successfully processed {filename} with {len(ocr_result)} pages.{Style.RESET_ALL}")
306
+ for page in ocr_result:
307
+ print(page)
308
+ print("-" * 80)
309
+
310
+ return ocr_results
180
311
 
181
- return asyncio.run(self._run_ocr_async(file_paths,
182
- max_new_tokens=max_new_tokens,
183
- temperature=temperature,
184
- concurrent_batch_size=concurrent_batch_size,
185
- **kwrs))
186
-
187
- # Sync processing
188
- return self._run_ocr(file_paths, max_new_tokens=max_new_tokens, temperature=temperature, verbose=verbose, **kwrs)
189
-
190
312
 
191
- def _run_ocr(self, file_paths: Union[str, Iterable[str]], max_new_tokens:int=4096,
192
- temperature:float=0.0, verbose:bool=False, **kwrs) -> Iterable[str]:
313
+ def concurrent_ocr(self, file_paths: Union[str, Iterable[str]], rotate_correction:bool=False,
314
+ max_dimension_pixels:int=None, concurrent_batch_size: int=32, max_file_load: int=None) -> AsyncGenerator[OCRResult, None]:
193
315
  """
194
- This method inputs a file path or a list of file paths (image, PDF, TIFF) and performs OCR using the VLM inference engine.
316
+ First complete first out. Input and output order not guaranteed.
317
+ This method inputs a file path or a list of file paths (image, PDF, TIFF) and performs OCR using the VLM inference engine.
318
+ Results are processed concurrently using asyncio.
195
319
 
196
320
  Parameters:
197
321
  -----------
198
322
  file_paths : Union[str, Iterable[str]]
199
323
  A file path or a list of file paths to process. Must be one of '.pdf', '.tif', '.tiff', '.png', '.jpg', '.jpeg', '.bmp', '.gif', '.webp'
200
- max_new_tokens : int, Optional
201
- The maximum number of tokens to generate.
202
- temperature : float, Optional
203
- The temperature to use for sampling.
204
- verbose : bool, Optional
205
- If True, the function will print the output in terminal.
324
+ rotate_correction : bool, Optional
325
+ If True, applies rotate correction to the images using pytesseract.
326
+ max_dimension_pixels : int, Optional
327
+ The maximum dimension of the image in pixels. Original dimensions will be resized to fit in. If None, no resizing is applied.
328
+ concurrent_batch_size : int, Optional
329
+ The number of concurrent VLM calls to make.
330
+ max_file_load : int, Optional
331
+ The maximum number of files to load concurrently. If None, defaults to 2 times concurrent_batch_size.
206
332
 
207
333
  Returns:
208
334
  --------
209
- Iterable[str]
210
- A list of strings containing the OCR results.
335
+ AsyncGenerator[OCRResult, None]
336
+ A generator that yields OCR result objects as they complete.
211
337
  """
212
- ocr_results = []
213
- for file_path in file_paths:
214
- file_ext = os.path.splitext(file_path)[1].lower()
215
- # PDF or TIFF
216
- if file_ext in ['.pdf', '.tif', '.tiff']:
217
- images = get_images_from_pdf(file_path) if file_ext == '.pdf' else get_images_from_tiff(file_path)
218
- if not images:
219
- raise ValueError(f"No images extracted from file: {file_path}")
220
- results = []
221
- for image in images:
222
- messages = self.vlm_engine.get_ocr_messages(self.system_prompt, self.user_prompt, image)
223
- response = self.vlm_engine.chat(
224
- messages,
225
- max_new_tokens=max_new_tokens,
226
- temperature=temperature,
227
- verbose=verbose,
228
- stream=False,
229
- **kwrs
230
- )
231
- results.append(response)
232
-
233
- ocr_text = self.page_delimiter.join(results)
234
- # Image
235
- else:
236
- image = get_image_from_file(file_path)
237
- messages = self.vlm_engine.get_ocr_messages(self.system_prompt, self.user_prompt, image)
238
- ocr_text = self.vlm_engine.chat(
239
- messages,
240
- max_new_tokens=max_new_tokens,
241
- temperature=temperature,
242
- verbose=verbose,
243
- stream=False,
244
- **kwrs
245
- )
246
-
247
- # Clean markdown
248
- if self.output_mode == "markdown":
249
- ocr_text = clean_markdown(ocr_text)
250
- ocr_results.append(ocr_text)
338
+ if isinstance(file_paths, str):
339
+ file_paths = [file_paths]
340
+
341
+ if max_file_load is None:
342
+ max_file_load = concurrent_batch_size * 2
251
343
 
252
- return ocr_results
344
+ if not isinstance(max_file_load, int) or max_file_load <= 0:
345
+ raise ValueError("max_file_load must be a positive integer")
346
+
347
+ if self.image_processor.has_tesseract==False and rotate_correction:
348
+ raise ImportError("pytesseract is not installed. Please install it to use rotate correction.")
253
349
 
350
+ return self._ocr_async(file_paths=file_paths,
351
+ rotate_correction=rotate_correction,
352
+ max_dimension_pixels=max_dimension_pixels,
353
+ concurrent_batch_size=concurrent_batch_size,
354
+ max_file_load=max_file_load)
355
+
254
356
 
255
- async def _run_ocr_async(self, file_paths: Union[str, Iterable[str]], max_new_tokens:int=4096,
256
- temperature:float=0.0, concurrent_batch_size:int=32, **kwrs) -> List[str]:
357
+ async def _ocr_async(self, file_paths: Iterable[str], rotate_correction:bool=False, max_dimension_pixels:int=None,
358
+ concurrent_batch_size: int=32, max_file_load: int=None) -> AsyncGenerator[OCRResult, None]:
257
359
  """
258
- This is the async version of the _run_ocr method.
360
+ Internal method to asynchronously process an iterable of file paths.
361
+ Yields OCRResult objects as they complete. Order not guaranteed.
362
+ concurrent_batch_size controls how many VLM calls are made concurrently.
259
363
  """
260
- # flatten pages/images in file_paths
261
- flat_page_list = []
364
+ vlm_call_semaphore = asyncio.Semaphore(concurrent_batch_size)
365
+ file_load_semaphore = asyncio.Semaphore(max_file_load)
366
+
367
+ tasks = []
262
368
  for file_path in file_paths:
369
+ task = self._ocr_file_with_semaphore(file_load_semaphore=file_load_semaphore,
370
+ vlm_call_semaphore=vlm_call_semaphore,
371
+ file_path=file_path,
372
+ rotate_correction=rotate_correction,
373
+ max_dimension_pixels=max_dimension_pixels)
374
+ tasks.append(task)
375
+
376
+
377
+ for future in asyncio.as_completed(tasks):
378
+ result: OCRResult = await future
379
+ yield result
380
+
381
+ async def _ocr_file_with_semaphore(self, file_load_semaphore:asyncio.Semaphore, vlm_call_semaphore:asyncio.Semaphore,
382
+ file_path:str, rotate_correction:bool=False, max_dimension_pixels:int=None) -> OCRResult:
383
+ """
384
+ This internal method takes a semaphore and OCR a single file using the VLM inference engine.
385
+ """
386
+ async with file_load_semaphore:
387
+ filename = os.path.basename(file_path)
263
388
  file_ext = os.path.splitext(file_path)[1].lower()
264
- # PDF or TIFF
265
- if file_ext in ['.pdf', '.tif', '.tiff']:
266
- images = get_images_from_pdf(file_path) if file_ext == '.pdf' else get_images_from_tiff(file_path)
267
- if not images:
268
- flat_page_list.append({'file_path': file_path, 'file_type': "PDF/TIFF", "image": image, "page_num": 0, "total_page_count": 0})
269
- for page_num, image in enumerate(images):
270
- flat_page_list.append({'file_path': file_path, 'file_type': "PDF/TIFF", "image": image, "page_num": page_num, "total_page_count": len(images)})
271
- # Image
272
- else:
273
- image = get_image_from_file(file_path)
274
- flat_page_list.append({'file_path': file_path, 'file_type': "image", "image": image})
275
-
276
- # Process images with asyncio.Semaphore
277
- semaphore = asyncio.Semaphore(concurrent_batch_size)
278
- async def semaphore_helper(page:List[Dict[str,str]], max_new_tokens:int, temperature:float, **kwrs):
389
+ result = OCRResult(input_dir=file_path, output_mode=self.output_mode)
390
+ # check file extension
391
+ if file_ext not in SUPPORTED_IMAGE_EXTS:
392
+ result.status = "error"
393
+ result.add_page(text=f"Unsupported file type: {file_ext}. Supported types are: {SUPPORTED_IMAGE_EXTS}",
394
+ image_processing_status={})
395
+ return result
396
+
279
397
  try:
280
- messages = self.vlm_engine.get_ocr_messages(self.system_prompt, self.user_prompt, page["image"])
281
- async with semaphore:
282
- async_task = self.vlm_engine.chat_async(
283
- messages,
284
- max_new_tokens=max_new_tokens,
285
- temperature=temperature,
286
- **kwrs
398
+ # Load images from file
399
+ if file_ext == '.pdf':
400
+ data_loader = PDFDataLoader(file_path)
401
+ elif file_ext in ['.tif', '.tiff']:
402
+ data_loader = TIFFDataLoader(file_path)
403
+ else:
404
+ data_loader = ImageDataLoader(file_path)
405
+
406
+ except Exception as e:
407
+ result.status = "error"
408
+ result.add_page(text=f"Error processing file {filename}: {str(e)}", image_processing_status={})
409
+ return result
410
+
411
+ try:
412
+ page_processing_tasks = []
413
+ for page_index in range(data_loader.get_page_count()):
414
+ task = self._ocr_page_with_semaphore(
415
+ vlm_call_semaphore=vlm_call_semaphore,
416
+ data_loader=data_loader,
417
+ page_index=page_index,
418
+ rotate_correction=rotate_correction,
419
+ max_dimension_pixels=max_dimension_pixels
287
420
  )
288
- return await async_task
421
+ page_processing_tasks.append(task)
422
+
423
+ if page_processing_tasks:
424
+ processed_page_results = await asyncio.gather(*page_processing_tasks)
425
+ for text, image_processing_status in processed_page_results:
426
+ result.add_page(text=text, image_processing_status=image_processing_status)
427
+
289
428
  except Exception as e:
290
- print(f"Error processing image: {e}")
291
- return f"[Error: {e}]"
429
+ result.status = "error"
430
+ result.add_page(text=f"Error during OCR for {filename}: {str(e)}", image_processing_status={})
431
+ return result
292
432
 
293
- tasks = []
294
- for page in flat_page_list:
295
- async_task = semaphore_helper(
296
- page,
297
- max_new_tokens=max_new_tokens,
298
- temperature=temperature,
299
- **kwrs
300
- )
301
- tasks.append(asyncio.create_task(async_task))
433
+ # Set status to success if no errors occurred
434
+ result.status = "success"
435
+ return result
302
436
 
303
- responses = await asyncio.gather(*tasks)
437
+ async def _ocr_page_with_semaphore(self, vlm_call_semaphore: asyncio.Semaphore, data_loader: DataLoader,
438
+ page_index:int, rotate_correction:bool=False, max_dimension_pixels:int=None) -> Tuple[str, Dict[str, str]]:
439
+ """
440
+ This internal method takes a semaphore and OCR a single image/page using the VLM inference engine.
304
441
 
305
- # Restructure the results
306
- ocr_results = []
307
- page_text_buffer = ""
308
- for page, ocr_text in zip(flat_page_list, responses):
309
- # PDF or TIFF
310
- if page['file_type'] == "PDF/TIFF":
311
- page_text_buffer += ocr_text + self.page_delimiter
312
- if page['page_num'] == page['total_page_count'] - 1:
313
- if self.output_mode == "markdown":
314
- page_text_buffer = clean_markdown(page_text_buffer)
315
- ocr_results.append(page_text_buffer)
316
- page_text_buffer = ""
317
- # Image
318
- if page['file_type'] == "image":
319
- if self.output_mode == "markdown":
320
- ocr_text = clean_markdown(ocr_text)
321
- ocr_results.append(ocr_text)
322
-
323
- return ocr_results
442
+ Returns:
443
+ -------
444
+ Tuple[str, Dict[str, str]]
445
+ A tuple containing the OCR text and a dictionary with image processing status.
446
+ """
447
+ async with vlm_call_semaphore:
448
+ image = await data_loader.get_page_async(page_index)
449
+ image_processing_status = {}
450
+ # Apply rotate correction if specified and tesseract is available
451
+ if rotate_correction and self.image_processor.has_tesseract:
452
+ try:
453
+ image, rotation_angle = await self.image_processor.rotate_correction_async(image)
454
+ image_processing_status["rotate_correction"] = {
455
+ "status": "success",
456
+ "rotation_angle": rotation_angle
457
+ }
458
+ except Exception as e:
459
+ image_processing_status["rotate_correction"] = {
460
+ "status": "error",
461
+ "error": str(e)
462
+ }
463
+
464
+ # Resize the image if max_dimension_pixels is specified
465
+ if max_dimension_pixels is not None:
466
+ try:
467
+ image, resized = await self.image_processor.resize_async(image, max_dimension_pixels=max_dimension_pixels)
468
+ image_processing_status["resize"] = {
469
+ "status": "success",
470
+ "resized": resized
471
+ }
472
+ except Exception as e:
473
+ image_processing_status["resize"] = {
474
+ "status": "error",
475
+ "error": str(e)
476
+ }
477
+
478
+ messages = self.vlm_engine.get_ocr_messages(self.system_prompt, self.user_prompt, image)
479
+ ocr_text = await self.vlm_engine.chat_async(
480
+ messages,
481
+ )
482
+ # Clean the OCR text if output mode is markdown
483
+ if self.output_mode == "markdown":
484
+ ocr_text = clean_markdown(ocr_text)
485
+
486
+ # Parse the response if output mode is JSON
487
+ if self.output_mode == "JSON":
488
+ json_list = extract_json(ocr_text)
489
+ # Serialize the JSON list to a string
490
+ ocr_text = json.dumps(json_list, indent=4)
324
491
 
492
+ return ocr_text, image_processing_status