vlm4ocr 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
vlm4ocr/ocr_engines.py CHANGED
@@ -1,15 +1,18 @@
1
1
  import os
2
- from typing import List, Dict, Union, Generator, Iterable
2
+ from typing import Tuple, List, Dict, Union, Generator, AsyncGenerator, Iterable
3
3
  import importlib
4
4
  import asyncio
5
- from vlm4ocr.utils import get_images_from_pdf, get_images_from_tiff, get_image_from_file, clean_markdown
5
+ from colorama import Fore, Style
6
+ from PIL import Image
7
+ from vlm4ocr.utils import DataLoader, PDFDataLoader, TIFFDataLoader, ImageDataLoader, ImageProcessor, clean_markdown, get_default_page_delimiter
8
+ from vlm4ocr.data_types import OCRResult
6
9
  from vlm4ocr.vlm_engines import VLMEngine
7
10
 
8
11
  SUPPORTED_IMAGE_EXTS = ['.pdf', '.tif', '.tiff', '.png', '.jpg', '.jpeg', '.bmp', '.gif', '.webp']
9
12
 
10
13
 
11
14
  class OCREngine:
12
- def __init__(self, vlm_engine:VLMEngine, output_mode:str="markdown", system_prompt:str=None, user_prompt:str=None, page_delimiter:str="auto"):
15
+ def __init__(self, vlm_engine:VLMEngine, output_mode:str="markdown", system_prompt:str=None, user_prompt:str=None):
13
16
  """
14
17
  This class inputs an image or PDF file path and processes it using a VLM inference engine. Outputs plain text or markdown.
15
18
 
@@ -23,12 +26,6 @@ class OCREngine:
23
26
  Custom system prompt. We recommend using the default system prompt by leaving this blank.
24
27
  user_prompt : str, Optional
25
28
  Custom user prompt. It is good to include some information regarding the document. If not specified, a default will be used.
26
- page_delimiter : str, Optional
27
- The delimiter to use between PDF pages.
28
- if 'auto', it will be set to the default page delimiter for the output mode:
29
- 'markdown' -> '\n\n---\n\n'
30
- 'HTML' -> '<br><br>'
31
- 'text' -> '\n\n---\n\n'
32
29
  """
33
30
  # Check inference engine
34
31
  if not isinstance(vlm_engine, VLMEngine):
@@ -44,34 +41,23 @@ class OCREngine:
44
41
  if isinstance(system_prompt, str) and system_prompt:
45
42
  self.system_prompt = system_prompt
46
43
  else:
47
- file_path = importlib.resources.files('vlm4ocr.assets.default_prompt_templates').joinpath(f'ocr_{self.output_mode}_system_prompt.txt')
48
- with open(file_path, 'r', encoding='utf-8') as f:
44
+ prompt_template_path = importlib.resources.files('vlm4ocr.assets.default_prompt_templates').joinpath(f'ocr_{self.output_mode}_system_prompt.txt')
45
+ with prompt_template_path.open('r', encoding='utf-8') as f:
49
46
  self.system_prompt = f.read()
50
47
 
51
48
  # User prompt
52
49
  if isinstance(user_prompt, str) and user_prompt:
53
50
  self.user_prompt = user_prompt
54
51
  else:
55
- file_path = importlib.resources.files('vlm4ocr.assets.default_prompt_templates').joinpath(f'ocr_{self.output_mode}_user_prompt.txt')
56
- with open(file_path, 'r', encoding='utf-8') as f:
52
+ prompt_template_path = importlib.resources.files('vlm4ocr.assets.default_prompt_templates').joinpath(f'ocr_{self.output_mode}_user_prompt.txt')
53
+ with prompt_template_path.open('r', encoding='utf-8') as f:
57
54
  self.user_prompt = f.read()
58
55
 
