vlm4ocr 0.3.0__tar.gz → 0.4.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: vlm4ocr
3
- Version: 0.3.0
3
+ Version: 0.4.0
4
4
  Summary: Python package and Web App for OCR with vision language models.
5
5
  License: MIT
6
6
  Author: Enshuo (David) Hsu
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "vlm4ocr"
3
- version = "0.3.0"
3
+ version = "0.4.0"
4
4
  description = "Python package and Web App for OCR with vision language models."
5
5
  authors = ["Enshuo (David) Hsu"]
6
6
  license = "MIT"
@@ -0,0 +1,17 @@
1
+ from .data_types import FewShotExample
2
+ from .ocr_engines import OCREngine
3
+ from .vlm_engines import BasicVLMConfig, ReasoningVLMConfig, OpenAIReasoningVLMConfig, OllamaVLMEngine, OpenAICompatibleVLMEngine, VLLMVLMEngine, OpenRouterVLMEngine, OpenAIVLMEngine, AzureOpenAIVLMEngine
4
+
5
+ __all__ = [
6
+ "FewShotExample",
7
+ "BasicVLMConfig",
8
+ "ReasoningVLMConfig",
9
+ "OpenAIReasoningVLMConfig",
10
+ "OCREngine",
11
+ "OllamaVLMEngine",
12
+ "OpenAICompatibleVLMEngine",
13
+ "VLLMVLMEngine",
14
+ "OpenRouterVLMEngine",
15
+ "OpenAIVLMEngine",
16
+ "AzureOpenAIVLMEngine"
17
+ ]
@@ -4,18 +4,9 @@ import sys
4
4
  import logging
5
5
  import asyncio
6
6
  import time
7
-
8
- # Attempt to import from the local package structure
9
- try:
10
- from .ocr_engines import OCREngine
11
- from .vlm_engines import OpenAIVLMEngine, AzureOpenAIVLMEngine, OllamaVLMEngine, BasicVLMConfig
12
- from .data_types import OCRResult
13
- except ImportError:
14
- # Fallback for when the package is installed
15
- from vlm4ocr.ocr_engines import OCREngine
16
- from vlm4ocr.vlm_engines import OpenAIVLMEngine, AzureOpenAIVLMEngine, OllamaVLMEngine, BasicVLMConfig
17
- from vlm4ocr.data_types import OCRResult
18
-
7
+ from .ocr_engines import OCREngine
8
+ from .vlm_engines import OpenAICompatibleVLMEngine, OpenAIVLMEngine, AzureOpenAIVLMEngine, OllamaVLMEngine, BasicVLMConfig
9
+ from .data_types import OCRResult
19
10
  import tqdm.asyncio
20
11
 
21
12
  # --- Global logger setup (console) ---
@@ -24,7 +15,12 @@ logging.basicConfig(
24
15
  format='%(asctime)s - %(levelname)s: %(message)s',
25
16
  datefmt='%Y-%m-%d %H:%M:%S'
26
17
  )
18
+ # Get our specific logger for CLI messages
27
19
  logger = logging.getLogger("vlm4ocr_cli")
20
+ # Get the logger that will receive captured warnings
21
+ # By default, warnings are logged to a logger named 'py.warnings'
22
+ warnings_logger = logging.getLogger('py.warnings')
23
+
28
24
 
29
25
  SUPPORTED_IMAGE_EXTS_CLI = ['.pdf', '.tif', '.tiff', '.png', '.jpg', '.jpeg', '.bmp', '.gif', '.webp']
30
26
  OUTPUT_EXTENSIONS = {'markdown': '.md', 'HTML':'.html', 'text':'.txt'}
@@ -65,17 +61,26 @@ def setup_file_logger(log_dir, timestamp_str, debug_mode):
65
61
  log_file_path = os.path.join(log_dir, log_file_name)
66
62
 
67
63
  file_handler = logging.FileHandler(log_file_path, mode='a')
