vlm4ocr 0.3.0__tar.gz → 0.4.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {vlm4ocr-0.3.0 → vlm4ocr-0.4.0}/PKG-INFO +1 -1
- {vlm4ocr-0.3.0 → vlm4ocr-0.4.0}/pyproject.toml +1 -1
- vlm4ocr-0.4.0/vlm4ocr/__init__.py +17 -0
- {vlm4ocr-0.3.0 → vlm4ocr-0.4.0}/vlm4ocr/cli.py +73 -28
- {vlm4ocr-0.3.0 → vlm4ocr-0.4.0}/vlm4ocr/data_types.py +57 -7
- {vlm4ocr-0.3.0 → vlm4ocr-0.4.0}/vlm4ocr/ocr_engines.py +85 -42
- vlm4ocr-0.4.0/vlm4ocr/vlm_engines.py +1246 -0
- vlm4ocr-0.3.0/vlm4ocr/__init__.py +0 -11
- vlm4ocr-0.3.0/vlm4ocr/vlm_engines.py +0 -570
- {vlm4ocr-0.3.0 → vlm4ocr-0.4.0}/README.md +0 -0
- {vlm4ocr-0.3.0 → vlm4ocr-0.4.0}/vlm4ocr/assets/default_prompt_templates/ocr_HTML_system_prompt.txt +0 -0
- {vlm4ocr-0.3.0 → vlm4ocr-0.4.0}/vlm4ocr/assets/default_prompt_templates/ocr_HTML_user_prompt.txt +0 -0
- {vlm4ocr-0.3.0 → vlm4ocr-0.4.0}/vlm4ocr/assets/default_prompt_templates/ocr_JSON_system_prompt.txt +0 -0
- {vlm4ocr-0.3.0 → vlm4ocr-0.4.0}/vlm4ocr/assets/default_prompt_templates/ocr_markdown_system_prompt.txt +0 -0
- {vlm4ocr-0.3.0 → vlm4ocr-0.4.0}/vlm4ocr/assets/default_prompt_templates/ocr_markdown_user_prompt.txt +0 -0
- {vlm4ocr-0.3.0 → vlm4ocr-0.4.0}/vlm4ocr/assets/default_prompt_templates/ocr_text_system_prompt.txt +0 -0
- {vlm4ocr-0.3.0 → vlm4ocr-0.4.0}/vlm4ocr/assets/default_prompt_templates/ocr_text_user_prompt.txt +0 -0
- {vlm4ocr-0.3.0 → vlm4ocr-0.4.0}/vlm4ocr/utils.py +0 -0
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
from .data_types import FewShotExample
|
|
2
|
+
from .ocr_engines import OCREngine
|
|
3
|
+
from .vlm_engines import BasicVLMConfig, ReasoningVLMConfig, OpenAIReasoningVLMConfig, OllamaVLMEngine, OpenAICompatibleVLMEngine, VLLMVLMEngine, OpenRouterVLMEngine, OpenAIVLMEngine, AzureOpenAIVLMEngine
|
|
4
|
+
|
|
5
|
+
__all__ = [
|
|
6
|
+
"FewShotExample",
|
|
7
|
+
"BasicVLMConfig",
|
|
8
|
+
"ReasoningVLMConfig",
|
|
9
|
+
"OpenAIReasoningVLMConfig",
|
|
10
|
+
"OCREngine",
|
|
11
|
+
"OllamaVLMEngine",
|
|
12
|
+
"OpenAICompatibleVLMEngine",
|
|
13
|
+
"VLLMVLMEngine",
|
|
14
|
+
"OpenRouterVLMEngine",
|
|
15
|
+
"OpenAIVLMEngine",
|
|
16
|
+
"AzureOpenAIVLMEngine"
|
|
17
|
+
]
|
|
@@ -4,18 +4,9 @@ import sys
|
|
|
4
4
|
import logging
|
|
5
5
|
import asyncio
|
|
6
6
|
import time
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
from .ocr_engines import OCREngine
|
|
11
|
-
from .vlm_engines import OpenAIVLMEngine, AzureOpenAIVLMEngine, OllamaVLMEngine, BasicVLMConfig
|
|
12
|
-
from .data_types import OCRResult
|
|
13
|
-
except ImportError:
|
|
14
|
-
# Fallback for when the package is installed
|
|
15
|
-
from vlm4ocr.ocr_engines import OCREngine
|
|
16
|
-
from vlm4ocr.vlm_engines import OpenAIVLMEngine, AzureOpenAIVLMEngine, OllamaVLMEngine, BasicVLMConfig
|
|
17
|
-
from vlm4ocr.data_types import OCRResult
|
|
18
|
-
|
|
7
|
+
from .ocr_engines import OCREngine
|
|
8
|
+
from .vlm_engines import OpenAICompatibleVLMEngine, OpenAIVLMEngine, AzureOpenAIVLMEngine, OllamaVLMEngine, BasicVLMConfig
|
|
9
|
+
from .data_types import OCRResult
|
|
19
10
|
import tqdm.asyncio
|
|
20
11
|
|
|
21
12
|
# --- Global logger setup (console) ---
|
|
@@ -24,7 +15,12 @@ logging.basicConfig(
|
|
|
24
15
|
format='%(asctime)s - %(levelname)s: %(message)s',
|
|
25
16
|
datefmt='%Y-%m-%d %H:%M:%S'
|
|
26
17
|
)
|
|
18
|
+
# Get our specific logger for CLI messages
|
|
27
19
|
logger = logging.getLogger("vlm4ocr_cli")
|
|
20
|
+
# Get the logger that will receive captured warnings
|
|
21
|
+
# By default, warnings are logged to a logger named 'py.warnings'
|
|
22
|
+
warnings_logger = logging.getLogger('py.warnings')
|
|
23
|
+
|
|
28
24
|
|
|
29
25
|
SUPPORTED_IMAGE_EXTS_CLI = ['.pdf', '.tif', '.tiff', '.png', '.jpg', '.jpeg', '.bmp', '.gif', '.webp']
|
|
30
26
|
OUTPUT_EXTENSIONS = {'markdown': '.md', 'HTML':'.html', 'text':'.txt'}
|
|
@@ -65,17 +61,26 @@ def setup_file_logger(log_dir, timestamp_str, debug_mode):
|
|
|
65
61
|
log_file_path = os.path.join(log_dir, log_file_name)
|
|
66
62
|
|
|
67
63
|
file_handler = logging.FileHandler(log_file_path, mode='a')
|
|
68
|
-
formatter = logging.Formatter('%(asctime)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
|
|
64
|
+
formatter = logging.Formatter('%(asctime)s - %(levelname)s - [%(name)s:%(filename)s:%(lineno)d] - %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
|
|
69
65
|
file_handler.setFormatter(formatter)
|
|
70
66
|
|
|
71
67
|
log_level = logging.DEBUG if debug_mode else logging.INFO
|
|
72
68
|
file_handler.setLevel(log_level)
|
|
73
69
|
|
|
74
|
-
logger
|
|
70
|
+
# Add handler to the root logger to capture all logs (from our logger,
|
|
71
|
+
# and from the warnings logger 'py.warnings')
|
|
72
|
+
root_logger = logging.getLogger()
|
|
73
|
+
root_logger.addHandler(file_handler)
|
|
74
|
+
|
|
75
|
+
# We still configure our specific logger's level for console output
|
|
75
76
|
logger.info(f"Logging to file: {log_file_path}")
|
|
76
77
|
|
|
77
78
|
|
|
78
79
|
def main():
|
|
80
|
+
# Capture warnings from the 'warnings' module (like RuntimeWarning)
|
|
81
|
+
# and redirect them to the 'logging' system.
|
|
82
|
+
logging.captureWarnings(True)
|
|
83
|
+
|
|
79
84
|
parser = argparse.ArgumentParser(
|
|
80
85
|
description="VLM4OCR: Perform OCR on images, PDFs, or TIFF files using Vision Language Models. Processing is concurrent by default.",
|
|
81
86
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
|
@@ -103,7 +108,8 @@ def main():
|
|
|
103
108
|
vlm_engine_group.add_argument("--vlm_engine", choices=["openai", "azure_openai", "ollama", "openai_compatible"], required=True, help="VLM engine.")
|
|
104
109
|
vlm_engine_group.add_argument("--model", required=True, help="Model identifier for the VLM engine.")
|
|
105
110
|
vlm_engine_group.add_argument("--max_new_tokens", type=int, default=4096, help="Max new tokens for VLM.")
|
|
106
|
-
vlm_engine_group.add_argument("--temperature", type=float, default=
|
|
111
|
+
vlm_engine_group.add_argument("--temperature", type=float, default=None, help="Sampling temperature.")
|
|
112
|
+
vlm_engine_group.add_argument("--top_p", type=float, default=None, help="Sampling top p.")
|
|
107
113
|
|
|
108
114
|
openai_group = parser.add_argument_group("OpenAI & OpenAI-Compatible Options")
|
|
109
115
|
openai_group.add_argument("--api_key", default=os.environ.get("OPENAI_API_KEY"), help="API key.")
|
|
@@ -144,16 +150,23 @@ def main():
|
|
|
144
150
|
current_timestamp_str = time.strftime("%Y%m%d_%H%M%S")
|
|
145
151
|
|
|
146
152
|
# --- Configure Logger Level based on args ---
|
|
153
|
+
# Get root logger to control global level for libraries
|
|
154
|
+
root_logger = logging.getLogger()
|
|
155
|
+
|
|
147
156
|
if args.debug:
|
|
148
|
-
logger.setLevel(logging.DEBUG)
|
|
149
|
-
#
|
|
150
|
-
|
|
151
|
-
logging.getLogger().setLevel(logging.DEBUG)
|
|
157
|
+
logger.setLevel(logging.DEBUG) # Our logger to DEBUG
|
|
158
|
+
warnings_logger.setLevel(logging.DEBUG) # Warnings logger to DEBUG
|
|
159
|
+
root_logger.setLevel(logging.DEBUG) # Root to DEBUG
|
|
152
160
|
logger.debug("Debug mode enabled for console.")
|
|
153
161
|
else:
|
|
154
|
-
logger.setLevel(logging.INFO) #
|
|
155
|
-
|
|
156
|
-
|
|
162
|
+
logger.setLevel(logging.INFO) # Our logger to INFO
|
|
163
|
+
warnings_logger.setLevel(logging.INFO) # Warnings logger to INFO
|
|
164
|
+
root_logger.setLevel(logging.WARNING) # Root to WARNING (quieter libraries)
|
|
165
|
+
# Our console handler (from basicConfig) is on the root logger,
|
|
166
|
+
# so setting root to WARNING makes console quiet
|
|
167
|
+
# But our logger (vlm4ocr_cli) is INFO, so if a file handler
|
|
168
|
+
# is added, it will get INFO messages from 'logger'
|
|
169
|
+
|
|
157
170
|
if args.concurrent_batch_size < 1:
|
|
158
171
|
parser.error("--concurrent_batch_size must be 1 or greater.")
|
|
159
172
|
|
|
@@ -192,6 +205,15 @@ def main():
|
|
|
192
205
|
# --- Setup File Logger (if --log is specified) ---
|
|
193
206
|
if args.log:
|
|
194
207
|
setup_file_logger(effective_output_dir, current_timestamp_str, args.debug)
|
|
208
|
+
# If logging to file, we want our console to be less verbose
|
|
209
|
+
# if not in debug mode, so we set the console handler's level higher.
|
|
210
|
+
if not args.debug:
|
|
211
|
+
# Find the console handler (from basicConfig) and set its level
|
|
212
|
+
for handler in root_logger.handlers:
|
|
213
|
+
if isinstance(handler, logging.StreamHandler) and handler.stream == sys.stderr:
|
|
214
|
+
handler.setLevel(logging.WARNING)
|
|
215
|
+
logger.debug("Set console handler level to WARNING.")
|
|
216
|
+
break
|
|
195
217
|
|
|
196
218
|
logger.debug(f"Parsed arguments: {args}")
|
|
197
219
|
|
|
@@ -199,16 +221,18 @@ def main():
|
|
|
199
221
|
vlm_engine_instance = None
|
|
200
222
|
try:
|
|
201
223
|
logger.info(f"Initializing VLM engine: {args.vlm_engine} with model: {args.model}")
|
|
224
|
+
logger.info(f"max_new_tokens: {args.max_new_tokens}, temperature: {args.temperature}, top_p: {args.top_p}")
|
|
202
225
|
config = BasicVLMConfig(
|
|
203
226
|
max_new_tokens=args.max_new_tokens,
|
|
204
|
-
temperature=args.temperature
|
|
227
|
+
temperature=args.temperature,
|
|
228
|
+
top_p=args.top_p
|
|
205
229
|
)
|
|
206
230
|
if args.vlm_engine == "openai":
|
|
207
231
|
if not args.api_key: parser.error("--api_key (or OPENAI_API_KEY) is required for OpenAI.")
|
|
208
232
|
vlm_engine_instance = OpenAIVLMEngine(model=args.model, api_key=args.api_key, config=config)
|
|
209
233
|
elif args.vlm_engine == "openai_compatible":
|
|
210
234
|
if not args.base_url: parser.error("--base_url is required for openai_compatible.")
|
|
211
|
-
vlm_engine_instance =
|
|
235
|
+
vlm_engine_instance = OpenAICompatibleVLMEngine(model=args.model, api_key=args.api_key, base_url=args.base_url, config=config)
|
|
212
236
|
elif args.vlm_engine == "azure_openai":
|
|
213
237
|
if not args.azure_api_key: parser.error("--azure_api_key (or AZURE_OPENAI_API_KEY) is required.")
|
|
214
238
|
if not args.azure_endpoint: parser.error("--azure_endpoint (or AZURE_OPENAI_ENDPOINT) is required.")
|
|
@@ -295,16 +319,34 @@ def main():
|
|
|
295
319
|
# console verbosity controlled by logger level.
|
|
296
320
|
show_progress_bar = (num_actual_files > 0)
|
|
297
321
|
|
|
322
|
+
# Only show progress bar if not in debug mode (debug logs would interfere)
|
|
323
|
+
# and if there are files to process.
|
|
324
|
+
# If logging to file, console can be quiet (INFO level).
|
|
325
|
+
# If NOT logging to file, console must be INFO level to show bar.
|
|
326
|
+
|
|
327
|
+
# Determine if progress bar should be active (not disabled)
|
|
328
|
+
# Disable bar if in debug mode (logs interfere) or no files
|
|
329
|
+
disable_bar = args.debug or not show_progress_bar
|
|
330
|
+
|
|
331
|
+
# If not logging to file AND not debug, we need console at INFO
|
|
332
|
+
if not args.log and not args.debug:
|
|
333
|
+
for handler in logging.getLogger().handlers:
|
|
334
|
+
if isinstance(handler, logging.StreamHandler) and handler.stream == sys.stderr:
|
|
335
|
+
handler.setLevel(logging.INFO)
|
|
336
|
+
logger.debug("Set console handler level to INFO for progress bar.")
|
|
337
|
+
break
|
|
338
|
+
|
|
298
339
|
iterator_wrapper = tqdm.asyncio.tqdm(
|
|
299
340
|
ocr_task_generator,
|
|
300
341
|
total=num_actual_files,
|
|
301
342
|
desc="Processing files",
|
|
302
343
|
unit="file",
|
|
303
|
-
disable=
|
|
344
|
+
disable=disable_bar
|
|
304
345
|
)
|
|
305
346
|
|
|
306
347
|
async for result_object in iterator_wrapper:
|
|
307
348
|
if not isinstance(result_object, OCRResult):
|
|
349
|
+
# This warning *will* now be captured by the file log
|
|
308
350
|
logger.warning(f"Received unexpected data type: {type(result_object)}")
|
|
309
351
|
continue
|
|
310
352
|
|
|
@@ -323,9 +365,12 @@ def main():
|
|
|
323
365
|
content_to_write = result_object.to_string()
|
|
324
366
|
with open(current_ocr_output_file_path, "w", encoding="utf-8") as f:
|
|
325
367
|
f.write(content_to_write)
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
368
|
+
|
|
369
|
+
# MODIFIED: Always log success info.
|
|
370
|
+
# This will go to the file log if active.
|
|
371
|
+
# It will NOT go to console if console level is WARNING.
|
|
372
|
+
logger.info(f"OCR result for '{input_file_path_from_result}' saved to: {current_ocr_output_file_path}")
|
|
373
|
+
|
|
329
374
|
except Exception as e:
|
|
330
375
|
logger.error(f"Error writing output for '{input_file_path_from_result}' to '{current_ocr_output_file_path}': {e}")
|
|
331
376
|
|
|
@@ -1,7 +1,8 @@
|
|
|
1
1
|
import os
|
|
2
|
-
from typing import List, Literal
|
|
2
|
+
from typing import List, Dict, Literal
|
|
3
|
+
from PIL import Image
|
|
3
4
|
from dataclasses import dataclass, field
|
|
4
|
-
from vlm4ocr.utils import get_default_page_delimiter
|
|
5
|
+
from vlm4ocr.utils import get_default_page_delimiter, ImageProcessor
|
|
5
6
|
|
|
6
7
|
OutputMode = Literal["markdown", "HTML", "text", "JSON"]
|
|
7
8
|
|
|
@@ -24,6 +25,7 @@ class OCRResult:
|
|
|
24
25
|
pages: List[dict] = field(default_factory=list)
|
|
25
26
|
filename: str = field(init=False)
|
|
26
27
|
status: str = field(init=False, default="processing")
|
|
28
|
+
messages_log: List[List[Dict[str,str]]] = field(default_factory=list)
|
|
27
29
|
|
|
28
30
|
def __post_init__(self):
|
|
29
31
|
"""
|
|
@@ -67,10 +69,6 @@ class OCRResult:
|
|
|
67
69
|
}
|
|
68
70
|
self.pages.append(page)
|
|
69
71
|
|
|
70
|
-
|
|
71
|
-
def __len__(self):
|
|
72
|
-
return len(self.pages)
|
|
73
|
-
|
|
74
72
|
def get_page(self, idx):
|
|
75
73
|
if not isinstance(idx, int):
|
|
76
74
|
raise ValueError("Index must be an integer")
|
|
@@ -78,6 +76,21 @@ class OCRResult:
|
|
|
78
76
|
raise IndexError(f"Index out of range. The OCRResult has {len(self.pages)} pages, but index {idx} was requested.")
|
|
79
77
|
|
|
80
78
|
return self.pages[idx]
|
|
79
|
+
|
|
80
|
+
def clear_messages_log(self):
|
|
81
|
+
self.messages_log = []
|
|
82
|
+
|
|
83
|
+
def add_messages_to_log(self, messages: List[Dict[str,str]]):
|
|
84
|
+
if not isinstance(messages, list):
|
|
85
|
+
raise ValueError("messages must be a list of dict")
|
|
86
|
+
|
|
87
|
+
self.messages_log.extend(messages)
|
|
88
|
+
|
|
89
|
+
def get_messages_log(self) -> List[List[Dict[str,str]]]:
|
|
90
|
+
return self.messages_log.copy()
|
|
91
|
+
|
|
92
|
+
def __len__(self):
|
|
93
|
+
return len(self.pages)
|
|
81
94
|
|
|
82
95
|
def __iter__(self):
|
|
83
96
|
return iter(self.pages)
|
|
@@ -106,4 +119,41 @@ class OCRResult:
|
|
|
106
119
|
else:
|
|
107
120
|
self.page_delimiter = page_delimiter
|
|
108
121
|
|
|
109
|
-
return self.page_delimiter.join([page.get("text", "") for page in self.pages])
|
|
122
|
+
return self.page_delimiter.join([page.get("text", "") for page in self.pages])
|
|
123
|
+
|
|
124
|
+
@dataclass
|
|
125
|
+
class FewShotExample:
|
|
126
|
+
"""
|
|
127
|
+
This class represents a few-shot example for OCR tasks.
|
|
128
|
+
|
|
129
|
+
Parameters:
|
|
130
|
+
----------
|
|
131
|
+
image : PIL.Image.Image
|
|
132
|
+
The image associated with the example.
|
|
133
|
+
text : str
|
|
134
|
+
The expected OCR result text for the image.
|
|
135
|
+
rotate_correction : bool, Optional
|
|
136
|
+
If True, applies rotate correction to the images using pytesseract.
|
|
137
|
+
max_dimension_pixels : int, Optional
|
|
138
|
+
The maximum dimension of the image in pixels. Original dimensions will be resized to fit in. If None, no resizing is applied.
|
|
139
|
+
"""
|
|
140
|
+
image: Image.Image
|
|
141
|
+
text: str
|
|
142
|
+
rotate_correction: bool = False
|
|
143
|
+
max_dimension_pixels: int = None
|
|
144
|
+
def __post_init__(self):
|
|
145
|
+
if not isinstance(self.image, Image.Image):
|
|
146
|
+
raise ValueError("image must be a PIL.Image.Image object")
|
|
147
|
+
if not isinstance(self.text, str):
|
|
148
|
+
raise ValueError("text must be a string")
|
|
149
|
+
|
|
150
|
+
if self.rotate_correction or self.max_dimension_pixels is not None:
|
|
151
|
+
self.image_processor = ImageProcessor()
|
|
152
|
+
|
|
153
|
+
# Rotate correction if specified
|
|
154
|
+
if self.rotate_correction:
|
|
155
|
+
self.image, _ = self.image_processor.rotate_correction(self.image)
|
|
156
|
+
|
|
157
|
+
# Resize image if max_dimension_pixels is specified
|
|
158
|
+
if self.max_dimension_pixels is not None:
|
|
159
|
+
self.image, _ = self.image_processor.resize(image=self.image, max_dimension_pixels=self.max_dimension_pixels)
|
|
@@ -1,12 +1,12 @@
|
|
|
1
1
|
import os
|
|
2
|
-
from typing import Tuple, List, Dict, Union, Generator, AsyncGenerator, Iterable
|
|
2
|
+
from typing import Any, Tuple, List, Dict, Union, Generator, AsyncGenerator, Iterable
|
|
3
3
|
import importlib
|
|
4
4
|
import asyncio
|
|
5
5
|
from colorama import Fore, Style
|
|
6
6
|
import json
|
|
7
7
|
from vlm4ocr.utils import DataLoader, PDFDataLoader, TIFFDataLoader, ImageDataLoader, ImageProcessor, clean_markdown, extract_json, get_default_page_delimiter
|
|
8
|
-
from vlm4ocr.data_types import OCRResult
|
|
9
|
-
from vlm4ocr.vlm_engines import VLMEngine
|
|
8
|
+
from vlm4ocr.data_types import OCRResult, FewShotExample
|
|
9
|
+
from vlm4ocr.vlm_engines import VLMEngine, MessagesLogger
|
|
10
10
|
|
|
11
11
|
SUPPORTED_IMAGE_EXTS = ['.pdf', '.tif', '.tiff', '.png', '.jpg', '.jpeg', '.bmp', '.gif', '.webp']
|
|
12
12
|
|
|
@@ -60,7 +60,8 @@ class OCREngine:
|
|
|
60
60
|
self.image_processor = ImageProcessor()
|
|
61
61
|
|
|
62
62
|
|
|
63
|
-
def stream_ocr(self, file_path: str, rotate_correction:bool=False, max_dimension_pixels:int=None
|
|
63
|
+
def stream_ocr(self, file_path: str, rotate_correction:bool=False, max_dimension_pixels:int=None,
|
|
64
|
+
few_shot_examples:List[FewShotExample]=None) -> Generator[Dict[str, str], None, None]:
|
|
64
65
|
"""
|
|
65
66
|
This method inputs a file path (image or PDF) and stream OCR results in real-time. This is useful for frontend applications.
|
|
66
67
|
Yields dictionaries with 'type' ('ocr_chunk' or 'page_delimiter') and 'data'.
|
|
@@ -73,6 +74,8 @@ class OCREngine:
|
|
|
73
74
|
If True, applies rotate correction to the images using pytesseract.
|
|
74
75
|
max_dimension_pixels : int, Optional
|
|
75
76
|
The maximum dimension of the image in pixels. Original dimensions will be resized to fit in. If None, no resizing is applied.
|
|
77
|
+
few_shot_examples : List[FewShotExample], Optional
|
|
78
|
+
list of few-shot examples.
|
|
76
79
|
|
|
77
80
|
Returns:
|
|
78
81
|
--------
|
|
@@ -90,10 +93,6 @@ class OCREngine:
|
|
|
90
93
|
file_ext = os.path.splitext(file_path)[1].lower()
|
|
91
94
|
if file_ext not in SUPPORTED_IMAGE_EXTS:
|
|
92
95
|
raise ValueError(f"Unsupported file type: {file_ext}. Supported types are: {SUPPORTED_IMAGE_EXTS}")
|
|
93
|
-
|
|
94
|
-
# Check if image preprocessing can be applied
|
|
95
|
-
if self.image_processor.has_tesseract==False and rotate_correction:
|
|
96
|
-
raise ImportError("pytesseract is not installed. Please install it to use rotate correction.")
|
|
97
96
|
|
|
98
97
|
# PDF or TIFF
|
|
99
98
|
if file_ext in ['.pdf', '.tif', '.tiff']:
|
|
@@ -105,8 +104,8 @@ class OCREngine:
|
|
|
105
104
|
|
|
106
105
|
# OCR each image
|
|
107
106
|
for i, image in enumerate(images):
|
|
108
|
-
# Apply rotate correction if specified
|
|
109
|
-
if rotate_correction
|
|
107
|
+
# Apply rotate correction if specified
|
|
108
|
+
if rotate_correction:
|
|
110
109
|
try:
|
|
111
110
|
image, _ = self.image_processor.rotate_correction(image)
|
|
112
111
|
|
|
@@ -120,13 +119,20 @@ class OCREngine:
|
|
|
120
119
|
except Exception as e:
|
|
121
120
|
yield {"type": "info", "data": f"Error resizing image: {str(e)}"}
|
|
122
121
|
|
|
123
|
-
|
|
122
|
+
# Get OCR messages
|
|
123
|
+
messages = self.vlm_engine.get_ocr_messages(system_prompt=self.system_prompt,
|
|
124
|
+
user_prompt=self.user_prompt,
|
|
125
|
+
image=image,
|
|
126
|
+
few_shot_examples=few_shot_examples)
|
|
127
|
+
|
|
128
|
+
# Stream response
|
|
124
129
|
response_stream = self.vlm_engine.chat(
|
|
125
130
|
messages,
|
|
126
131
|
stream=True
|
|
127
132
|
)
|
|
128
133
|
for chunk in response_stream:
|
|
129
|
-
|
|
134
|
+
if chunk["type"] == "response":
|
|
135
|
+
yield {"type": "ocr_chunk", "data": chunk["data"]}
|
|
130
136
|
|
|
131
137
|
if i < len(images) - 1:
|
|
132
138
|
yield {"type": "page_delimiter", "data": get_default_page_delimiter(self.output_mode)}
|
|
@@ -136,8 +142,8 @@ class OCREngine:
|
|
|
136
142
|
data_loader = ImageDataLoader(file_path)
|
|
137
143
|
image = data_loader.get_page(0)
|
|
138
144
|
|
|
139
|
-
# Apply rotate correction if specified
|
|
140
|
-
if rotate_correction
|
|
145
|
+
# Apply rotate correction if specified
|
|
146
|
+
if rotate_correction:
|
|
141
147
|
try:
|
|
142
148
|
image, _ = self.image_processor.rotate_correction(image)
|
|
143
149
|
|
|
@@ -151,17 +157,23 @@ class OCREngine:
|
|
|
151
157
|
except Exception as e:
|
|
152
158
|
yield {"type": "info", "data": f"Error resizing image: {str(e)}"}
|
|
153
159
|
|
|
154
|
-
|
|
160
|
+
# Get OCR messages
|
|
161
|
+
messages = self.vlm_engine.get_ocr_messages(system_prompt=self.system_prompt,
|
|
162
|
+
user_prompt=self.user_prompt,
|
|
163
|
+
image=image,
|
|
164
|
+
few_shot_examples=few_shot_examples)
|
|
165
|
+
# Stream response
|
|
155
166
|
response_stream = self.vlm_engine.chat(
|
|
156
167
|
messages,
|
|
157
168
|
stream=True
|
|
158
169
|
)
|
|
159
170
|
for chunk in response_stream:
|
|
160
|
-
|
|
171
|
+
if chunk["type"] == "response":
|
|
172
|
+
yield {"type": "ocr_chunk", "data": chunk["data"]}
|
|
161
173
|
|
|
162
174
|
|
|
163
175
|
def sequential_ocr(self, file_paths: Union[str, Iterable[str]], rotate_correction:bool=False,
|
|
164
|
-
max_dimension_pixels:int=None, verbose:bool=False) -> List[OCRResult]:
|
|
176
|
+
max_dimension_pixels:int=None, verbose:bool=False, few_shot_examples:List[FewShotExample]=None) -> List[OCRResult]:
|
|
165
177
|
"""
|
|
166
178
|
This method inputs a file path or a list of file paths (image, PDF, TIFF) and performs OCR using the VLM inference engine.
|
|
167
179
|
|
|
@@ -175,6 +187,8 @@ class OCREngine:
|
|
|
175
187
|
The maximum dimension of the image in pixels. Original dimensions will be resized to fit in. If None, no resizing is applied.
|
|
176
188
|
verbose : bool, Optional
|
|
177
189
|
If True, the function will print the output in terminal.
|
|
190
|
+
few_shot_examples : List[FewShotExample], Optional
|
|
191
|
+
list of few-shot examples. Each example is a dict with keys "image" (PIL.Image.Image) and "text" (str).
|
|
178
192
|
|
|
179
193
|
Returns:
|
|
180
194
|
--------
|
|
@@ -184,6 +198,7 @@ class OCREngine:
|
|
|
184
198
|
if isinstance(file_paths, str):
|
|
185
199
|
file_paths = [file_paths]
|
|
186
200
|
|
|
201
|
+
# Iterate through file paths
|
|
187
202
|
ocr_results = []
|
|
188
203
|
for file_path in file_paths:
|
|
189
204
|
# Define OCRResult object
|
|
@@ -233,8 +248,8 @@ class OCREngine:
|
|
|
233
248
|
# OCR images
|
|
234
249
|
for i, image in enumerate(images):
|
|
235
250
|
image_processing_status = {}
|
|
236
|
-
# Apply rotate correction if specified
|
|
237
|
-
if rotate_correction
|
|
251
|
+
# Apply rotate correction if specified
|
|
252
|
+
if rotate_correction:
|
|
238
253
|
try:
|
|
239
254
|
image, rotation_angle = self.image_processor.rotate_correction(image)
|
|
240
255
|
image_processing_status["rotate_correction"] = {
|
|
@@ -270,25 +285,36 @@ class OCREngine:
|
|
|
270
285
|
print(f"{Fore.RED}Error resizing image for {filename}:{Style.RESET_ALL} {resized['error']}. OCR continues without resizing.")
|
|
271
286
|
|
|
272
287
|
try:
|
|
273
|
-
messages = self.vlm_engine.get_ocr_messages(self.system_prompt,
|
|
288
|
+
messages = self.vlm_engine.get_ocr_messages(system_prompt=self.system_prompt,
|
|
289
|
+
user_prompt=self.user_prompt,
|
|
290
|
+
image=image,
|
|
291
|
+
few_shot_examples=few_shot_examples)
|
|
292
|
+
# Define a messages logger to capture messages
|
|
293
|
+
messages_logger = MessagesLogger()
|
|
294
|
+
# Generate response
|
|
274
295
|
response = self.vlm_engine.chat(
|
|
275
296
|
messages,
|
|
276
297
|
verbose=verbose,
|
|
277
|
-
stream=False
|
|
298
|
+
stream=False,
|
|
299
|
+
messages_logger=messages_logger
|
|
278
300
|
)
|
|
301
|
+
ocr_text = response["response"]
|
|
279
302
|
# Clean the response if output mode is markdown
|
|
280
303
|
if self.output_mode == "markdown":
|
|
281
|
-
|
|
304
|
+
ocr_text = clean_markdown(ocr_text)
|
|
282
305
|
|
|
283
306
|
# Parse the response if output mode is JSON
|
|
284
|
-
|
|
285
|
-
json_list = extract_json(
|
|
307
|
+
elif self.output_mode == "JSON":
|
|
308
|
+
json_list = extract_json(ocr_text)
|
|
286
309
|
# Serialize the JSON list to a string
|
|
287
|
-
|
|
310
|
+
ocr_text = json.dumps(json_list, indent=4)
|
|
288
311
|
|
|
289
312
|
# Add the page to the OCR result
|
|
290
|
-
ocr_result.add_page(text=
|
|
313
|
+
ocr_result.add_page(text=ocr_text,
|
|
291
314
|
image_processing_status=image_processing_status)
|
|
315
|
+
|
|
316
|
+
# Add messages log to the OCR result
|
|
317
|
+
ocr_result.add_messages_to_log(messages_logger.get_messages_log())
|
|
292
318
|
|
|
293
319
|
except Exception as page_e:
|
|
294
320
|
ocr_result.status = "error"
|
|
@@ -298,11 +324,12 @@ class OCREngine:
|
|
|
298
324
|
print(f"{Fore.RED}Error during OCR for a page in {filename}:{Style.RESET_ALL} {page_e}")
|
|
299
325
|
|
|
300
326
|
# Add the OCR result to the list
|
|
301
|
-
ocr_result.status
|
|
327
|
+
if ocr_result.status != "error":
|
|
328
|
+
ocr_result.status = "success"
|
|
302
329
|
ocr_results.append(ocr_result)
|
|
303
330
|
|
|
304
331
|
if verbose:
|
|
305
|
-
print(f"{Fore.BLUE}
|
|
332
|
+
print(f"{Fore.BLUE}Processed {filename} with {len(ocr_result)} pages.{Style.RESET_ALL}")
|
|
306
333
|
for page in ocr_result:
|
|
307
334
|
print(page)
|
|
308
335
|
print("-" * 80)
|
|
@@ -311,7 +338,8 @@ class OCREngine:
|
|
|
311
338
|
|
|
312
339
|
|
|
313
340
|
def concurrent_ocr(self, file_paths: Union[str, Iterable[str]], rotate_correction:bool=False,
|
|
314
|
-
max_dimension_pixels:int=None,
|
|
341
|
+
max_dimension_pixels:int=None, few_shot_examples:List[FewShotExample]=None,
|
|
342
|
+
concurrent_batch_size: int=32, max_file_load: int=None) -> AsyncGenerator[OCRResult, None]:
|
|
315
343
|
"""
|
|
316
344
|
First complete first out. Input and output order not guaranteed.
|
|
317
345
|
This method inputs a file path or a list of file paths (image, PDF, TIFF) and performs OCR using the VLM inference engine.
|
|
@@ -325,6 +353,8 @@ class OCREngine:
|
|
|
325
353
|
If True, applies rotate correction to the images using pytesseract.
|
|
326
354
|
max_dimension_pixels : int, Optional
|
|
327
355
|
The maximum dimension of the image in pixels. Origianl dimensions will be resized to fit in. If None, no resizing is applied.
|
|
356
|
+
few_shot_examples : List[FewShotExample], Optional
|
|
357
|
+
list of few-shot examples. Each example is a dict with keys "image" (PIL.Image.Image) and "text" (str).
|
|
328
358
|
concurrent_batch_size : int, Optional
|
|
329
359
|
The number of concurrent VLM calls to make.
|
|
330
360
|
max_file_load : int, Optional
|
|
@@ -343,18 +373,17 @@ class OCREngine:
|
|
|
343
373
|
|
|
344
374
|
if not isinstance(max_file_load, int) or max_file_load <= 0:
|
|
345
375
|
raise ValueError("max_file_load must be a positive integer")
|
|
346
|
-
|
|
347
|
-
if self.image_processor.has_tesseract==False and rotate_correction:
|
|
348
|
-
raise ImportError("pytesseract is not installed. Please install it to use rotate correction.")
|
|
349
376
|
|
|
350
377
|
return self._ocr_async(file_paths=file_paths,
|
|
351
378
|
rotate_correction=rotate_correction,
|
|
352
379
|
max_dimension_pixels=max_dimension_pixels,
|
|
380
|
+
few_shot_examples=few_shot_examples,
|
|
353
381
|
concurrent_batch_size=concurrent_batch_size,
|
|
354
382
|
max_file_load=max_file_load)
|
|
355
383
|
|
|
356
384
|
|
|
357
385
|
async def _ocr_async(self, file_paths: Iterable[str], rotate_correction:bool=False, max_dimension_pixels:int=None,
|
|
386
|
+
few_shot_examples:List[FewShotExample]=None,
|
|
358
387
|
concurrent_batch_size: int=32, max_file_load: int=None) -> AsyncGenerator[OCRResult, None]:
|
|
359
388
|
"""
|
|
360
389
|
Internal method to asynchronously process an iterable of file paths.
|
|
@@ -370,7 +399,8 @@ class OCREngine:
|
|
|
370
399
|
vlm_call_semaphore=vlm_call_semaphore,
|
|
371
400
|
file_path=file_path,
|
|
372
401
|
rotate_correction=rotate_correction,
|
|
373
|
-
max_dimension_pixels=max_dimension_pixels
|
|
402
|
+
max_dimension_pixels=max_dimension_pixels,
|
|
403
|
+
few_shot_examples=few_shot_examples)
|
|
374
404
|
tasks.append(task)
|
|
375
405
|
|
|
376
406
|
|
|
@@ -379,7 +409,8 @@ class OCREngine:
|
|
|
379
409
|
yield result
|
|
380
410
|
|
|
381
411
|
async def _ocr_file_with_semaphore(self, file_load_semaphore:asyncio.Semaphore, vlm_call_semaphore:asyncio.Semaphore,
|
|
382
|
-
file_path:str, rotate_correction:bool=False, max_dimension_pixels:int=None
|
|
412
|
+
file_path:str, rotate_correction:bool=False, max_dimension_pixels:int=None,
|
|
413
|
+
few_shot_examples:List[FewShotExample]=None) -> OCRResult:
|
|
383
414
|
"""
|
|
384
415
|
This internal method takes a semaphore and OCR a single file using the VLM inference engine.
|
|
385
416
|
"""
|
|
@@ -387,6 +418,7 @@ class OCREngine:
|
|
|
387
418
|
filename = os.path.basename(file_path)
|
|
388
419
|
file_ext = os.path.splitext(file_path)[1].lower()
|
|
389
420
|
result = OCRResult(input_dir=file_path, output_mode=self.output_mode)
|
|
421
|
+
messages_logger = MessagesLogger()
|
|
390
422
|
# check file extension
|
|
391
423
|
if file_ext not in SUPPORTED_IMAGE_EXTS:
|
|
392
424
|
result.status = "error"
|
|
@@ -416,7 +448,9 @@ class OCREngine:
|
|
|
416
448
|
data_loader=data_loader,
|
|
417
449
|
page_index=page_index,
|
|
418
450
|
rotate_correction=rotate_correction,
|
|
419
|
-
max_dimension_pixels=max_dimension_pixels
|
|
451
|
+
max_dimension_pixels=max_dimension_pixels,
|
|
452
|
+
few_shot_examples=few_shot_examples,
|
|
453
|
+
messages_logger=messages_logger
|
|
420
454
|
)
|
|
421
455
|
page_processing_tasks.append(task)
|
|
422
456
|
|
|
@@ -428,14 +462,18 @@ class OCREngine:
|
|
|
428
462
|
except Exception as e:
|
|
429
463
|
result.status = "error"
|
|
430
464
|
result.add_page(text=f"Error during OCR for {filename}: {str(e)}", image_processing_status={})
|
|
465
|
+
result.add_messages_to_log(messages_logger.get_messages_log())
|
|
431
466
|
return result
|
|
432
467
|
|
|
433
468
|
# Set status to success if no errors occurred
|
|
434
|
-
result.status
|
|
469
|
+
if result.status != "error":
|
|
470
|
+
result.status = "success"
|
|
471
|
+
result.add_messages_to_log(messages_logger.get_messages_log())
|
|
435
472
|
return result
|
|
436
473
|
|
|
437
474
|
async def _ocr_page_with_semaphore(self, vlm_call_semaphore: asyncio.Semaphore, data_loader: DataLoader,
|
|
438
|
-
page_index:int, rotate_correction:bool=False, max_dimension_pixels:int=None
|
|
475
|
+
page_index:int, rotate_correction:bool=False, max_dimension_pixels:int=None,
|
|
476
|
+
few_shot_examples:List[FewShotExample]=None, messages_logger:MessagesLogger=None) -> Tuple[str, Dict[str, str]]:
|
|
439
477
|
"""
|
|
440
478
|
This internal method takes a semaphore and OCR a single image/page using the VLM inference engine.
|
|
441
479
|
|
|
@@ -447,8 +485,8 @@ class OCREngine:
|
|
|
447
485
|
async with vlm_call_semaphore:
|
|
448
486
|
image = await data_loader.get_page_async(page_index)
|
|
449
487
|
image_processing_status = {}
|
|
450
|
-
# Apply rotate correction if specified
|
|
451
|
-
if rotate_correction
|
|
488
|
+
# Apply rotate correction if specified
|
|
489
|
+
if rotate_correction:
|
|
452
490
|
try:
|
|
453
491
|
image, rotation_angle = await self.image_processor.rotate_correction_async(image)
|
|
454
492
|
image_processing_status["rotate_correction"] = {
|
|
@@ -475,16 +513,21 @@ class OCREngine:
|
|
|
475
513
|
"error": str(e)
|
|
476
514
|
}
|
|
477
515
|
|
|
478
|
-
messages = self.vlm_engine.get_ocr_messages(self.system_prompt,
|
|
479
|
-
|
|
516
|
+
messages = self.vlm_engine.get_ocr_messages(system_prompt=self.system_prompt,
|
|
517
|
+
user_prompt=self.user_prompt,
|
|
518
|
+
image=image,
|
|
519
|
+
few_shot_examples=few_shot_examples)
|
|
520
|
+
response = await self.vlm_engine.chat_async(
|
|
480
521
|
messages,
|
|
522
|
+
messages_logger=messages_logger
|
|
481
523
|
)
|
|
524
|
+
ocr_text = response["response"]
|
|
482
525
|
# Clean the OCR text if output mode is markdown
|
|
483
526
|
if self.output_mode == "markdown":
|
|
484
527
|
ocr_text = clean_markdown(ocr_text)
|
|
485
528
|
|
|
486
529
|
# Parse the response if output mode is JSON
|
|
487
|
-
|
|
530
|
+
elif self.output_mode == "JSON":
|
|
488
531
|
json_list = extract_json(ocr_text)
|
|
489
532
|
# Serialize the JSON list to a string
|
|
490
533
|
ocr_text = json.dumps(json_list, indent=4)
|