59
- # Page delimiter
60
- if isinstance(page_delimiter, str):
61
- if page_delimiter == "auto":
62
- if self.output_mode == "markdown":
63
- self.page_delimiter = "\n\n---\n\n"
64
- elif self.output_mode == "HTML":
65
- self.page_delimiter = "<br><br>"
66
- else:
67
- self.page_delimiter = "\n\n---\n\n"
68
- else:
69
- self.page_delimiter = page_delimiter
70
- else:
71
- raise ValueError("page_delimiter must be a string")
72
-
56
+ # Image processor
57
+ self.image_processor = ImageProcessor()
73
58
 
74
- def stream_ocr(self, file_path: str, max_new_tokens:int=4096, temperature:float=0.0, **kwrs) -> Generator[str, None, None]:
59
+
60
+ def stream_ocr(self, file_path: str, rotate_correction:bool=False, max_dimension_pixels:int=None) -> Generator[Dict[str, str], None, None]:
75
61
  """
76
62
  This method inputs a file path (image or PDF) and streams OCR results in real time. This is useful for frontend applications.
77
63
  Yields dictionaries with 'type' ('ocr_chunk' or 'page_delimiter') and 'data'.
@@ -80,15 +66,18 @@ class OCREngine:
80
66
  -----------
81
67
  file_path : str
82
68
  The path to the image or PDF file. Must be one of '.pdf', '.tif', '.tiff', '.png', '.jpg', '.jpeg', '.bmp', '.gif', '.webp'
83
- max_new_tokens : int, Optional
84
- The maximum number of tokens to generate.
85
- temperature : float, Optional
86
- The temperature to use for sampling.
69
+ rotate_correction : bool, Optional
70
+ If True, applies rotate correction to the images using pytesseract.
71
+ max_dimension_pixels : int, Optional
72
+ The maximum dimension of the image in pixels. Original dimensions will be resized to fit in. If None, no resizing is applied.
87
73
 
88
74
  Returns:
89
75
  --------
90
- Generator[str, None, None]
91
- A generator that yields the output.
76
+ Generator[Dict[str, str], None, None]
77
+ A generator that yields the output:
78
+ {"type": "info", "data": msg}
79
+ {"type": "ocr_chunk", "data": chunk}
80
+ {"type": "page_delimiter", "data": page_delimiter}
92
81
  """
93
82
  # Check file path
94
83
  if not isinstance(file_path, str):
@@ -98,227 +87,389 @@ class OCREngine:
98
87
  file_ext = os.path.splitext(file_path)[1].lower()
99
88
  if file_ext not in SUPPORTED_IMAGE_EXTS:
100
89
  raise ValueError(f"Unsupported file type: {file_ext}. Supported types are: {SUPPORTED_IMAGE_EXTS}")
90
+
91
+ # Check if image preprocessing can be applied
92
+ if self.image_processor.has_tesseract==False and rotate_correction:
93
+ raise ImportError("pytesseract is not installed. Please install it to use rotate correction.")
101
94
 
102
95
  # PDF or TIFF
103
96
  if file_ext in ['.pdf', '.tif', '.tiff']:
104
- images = get_images_from_pdf(file_path) if file_ext == '.pdf' else get_images_from_tiff(file_path)
97
+ data_loader = PDFDataLoader(file_path) if file_ext == '.pdf' else TIFFDataLoader(file_path)
98
+ images = data_loader.get_all_pages()
99
+ # Check if images were extracted
105
100
  if not images:
106
101
  raise ValueError(f"No images extracted from file: {file_path}")
102
+
103
+ # OCR each image
107
104
  for i, image in enumerate(images):
105
+ # Apply rotate correction if specified and tesseract is available
106
+ if rotate_correction and self.image_processor.has_tesseract:
107
+ try:
108
+ image, _ = self.image_processor.rotate_correction(image)
109
+
110
+ except Exception as e:
111
+ yield {"type": "info", "data": f"Error during rotate correction: {str(e)}"}
112
+
113
+ # Resize the image if max_dimension_pixels is specified
114
+ if max_dimension_pixels is not None:
115
+ try:
116
+ image, _ = self.image_processor.resize(image, max_dimension_pixels=max_dimension_pixels)
117
+ except Exception as e:
118
+ yield {"type": "info", "data": f"Error resizing image: {str(e)}"}
119
+
108
120
  messages = self.vlm_engine.get_ocr_messages(self.system_prompt, self.user_prompt, image)