68
- formatter = logging.Formatter('%(asctime)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
64
+ formatter = logging.Formatter('%(asctime)s - %(levelname)s - [%(name)s:%(filename)s:%(lineno)d] - %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
69
65
  file_handler.setFormatter(formatter)
70
66
 
71
67
  log_level = logging.DEBUG if debug_mode else logging.INFO
72
68
  file_handler.setLevel(log_level)
73
69
 
74
- logger.addHandler(file_handler)
70
+ # Add handler to the root logger to capture all logs (from our logger,
71
+ # and from the warnings logger 'py.warnings')
72
+ root_logger = logging.getLogger()
73
+ root_logger.addHandler(file_handler)
74
+
75
+ # We still configure our specific logger's level for console output
75
76
  logger.info(f"Logging to file: {log_file_path}")
76
77
 
77
78
 
78
79
  def main():
80
+ # Capture warnings from the 'warnings' module (like RuntimeWarning)
81
+ # and redirect them to the 'logging' system.
82
+ logging.captureWarnings(True)
83
+
79
84
  parser = argparse.ArgumentParser(
80
85
  description="VLM4OCR: Perform OCR on images, PDFs, or TIFF files using Vision Language Models. Processing is concurrent by default.",
81
86
  formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -103,7 +108,8 @@ def main():
103
108
  vlm_engine_group.add_argument("--vlm_engine", choices=["openai", "azure_openai", "ollama", "openai_compatible"], required=True, help="VLM engine.")
104
109
  vlm_engine_group.add_argument("--model", required=True, help="Model identifier for the VLM engine.")
105
110
  vlm_engine_group.add_argument("--max_new_tokens", type=int, default=4096, help="Max new tokens for VLM.")
106
- vlm_engine_group.add_argument("--temperature", type=float, default=0.0, help="Sampling temperature.")
111
+ vlm_engine_group.add_argument("--temperature", type=float, default=None, help="Sampling temperature.")
112
+ vlm_engine_group.add_argument("--top_p", type=float, default=None, help="Sampling top p.")
107
113
 
108
114
  openai_group = parser.add_argument_group("OpenAI & OpenAI-Compatible Options")
109
115
  openai_group.add_argument("--api_key", default=os.environ.get("OPENAI_API_KEY"), help="API key.")
@@ -144,16 +150,23 @@ def main():
144
150
  current_timestamp_str = time.strftime("%Y%m%d_%H%M%S")
145
151
 
146
152
  # --- Configure Logger Level based on args ---
153
+ # Get root logger to control global level for libraries
154
+ root_logger = logging.getLogger()
155
+
147
156
  if args.debug:
148
- logger.setLevel(logging.DEBUG)
149
- # Set root logger to DEBUG only if our specific logger is DEBUG, to avoid overly verbose library logs unless intended.
150
- if logger.getEffectiveLevel() <= logging.DEBUG:
151
- logging.getLogger().setLevel(logging.DEBUG)
157
+ logger.setLevel(logging.DEBUG) # Our logger to DEBUG
158
+ warnings_logger.setLevel(logging.DEBUG) # Warnings logger to DEBUG
159
+ root_logger.setLevel(logging.DEBUG) # Root to DEBUG
152
160
  logger.debug("Debug mode enabled for console.")
153
161
  else:
154
- logger.setLevel(logging.INFO) # Default for our CLI's own messages
155
- logging.getLogger().setLevel(logging.WARNING) # Keep external libraries quieter by default
156
-
162
+ logger.setLevel(logging.INFO) # Our logger to INFO
163
+ warnings_logger.setLevel(logging.INFO) # Warnings logger to INFO
164
+ root_logger.setLevel(logging.WARNING) # Root to WARNING (quieter libraries)
165
+ # Our console handler (from basicConfig) is on the root logger,
166
+ # so setting root to WARNING makes console quiet
167
+ # But our logger (vlm4ocr_cli) is INFO, so if a file handler
168
+ # is added, it will get INFO messages from 'logger'
169
+
157
170
  if args.concurrent_batch_size < 1:
158
171
  parser.error("--concurrent_batch_size must be 1 or greater.")
159
172
 
@@ -192,6 +205,15 @@ def main():
192
205
  # --- Setup File Logger (if --log is specified) ---
193
206
  if args.log:
194
207
  setup_file_logger(effective_output_dir, current_timestamp_str, args.debug)
208
+ # If logging to file, we want our console to be less verbose
209
+ # if not in debug mode, so we set the console handler's level higher.
210
+ if not args.debug:
211
+ # Find the console handler (from basicConfig) and set its level
212
+ for handler in root_logger.handlers:
213
+ if isinstance(handler, logging.StreamHandler) and handler.stream == sys.stderr:
214
+ handler.setLevel(logging.WARNING)
215
+ logger.debug("Set console handler level to WARNING.")
216
+ break
195
217
 
196
218
  logger.debug(f"Parsed arguments: {args}")
197
219
 
@@ -199,16 +221,18 @@ def main():
199
221
  vlm_engine_instance = None
200
222
  try:
201
223
  logger.info(f"Initializing VLM engine: {args.vlm_engine} with model: {args.model}")
224
+ logger.info(f"max_new_tokens: {args.max_new_tokens}, temperature: {args.temperature}, top_p: {args.top_p}")
202
225
  config = BasicVLMConfig(
203
226
  max_new_tokens=args.max_new_tokens,
204
- temperature=args.temperature
227
+ temperature=args.temperature,
228
+ top_p=args.top_p
205
229
  )
206
230
  if args.vlm_engine == "openai":
207
231
  if not args.api_key: parser.error("--api_key (or OPENAI_API_KEY) is required for OpenAI.")
208
232
  vlm_engine_instance = OpenAIVLMEngine(model=args.model, api_key=args.api_key, config=config)
209
233
  elif args.vlm_engine == "openai_compatible":
210
234
  if not args.base_url: parser.error("--base_url is required for openai_compatible.")
211
- vlm_engine_instance = OpenAIVLMEngine(model=args.model, api_key=args.api_key, base_url=args.base_url, config=config)
235
+ vlm_engine_instance = OpenAICompatibleVLMEngine(model=args.model, api_key=args.api_key, base_url=args.base_url, config=config)
212
236
  elif args.vlm_engine == "azure_openai":
213
237
  if not args.azure_api_key: parser.error("--azure_api_key (or AZURE_OPENAI_API_KEY) is required.")
214
238
  if not args.azure_endpoint: parser.error("--azure_endpoint (or AZURE_OPENAI_ENDPOINT) is required.")
@@ -295,16 +319,34 @@ def main():
295
319
  # console verbosity controlled by logger level.
296
320
  show_progress_bar = (num_actual_files > 0)
297
321
 
322
+ # Only show progress bar if not in debug mode (debug logs would interfere)
323
+ # and if there are files to process.
324
+ # If logging to file, console can be quiet (WARNING level).
325
+ # If NOT logging to file, console must be INFO level to show bar.
326
+
327
+ # Determine if progress bar should be active (not disabled)
328
+ # Disable bar if in debug mode (logs interfere) or no files
329
+ disable_bar = args.debug or not show_progress_bar
330
+
331
+ # If not logging to file AND not debug, we need console at INFO
332
+ if not args.log and not args.debug:
333
+ for handler in logging.getLogger().handlers:
334
+ if isinstance(handler, logging.StreamHandler) and handler.stream == sys.stderr:
335
+ handler.setLevel(logging.INFO)
336
+ logger.debug("Set console handler level to INFO for progress bar.")
337
+ break
338
+
298
339
  iterator_wrapper = tqdm.asyncio.tqdm(
299
340
  ocr_task_generator,
300
341
  total=num_actual_files,
301
342
  desc="Processing files",
302
343
  unit="file",
303
- disable=not show_progress_bar # disable if no files, or can remove this disable if tqdm handles total=0
344
+ disable=disable_bar
304
345
  )
305
346
 
306
347
  async for result_object in iterator_wrapper:
307
348
  if not isinstance(result_object, OCRResult):
349
+ # This warning *will* now be captured by the file log
308
350
  logger.warning(f"Received unexpected data type: {type(result_object)}")
309
351
  continue
310
352
 
@@ -323,9 +365,12 @@ def main():
323
365
  content_to_write = result_object.to_string()
324
366
  with open(current_ocr_output_file_path, "w", encoding="utf-8") as f:
325
367
  f.write(content_to_write)
326
- # Log less verbosely to console if progress bar is active
327
- if not show_progress_bar or logger.getEffectiveLevel() <= logging.DEBUG:
328
- logger.info(f"OCR result for '{input_file_path_from_result}' saved to: {current_ocr_output_file_path}")
368
+
369
+ # MODIFIED: Always log success info.
370
+ # This will go to the file log if active.
371
+ # It will NOT go to console if console level is WARNING.
372
+ logger.info(f"OCR result for '{input_file_path_from_result}' saved to: {current_ocr_output_file_path}")
373
+
329
374
  except Exception as e:
330
375
  logger.error(f"Error writing output for '{input_file_path_from_result}' to '{current_ocr_output_file_path}': {e}")
331
376
 
@@ -1,7 +1,8 @@
1
1
  import os
2
- from typing import List, Literal
2
+ from typing import List, Dict, Literal
3
+ from PIL import Image
3
4
  from dataclasses import dataclass, field
4
- from vlm4ocr.utils import get_default_page_delimiter
5
+ from vlm4ocr.utils import get_default_page_delimiter, ImageProcessor
5
6
 
6
7
  OutputMode = Literal["markdown", "HTML", "text", "JSON"]
7
8
 
@@ -24,6 +25,7 @@ class OCRResult:
24
25
  pages: List[dict] = field(default_factory=list)
25
26
  filename: str = field(init=False)
26
27
  status: str = field(init=False, default="processing")
28
+ messages_log: List[List[Dict[str,str]]] = field(default_factory=list)
27
29
 
28
30
  def __post_init__(self):
29
31
  """
@@ -67,10 +69,6 @@ class OCRResult:
67
69
  }
68
70
  self.pages.append(page)
69
71
 
70
-
71
- def __len__(self):
72
- return len(self.pages)
73
-
74
72
  def get_page(self, idx):
75
73
  if not isinstance(idx, int):
76
74
  raise ValueError("Index must be an integer")
@@ -78,6 +76,21 @@ class OCRResult:
78
76
  raise IndexError(f"Index out of range. The OCRResult has {len(self.pages)} pages, but index {idx} was requested.")
79
77
 
80
78
  return self.pages[idx]
79
+
80
+ def clear_messages_log(self):
81
+ self.messages_log = []
82
+
83
+ def add_messages_to_log(self, messages: List[Dict[str,str]]):
84
+ if not isinstance(messages, list):
85
+ raise ValueError("messages must be a list of dict")
86
+
87
+ self.messages_log.extend(messages)
88
+
89
+ def get_messages_log(self) -> List[List[Dict[str,str]]]:
90
+ return self.messages_log.copy()
91
+
92
+ def __len__(self):
93
+ return len(self.pages)
81
94
 
82
95
  def __iter__(self):
83
96
  return iter(self.pages)
@@ -106,4 +119,41 @@ class OCRResult:
106
119
  else:
107
120
  self.page_delimiter = page_delimiter
108
121
 
109
- return self.page_delimiter.join([page.get("text", "") for page in self.pages])
122
+ return self.page_delimiter.join([page.get("text", "") for page in self.pages])
123
+
124
+ @dataclass
125
+ class FewShotExample:
126
+ """
127
+ This class represents a few-shot example for OCR tasks.
128
+
129
+ Parameters:
130
+ ----------
131
+ image : PIL.Image.Image
132
+ The image associated with the example.
133
+ text : str
134
+ The expected OCR result text for the image.
135
+ rotate_correction : bool, Optional
136
+ If True, applies rotate correction to the images using pytesseract.
137
+ max_dimension_pixels : int, Optional
138
+ The maximum dimension of the image in pixels. Original dimensions will be resized to fit in. If None, no resizing is applied.
139
+ """
140
+ image: Image.Image
141
+ text: str
142
+ rotate_correction: bool = False
143
+ max_dimension_pixels: int = None
144
+ def __post_init__(self):
145
+ if not isinstance(self.image, Image.Image):
146
+ raise ValueError("image must be a PIL.Image.Image object")
147
+ if not isinstance(self.text, str):
148
+ raise ValueError("text must be a string")
149
+
150
+ if self.rotate_correction or self.max_dimension_pixels is not None:
151
+ self.image_processor = ImageProcessor()
152
+
153
+ # Rotate correction if specified
154
+ if self.rotate_correction:
155
+ self.image, _ = self.image_processor.rotate_correction(self.image)
156
+
157
+ # Resize image if max_dimension_pixels is specified
158
+ if self.max_dimension_pixels is not None:
159
+ self.image, _ = self.image_processor.resize(image=self.image, max_dimension_pixels=self.max_dimension_pixels)
@@ -1,12 +1,12 @@
1
1
  import os
2
- from typing import Tuple, List, Dict, Union, Generator, AsyncGenerator, Iterable
2
+ from typing import Any, Tuple, List, Dict, Union, Generator, AsyncGenerator, Iterable
3
3
  import importlib
4
4
  import asyncio
5
5
  from colorama import Fore, Style
6
6
  import json
7
7
  from vlm4ocr.utils import DataLoader, PDFDataLoader, TIFFDataLoader, ImageDataLoader, ImageProcessor, clean_markdown, extract_json, get_default_page_delimiter
8
- from vlm4ocr.data_types import OCRResult
9
- from vlm4ocr.vlm_engines import VLMEngine
8
+ from vlm4ocr.data_types import OCRResult, FewShotExample
9
+ from vlm4ocr.vlm_engines import VLMEngine, MessagesLogger
10
10
 
11
11
  SUPPORTED_IMAGE_EXTS = ['.pdf', '.tif', '.tiff', '.png', '.jpg', '.jpeg', '.bmp', '.gif', '.webp']
12
12
 
@@ -60,7 +60,8 @@ class OCREngine:
60
60
  self.image_processor = ImageProcessor()
61
61
 
62
62
 
63
- def stream_ocr(self, file_path: str, rotate_correction:bool=False, max_dimension_pixels:int=None) -> Generator[Dict[str, str], None, None]:
63
+ def stream_ocr(self, file_path: str, rotate_correction:bool=False, max_dimension_pixels:int=None,
64
+ few_shot_examples:List[FewShotExample]=None) -> Generator[Dict[str, str], None, None]:
64
65
  """
65
66
  This method inputs a file path (image or PDF) and streams OCR results in real time. This is useful for frontend applications.
66
67
  Yields dictionaries with 'type' ('ocr_chunk' or 'page_delimiter') and 'data'.
@@ -73,6 +74,8 @@ class OCREngine:
73
74
  If True, applies rotate correction to the images using pytesseract.
74
75
  max_dimension_pixels : int, Optional
75
76
  The maximum dimension of the image in pixels. Original dimensions will be resized to fit in. If None, no resizing is applied.
77
+ few_shot_examples : List[FewShotExample], Optional
78
+ list of few-shot examples.
76
79
 
77
80
  Returns:
78
81
  --------
@@ -90,10 +93,6 @@ class OCREngine:
90
93
  file_ext = os.path.splitext(file_path)[1].lower()
91
94
  if file_ext not in SUPPORTED_IMAGE_EXTS:
92
95
  raise ValueError(f"Unsupported file type: {file_ext}. Supported types are: {SUPPORTED_IMAGE_EXTS}")
93
-
94
- # Check if image preprocessing can be applied
95
- if self.image_processor.has_tesseract==False and rotate_correction:
96
- raise ImportError("pytesseract is not installed. Please install it to use rotate correction.")
97
96
 
98
97
  # PDF or TIFF
99
98
  if file_ext in ['.pdf', '.tif', '.tiff']:
@@ -105,8 +104,8 @@ class OCREngine:
105
104
 
106
105
  # OCR each image
107
106
  for i, image in enumerate(images):
108
- # Apply rotate correction if specified and tesseract is available
109
- if rotate_correction and self.image_processor.has_tesseract:
107
+ # Apply rotate correction if specified
108
+ if rotate_correction:
110
109
  try:
111
110
  image, _ = self.image_processor.rotate_correction(image)
112
111
 
@@ -120,13 +119,20 @@ class OCREngine:
120
119
  except Exception as e:
121
120
  yield {"type": "info", "data": f"Error resizing image: {str(e)}"}
122
121
 
123
- messages = self.vlm_engine.get_ocr_messages(self.system_prompt, self.user_prompt, image)
122
+ # Get OCR messages
123
+ messages = self.vlm_engine.get_ocr_messages(system_prompt=self.system_prompt,
124
+ user_prompt=self.user_prompt,
125
+ image=image,
126
+ few_shot_examples=few_shot_examples)
127
+
128
+ # Stream response
124
129
  response_stream = self.vlm_engine.chat(
125
130
  messages,
126
131
  stream=True
127
132
  )
128
133
  for chunk in response_stream:
129
- yield {"type": "ocr_chunk", "data": chunk}
134
+ if chunk["type"] == "response":
135
+ yield {"type": "ocr_chunk", "data": chunk["data"]}
130
136
 
131
137
  if i < len(images) - 1:
132
138
  yield {"type": "page_delimiter", "data": get_default_page_delimiter(self.output_mode)}
@@ -136,8 +142,8 @@ class OCREngine:
136
142
  data_loader = ImageDataLoader(file_path)
137
143
  image = data_loader.get_page(0)
138
144
 
139
- # Apply rotate correction if specified and tesseract is available
140
- if rotate_correction and self.image_processor.has_tesseract:
145
+ # Apply rotate correction if specified
146
+ if rotate_correction:
141
147
  try:
142
148
  image, _ = self.image_processor.rotate_correction(image)
143
149
 
@@ -151,17 +157,23 @@ class OCREngine:
151
157
  except Exception as e:
152
158
  yield {"type": "info", "data": f"Error resizing image: {str(e)}"}
153
159
 
154
- messages = self.vlm_engine.get_ocr_messages(self.system_prompt, self.user_prompt, image)
160
+ # Get OCR messages
161
+ messages = self.vlm_engine.get_ocr_messages(system_prompt=self.system_prompt,
162
+ user_prompt=self.user_prompt,
163
+ image=image,
164
+ few_shot_examples=few_shot_examples)
165
+ # Stream response
155
166
  response_stream = self.vlm_engine.chat(
156
167
  messages,
157
168
  stream=True
158
169
  )
159
170
  for chunk in response_stream:
160
- yield {"type": "ocr_chunk", "data": chunk}
171
+ if chunk["type"] == "response":
172
+ yield {"type": "ocr_chunk", "data": chunk["data"]}
161
173
 
162
174
 
163
175
  def sequential_ocr(self, file_paths: Union[str, Iterable[str]], rotate_correction:bool=False,
164
- max_dimension_pixels:int=None, verbose:bool=False) -> List[OCRResult]:
176
+ max_dimension_pixels:int=None, verbose:bool=False, few_shot_examples:List[FewShotExample]=None) -> List[OCRResult]:
165
177
  """
166
178
  This method inputs a file path or a list of file paths (image, PDF, TIFF) and performs OCR using the VLM inference engine.
167
179
 
@@ -175,6 +187,8 @@ class OCREngine:
175
187
  The maximum dimension of the image in pixels. Original dimensions will be resized to fit in. If None, no resizing is applied.
176
188
  verbose : bool, Optional
177
189
  If True, the function will print the output in terminal.
190
+ few_shot_examples : List[FewShotExample], Optional
191
+ list of few-shot examples. Each example is a FewShotExample with an `image` (PIL.Image.Image) and its expected OCR `text` (str).
178
192
 
179
193
  Returns:
180
194
  --------
@@ -184,6 +198,7 @@ class OCREngine:
184
198
  if isinstance(file_paths, str):
185
199
  file_paths = [file_paths]
186
200
 
201
+ # Iterate through file paths
187
202
  ocr_results = []
188
203
  for file_path in file_paths:
189
204
  # Define OCRResult object
@@ -233,8 +248,8 @@ class OCREngine:
233
248
  # OCR images
234
249
  for i, image in enumerate(images):
235
250
  image_processing_status = {}
236
- # Apply rotate correction if specified and tesseract is available
237
- if rotate_correction and self.image_processor.has_tesseract:
251
+ # Apply rotate correction if specified
252
+ if rotate_correction:
238
253
  try:
239
254
  image, rotation_angle = self.image_processor.rotate_correction(image)
240
255
  image_processing_status["rotate_correction"] = {
@@ -270,25 +285,36 @@ class OCREngine:
270
285
  print(f"{Fore.RED}Error resizing image for {filename}:{Style.RESET_ALL} {resized['error']}. OCR continues without resizing.")
271
286
 
272
287
  try:
273
- messages = self.vlm_engine.get_ocr_messages(self.system_prompt, self.user_prompt, image)
288
+ messages = self.vlm_engine.get_ocr_messages(system_prompt=self.system_prompt,
289
+ user_prompt=self.user_prompt,
290
+ image=image,
291
+ few_shot_examples=few_shot_examples)
292
+ # Define a messages logger to capture messages
293
+ messages_logger = MessagesLogger()
294
+ # Generate response
274
295
  response = self.vlm_engine.chat(
275
296
  messages,
276
297
  verbose=verbose,
277
- stream=False
298
+ stream=False,
299
+ messages_logger=messages_logger
278
300
  )
301
+ ocr_text = response["response"]
279
302
  # Clean the response if output mode is markdown
280
303
  if self.output_mode == "markdown":
281
- response = clean_markdown(response)
304
+ ocr_text = clean_markdown(ocr_text)
282
305
 
283
306
  # Parse the response if output mode is JSON
284
- if self.output_mode == "JSON":
285
- json_list = extract_json(response)
307
+ elif self.output_mode == "JSON":
308
+ json_list = extract_json(ocr_text)
286
309
  # Serialize the JSON list to a string
287
- response = json.dumps(json_list, indent=4)
310
+ ocr_text = json.dumps(json_list, indent=4)
288
311
 
289
312
  # Add the page to the OCR result
290
- ocr_result.add_page(text=response,
313
+ ocr_result.add_page(text=ocr_text,
291
314
  image_processing_status=image_processing_status)
315
+
316
+ # Add messages log to the OCR result
317
+ ocr_result.add_messages_to_log(messages_logger.get_messages_log())
292
318
 
293
319
  except Exception as page_e:
294
320
  ocr_result.status = "error"
@@ -298,11 +324,12 @@ class OCREngine:
298
324
  print(f"{Fore.RED}Error during OCR for a page in {filename}:{Style.RESET_ALL} {page_e}")
299
325
 
300
326
  # Add the OCR result to the list
301
- ocr_result.status = "success"
327
+ if ocr_result.status != "error":
328
+ ocr_result.status = "success"
302
329
  ocr_results.append(ocr_result)
303
330
 
304
331
  if verbose:
305
- print(f"{Fore.BLUE}Successfully processed {filename} with {len(ocr_result)} pages.{Style.RESET_ALL}")
332
+ print(f"{Fore.BLUE}Processed {filename} with {len(ocr_result)} pages.{Style.RESET_ALL}")
306
333
  for page in ocr_result:
307
334
  print(page)
308
335
  print("-" * 80)
@@ -311,7 +338,8 @@ class OCREngine:
311
338
 
312
339
 
313
340
  def concurrent_ocr(self, file_paths: Union[str, Iterable[str]], rotate_correction:bool=False,
314
- max_dimension_pixels:int=None, concurrent_batch_size: int=32, max_file_load: int=None) -> AsyncGenerator[OCRResult, None]:
341
+ max_dimension_pixels:int=None, few_shot_examples:List[FewShotExample]=None,
342
+ concurrent_batch_size: int=32, max_file_load: int=None) -> AsyncGenerator[OCRResult, None]:
315
343
  """
316
344
  First complete first out. Input and output order not guaranteed.
317
345
  This method inputs a file path or a list of file paths (image, PDF, TIFF) and performs OCR using the VLM inference engine.
@@ -325,6 +353,8 @@ class OCREngine:
325
353
  If True, applies rotate correction to the images using pytesseract.
326
354
  max_dimension_pixels : int, Optional
327
355
  The maximum dimension of the image in pixels. Original dimensions will be resized to fit in. If None, no resizing is applied.
356
+ few_shot_examples : List[FewShotExample], Optional
357
+ list of few-shot examples. Each example is a FewShotExample with an `image` (PIL.Image.Image) and its expected OCR `text` (str).
328
358
  concurrent_batch_size : int, Optional
329
359
  The number of concurrent VLM calls to make.
330
360
  max_file_load : int, Optional
@@ -343,18 +373,17 @@ class OCREngine:
343
373
 
344
374
  if not isinstance(max_file_load, int) or max_file_load <= 0:
345
375
  raise ValueError("max_file_load must be a positive integer")
346
-
347
- if self.image_processor.has_tesseract==False and rotate_correction:
348
- raise ImportError("pytesseract is not installed. Please install it to use rotate correction.")
349
376
 
350
377
  return self._ocr_async(file_paths=file_paths,
351
378
  rotate_correction=rotate_correction,
352
379
  max_dimension_pixels=max_dimension_pixels,
380
+ few_shot_examples=few_shot_examples,
353
381
  concurrent_batch_size=concurrent_batch_size,
354
382
  max_file_load=max_file_load)
355
383
 
356
384
 
357
385
  async def _ocr_async(self, file_paths: Iterable[str], rotate_correction:bool=False, max_dimension_pixels:int=None,
386
+ few_shot_examples:List[FewShotExample]=None,
358
387
  concurrent_batch_size: int=32, max_file_load: int=None) -> AsyncGenerator[OCRResult, None]:
359
388
  """
360
389
  Internal method to asynchronously process an iterable of file paths.
@@ -370,7 +399,8 @@ class OCREngine:
370
399
  vlm_call_semaphore=vlm_call_semaphore,
371
400
  file_path=file_path,
372
401
  rotate_correction=rotate_correction,
373
- max_dimension_pixels=max_dimension_pixels)
402
+ max_dimension_pixels=max_dimension_pixels,
403
+ few_shot_examples=few_shot_examples)
374
404
  tasks.append(task)
375
405
 
376
406
 
@@ -379,7 +409,8 @@ class OCREngine:
379
409
  yield result
380
410
 
381
411
  async def _ocr_file_with_semaphore(self, file_load_semaphore:asyncio.Semaphore, vlm_call_semaphore:asyncio.Semaphore,
382
- file_path:str, rotate_correction:bool=False, max_dimension_pixels:int=None) -> OCRResult:
412
+ file_path:str, rotate_correction:bool=False, max_dimension_pixels:int=None,
413
+ few_shot_examples:List[FewShotExample]=None) -> OCRResult:
383
414
  """
384
415
  This internal method takes a semaphore and OCR a single file using the VLM inference engine.
385
416
  """
@@ -387,6 +418,7 @@ class OCREngine:
387
418
  filename = os.path.basename(file_path)
388
419
  file_ext = os.path.splitext(file_path)[1].lower()
389
420
  result = OCRResult(input_dir=file_path, output_mode=self.output_mode)
421
+ messages_logger = MessagesLogger()
390
422
  # check file extension
391
423
  if file_ext not in SUPPORTED_IMAGE_EXTS:
392
424
  result.status = "error"
@@ -416,7 +448,9 @@ class OCREngine:
416
448
  data_loader=data_loader,
417
449
  page_index=page_index,
418
450
  rotate_correction=rotate_correction,
419
- max_dimension_pixels=max_dimension_pixels
451
+ max_dimension_pixels=max_dimension_pixels,
452
+ few_shot_examples=few_shot_examples,
453
+ messages_logger=messages_logger
420
454
  )
421
455
  page_processing_tasks.append(task)
422
456
 
@@ -428,14 +462,18 @@ class OCREngine:
428
462
  except Exception as e:
429
463
  result.status = "error"
430
464
  result.add_page(text=f"Error during OCR for {filename}: {str(e)}", image_processing_status={})
465
+ result.add_messages_to_log(messages_logger.get_messages_log())
431
466
  return result
432
467
 
433
468
  # Set status to success if no errors occurred
434
- result.status = "success"
469
+ if result.status != "error":
470
+ result.status = "success"
471
+ result.add_messages_to_log(messages_logger.get_messages_log())
435
472
  return result
436
473
 
437
474
  async def _ocr_page_with_semaphore(self, vlm_call_semaphore: asyncio.Semaphore, data_loader: DataLoader,
438
- page_index:int, rotate_correction:bool=False, max_dimension_pixels:int=None) -> Tuple[str, Dict[str, str]]:
475
+ page_index:int, rotate_correction:bool=False, max_dimension_pixels:int=None,
476
+ few_shot_examples:List[FewShotExample]=None, messages_logger:MessagesLogger=None) -> Tuple[str, Dict[str, str]]:
439
477
  """
440
478
  This internal method takes a semaphore and OCR a single image/page using the VLM inference engine.
441
479
 
@@ -447,8 +485,8 @@ class OCREngine:
447
485
  async with vlm_call_semaphore:
448
486
  image = await data_loader.get_page_async(page_index)
449
487
  image_processing_status = {}
450
- # Apply rotate correction if specified and tesseract is available
451
- if rotate_correction and self.image_processor.has_tesseract:
488
+ # Apply rotate correction if specified
489
+ if rotate_correction:
452
490
  try:
453
491
  image, rotation_angle = await self.image_processor.rotate_correction_async(image)
454
492
  image_processing_status["rotate_correction"] = {
@@ -475,16 +513,21 @@ class OCREngine:
475
513
  "error": str(e)
476
514
  }
477
515
 
478
- messages = self.vlm_engine.get_ocr_messages(self.system_prompt, self.user_prompt, image)
479
- ocr_text = await self.vlm_engine.chat_async(
516
+ messages = self.vlm_engine.get_ocr_messages(system_prompt=self.system_prompt,
517
+ user_prompt=self.user_prompt,
518
+ image=image,
519
+ few_shot_examples=few_shot_examples)
520
+ response = await self.vlm_engine.chat_async(
480
521
  messages,
522
+ messages_logger=messages_logger
481
523
  )
524
+ ocr_text = response["response"]
482
525
  # Clean the OCR text if output mode is markdown
483
526
  if self.output_mode == "markdown":
484
527
  ocr_text = clean_markdown(ocr_text)
485
528
 
486
529
  # Parse the response if output mode is JSON
487
- if self.output_mode == "JSON":
530
+ elif self.output_mode == "JSON":
488
531
  json_list = extract_json(ocr_text)
489
532
  # Serialize the JSON list to a string
490
533
  ocr_text = json.dumps(json_list, indent=4)