vlm4ocr 0.3.1__tar.gz → 0.4.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {vlm4ocr-0.3.1 → vlm4ocr-0.4.1}/PKG-INFO +2 -1
- {vlm4ocr-0.3.1 → vlm4ocr-0.4.1}/pyproject.toml +2 -1
- {vlm4ocr-0.3.1 → vlm4ocr-0.4.1}/vlm4ocr/__init__.py +2 -0
- {vlm4ocr-0.3.1 → vlm4ocr-0.4.1}/vlm4ocr/cli.py +69 -15
- {vlm4ocr-0.3.1 → vlm4ocr-0.4.1}/vlm4ocr/data_types.py +40 -2
- {vlm4ocr-0.3.1 → vlm4ocr-0.4.1}/vlm4ocr/ocr_engines.py +56 -30
- vlm4ocr-0.4.1/vlm4ocr/vlm_engines.py +276 -0
- vlm4ocr-0.3.1/vlm4ocr/vlm_engines.py +0 -1163
- {vlm4ocr-0.3.1 → vlm4ocr-0.4.1}/README.md +0 -0
- {vlm4ocr-0.3.1 → vlm4ocr-0.4.1}/vlm4ocr/assets/default_prompt_templates/ocr_HTML_system_prompt.txt +0 -0
- {vlm4ocr-0.3.1 → vlm4ocr-0.4.1}/vlm4ocr/assets/default_prompt_templates/ocr_HTML_user_prompt.txt +0 -0
- {vlm4ocr-0.3.1 → vlm4ocr-0.4.1}/vlm4ocr/assets/default_prompt_templates/ocr_JSON_system_prompt.txt +0 -0
- {vlm4ocr-0.3.1 → vlm4ocr-0.4.1}/vlm4ocr/assets/default_prompt_templates/ocr_markdown_system_prompt.txt +0 -0
- {vlm4ocr-0.3.1 → vlm4ocr-0.4.1}/vlm4ocr/assets/default_prompt_templates/ocr_markdown_user_prompt.txt +0 -0
- {vlm4ocr-0.3.1 → vlm4ocr-0.4.1}/vlm4ocr/assets/default_prompt_templates/ocr_text_system_prompt.txt +0 -0
- {vlm4ocr-0.3.1 → vlm4ocr-0.4.1}/vlm4ocr/assets/default_prompt_templates/ocr_text_user_prompt.txt +0 -0
- {vlm4ocr-0.3.1 → vlm4ocr-0.4.1}/vlm4ocr/utils.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: vlm4ocr
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.4.1
|
|
4
4
|
Summary: Python package and Web App for OCR with vision language models.
|
|
5
5
|
License: MIT
|
|
6
6
|
Author: Enshuo (David) Hsu
|
|
@@ -12,6 +12,7 @@ Classifier: Programming Language :: Python :: 3.12
|
|
|
12
12
|
Provides-Extra: tesseract
|
|
13
13
|
Requires-Dist: colorama (>=0.4.4)
|
|
14
14
|
Requires-Dist: json-repair (>=0.30.0)
|
|
15
|
+
Requires-Dist: llm-inference-engine (>=0.1.1,<0.2.0)
|
|
15
16
|
Requires-Dist: pdf2image (>=1.16.0)
|
|
16
17
|
Requires-Dist: pillow (>=10.0.0)
|
|
17
18
|
Requires-Dist: pytesseract (>=0.3.13) ; extra == "tesseract"
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[tool.poetry]
|
|
2
2
|
name = "vlm4ocr"
|
|
3
|
-
version = "0.
|
|
3
|
+
version = "0.4.1"
|
|
4
4
|
description = "Python package and Web App for OCR with vision language models."
|
|
5
5
|
authors = ["Enshuo (David) Hsu"]
|
|
6
6
|
license = "MIT"
|
|
@@ -18,6 +18,7 @@ pdf2image = ">=1.16.0"
|
|
|
18
18
|
colorama = ">=0.4.4"
|
|
19
19
|
pillow = ">=10.0.0"
|
|
20
20
|
json-repair = ">=0.30.0"
|
|
21
|
+
llm-inference-engine = "^0.1.1"
|
|
21
22
|
pytesseract = { version = ">=0.3.13", optional = true }
|
|
22
23
|
|
|
23
24
|
[tool.poetry.scripts]
|
|
@@ -1,7 +1,9 @@
|
|
|
1
|
+
from .data_types import FewShotExample
|
|
1
2
|
from .ocr_engines import OCREngine
|
|
2
3
|
from .vlm_engines import BasicVLMConfig, ReasoningVLMConfig, OpenAIReasoningVLMConfig, OllamaVLMEngine, OpenAICompatibleVLMEngine, VLLMVLMEngine, OpenRouterVLMEngine, OpenAIVLMEngine, AzureOpenAIVLMEngine
|
|
3
4
|
|
|
4
5
|
__all__ = [
|
|
6
|
+
"FewShotExample",
|
|
5
7
|
"BasicVLMConfig",
|
|
6
8
|
"ReasoningVLMConfig",
|
|
7
9
|
"OpenAIReasoningVLMConfig",
|
|
@@ -15,7 +15,12 @@ logging.basicConfig(
|
|
|
15
15
|
format='%(asctime)s - %(levelname)s: %(message)s',
|
|
16
16
|
datefmt='%Y-%m-%d %H:%M:%S'
|
|
17
17
|
)
|
|
18
|
+
# Get our specific logger for CLI messages
|
|
18
19
|
logger = logging.getLogger("vlm4ocr_cli")
|
|
20
|
+
# Get the logger that will receive captured warnings
|
|
21
|
+
# By default, warnings are logged to a logger named 'py.warnings'
|
|
22
|
+
warnings_logger = logging.getLogger('py.warnings')
|
|
23
|
+
|
|
19
24
|
|
|
20
25
|
SUPPORTED_IMAGE_EXTS_CLI = ['.pdf', '.tif', '.tiff', '.png', '.jpg', '.jpeg', '.bmp', '.gif', '.webp']
|
|
21
26
|
OUTPUT_EXTENSIONS = {'markdown': '.md', 'HTML':'.html', 'text':'.txt'}
|
|
@@ -56,17 +61,26 @@ def setup_file_logger(log_dir, timestamp_str, debug_mode):
|
|
|
56
61
|
log_file_path = os.path.join(log_dir, log_file_name)
|
|
57
62
|
|
|
58
63
|
file_handler = logging.FileHandler(log_file_path, mode='a')
|
|
59
|
-
formatter = logging.Formatter('%(asctime)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
|
|
64
|
+
formatter = logging.Formatter('%(asctime)s - %(levelname)s - [%(name)s:%(filename)s:%(lineno)d] - %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
|
|
60
65
|
file_handler.setFormatter(formatter)
|
|
61
66
|
|
|
62
67
|
log_level = logging.DEBUG if debug_mode else logging.INFO
|
|
63
68
|
file_handler.setLevel(log_level)
|
|
64
69
|
|
|
65
|
-
logger
|
|
70
|
+
# Add handler to the root logger to capture all logs (from our logger,
|
|
71
|
+
# and from the warnings logger 'py.warnings')
|
|
72
|
+
root_logger = logging.getLogger()
|
|
73
|
+
root_logger.addHandler(file_handler)
|
|
74
|
+
|
|
75
|
+
# We still configure our specific logger's level for console output
|
|
66
76
|
logger.info(f"Logging to file: {log_file_path}")
|
|
67
77
|
|
|
68
78
|
|
|
69
79
|
def main():
|
|
80
|
+
# Capture warnings from the 'warnings' module (like RuntimeWarning)
|
|
81
|
+
# and redirect them to the 'logging' system.
|
|
82
|
+
logging.captureWarnings(True)
|
|
83
|
+
|
|
70
84
|
parser = argparse.ArgumentParser(
|
|
71
85
|
description="VLM4OCR: Perform OCR on images, PDFs, or TIFF files using Vision Language Models. Processing is concurrent by default.",
|
|
72
86
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
|
@@ -94,7 +108,8 @@ def main():
|
|
|
94
108
|
vlm_engine_group.add_argument("--vlm_engine", choices=["openai", "azure_openai", "ollama", "openai_compatible"], required=True, help="VLM engine.")
|
|
95
109
|
vlm_engine_group.add_argument("--model", required=True, help="Model identifier for the VLM engine.")
|
|
96
110
|
vlm_engine_group.add_argument("--max_new_tokens", type=int, default=4096, help="Max new tokens for VLM.")
|
|
97
|
-
vlm_engine_group.add_argument("--temperature", type=float, default=
|
|
111
|
+
vlm_engine_group.add_argument("--temperature", type=float, default=None, help="Sampling temperature.")
|
|
112
|
+
vlm_engine_group.add_argument("--top_p", type=float, default=None, help="Sampling top p.")
|
|
98
113
|
|
|
99
114
|
openai_group = parser.add_argument_group("OpenAI & OpenAI-Compatible Options")
|
|
100
115
|
openai_group.add_argument("--api_key", default=os.environ.get("OPENAI_API_KEY"), help="API key.")
|
|
@@ -135,16 +150,23 @@ def main():
|
|
|
135
150
|
current_timestamp_str = time.strftime("%Y%m%d_%H%M%S")
|
|
136
151
|
|
|
137
152
|
# --- Configure Logger Level based on args ---
|
|
153
|
+
# Get root logger to control global level for libraries
|
|
154
|
+
root_logger = logging.getLogger()
|
|
155
|
+
|
|
138
156
|
if args.debug:
|
|
139
|
-
logger.setLevel(logging.DEBUG)
|
|
140
|
-
#
|
|
141
|
-
|
|
142
|
-
logging.getLogger().setLevel(logging.DEBUG)
|
|
157
|
+
logger.setLevel(logging.DEBUG) # Our logger to DEBUG
|
|
158
|
+
warnings_logger.setLevel(logging.DEBUG) # Warnings logger to DEBUG
|
|
159
|
+
root_logger.setLevel(logging.DEBUG) # Root to DEBUG
|
|
143
160
|
logger.debug("Debug mode enabled for console.")
|
|
144
161
|
else:
|
|
145
|
-
logger.setLevel(logging.INFO) #
|
|
146
|
-
|
|
147
|
-
|
|
162
|
+
logger.setLevel(logging.INFO) # Our logger to INFO
|
|
163
|
+
warnings_logger.setLevel(logging.INFO) # Warnings logger to INFO
|
|
164
|
+
root_logger.setLevel(logging.WARNING) # Root to WARNING (quieter libraries)
|
|
165
|
+
# Our console handler (from basicConfig) is on the root logger,
|
|
166
|
+
# so setting root to WARNING makes console quiet
|
|
167
|
+
# But our logger (vlm4ocr_cli) is INFO, so if a file handler
|
|
168
|
+
# is added, it will get INFO messages from 'logger'
|
|
169
|
+
|
|
148
170
|
if args.concurrent_batch_size < 1:
|
|
149
171
|
parser.error("--concurrent_batch_size must be 1 or greater.")
|
|
150
172
|
|
|
@@ -183,6 +205,15 @@ def main():
|
|
|
183
205
|
# --- Setup File Logger (if --log is specified) ---
|
|
184
206
|
if args.log:
|
|
185
207
|
setup_file_logger(effective_output_dir, current_timestamp_str, args.debug)
|
|
208
|
+
# If logging to file, we want our console to be less verbose
|
|
209
|
+
# if not in debug mode, so we set the console handler's level higher.
|
|
210
|
+
if not args.debug:
|
|
211
|
+
# Find the console handler (from basicConfig) and set its level
|
|
212
|
+
for handler in root_logger.handlers:
|
|
213
|
+
if isinstance(handler, logging.StreamHandler) and handler.stream == sys.stderr:
|
|
214
|
+
handler.setLevel(logging.WARNING)
|
|
215
|
+
logger.debug("Set console handler level to WARNING.")
|
|
216
|
+
break
|
|
186
217
|
|
|
187
218
|
logger.debug(f"Parsed arguments: {args}")
|
|
188
219
|
|
|
@@ -190,9 +221,11 @@ def main():
|
|
|
190
221
|
vlm_engine_instance = None
|
|
191
222
|
try:
|
|
192
223
|
logger.info(f"Initializing VLM engine: {args.vlm_engine} with model: {args.model}")
|
|
224
|
+
logger.info(f"max_new_tokens: {args.max_new_tokens}, temperature: {args.temperature}, top_p: {args.top_p}")
|
|
193
225
|
config = BasicVLMConfig(
|
|
194
226
|
max_new_tokens=args.max_new_tokens,
|
|
195
|
-
temperature=args.temperature
|
|
227
|
+
temperature=args.temperature,
|
|
228
|
+
top_p=args.top_p
|
|
196
229
|
)
|
|
197
230
|
if args.vlm_engine == "openai":
|
|
198
231
|
if not args.api_key: parser.error("--api_key (or OPENAI_API_KEY) is required for OpenAI.")
|
|
@@ -286,16 +319,34 @@ def main():
|
|
|
286
319
|
# console verbosity controlled by logger level.
|
|
287
320
|
show_progress_bar = (num_actual_files > 0)
|
|
288
321
|
|
|
322
|
+
# Only show progress bar if not in debug mode (debug logs would interfere)
|
|
323
|
+
# and if there are files to process.
|
|
324
|
+
# If logging to file, console can be quiet (INFO level).
|
|
325
|
+
# If NOT logging to file, console must be INFO level to show bar.
|
|
326
|
+
|
|
327
|
+
# Determine if progress bar should be active (not disabled)
|
|
328
|
+
# Disable bar if in debug mode (logs interfere) or no files
|
|
329
|
+
disable_bar = args.debug or not show_progress_bar
|
|
330
|
+
|
|
331
|
+
# If not logging to file AND not debug, we need console at INFO
|
|
332
|
+
if not args.log and not args.debug:
|
|
333
|
+
for handler in logging.getLogger().handlers:
|
|
334
|
+
if isinstance(handler, logging.StreamHandler) and handler.stream == sys.stderr:
|
|
335
|
+
handler.setLevel(logging.INFO)
|
|
336
|
+
logger.debug("Set console handler level to INFO for progress bar.")
|
|
337
|
+
break
|
|
338
|
+
|
|
289
339
|
iterator_wrapper = tqdm.asyncio.tqdm(
|
|
290
340
|
ocr_task_generator,
|
|
291
341
|
total=num_actual_files,
|
|
292
342
|
desc="Processing files",
|
|
293
343
|
unit="file",
|
|
294
|
-
disable=
|
|
344
|
+
disable=disable_bar
|
|
295
345
|
)
|
|
296
346
|
|
|
297
347
|
async for result_object in iterator_wrapper:
|
|
298
348
|
if not isinstance(result_object, OCRResult):
|
|
349
|
+
# This warning *will* now be captured by the file log
|
|
299
350
|
logger.warning(f"Received unexpected data type: {type(result_object)}")
|
|
300
351
|
continue
|
|
301
352
|
|
|
@@ -314,9 +365,12 @@ def main():
|
|
|
314
365
|
content_to_write = result_object.to_string()
|
|
315
366
|
with open(current_ocr_output_file_path, "w", encoding="utf-8") as f:
|
|
316
367
|
f.write(content_to_write)
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
368
|
+
|
|
369
|
+
# MODIFIED: Always log success info.
|
|
370
|
+
# This will go to the file log if active.
|
|
371
|
+
# It will NOT go to console if console level is WARNING.
|
|
372
|
+
logger.info(f"OCR result for '{input_file_path_from_result}' saved to: {current_ocr_output_file_path}")
|
|
373
|
+
|
|
320
374
|
except Exception as e:
|
|
321
375
|
logger.error(f"Error writing output for '{input_file_path_from_result}' to '{current_ocr_output_file_path}': {e}")
|
|
322
376
|
|
|
@@ -1,7 +1,8 @@
|
|
|
1
1
|
import os
|
|
2
2
|
from typing import List, Dict, Literal
|
|
3
|
+
from PIL import Image
|
|
3
4
|
from dataclasses import dataclass, field
|
|
4
|
-
from vlm4ocr.utils import get_default_page_delimiter
|
|
5
|
+
from vlm4ocr.utils import get_default_page_delimiter, ImageProcessor
|
|
5
6
|
|
|
6
7
|
OutputMode = Literal["markdown", "HTML", "text", "JSON"]
|
|
7
8
|
|
|
@@ -118,4 +119,41 @@ class OCRResult:
|
|
|
118
119
|
else:
|
|
119
120
|
self.page_delimiter = page_delimiter
|
|
120
121
|
|
|
121
|
-
return self.page_delimiter.join([page.get("text", "") for page in self.pages])
|
|
122
|
+
return self.page_delimiter.join([page.get("text", "") for page in self.pages])
|
|
123
|
+
|
|
124
|
+
@dataclass
class FewShotExample:
    """
    This class represents a few-shot example for OCR tasks.

    The inputs are validated and, if requested, the image is preprocessed
    (rotate correction and/or resizing) once at construction time. The stored
    ``image`` attribute holds the processed result.

    Parameters:
    ----------
    image : PIL.Image.Image
        The image associated with the example.
    text : str
        The expected OCR result text for the image.
    rotate_correction : bool, Optional
        If True, applies rotate correction to the images using pytesseract.
    max_dimension_pixels : int, Optional
        The maximum dimension of the image in pixels. Original dimensions will be resized to fit in. If None, no resizing is applied.
        Must be a positive integer when provided.

    Raises:
    ------
    ValueError
        If image is not a PIL.Image.Image, text is not a string, or
        max_dimension_pixels is neither None nor a positive integer.
    """
    image: Image.Image
    text: str
    rotate_correction: bool = False
    max_dimension_pixels: int = None  # None disables resizing

    def __post_init__(self):
        if not isinstance(self.image, Image.Image):
            raise ValueError("image must be a PIL.Image.Image object")
        if not isinstance(self.text, str):
            raise ValueError("text must be a string")
        # Validate before any (potentially expensive) image processing runs.
        # bool is a subclass of int, so it is excluded explicitly.
        if self.max_dimension_pixels is not None and (
            isinstance(self.max_dimension_pixels, bool)
            or not isinstance(self.max_dimension_pixels, int)
            or self.max_dimension_pixels <= 0
        ):
            raise ValueError("max_dimension_pixels must be a positive integer or None")

        # Nothing to preprocess: avoid constructing an ImageProcessor at all.
        if not self.rotate_correction and self.max_dimension_pixels is None:
            return

        # Keep the processor local so the dataclass does not carry an
        # undeclared, conditionally-present helper attribute.
        image_processor = ImageProcessor()

        # Rotate correction if specified
        if self.rotate_correction:
            self.image, _ = image_processor.rotate_correction(self.image)

        # Resize image if max_dimension_pixels is specified
        if self.max_dimension_pixels is not None:
            self.image, _ = image_processor.resize(image=self.image, max_dimension_pixels=self.max_dimension_pixels)
|
|
@@ -1,11 +1,11 @@
|
|
|
1
1
|
import os
|
|
2
|
-
from typing import Tuple, List, Dict, Union, Generator, AsyncGenerator, Iterable
|
|
2
|
+
from typing import Any, Tuple, List, Dict, Union, Generator, AsyncGenerator, Iterable
|
|
3
3
|
import importlib
|
|
4
4
|
import asyncio
|
|
5
5
|
from colorama import Fore, Style
|
|
6
6
|
import json
|
|
7
7
|
from vlm4ocr.utils import DataLoader, PDFDataLoader, TIFFDataLoader, ImageDataLoader, ImageProcessor, clean_markdown, extract_json, get_default_page_delimiter
|
|
8
|
-
from vlm4ocr.data_types import OCRResult
|
|
8
|
+
from vlm4ocr.data_types import OCRResult, FewShotExample
|
|
9
9
|
from vlm4ocr.vlm_engines import VLMEngine, MessagesLogger
|
|
10
10
|
|
|
11
11
|
SUPPORTED_IMAGE_EXTS = ['.pdf', '.tif', '.tiff', '.png', '.jpg', '.jpeg', '.bmp', '.gif', '.webp']
|
|
@@ -60,7 +60,8 @@ class OCREngine:
|
|
|
60
60
|
self.image_processor = ImageProcessor()
|
|
61
61
|
|
|
62
62
|
|
|
63
|
-
def stream_ocr(self, file_path: str, rotate_correction:bool=False, max_dimension_pixels:int=None
|
|
63
|
+
def stream_ocr(self, file_path: str, rotate_correction:bool=False, max_dimension_pixels:int=None,
|
|
64
|
+
few_shot_examples:List[FewShotExample]=None) -> Generator[Dict[str, str], None, None]:
|
|
64
65
|
"""
|
|
65
66
|
This method inputs a file path (image or PDF) and stream OCR results in real-time. This is useful for frontend applications.
|
|
66
67
|
Yields dictionaries with 'type' ('ocr_chunk' or 'page_delimiter') and 'data'.
|
|
@@ -73,6 +74,8 @@ class OCREngine:
|
|
|
73
74
|
If True, applies rotate correction to the images using pytesseract.
|
|
74
75
|
max_dimension_pixels : int, Optional
|
|
75
76
|
The maximum dimension of the image in pixels. Original dimensions will be resized to fit in. If None, no resizing is applied.
|
|
77
|
+
few_shot_examples : List[FewShotExample], Optional
|
|
78
|
+
list of few-shot examples.
|
|
76
79
|
|
|
77
80
|
Returns:
|
|
78
81
|
--------
|
|
@@ -90,10 +93,6 @@ class OCREngine:
|
|
|
90
93
|
file_ext = os.path.splitext(file_path)[1].lower()
|
|
91
94
|
if file_ext not in SUPPORTED_IMAGE_EXTS:
|
|
92
95
|
raise ValueError(f"Unsupported file type: {file_ext}. Supported types are: {SUPPORTED_IMAGE_EXTS}")
|
|
93
|
-
|
|
94
|
-
# Check if image preprocessing can be applied
|
|
95
|
-
if self.image_processor.has_tesseract==False and rotate_correction:
|
|
96
|
-
raise ImportError("pytesseract is not installed. Please install it to use rotate correction.")
|
|
97
96
|
|
|
98
97
|
# PDF or TIFF
|
|
99
98
|
if file_ext in ['.pdf', '.tif', '.tiff']:
|
|
@@ -105,8 +104,8 @@ class OCREngine:
|
|
|
105
104
|
|
|
106
105
|
# OCR each image
|
|
107
106
|
for i, image in enumerate(images):
|
|
108
|
-
# Apply rotate correction if specified
|
|
109
|
-
if rotate_correction
|
|
107
|
+
# Apply rotate correction if specified
|
|
108
|
+
if rotate_correction:
|
|
110
109
|
try:
|
|
111
110
|
image, _ = self.image_processor.rotate_correction(image)
|
|
112
111
|
|
|
@@ -120,7 +119,13 @@ class OCREngine:
|
|
|
120
119
|
except Exception as e:
|
|
121
120
|
yield {"type": "info", "data": f"Error resizing image: {str(e)}"}
|
|
122
121
|
|
|
123
|
-
|
|
122
|
+
# Get OCR messages
|
|
123
|
+
messages = self.vlm_engine.get_ocr_messages(system_prompt=self.system_prompt,
|
|
124
|
+
user_prompt=self.user_prompt,
|
|
125
|
+
image=image,
|
|
126
|
+
few_shot_examples=few_shot_examples)
|
|
127
|
+
|
|
128
|
+
# Stream response
|
|
124
129
|
response_stream = self.vlm_engine.chat(
|
|
125
130
|
messages,
|
|
126
131
|
stream=True
|
|
@@ -137,8 +142,8 @@ class OCREngine:
|
|
|
137
142
|
data_loader = ImageDataLoader(file_path)
|
|
138
143
|
image = data_loader.get_page(0)
|
|
139
144
|
|
|
140
|
-
# Apply rotate correction if specified
|
|
141
|
-
if rotate_correction
|
|
145
|
+
# Apply rotate correction if specified
|
|
146
|
+
if rotate_correction:
|
|
142
147
|
try:
|
|
143
148
|
image, _ = self.image_processor.rotate_correction(image)
|
|
144
149
|
|
|
@@ -152,7 +157,12 @@ class OCREngine:
|
|
|
152
157
|
except Exception as e:
|
|
153
158
|
yield {"type": "info", "data": f"Error resizing image: {str(e)}"}
|
|
154
159
|
|
|
155
|
-
|
|
160
|
+
# Get OCR messages
|
|
161
|
+
messages = self.vlm_engine.get_ocr_messages(system_prompt=self.system_prompt,
|
|
162
|
+
user_prompt=self.user_prompt,
|
|
163
|
+
image=image,
|
|
164
|
+
few_shot_examples=few_shot_examples)
|
|
165
|
+
# Stream response
|
|
156
166
|
response_stream = self.vlm_engine.chat(
|
|
157
167
|
messages,
|
|
158
168
|
stream=True
|
|
@@ -163,7 +173,7 @@ class OCREngine:
|
|
|
163
173
|
|
|
164
174
|
|
|
165
175
|
def sequential_ocr(self, file_paths: Union[str, Iterable[str]], rotate_correction:bool=False,
|
|
166
|
-
max_dimension_pixels:int=None, verbose:bool=False) -> List[OCRResult]:
|
|
176
|
+
max_dimension_pixels:int=None, verbose:bool=False, few_shot_examples:List[FewShotExample]=None) -> List[OCRResult]:
|
|
167
177
|
"""
|
|
168
178
|
This method inputs a file path or a list of file paths (image, PDF, TIFF) and performs OCR using the VLM inference engine.
|
|
169
179
|
|
|
@@ -177,6 +187,8 @@ class OCREngine:
|
|
|
177
187
|
The maximum dimension of the image in pixels. Original dimensions will be resized to fit in. If None, no resizing is applied.
|
|
178
188
|
verbose : bool, Optional
|
|
179
189
|
If True, the function will print the output in terminal.
|
|
190
|
+
few_shot_examples : List[FewShotExample], Optional
|
|
191
|
+
list of few-shot examples. Each example is a dict with keys "image" (PIL.Image.Image) and "text" (str).
|
|
180
192
|
|
|
181
193
|
Returns:
|
|
182
194
|
--------
|
|
@@ -186,6 +198,7 @@ class OCREngine:
|
|
|
186
198
|
if isinstance(file_paths, str):
|
|
187
199
|
file_paths = [file_paths]
|
|
188
200
|
|
|
201
|
+
# Iterate through file paths
|
|
189
202
|
ocr_results = []
|
|
190
203
|
for file_path in file_paths:
|
|
191
204
|
# Define OCRResult object
|
|
@@ -235,8 +248,8 @@ class OCREngine:
|
|
|
235
248
|
# OCR images
|
|
236
249
|
for i, image in enumerate(images):
|
|
237
250
|
image_processing_status = {}
|
|
238
|
-
# Apply rotate correction if specified
|
|
239
|
-
if rotate_correction
|
|
251
|
+
# Apply rotate correction if specified
|
|
252
|
+
if rotate_correction:
|
|
240
253
|
try:
|
|
241
254
|
image, rotation_angle = self.image_processor.rotate_correction(image)
|
|
242
255
|
image_processing_status["rotate_correction"] = {
|
|
@@ -272,7 +285,10 @@ class OCREngine:
|
|
|
272
285
|
print(f"{Fore.RED}Error resizing image for {filename}:{Style.RESET_ALL} {resized['error']}. OCR continues without resizing.")
|
|
273
286
|
|
|
274
287
|
try:
|
|
275
|
-
messages = self.vlm_engine.get_ocr_messages(self.system_prompt,
|
|
288
|
+
messages = self.vlm_engine.get_ocr_messages(system_prompt=self.system_prompt,
|
|
289
|
+
user_prompt=self.user_prompt,
|
|
290
|
+
image=image,
|
|
291
|
+
few_shot_examples=few_shot_examples)
|
|
276
292
|
# Define a messages logger to capture messages
|
|
277
293
|
messages_logger = MessagesLogger()
|
|
278
294
|
# Generate response
|
|
@@ -308,11 +324,12 @@ class OCREngine:
|
|
|
308
324
|
print(f"{Fore.RED}Error during OCR for a page in {filename}:{Style.RESET_ALL} {page_e}")
|
|
309
325
|
|
|
310
326
|
# Add the OCR result to the list
|
|
311
|
-
ocr_result.status
|
|
327
|
+
if ocr_result.status != "error":
|
|
328
|
+
ocr_result.status = "success"
|
|
312
329
|
ocr_results.append(ocr_result)
|
|
313
330
|
|
|
314
331
|
if verbose:
|
|
315
|
-
print(f"{Fore.BLUE}
|
|
332
|
+
print(f"{Fore.BLUE}Processed {filename} with {len(ocr_result)} pages.{Style.RESET_ALL}")
|
|
316
333
|
for page in ocr_result:
|
|
317
334
|
print(page)
|
|
318
335
|
print("-" * 80)
|
|
@@ -321,7 +338,8 @@ class OCREngine:
|
|
|
321
338
|
|
|
322
339
|
|
|
323
340
|
def concurrent_ocr(self, file_paths: Union[str, Iterable[str]], rotate_correction:bool=False,
|
|
324
|
-
max_dimension_pixels:int=None,
|
|
341
|
+
max_dimension_pixels:int=None, few_shot_examples:List[FewShotExample]=None,
|
|
342
|
+
concurrent_batch_size: int=32, max_file_load: int=None) -> AsyncGenerator[OCRResult, None]:
|
|
325
343
|
"""
|
|
326
344
|
First complete first out. Input and output order not guaranteed.
|
|
327
345
|
This method inputs a file path or a list of file paths (image, PDF, TIFF) and performs OCR using the VLM inference engine.
|
|
@@ -335,6 +353,8 @@ class OCREngine:
|
|
|
335
353
|
If True, applies rotate correction to the images using pytesseract.
|
|
336
354
|
max_dimension_pixels : int, Optional
|
|
337
355
|
The maximum dimension of the image in pixels. Origianl dimensions will be resized to fit in. If None, no resizing is applied.
|
|
356
|
+
few_shot_examples : List[FewShotExample], Optional
|
|
357
|
+
list of few-shot examples. Each example is a dict with keys "image" (PIL.Image.Image) and "text" (str).
|
|
338
358
|
concurrent_batch_size : int, Optional
|
|
339
359
|
The number of concurrent VLM calls to make.
|
|
340
360
|
max_file_load : int, Optional
|
|
@@ -353,18 +373,17 @@ class OCREngine:
|
|
|
353
373
|
|
|
354
374
|
if not isinstance(max_file_load, int) or max_file_load <= 0:
|
|
355
375
|
raise ValueError("max_file_load must be a positive integer")
|
|
356
|
-
|
|
357
|
-
if self.image_processor.has_tesseract==False and rotate_correction:
|
|
358
|
-
raise ImportError("pytesseract is not installed. Please install it to use rotate correction.")
|
|
359
376
|
|
|
360
377
|
return self._ocr_async(file_paths=file_paths,
|
|
361
378
|
rotate_correction=rotate_correction,
|
|
362
379
|
max_dimension_pixels=max_dimension_pixels,
|
|
380
|
+
few_shot_examples=few_shot_examples,
|
|
363
381
|
concurrent_batch_size=concurrent_batch_size,
|
|
364
382
|
max_file_load=max_file_load)
|
|
365
383
|
|
|
366
384
|
|
|
367
385
|
async def _ocr_async(self, file_paths: Iterable[str], rotate_correction:bool=False, max_dimension_pixels:int=None,
|
|
386
|
+
few_shot_examples:List[FewShotExample]=None,
|
|
368
387
|
concurrent_batch_size: int=32, max_file_load: int=None) -> AsyncGenerator[OCRResult, None]:
|
|
369
388
|
"""
|
|
370
389
|
Internal method to asynchronously process an iterable of file paths.
|
|
@@ -380,7 +399,8 @@ class OCREngine:
|
|
|
380
399
|
vlm_call_semaphore=vlm_call_semaphore,
|
|
381
400
|
file_path=file_path,
|
|
382
401
|
rotate_correction=rotate_correction,
|
|
383
|
-
max_dimension_pixels=max_dimension_pixels
|
|
402
|
+
max_dimension_pixels=max_dimension_pixels,
|
|
403
|
+
few_shot_examples=few_shot_examples)
|
|
384
404
|
tasks.append(task)
|
|
385
405
|
|
|
386
406
|
|
|
@@ -389,7 +409,8 @@ class OCREngine:
|
|
|
389
409
|
yield result
|
|
390
410
|
|
|
391
411
|
async def _ocr_file_with_semaphore(self, file_load_semaphore:asyncio.Semaphore, vlm_call_semaphore:asyncio.Semaphore,
|
|
392
|
-
file_path:str, rotate_correction:bool=False, max_dimension_pixels:int=None
|
|
412
|
+
file_path:str, rotate_correction:bool=False, max_dimension_pixels:int=None,
|
|
413
|
+
few_shot_examples:List[FewShotExample]=None) -> OCRResult:
|
|
393
414
|
"""
|
|
394
415
|
This internal method takes a semaphore and OCR a single file using the VLM inference engine.
|
|
395
416
|
"""
|
|
@@ -428,6 +449,7 @@ class OCREngine:
|
|
|
428
449
|
page_index=page_index,
|
|
429
450
|
rotate_correction=rotate_correction,
|
|
430
451
|
max_dimension_pixels=max_dimension_pixels,
|
|
452
|
+
few_shot_examples=few_shot_examples,
|
|
431
453
|
messages_logger=messages_logger
|
|
432
454
|
)
|
|
433
455
|
page_processing_tasks.append(task)
|
|
@@ -444,13 +466,14 @@ class OCREngine:
|
|
|
444
466
|
return result
|
|
445
467
|
|
|
446
468
|
# Set status to success if no errors occurred
|
|
447
|
-
result.status
|
|
469
|
+
if result.status != "error":
|
|
470
|
+
result.status = "success"
|
|
448
471
|
result.add_messages_to_log(messages_logger.get_messages_log())
|
|
449
472
|
return result
|
|
450
473
|
|
|
451
474
|
async def _ocr_page_with_semaphore(self, vlm_call_semaphore: asyncio.Semaphore, data_loader: DataLoader,
|
|
452
475
|
page_index:int, rotate_correction:bool=False, max_dimension_pixels:int=None,
|
|
453
|
-
messages_logger:MessagesLogger=None) -> Tuple[str, Dict[str, str]]:
|
|
476
|
+
few_shot_examples:List[FewShotExample]=None, messages_logger:MessagesLogger=None) -> Tuple[str, Dict[str, str]]:
|
|
454
477
|
"""
|
|
455
478
|
This internal method takes a semaphore and OCR a single image/page using the VLM inference engine.
|
|
456
479
|
|
|
@@ -462,8 +485,8 @@ class OCREngine:
|
|
|
462
485
|
async with vlm_call_semaphore:
|
|
463
486
|
image = await data_loader.get_page_async(page_index)
|
|
464
487
|
image_processing_status = {}
|
|
465
|
-
# Apply rotate correction if specified
|
|
466
|
-
if rotate_correction
|
|
488
|
+
# Apply rotate correction if specified
|
|
489
|
+
if rotate_correction:
|
|
467
490
|
try:
|
|
468
491
|
image, rotation_angle = await self.image_processor.rotate_correction_async(image)
|
|
469
492
|
image_processing_status["rotate_correction"] = {
|
|
@@ -490,7 +513,10 @@ class OCREngine:
|
|
|
490
513
|
"error": str(e)
|
|
491
514
|
}
|
|
492
515
|
|
|
493
|
-
messages = self.vlm_engine.get_ocr_messages(self.system_prompt,
|
|
516
|
+
messages = self.vlm_engine.get_ocr_messages(system_prompt=self.system_prompt,
|
|
517
|
+
user_prompt=self.user_prompt,
|
|
518
|
+
image=image,
|
|
519
|
+
few_shot_examples=few_shot_examples)
|
|
494
520
|
response = await self.vlm_engine.chat_async(
|
|
495
521
|
messages,
|
|
496
522
|
messages_logger=messages_logger
|