109
121
  response_stream = self.vlm_engine.chat(
110
122
  messages,
111
- max_new_tokens=max_new_tokens,
112
- temperature=temperature,
113
- stream=True,
114
- **kwrs
123
+ stream=True
115
124
  )
116
125
  for chunk in response_stream:
117
126
  yield {"type": "ocr_chunk", "data": chunk}
118
127
 
119
128
  if i < len(images) - 1:
120
- yield {"type": "page_delimiter", "data": self.page_delimiter}
129
+ yield {"type": "page_delimiter", "data": get_default_page_delimiter(self.output_mode)}
121
130
 
122
131
  # Image
123
132
  else:
124
- image = get_image_from_file(file_path)
133
+ data_loader = ImageDataLoader(file_path)
134
+ image = data_loader.get_page(0)
135
+
136
+ # Apply rotate correction if specified and tesseract is available
137
+ if rotate_correction and self.image_processor.has_tesseract:
138
+ try:
139
+ image, _ = self.image_processor.rotate_correction(image)
140
+
141
+ except Exception as e:
142
+ yield {"type": "info", "data": f"Error during rotate correction: {str(e)}"}
143
+
144
+ # Resize the image if max_dimension_pixels is specified
145
+ if max_dimension_pixels is not None:
146
+ try:
147
+ image, _ = self.image_processor.resize(image, max_dimension_pixels=max_dimension_pixels)
148
+ except Exception as e:
149
+ yield {"type": "info", "data": f"Error resizing image: {str(e)}"}
150
+
125
151
  messages = self.vlm_engine.get_ocr_messages(self.system_prompt, self.user_prompt, image)
126
152
  response_stream = self.vlm_engine.chat(
127
153
  messages,
128
- max_new_tokens=max_new_tokens,
129
- temperature=temperature,
130
- stream=True,
131
- **kwrs
154
+ stream=True
132
155
  )
133
156
  for chunk in response_stream:
134
157
  yield {"type": "ocr_chunk", "data": chunk}
135
158
 
136
159
 
137
- def run_ocr(self, file_paths: Union[str, Iterable[str]], max_new_tokens:int=4096, temperature:float=0.0,
138
- verbose:bool=False, concurrent:bool=False, concurrent_batch_size:int=32, **kwrs) -> Union[str, Generator[str, None, None]]:
160
+ def sequential_ocr(self, file_paths: Union[str, Iterable[str]], rotate_correction:bool=False,
161
+ max_dimension_pixels:int=None, verbose:bool=False) -> List[OCRResult]:
139
162
  """
140
- This method takes a list of file paths (image, PDF, TIFF) and perform OCR using the VLM inference engine.
163
+ This method inputs a file path or a list of file paths (image, PDF, TIFF) and performs OCR using the VLM inference engine.
141
164
 
142
165
  Parameters:
143
166
  -----------
144
167
  file_paths : Union[str, Iterable[str]]
145
- A file path or a list of file paths to process. Must be one of '.pdf', '.tif', '.tiff, '.png', '.jpg', '.jpeg', '.bmp', '.gif', '.webp'
146
- max_new_tokens : int, Optional
147
- The maximum number of tokens to generate.
148
- temperature : float, Optional
149
- The temperature to use for sampling.
168
+ A file path or a list of file paths to process. Must be one of '.pdf', '.tif', '.tiff', '.png', '.jpg', '.jpeg', '.bmp', '.gif', '.webp'
169
+ rotate_correction : bool, Optional
170
+ If True, applies rotate correction to the images using pytesseract.
171
+ max_dimension_pixels : int, Optional
172
+ The maximum dimension of the image in pixels. Original dimensions will be resized to fit in. If None, no resizing is applied.
150
173
  verbose : bool, Optional
151
- If True, the function will print the output in terminal.
152
- concurrent : bool, Optional
153
- If True, the function will process the files concurrently.
154
- concurrent_batch_size : int, Optional
155
- The number of images/pages to process concurrently.
174
+ If True, the function will print the output in terminal.
175
+
176
+ Returns:
177
+ --------
178
+ List[OCRResult]
179
+ A list of OCR result objects.
156
180
  """
