vlm4ocr 0.3.1__tar.gz → 0.4.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: vlm4ocr
-Version: 0.3.1
+Version: 0.4.0
 Summary: Python package and Web App for OCR with vision language models.
 License: MIT
 Author: Enshuo (David) Hsu
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "vlm4ocr"
-version = "0.3.1"
+version = "0.4.0"
 description = "Python package and Web App for OCR with vision language models."
 authors = ["Enshuo (David) Hsu"]
 license = "MIT"
@@ -1,7 +1,9 @@
+from .data_types import FewShotExample
 from .ocr_engines import OCREngine
 from .vlm_engines import BasicVLMConfig, ReasoningVLMConfig, OpenAIReasoningVLMConfig, OllamaVLMEngine, OpenAICompatibleVLMEngine, VLLMVLMEngine, OpenRouterVLMEngine, OpenAIVLMEngine, AzureOpenAIVLMEngine
 
 __all__ = [
+    "FewShotExample",
     "BasicVLMConfig",
     "ReasoningVLMConfig",
     "OpenAIReasoningVLMConfig",
@@ -15,7 +15,12 @@ logging.basicConfig(
     format='%(asctime)s - %(levelname)s: %(message)s',
     datefmt='%Y-%m-%d %H:%M:%S'
 )
+# Get our specific logger for CLI messages.
 logger = logging.getLogger("vlm4ocr_cli")
+# Get the logger that receives captured warnings. By default, captured
+# warnings are logged to a logger named 'py.warnings'.
+warnings_logger = logging.getLogger('py.warnings')
+
 
 SUPPORTED_IMAGE_EXTS_CLI = ['.pdf', '.tif', '.tiff', '.png', '.jpg', '.jpeg', '.bmp', '.gif', '.webp']
 OUTPUT_EXTENSIONS = {'markdown': '.md', 'HTML':'.html', 'text':'.txt'}
@@ -56,17 +61,26 @@ def setup_file_logger(log_dir, timestamp_str, debug_mode):
     log_file_path = os.path.join(log_dir, log_file_name)
 
     file_handler = logging.FileHandler(log_file_path, mode='a')
-    formatter = logging.Formatter('%(asctime)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
+    formatter = logging.Formatter('%(asctime)s - %(levelname)s - [%(name)s:%(filename)s:%(lineno)d] - %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
     file_handler.setFormatter(formatter)
 
     log_level = logging.DEBUG if debug_mode else logging.INFO
     file_handler.setLevel(log_level)
 
-    logger.addHandler(file_handler)
+    # Add the handler to the root logger so it captures all logs (from our
+    # logger and from the warnings logger 'py.warnings').
+    root_logger = logging.getLogger()
+    root_logger.addHandler(file_handler)
+
+    # We still configure our specific logger's level for console output.
     logger.info(f"Logging to file: {log_file_path}")
 
 
 def main():
+    # Capture warnings from the 'warnings' module (like RuntimeWarning)
+    # and redirect them to the 'logging' system.
+    logging.captureWarnings(True)
+
     parser = argparse.ArgumentParser(
         description="VLM4OCR: Perform OCR on images, PDFs, or TIFF files using Vision Language Models. Processing is concurrent by default.",
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
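
For context, `logging.captureWarnings(True)` is standard-library behavior: warnings raised through the `warnings` module are rerouted to the 'py.warnings' logger, so any handler attached to the root logger (such as the file handler installed by `setup_file_logger` above) records them. A minimal self-contained sketch:

    import logging
    import warnings

    logging.basicConfig(level=logging.INFO)
    logging.captureWarnings(True)  # reroute warnings.warn(...) into logging

    # No longer printed to stderr directly; emitted as a WARNING record
    # on the 'py.warnings' logger and handled by root-logger handlers.
    warnings.warn("deprecated code path", RuntimeWarning)
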
@@ -94,7 +108,8 @@ def main():
     vlm_engine_group.add_argument("--vlm_engine", choices=["openai", "azure_openai", "ollama", "openai_compatible"], required=True, help="VLM engine.")
     vlm_engine_group.add_argument("--model", required=True, help="Model identifier for the VLM engine.")
     vlm_engine_group.add_argument("--max_new_tokens", type=int, default=4096, help="Max new tokens for VLM.")
-    vlm_engine_group.add_argument("--temperature", type=float, default=0.0, help="Sampling temperature.")
+    vlm_engine_group.add_argument("--temperature", type=float, default=None, help="Sampling temperature.")
+    vlm_engine_group.add_argument("--top_p", type=float, default=None, help="Sampling top-p.")
 
     openai_group = parser.add_argument_group("OpenAI & OpenAI-Compatible Options")
     openai_group.add_argument("--api_key", default=os.environ.get("OPENAI_API_KEY"), help="API key.")
@@ -135,16 +150,23 @@ def main():
     current_timestamp_str = time.strftime("%Y%m%d_%H%M%S")
 
     # --- Configure Logger Level based on args ---
+    # Get the root logger to control the global level for libraries.
+    root_logger = logging.getLogger()
+
     if args.debug:
-        logger.setLevel(logging.DEBUG)
-        # Set root logger to DEBUG only if our specific logger is DEBUG, to avoid overly verbose library logs unless intended.
-        if logger.getEffectiveLevel() <= logging.DEBUG:
-            logging.getLogger().setLevel(logging.DEBUG)
+        logger.setLevel(logging.DEBUG)           # our logger to DEBUG
+        warnings_logger.setLevel(logging.DEBUG)  # warnings logger to DEBUG
+        root_logger.setLevel(logging.DEBUG)      # root to DEBUG
         logger.debug("Debug mode enabled for console.")
     else:
-        logger.setLevel(logging.INFO) # Default for our CLI's own messages
-        logging.getLogger().setLevel(logging.WARNING) # Keep external libraries quieter by default
-
+        logger.setLevel(logging.INFO)            # our logger to INFO
+        warnings_logger.setLevel(logging.INFO)   # warnings logger to INFO
+        root_logger.setLevel(logging.WARNING)    # root to WARNING (quieter libraries)
+        # The console handler (from basicConfig) lives on the root logger, so
+        # setting root to WARNING keeps the console quiet. Our logger
+        # (vlm4ocr_cli) stays at INFO, so a file handler added later still
+        # receives INFO messages from it.
+
     if args.concurrent_batch_size < 1:
         parser.error("--concurrent_batch_size must be 1 or greater.")
 
@@ -183,6 +205,15 @@ def main():
     # --- Setup File Logger (if --log is specified) ---
     if args.log:
         setup_file_logger(effective_output_dir, current_timestamp_str, args.debug)
+        # When logging to a file and not in debug mode, keep the console less
+        # verbose by raising the console handler's level.
+        if not args.debug:
+            # Find the console handler (from basicConfig) and set its level.
+            for handler in root_logger.handlers:
+                if isinstance(handler, logging.StreamHandler) and handler.stream == sys.stderr:
+                    handler.setLevel(logging.WARNING)
+                    logger.debug("Set console handler level to WARNING.")
+                    break
 
     logger.debug(f"Parsed arguments: {args}")
 
@@ -190,9 +221,11 @@ def main():
     vlm_engine_instance = None
     try:
         logger.info(f"Initializing VLM engine: {args.vlm_engine} with model: {args.model}")
+        logger.info(f"max_new_tokens: {args.max_new_tokens}, temperature: {args.temperature}, top_p: {args.top_p}")
         config = BasicVLMConfig(
             max_new_tokens=args.max_new_tokens,
-            temperature=args.temperature
+            temperature=args.temperature,
+            top_p=args.top_p
         )
         if args.vlm_engine == "openai":
             if not args.api_key: parser.error("--api_key (or OPENAI_API_KEY) is required for OpenAI.")
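
A short sketch of the new sampling defaults, assuming `BasicVLMConfig` treats `None` as "defer to the backend's default" (the `None` handling inside the config class is not shown in this diff):

    from vlm4ocr import BasicVLMConfig

    # 0.4.0 defaults both sampling knobs to None (the CLI previously forced
    # temperature=0.0), leaving sampling behavior to the VLM backend.
    config = BasicVLMConfig(max_new_tokens=4096, temperature=None, top_p=None)

    # Explicit values still work as before:
    greedy = BasicVLMConfig(max_new_tokens=4096, temperature=0.0, top_p=1.0)
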
@@ -286,16 +319,34 @@ def main():
     # console verbosity controlled by logger level.
     show_progress_bar = (num_actual_files > 0)
 
+    # Show the progress bar only when there are files to process and we are
+    # not in debug mode (debug logs would interfere with the bar). When
+    # logging to a file the console can stay quiet; when NOT logging to a
+    # file the console handler must be at INFO for the bar to be useful.
+    disable_bar = args.debug or not show_progress_bar
+
+    # If not logging to a file AND not in debug mode, bring the console to INFO.
+    if not args.log and not args.debug:
+        for handler in logging.getLogger().handlers:
+            if isinstance(handler, logging.StreamHandler) and handler.stream == sys.stderr:
+                handler.setLevel(logging.INFO)
+                logger.debug("Set console handler level to INFO for progress bar.")
+                break
+
     iterator_wrapper = tqdm.asyncio.tqdm(
         ocr_task_generator,
         total=num_actual_files,
         desc="Processing files",
         unit="file",
-        disable=not show_progress_bar # disable if no files, or can remove this disable if tqdm handles total=0
+        disable=disable_bar
     )
 
     async for result_object in iterator_wrapper:
         if not isinstance(result_object, OCRResult):
+            # This warning is now captured by the file log as well.
             logger.warning(f"Received unexpected data type: {type(result_object)}")
             continue
 
@@ -314,9 +365,12 @@ def main():
                 content_to_write = result_object.to_string()
                 with open(current_ocr_output_file_path, "w", encoding="utf-8") as f:
                     f.write(content_to_write)
-                # Log less verbosely to console if progress bar is active
-                if not show_progress_bar or logger.getEffectiveLevel() <= logging.DEBUG:
-                    logger.info(f"OCR result for '{input_file_path_from_result}' saved to: {current_ocr_output_file_path}")
+
+                # Always log success info: it goes to the file log when active,
+                # and stays off the console when the console level is WARNING.
+                logger.info(f"OCR result for '{input_file_path_from_result}' saved to: {current_ocr_output_file_path}")
+
             except Exception as e:
                 logger.error(f"Error writing output for '{input_file_path_from_result}' to '{current_ocr_output_file_path}': {e}")
 
@@ -1,7 +1,8 @@
 import os
 from typing import List, Dict, Literal
+from PIL import Image
 from dataclasses import dataclass, field
-from vlm4ocr.utils import get_default_page_delimiter
+from vlm4ocr.utils import get_default_page_delimiter, ImageProcessor
 
 OutputMode = Literal["markdown", "HTML", "text", "JSON"]
 
@@ -118,4 +119,41 @@ class OCRResult:
         else:
             self.page_delimiter = page_delimiter
 
-        return self.page_delimiter.join([page.get("text", "") for page in self.pages])
+        return self.page_delimiter.join([page.get("text", "") for page in self.pages])
+
+
+@dataclass
+class FewShotExample:
+    """
+    This class represents a few-shot example for OCR tasks.
+
+    Parameters:
+    ----------
+    image : PIL.Image.Image
+        The image associated with the example.
+    text : str
+        The expected OCR result text for the image.
+    rotate_correction : bool, Optional
+        If True, applies rotate correction to the image using pytesseract.
+    max_dimension_pixels : int, Optional
+        The maximum dimension of the image in pixels. The original image is resized to fit. If None, no resizing is applied.
+    """
+    image: Image.Image
+    text: str
+    rotate_correction: bool = False
+    max_dimension_pixels: int = None
+
+    def __post_init__(self):
+        if not isinstance(self.image, Image.Image):
+            raise ValueError("image must be a PIL.Image.Image object")
+        if not isinstance(self.text, str):
+            raise ValueError("text must be a string")
+
+        if self.rotate_correction or self.max_dimension_pixels is not None:
+            self.image_processor = ImageProcessor()
+
+        # Rotate correction if specified
+        if self.rotate_correction:
+            self.image, _ = self.image_processor.rotate_correction(self.image)
+
+        # Resize the image if max_dimension_pixels is specified
+        if self.max_dimension_pixels is not None:
+            self.image, _ = self.image_processor.resize(image=self.image, max_dimension_pixels=self.max_dimension_pixels)
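
A usage sketch for the new `FewShotExample` class; the image path and transcription below are placeholders, and `rotate_correction=True` additionally requires pytesseract:

    from PIL import Image
    from vlm4ocr import FewShotExample

    # Pair a sample page with its expected transcription. Preprocessing
    # (rotation fix, resizing) runs once in __post_init__.
    example = FewShotExample(
        image=Image.open("example_invoice.png"),            # placeholder path
        text="ACME Corp.\nInvoice #0001\nTotal: $120.00",   # placeholder transcription
        rotate_correction=False,
        max_dimension_pixels=2048,
    )
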
@@ -1,11 +1,11 @@
 import os
-from typing import Tuple, List, Dict, Union, Generator, AsyncGenerator, Iterable
+from typing import Any, Tuple, List, Dict, Union, Generator, AsyncGenerator, Iterable
 import importlib
 import asyncio
 from colorama import Fore, Style
 import json
 from vlm4ocr.utils import DataLoader, PDFDataLoader, TIFFDataLoader, ImageDataLoader, ImageProcessor, clean_markdown, extract_json, get_default_page_delimiter
-from vlm4ocr.data_types import OCRResult
+from vlm4ocr.data_types import OCRResult, FewShotExample
 from vlm4ocr.vlm_engines import VLMEngine, MessagesLogger
 
 SUPPORTED_IMAGE_EXTS = ['.pdf', '.tif', '.tiff', '.png', '.jpg', '.jpeg', '.bmp', '.gif', '.webp']
@@ -60,7 +60,8 @@ class OCREngine:
         self.image_processor = ImageProcessor()
 
 
-    def stream_ocr(self, file_path: str, rotate_correction:bool=False, max_dimension_pixels:int=None) -> Generator[Dict[str, str], None, None]:
+    def stream_ocr(self, file_path: str, rotate_correction:bool=False, max_dimension_pixels:int=None,
+                   few_shot_examples:List[FewShotExample]=None) -> Generator[Dict[str, str], None, None]:
         """
         This method inputs a file path (image or PDF) and streams OCR results in real time. This is useful for frontend applications.
         Yields dictionaries with 'type' ('ocr_chunk' or 'page_delimiter') and 'data'.
@@ -73,6 +74,8 @@ class OCREngine:
             If True, applies rotate correction to the images using pytesseract.
         max_dimension_pixels : int, Optional
             The maximum dimension of the image in pixels. The original image is resized to fit. If None, no resizing is applied.
+        few_shot_examples : List[FewShotExample], Optional
+            List of few-shot examples.
 
         Returns:
         --------
@@ -90,10 +93,6 @@ class OCREngine:
         file_ext = os.path.splitext(file_path)[1].lower()
         if file_ext not in SUPPORTED_IMAGE_EXTS:
             raise ValueError(f"Unsupported file type: {file_ext}. Supported types are: {SUPPORTED_IMAGE_EXTS}")
-
-        # Check if image preprocessing can be applied
-        if self.image_processor.has_tesseract==False and rotate_correction:
-            raise ImportError("pytesseract is not installed. Please install it to use rotate correction.")
 
         # PDF or TIFF
         if file_ext in ['.pdf', '.tif', '.tiff']:
@@ -105,8 +104,8 @@ class OCREngine:
 
             # OCR each image
             for i, image in enumerate(images):
-                # Apply rotate correction if specified and tesseract is available
-                if rotate_correction and self.image_processor.has_tesseract:
+                # Apply rotate correction if specified
+                if rotate_correction:
                     try:
                         image, _ = self.image_processor.rotate_correction(image)
 
@@ -120,7 +119,13 @@ class OCREngine:
                 except Exception as e:
                     yield {"type": "info", "data": f"Error resizing image: {str(e)}"}
 
-                messages = self.vlm_engine.get_ocr_messages(self.system_prompt, self.user_prompt, image)
+                # Get OCR messages
+                messages = self.vlm_engine.get_ocr_messages(system_prompt=self.system_prompt,
+                                                            user_prompt=self.user_prompt,
+                                                            image=image,
+                                                            few_shot_examples=few_shot_examples)
+
+                # Stream response
                 response_stream = self.vlm_engine.chat(
                     messages,
                     stream=True
@@ -137,8 +142,8 @@ class OCREngine:
             data_loader = ImageDataLoader(file_path)
             image = data_loader.get_page(0)
 
-            # Apply rotate correction if specified and tesseract is available
-            if rotate_correction and self.image_processor.has_tesseract:
+            # Apply rotate correction if specified
+            if rotate_correction:
                 try:
                     image, _ = self.image_processor.rotate_correction(image)
 
@@ -152,7 +157,12 @@ class OCREngine:
             except Exception as e:
                 yield {"type": "info", "data": f"Error resizing image: {str(e)}"}
 
-            messages = self.vlm_engine.get_ocr_messages(self.system_prompt, self.user_prompt, image)
+            # Get OCR messages
+            messages = self.vlm_engine.get_ocr_messages(system_prompt=self.system_prompt,
+                                                        user_prompt=self.user_prompt,
+                                                        image=image,
+                                                        few_shot_examples=few_shot_examples)
+            # Stream response
             response_stream = self.vlm_engine.chat(
                 messages,
                 stream=True
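
Wiring examples into `stream_ocr` might look like the sketch below; the `OllamaVLMEngine` and `OCREngine` constructor arguments shown are illustrative assumptions, not taken from this diff:

    from vlm4ocr import OCREngine, OllamaVLMEngine

    vlm = OllamaVLMEngine(model="llama3.2-vision")              # assumed signature
    engine = OCREngine(vlm_engine=vlm, output_mode="markdown")  # assumed signature

    for chunk in engine.stream_ocr("scan.pdf", few_shot_examples=[example]):
        if chunk["type"] == "ocr_chunk":
            print(chunk["data"], end="")
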
@@ -163,7 +173,7 @@ class OCREngine:
 
 
     def sequential_ocr(self, file_paths: Union[str, Iterable[str]], rotate_correction:bool=False,
-                       max_dimension_pixels:int=None, verbose:bool=False) -> List[OCRResult]:
+                       max_dimension_pixels:int=None, verbose:bool=False, few_shot_examples:List[FewShotExample]=None) -> List[OCRResult]:
         """
         This method inputs a file path or a list of file paths (image, PDF, TIFF) and performs OCR using the VLM inference engine.
 
@@ -177,6 +187,8 @@ class OCREngine:
             The maximum dimension of the image in pixels. The original image is resized to fit. If None, no resizing is applied.
         verbose : bool, Optional
             If True, the function will print the output in the terminal.
+        few_shot_examples : List[FewShotExample], Optional
+            List of FewShotExample objects.
 
         Returns:
         --------
@@ -186,6 +198,7 @@ class OCREngine:
         if isinstance(file_paths, str):
             file_paths = [file_paths]
 
+        # Iterate through file paths
         ocr_results = []
         for file_path in file_paths:
             # Define OCRResult object
@@ -235,8 +248,8 @@ class OCREngine:
             # OCR images
             for i, image in enumerate(images):
                 image_processing_status = {}
-                # Apply rotate correction if specified and tesseract is available
-                if rotate_correction and self.image_processor.has_tesseract:
+                # Apply rotate correction if specified
+                if rotate_correction:
                     try:
                         image, rotation_angle = self.image_processor.rotate_correction(image)
                         image_processing_status["rotate_correction"] = {
@@ -272,7 +285,10 @@ class OCREngine:
                     print(f"{Fore.RED}Error resizing image for {filename}:{Style.RESET_ALL} {resized['error']}. OCR continues without resizing.")
 
                 try:
-                    messages = self.vlm_engine.get_ocr_messages(self.system_prompt, self.user_prompt, image)
+                    messages = self.vlm_engine.get_ocr_messages(system_prompt=self.system_prompt,
+                                                                user_prompt=self.user_prompt,
+                                                                image=image,
+                                                                few_shot_examples=few_shot_examples)
                     # Define a messages logger to capture messages
                     messages_logger = MessagesLogger()
                     # Generate response
@@ -308,11 +324,12 @@ class OCREngine:
                     print(f"{Fore.RED}Error during OCR for a page in {filename}:{Style.RESET_ALL} {page_e}")
 
             # Add the OCR result to the list
-            ocr_result.status = "success"
+            if ocr_result.status != "error":
+                ocr_result.status = "success"
             ocr_results.append(ocr_result)
 
             if verbose:
-                print(f"{Fore.BLUE}Successfully processed {filename} with {len(ocr_result)} pages.{Style.RESET_ALL}")
+                print(f"{Fore.BLUE}Processed {filename} with {len(ocr_result)} pages.{Style.RESET_ALL}")
                 for page in ocr_result:
                     print(page)
                     print("-" * 80)
@@ -321,7 +338,8 @@ class OCREngine:
 
 
     def concurrent_ocr(self, file_paths: Union[str, Iterable[str]], rotate_correction:bool=False,
-                       max_dimension_pixels:int=None, concurrent_batch_size: int=32, max_file_load: int=None) -> AsyncGenerator[OCRResult, None]:
+                       max_dimension_pixels:int=None, few_shot_examples:List[FewShotExample]=None,
+                       concurrent_batch_size: int=32, max_file_load: int=None) -> AsyncGenerator[OCRResult, None]:
         """
         First completed, first out: input and output order are not guaranteed.
         This method inputs a file path or a list of file paths (image, PDF, TIFF) and performs OCR using the VLM inference engine.
@@ -335,6 +353,8 @@ class OCREngine:
             If True, applies rotate correction to the images using pytesseract.
         max_dimension_pixels : int, Optional
             The maximum dimension of the image in pixels. The original image is resized to fit. If None, no resizing is applied.
+        few_shot_examples : List[FewShotExample], Optional
+            List of FewShotExample objects.
         concurrent_batch_size : int, Optional
             The number of concurrent VLM calls to make.
         max_file_load : int, Optional
@@ -353,18 +373,17 @@ class OCREngine:
 
         if not isinstance(max_file_load, int) or max_file_load <= 0:
             raise ValueError("max_file_load must be a positive integer")
-
-        if self.image_processor.has_tesseract==False and rotate_correction:
-            raise ImportError("pytesseract is not installed. Please install it to use rotate correction.")
 
         return self._ocr_async(file_paths=file_paths,
                                rotate_correction=rotate_correction,
                                max_dimension_pixels=max_dimension_pixels,
+                               few_shot_examples=few_shot_examples,
                                concurrent_batch_size=concurrent_batch_size,
                                max_file_load=max_file_load)
 
 
     async def _ocr_async(self, file_paths: Iterable[str], rotate_correction:bool=False, max_dimension_pixels:int=None,
+                         few_shot_examples:List[FewShotExample]=None,
                          concurrent_batch_size: int=32, max_file_load: int=None) -> AsyncGenerator[OCRResult, None]:
         """
         Internal method to asynchronously process an iterable of file paths.
@@ -380,7 +399,8 @@ class OCREngine:
                 vlm_call_semaphore=vlm_call_semaphore,
                 file_path=file_path,
                 rotate_correction=rotate_correction,
-                max_dimension_pixels=max_dimension_pixels)
+                max_dimension_pixels=max_dimension_pixels,
+                few_shot_examples=few_shot_examples)
             tasks.append(task)
 
 
@@ -389,7 +409,8 @@ class OCREngine:
             yield result
 
     async def _ocr_file_with_semaphore(self, file_load_semaphore:asyncio.Semaphore, vlm_call_semaphore:asyncio.Semaphore,
-                                       file_path:str, rotate_correction:bool=False, max_dimension_pixels:int=None) -> OCRResult:
+                                       file_path:str, rotate_correction:bool=False, max_dimension_pixels:int=None,
+                                       few_shot_examples:List[FewShotExample]=None) -> OCRResult:
         """
         This internal method takes a semaphore and OCRs a single file using the VLM inference engine.
         """
@@ -428,6 +449,7 @@ class OCREngine:
                     page_index=page_index,
                     rotate_correction=rotate_correction,
                     max_dimension_pixels=max_dimension_pixels,
+                    few_shot_examples=few_shot_examples,
                     messages_logger=messages_logger
                 )
                 page_processing_tasks.append(task)
@@ -444,13 +466,14 @@ class OCREngine:
                 return result
 
         # Set status to success if no errors occurred
-        result.status = "success"
+        if result.status != "error":
+            result.status = "success"
         result.add_messages_to_log(messages_logger.get_messages_log())
         return result
 
     async def _ocr_page_with_semaphore(self, vlm_call_semaphore: asyncio.Semaphore, data_loader: DataLoader,
                                        page_index:int, rotate_correction:bool=False, max_dimension_pixels:int=None,
-                                       messages_logger:MessagesLogger=None) -> Tuple[str, Dict[str, str]]:
+                                       few_shot_examples:List[FewShotExample]=None, messages_logger:MessagesLogger=None) -> Tuple[str, Dict[str, str]]:
         """
         This internal method takes a semaphore and OCRs a single image/page using the VLM inference engine.
 
@@ -462,8 +485,8 @@ class OCREngine:
         async with vlm_call_semaphore:
             image = await data_loader.get_page_async(page_index)
             image_processing_status = {}
-            # Apply rotate correction if specified and tesseract is available
-            if rotate_correction and self.image_processor.has_tesseract:
+            # Apply rotate correction if specified
+            if rotate_correction:
                 try:
                     image, rotation_angle = await self.image_processor.rotate_correction_async(image)
                     image_processing_status["rotate_correction"] = {
@@ -490,7 +513,10 @@ class OCREngine:
                         "error": str(e)
                     }
 
-            messages = self.vlm_engine.get_ocr_messages(self.system_prompt, self.user_prompt, image)
+            messages = self.vlm_engine.get_ocr_messages(system_prompt=self.system_prompt,
+                                                        user_prompt=self.user_prompt,
+                                                        image=image,
+                                                        few_shot_examples=few_shot_examples)
             response = await self.vlm_engine.chat_async(
                 messages,
                 messages_logger=messages_logger
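
And the concurrent path, which now threads `few_shot_examples` down to every page-level VLM call and yields results first-completed-first-out:

    import asyncio

    async def run_batch(paths):
        async for result in engine.concurrent_ocr(paths,
                                                  few_shot_examples=[example],
                                                  concurrent_batch_size=8):
            print(result.status)

    asyncio.run(run_batch(["a.pdf", "b.pdf"]))  # placeholder paths
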
@@ -6,6 +6,7 @@ import os
 import re
 from PIL import Image
 from vlm4ocr.utils import image_to_base64
+from vlm4ocr.data_types import FewShotExample
 
 
 class VLMConfig(abc.ABC):
@@ -332,7 +333,7 @@ class VLMEngine:
         return NotImplemented
 
     @abc.abstractmethod
-    def get_ocr_messages(self, system_prompt:str, user_prompt:str, image:Image.Image) -> List[Dict[str,str]]:
+    def get_ocr_messages(self, system_prompt:str, user_prompt:str, image:Image.Image, few_shot_examples:List[FewShotExample]=None) -> List[Dict[str,str]]:
         """
         This method inputs an image and returns the corresponding chat messages for the inference engine.
 
@@ -344,6 +345,8 @@ class VLMEngine:
             the user prompt.
         image : Image.Image
             the image for OCR.
+        few_shot_examples : List[FewShotExample], Optional
+            list of few-shot examples.
         """
         return NotImplemented
 
@@ -557,7 +560,7 @@ class OllamaVLMEngine(VLMEngine):
 
         return res_dict
 
-    def get_ocr_messages(self, system_prompt:str, user_prompt:str, image:Image.Image) -> List[Dict[str,str]]:
+    def get_ocr_messages(self, system_prompt:str, user_prompt:str, image:Image.Image, few_shot_examples:List[FewShotExample]=None) -> List[Dict[str,str]]:
         """
         This method inputs an image and returns the corresponding chat messages for the inference engine.
@@ -569,16 +572,32 @@ class OllamaVLMEngine(VLMEngine):
             the user prompt.
         image : Image.Image
             the image for OCR.
+        few_shot_examples : List[FewShotExample], Optional
+            list of few-shot examples.
         """
         base64_str = image_to_base64(image)
-        return [
-            {"role": "system", "content": system_prompt},
-            {
-                "role": "user",
-                "content": user_prompt,
-                "images": [base64_str]
-            }
-        ]
+        output_messages = []
+        # system message
+        system_message = {"role": "system", "content": system_prompt}
+        output_messages.append(system_message)
+
+        # few-shot examples
+        if few_shot_examples is not None:
+            for example in few_shot_examples:
+                if not isinstance(example, FewShotExample):
+                    raise ValueError("Few-shot example must be a FewShotExample object.")
+
+                example_image_b64 = image_to_base64(example.image)
+                example_user_message = {"role": "user", "content": user_prompt, "images": [example_image_b64]}
+                example_agent_message = {"role": "assistant", "content": example.text}
+                output_messages.append(example_user_message)
+                output_messages.append(example_agent_message)
+
+        # user message
+        user_message = {"role": "user", "content": user_prompt, "images": [base64_str]}
+        output_messages.append(user_message)
+
+        return output_messages
 
 
 class OpenAICompatibleVLMEngine(VLMEngine):
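
With one few-shot example, the Ollama engine now emits an interleaved conversation shaped roughly like this (base64 payloads abbreviated):

    messages = [
        {"role": "system", "content": system_prompt},
        # one user/assistant pair per few-shot example
        {"role": "user", "content": user_prompt, "images": ["<example base64>"]},
        {"role": "assistant", "content": example.text},
        # the actual page to transcribe comes last
        {"role": "user", "content": user_prompt, "images": ["<target base64>"]},
    ]
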
@@ -792,7 +811,8 @@ class OpenAICompatibleVLMEngine(VLMEngine):
 
         return res_dict
 
-    def get_ocr_messages(self, system_prompt:str, user_prompt:str, image:Image.Image, format:str='png', detail:str="high") -> List[Dict[str,str]]:
+    def get_ocr_messages(self, system_prompt:str, user_prompt:str, image:Image.Image, format:str='png',
+                         detail:str="high", few_shot_examples:List[FewShotExample]=None) -> List[Dict[str,str]]:
         """
         This method inputs an image and returns the corresponding chat messages for the inference engine.
 
@@ -808,24 +828,55 @@ class OpenAICompatibleVLMEngine(VLMEngine):
             the image format.
         detail : str, Optional
             the detail level of the image. Default is "high".
+        few_shot_examples : List[FewShotExample], Optional
+            list of few-shot examples.
         """
         base64_str = image_to_base64(image)
-        return [
-            {"role": "system", "content": system_prompt},
-            {
-                "role": "user",
-                "content": [
-                    {
-                        "type": "image_url",
-                        "image_url": {
-                            "url": f"data:image/{format};base64,{base64_str}",
-                            "detail": detail
-                        },
-                    },
-                    {"type": "text", "text": user_prompt},
-                ],
-            },
-        ]
+        output_messages = []
+        # system message
+        system_message = {"role": "system", "content": system_prompt}
+        output_messages.append(system_message)
+
+        # few-shot examples
+        if few_shot_examples is not None:
+            for example in few_shot_examples:
+                if not isinstance(example, FewShotExample):
+                    raise ValueError("Few-shot example must be a FewShotExample object.")
+
+                example_image_b64 = image_to_base64(example.image)
+                example_user_message = {
+                    "role": "user",
+                    "content": [
+                        {
+                            "type": "image_url",
+                            "image_url": {
+                                "url": f"data:image/{format};base64,{example_image_b64}",
+                                "detail": detail
+                            },
+                        },
+                        {"type": "text", "text": user_prompt},
+                    ],
+                }
+                example_agent_message = {"role": "assistant", "content": example.text}
+                output_messages.append(example_user_message)
+                output_messages.append(example_agent_message)
+
+        # user message
+        user_message = {
+            "role": "user",
+            "content": [
+                {
+                    "type": "image_url",
+                    "image_url": {
+                        "url": f"data:image/{format};base64,{base64_str}",
+                        "detail": detail
+                    },
+                },
+                {"type": "text", "text": user_prompt},
+            ],
+        }
+        output_messages.append(user_message)
+        return output_messages
 
 
 class VLLMVLMEngine(OpenAICompatibleVLMEngine):
@@ -1096,7 +1147,8 @@ class OpenAIVLMEngine(VLMEngine):
 
         return res_dict
 
-    def get_ocr_messages(self, system_prompt:str, user_prompt:str, image:Image.Image, format:str='png', detail:str="high") -> List[Dict[str,str]]:
+    def get_ocr_messages(self, system_prompt:str, user_prompt:str, image:Image.Image, format:str='png',
+                         detail:str="high", few_shot_examples:List[FewShotExample]=None) -> List[Dict[str,str]]:
         """
         This method inputs an image and returns the corresponding chat messages for the inference engine.
 
@@ -1112,24 +1164,55 @@ class OpenAIVLMEngine(VLMEngine):
             the image format.
         detail : str, Optional
             the detail level of the image. Default is "high".
+        few_shot_examples : List[FewShotExample], Optional
+            list of few-shot examples.
         """
         base64_str = image_to_base64(image)
-        return [
-            {"role": "system", "content": system_prompt},
-            {
-                "role": "user",
-                "content": [
-                    {
-                        "type": "image_url",
-                        "image_url": {
-                            "url": f"data:image/{format};base64,{base64_str}",
-                            "detail": detail
-                        },
-                    },
-                    {"type": "text", "text": user_prompt},
-                ],
-            },
-        ]
+        output_messages = []
+        # system message
+        system_message = {"role": "system", "content": system_prompt}
+        output_messages.append(system_message)
+
+        # few-shot examples
+        if few_shot_examples is not None:
+            for example in few_shot_examples:
+                if not isinstance(example, FewShotExample):
+                    raise ValueError("Few-shot example must be a FewShotExample object.")
+
+                example_image_b64 = image_to_base64(example.image)
+                example_user_message = {
+                    "role": "user",
+                    "content": [
+                        {
+                            "type": "image_url",
+                            "image_url": {
+                                "url": f"data:image/{format};base64,{example_image_b64}",
+                                "detail": detail
+                            },
+                        },
+                        {"type": "text", "text": user_prompt},
+                    ],
+                }
+                example_agent_message = {"role": "assistant", "content": example.text}
+                output_messages.append(example_user_message)
+                output_messages.append(example_agent_message)
+
+        # user message
+        user_message = {
+            "role": "user",
+            "content": [
+                {
+                    "type": "image_url",
+                    "image_url": {
+                        "url": f"data:image/{format};base64,{base64_str}",
+                        "detail": detail
+                    },
+                },
+                {"type": "text", "text": user_prompt},
+            ],
+        }
+        output_messages.append(user_message)
+        return output_messages
 
 
 class AzureOpenAIVLMEngine(OpenAIVLMEngine):