157
- # if file_paths is a string, convert it to a list
158
181
  if isinstance(file_paths, str):
159
182
  file_paths = [file_paths]
160
-
161
- if not isinstance(file_paths, Iterable):
162
- raise TypeError("file_paths must be a string or an iterable of strings")
163
-
164
- # check if all file paths are valid
183
+
184
+ ocr_results = []
165
185
  for file_path in file_paths:
166
- if not isinstance(file_path, str):
167
- raise TypeError("file_paths must be a string or an iterable of strings")
186
+ # Define OCRResult object
187
+ ocr_result = OCRResult(input_dir=file_path, output_mode=self.output_mode)
188
+ # get file extension
168
189
  file_ext = os.path.splitext(file_path)[1].lower()
190
+ # Check file extension
169
191
  if file_ext not in SUPPORTED_IMAGE_EXTS:
170
- raise ValueError(f"Unsupported file type: {file_ext}. Supported types are: {SUPPORTED_IMAGE_EXTS}")
192
+ if verbose:
193
+ print(f"{Fore.RED}Unsupported file type:{Style.RESET_ALL} {file_ext}. Supported types are: {SUPPORTED_IMAGE_EXTS}")
194
+ ocr_result.status = "error"
195
+ ocr_result.add_page(text=f"Unsupported file type: {file_ext}. Supported types are: {SUPPORTED_IMAGE_EXTS}",
196
+ image_processing_status={})
197
+ ocr_results.append(ocr_result)
198
+ continue
199
+
200
+ filename = os.path.basename(file_path)
201
+
202
+ try:
203
+ # Load images from file
204
+ if file_ext == '.pdf':
205
+ data_loader = PDFDataLoader(file_path)
206
+ elif file_ext in ['.tif', '.tiff']:
207
+ data_loader = TIFFDataLoader(file_path)
208
+ else:
209
+ data_loader = ImageDataLoader(file_path)
210
+
211
+ images = data_loader.get_all_pages()
212
+ except Exception as e:
213
+ if verbose:
214
+ print(f"{Fore.RED}Error processing file {filename}:{Style.RESET_ALL} {str(e)}")
215
+ ocr_result.status = "error"
216
+ ocr_result.add_page(text=f"Error processing file {filename}: {str(e)}", image_processing_status={})
217
+ ocr_results.append(ocr_result)
218
+ continue
219
+
220
+ # Check if images were extracted
221
+ if not images:
222
+ if verbose:
223
+ print(f"{Fore.RED}No images extracted from file:{Style.RESET_ALL} {filename}. It might be empty or corrupted.")
224
+ ocr_result.status = "error"
225
+ ocr_result.add_page(text=f"No images extracted from file: {filename}. It might be empty or corrupted.",
226
+ image_processing_status={})
227
+ ocr_results.append(ocr_result)
228
+ continue
229
+
230
+ # OCR images
231
+ for i, image in enumerate(images):
232
+ image_processing_status = {}
233
+ # Apply rotate correction if specified and tesseract is available
234
+ if rotate_correction and self.image_processor.has_tesseract:
235
+ try:
236
+ image, rotation_angle = self.image_processor.rotate_correction(image)
237
+ image_processing_status["rotate_correction"] = {
238
+ "status": "success",
239
+ "rotation_angle": rotation_angle
240
+ }
241
+ if verbose:
242
+ print(f"{Fore.GREEN}Rotate correction applied for {filename} page {i} with angle {rotation_angle} degrees.{Style.RESET_ALL}")
243
+ except Exception as e:
244
+ image_processing_status["rotate_correction"] = {
245
+ "status": "error",
246
+ "error": str(e)
247
+ }
248
+ if verbose:
249
+ print(f"{Fore.RED}Error during rotate correction for {filename}:{Style.RESET_ALL} {rotation_angle['error']}. OCR continues without rotate correction.")
250
+
251
+ # Resize the image if max_dimension_pixels is specified
252
+ if max_dimension_pixels is not None:
253
+ try:
254
+ image, resized = self.image_processor.resize(image, max_dimension_pixels=max_dimension_pixels)
255
+ image_processing_status["resize"] = {
256
+ "status": "success",
257
+ "resized": resized
258
+ }
259
+ if verbose and resized:
260
+ print(f"{Fore.GREEN}Image resized for {filename} page {i} to fit within {max_dimension_pixels} pixels.{Style.RESET_ALL}")
261
+ except Exception as e:
262
+ image_processing_status["resize"] = {
263
+ "status": "error",
264
+ "error": str(e)
265
+ }
266
+ if verbose:
267
+ print(f"{Fore.RED}Error resizing image for {filename}:{Style.RESET_ALL} {resized['error']}. OCR continues without resizing.")
268
+
269
+ try:
270
+ messages = self.vlm_engine.get_ocr_messages(self.system_prompt, self.user_prompt, image)
271
+ response = self.vlm_engine.chat(
272
+ messages,
273
+ verbose=verbose,
274
+ stream=False
275
+ )
276
+ # Clean the response if output mode is markdown
277
+ if self.output_mode == "markdown":
278
+ response = clean_markdown(response)
279
+
280
+ # Add the page to the OCR result
281
+ ocr_result.add_page(text=response,
282
+ image_processing_status=image_processing_status)
283
+
284
+ except Exception as page_e:
285
+ ocr_result.status = "error"
286
+ ocr_result.add_page(text=f"Error during OCR for a page in {filename}: {str(page_e)}",
287
+ image_processing_status={})
288
+ if verbose:
289
+ print(f"{Fore.RED}Error during OCR for a page in {filename}:{Style.RESET_ALL} {page_e}")
171
290
 
172
- # Concurrent processing
173
- if concurrent:
174
- # Check concurrent_batch_size
175
- if concurrent_batch_size <= 0:
176
- raise ValueError("concurrent_batch_size must be greater than 0")
291
+ # Add the OCR result to the list
292
+ ocr_result.status = "success"
293
+ ocr_results.append(ocr_result)
177
294
 
178
295
  if verbose:
179
- Warning("verbose is not supported for concurrent processing.", UserWarning)
296
+ print(f"{Fore.BLUE}Successfully processed {filename} with {len(ocr_result)} pages.{Style.RESET_ALL}")
297
+ for page in ocr_result:
298
+ print(page)
299
+ print("-" * 80)
300
+
301
+ return ocr_results
180
302
 
181
- return asyncio.run(self._run_ocr_async(file_paths,
182
- max_new_tokens=max_new_tokens,
183
- temperature=temperature,
184
- concurrent_batch_size=concurrent_batch_size,
185
- **kwrs))
186
-
187
- # Sync processing
188
- return self._run_ocr(file_paths, max_new_tokens=max_new_tokens, temperature=temperature, verbose=verbose, **kwrs)
189
-
190
303
 
191
- def _run_ocr(self, file_paths: Union[str, Iterable[str]], max_new_tokens:int=4096,
192
- temperature:float=0.0, verbose:bool=False, **kwrs) -> Iterable[str]:
304
+ def concurrent_ocr(self, file_paths: Union[str, Iterable[str]], rotate_correction:bool=False,
305
+ max_dimension_pixels:int=None, concurrent_batch_size: int=32, max_file_load: int=None) -> AsyncGenerator[OCRResult, None]:
193
306
  """
194
- This method inputs a file path or a list of file paths (image, PDF, TIFF) and performs OCR using the VLM inference engine.
307
+ First complete first out. Input and output order not guaranteed.
308
+ This method inputs a file path or a list of file paths (image, PDF, TIFF) and performs OCR using the VLM inference engine.
309
+ Results are processed concurrently using asyncio.
195
310
 
196
311
  Parameters:
197
312
  -----------
198
313
  file_paths : Union[str, Iterable[str]]
199
314
  A file path or a list of file paths to process. Must be one of '.pdf', '.tif', '.tiff', '.png', '.jpg', '.jpeg', '.bmp', '.gif', '.webp'
200
- max_new_tokens : int, Optional
201
- The maximum number of tokens to generate.
202
- temperature : float, Optional
203
- The temperature to use for sampling.
204
- verbose : bool, Optional
205
- If True, the function will print the output in terminal.
315
+ rotate_correction : bool, Optional
316
+ If True, applies rotate correction to the images using pytesseract.
317
+ max_dimension_pixels : int, Optional
318
+ The maximum dimension of the image in pixels. Original dimensions will be resized to fit in. If None, no resizing is applied.
319
+ concurrent_batch_size : int, Optional
320
+ The number of concurrent VLM calls to make.
321
+ max_file_load : int, Optional
322
+ The maximum number of files to load concurrently. If None, defaults to twice concurrent_batch_size.
206
323
 
207
324
  Returns:
208
325
  --------
209
- Iterable[str]
210
- A list of strings containing the OCR results.
326
+ AsyncGenerator[OCRResult, None]
327
+ A generator that yields OCR result objects as they complete.
211
328
  """
212
- ocr_results = []
213
- for file_path in file_paths:
214
- file_ext = os.path.splitext(file_path)[1].lower()
215
- # PDF or TIFF
216
- if file_ext in ['.pdf', '.tif', '.tiff']:
217
- images = get_images_from_pdf(file_path) if file_ext == '.pdf' else get_images_from_tiff(file_path)
218
- if not images:
219
- raise ValueError(f"No images extracted from file: {file_path}")
220
- results = []
221
- for image in images:
222
- messages = self.vlm_engine.get_ocr_messages(self.system_prompt, self.user_prompt, image)
223
- response = self.vlm_engine.chat(
224
- messages,
225
- max_new_tokens=max_new_tokens,
226
- temperature=temperature,
227
- verbose=verbose,
228
- stream=False,
229
- **kwrs
230
- )
231
- results.append(response)
232
-
233
- ocr_text = self.page_delimiter.join(results)
234
- # Image
235
- else:
236
- image = get_image_from_file(file_path)
237
- messages = self.vlm_engine.get_ocr_messages(self.system_prompt, self.user_prompt, image)
238
- ocr_text = self.vlm_engine.chat(
239
- messages,
240
- max_new_tokens=max_new_tokens,
241
- temperature=temperature,
242
- verbose=verbose,
243
- stream=False,
244
- **kwrs
245
- )
246
-
247
- # Clean markdown
248
- if self.output_mode == "markdown":
249
- ocr_text = clean_markdown(ocr_text)
250
- ocr_results.append(ocr_text)
329
+ if isinstance(file_paths, str):
330
+ file_paths = [file_paths]
331
+
332
+ if max_file_load is None:
333
+ max_file_load = concurrent_batch_size * 2
251
334
 
252
- return ocr_results
335
+ if not isinstance(max_file_load, int) or max_file_load <= 0:
336
+ raise ValueError("max_file_load must be a positive integer")
337
+
338
+ if self.image_processor.has_tesseract==False and rotate_correction:
339
+ raise ImportError("pytesseract is not installed. Please install it to use rotate correction.")
253
340
 
341
+ return self._ocr_async(file_paths=file_paths,
342
+ rotate_correction=rotate_correction,
343
+ max_dimension_pixels=max_dimension_pixels,
344
+ concurrent_batch_size=concurrent_batch_size,
345
+ max_file_load=max_file_load)
346
+
254
347
 
255
- async def _run_ocr_async(self, file_paths: Union[str, Iterable[str]], max_new_tokens:int=4096,
256
- temperature:float=0.0, concurrent_batch_size:int=32, **kwrs) -> List[str]:
348
+ async def _ocr_async(self, file_paths: Iterable[str], rotate_correction:bool=False, max_dimension_pixels:int=None,
349
+ concurrent_batch_size: int=32, max_file_load: int=None) -> AsyncGenerator[OCRResult, None]:
257
350
  """
258
- This is the async version of the _run_ocr method.
351
+ Internal method to asynchronously process an iterable of file paths.
352
+ Yields OCRResult objects as they complete. Order not guaranteed.
353
+ concurrent_batch_size controls how many VLM calls are made concurrently.
259
354
  """
260
- # flatten pages/images in file_paths
261
- flat_page_list = []
355
+ vlm_call_semaphore = asyncio.Semaphore(concurrent_batch_size)
356
+ file_load_semaphore = asyncio.Semaphore(max_file_load)
357
+
358
+ tasks = []
262
359
  for file_path in file_paths:
360
+ task = self._ocr_file_with_semaphore(file_load_semaphore=file_load_semaphore,
361
+ vlm_call_semaphore=vlm_call_semaphore,
362
+ file_path=file_path,
363
+ rotate_correction=rotate_correction,
364
+ max_dimension_pixels=max_dimension_pixels)
365
+ tasks.append(task)
366
+
367
+
368
+ for future in asyncio.as_completed(tasks):
369
+ result: OCRResult = await future
370
+ yield result
371
+
372
+ async def _ocr_file_with_semaphore(self, file_load_semaphore:asyncio.Semaphore, vlm_call_semaphore:asyncio.Semaphore,
373
+ file_path:str, rotate_correction:bool=False, max_dimension_pixels:int=None) -> OCRResult:
374
+ """
375
+ This internal method takes two semaphores and runs OCR on a single file using the VLM inference engine.
376
+ """
377
+ async with file_load_semaphore:
378
+ filename = os.path.basename(file_path)
263
379
  file_ext = os.path.splitext(file_path)[1].lower()
264
- # PDF or TIFF
265
- if file_ext in ['.pdf', '.tif', '.tiff']:
266
- images = get_images_from_pdf(file_path) if file_ext == '.pdf' else get_images_from_tiff(file_path)
267
- if not images:
268
- flat_page_list.append({'file_path': file_path, 'file_type': "PDF/TIFF", "image": image, "page_num": 0, "total_page_count": 0})
269
- for page_num, image in enumerate(images):
270
- flat_page_list.append({'file_path': file_path, 'file_type': "PDF/TIFF", "image": image, "page_num": page_num, "total_page_count": len(images)})
271
- # Image
272
- else:
273
- image = get_image_from_file(file_path)
274
- flat_page_list.append({'file_path': file_path, 'file_type': "image", "image": image})
275
-
276
- # Process images with asyncio.Semaphore
277
- semaphore = asyncio.Semaphore(concurrent_batch_size)
278
- async def semaphore_helper(page:List[Dict[str,str]], max_new_tokens:int, temperature:float, **kwrs):
380
+ result = OCRResult(input_dir=file_path, output_mode=self.output_mode)
381
+ # check file extension
382
+ if file_ext not in SUPPORTED_IMAGE_EXTS:
383
+ result.status = "error"
384
+ result.add_page(text=f"Unsupported file type: {file_ext}. Supported types are: {SUPPORTED_IMAGE_EXTS}",
385
+ image_processing_status={})
386
+ return result
387
+
279
388
  try:
280
- messages = self.vlm_engine.get_ocr_messages(self.system_prompt, self.user_prompt, page["image"])
281
- async with semaphore:
282
- async_task = self.vlm_engine.chat_async(
283
- messages,
284
- max_new_tokens=max_new_tokens,
285
- temperature=temperature,
286
- **kwrs
389
+ # Load images from file
390
+ if file_ext == '.pdf':
391
+ data_loader = PDFDataLoader(file_path)
392
+ elif file_ext in ['.tif', '.tiff']:
393
+ data_loader = TIFFDataLoader(file_path)
394
+ else:
395
+ data_loader = ImageDataLoader(file_path)
396
+
397
+ except Exception as e:
398
+ result.status = "error"
399
+ result.add_page(text=f"Error processing file {filename}: {str(e)}", image_processing_status={})
400
+ return result
401
+
402
+ try:
403
+ page_processing_tasks = []
404
+ for page_index in range(data_loader.get_page_count()):
405
+ task = self._ocr_page_with_semaphore(
406
+ vlm_call_semaphore=vlm_call_semaphore,
407
+ data_loader=data_loader,
408
+ page_index=page_index,
409
+ rotate_correction=rotate_correction,
410
+ max_dimension_pixels=max_dimension_pixels
287
411
  )
288
- return await async_task
412
+ page_processing_tasks.append(task)
413
+
414
+ if page_processing_tasks:
415
+ processed_page_results = await asyncio.gather(*page_processing_tasks)
416
+ for text, image_processing_status in processed_page_results:
417
+ result.add_page(text=text, image_processing_status=image_processing_status)
418
+
289
419
  except Exception as e:
290
- print(f"Error processing image: {e}")
291
- return f"[Error: {e}]"
420
+ result.status = "error"
421
+ result.add_page(text=f"Error during OCR for {filename}: {str(e)}", image_processing_status={})
422
+ return result
292
423
 
293
- tasks = []
294
- for page in flat_page_list:
295
- async_task = semaphore_helper(
296
- page,
297
- max_new_tokens=max_new_tokens,
298
- temperature=temperature,
299
- **kwrs
300
- )
301
- tasks.append(asyncio.create_task(async_task))
424
+ # Set status to success if no errors occurred
425
+ result.status = "success"
426
+ return result
302
427
 
303
- responses = await asyncio.gather(*tasks)
428
+ async def _ocr_page_with_semaphore(self, vlm_call_semaphore: asyncio.Semaphore, data_loader: DataLoader,
429
+ page_index:int, rotate_correction:bool=False, max_dimension_pixels:int=None) -> Tuple[str, Dict[str, str]]:
430
+ """
431
+ This internal method takes a semaphore and runs OCR on a single image/page using the VLM inference engine.
304
432
 
305
- # Restructure the results
306
- ocr_results = []
307
- page_text_buffer = ""
308
- for page, ocr_text in zip(flat_page_list, responses):
309
- # PDF or TIFF
310
- if page['file_type'] == "PDF/TIFF":
311
- page_text_buffer += ocr_text + self.page_delimiter
312
- if page['page_num'] == page['total_page_count'] - 1:
313
- if self.output_mode == "markdown":
314
- page_text_buffer = clean_markdown(page_text_buffer)
315
- ocr_results.append(page_text_buffer)
316
- page_text_buffer = ""
317
- # Image
318
- if page['file_type'] == "image":
319
- if self.output_mode == "markdown":
320
- ocr_text = clean_markdown(ocr_text)
321
- ocr_results.append(ocr_text)
322
-
323
- return ocr_results
433
+ Returns:
434
+ -------
435
+ Tuple[str, Dict[str, str]]
436
+ A tuple containing the OCR text and a dictionary with image processing status.
437
+ """
438
+ async with vlm_call_semaphore:
439
+ image = await data_loader.get_page_async(page_index)
440
+ image_processing_status = {}
441
+ # Apply rotate correction if specified and tesseract is available
442
+ if rotate_correction and self.image_processor.has_tesseract:
443
+ try:
444
+ image, rotation_angle = await self.image_processor.rotate_correction_async(image)
445
+ image_processing_status["rotate_correction"] = {
446
+ "status": "success",
447
+ "rotation_angle": rotation_angle
448
+ }
449
+ except Exception as e:
450
+ image_processing_status["rotate_correction"] = {
451
+ "status": "error",
452
+ "error": str(e)
453
+ }
454
+
455
+ # Resize the image if max_dimension_pixels is specified
456
+ if max_dimension_pixels is not None:
457
+ try:
458
+ image, resized = await self.image_processor.resize_async(image, max_dimension_pixels=max_dimension_pixels)
459
+ image_processing_status["resize"] = {
460
+ "status": "success",
461
+ "resized": resized
462
+ }
463
+ except Exception as e:
464
+ image_processing_status["resize"] = {
465
+ "status": "error",
466
+ "error": str(e)
467
+ }
324
468
 
469
+ messages = self.vlm_engine.get_ocr_messages(self.system_prompt, self.user_prompt, image)
470
+ ocr_text = await self.vlm_engine.chat_async(
471
+ messages,
472
+ )
473
+ if self.output_mode == "markdown":
474
+ ocr_text = clean_markdown(ocr_text)
475
+ return ocr_text, image_processing_status