vlm4ocr 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vlm4ocr/__init__.py +9 -0
- vlm4ocr/assets/default_prompt_templates/ocr_markdown_system_prompt.txt +1 -0
- vlm4ocr/assets/default_prompt_templates/ocr_text_system_prompt.txt +1 -0
- vlm4ocr/assets/default_prompt_templates/ocr_user_prompt.txt +1 -0
- vlm4ocr/ocr_engines.py +311 -0
- vlm4ocr/utils.py +47 -0
- vlm4ocr/vlm_engines.py +444 -0
- vlm4ocr-0.0.1.dist-info/METADATA +16 -0
- vlm4ocr-0.0.1.dist-info/RECORD +10 -0
- vlm4ocr-0.0.1.dist-info/WHEEL +4 -0
vlm4ocr/__init__.py
ADDED

vlm4ocr/assets/default_prompt_templates/ocr_markdown_system_prompt.txt
ADDED
You are a helpful assistant that can convert scanned documents into markdown text. Your output is accurate and well-formatted, starting with ```markdown and ending with ```. You will only output the markdown text without any additional explanations or comments. The markdown should include all text, tables, and lists with appropriate headers (e.g., "##"). You will ignore images, icons, or anything that can not be converted into text.

vlm4ocr/assets/default_prompt_templates/ocr_text_system_prompt.txt
ADDED
You are a helpful assistant that can convert scanned documents into plain text. Your output is accurate and well-formatted. You will only output the plain text without any additional explanations or comments. The plain text should include all text, tables, and lists with appropriate layout (e.g., \n, space, and tab). You will ignore images, icons, or anything that can not be converted into text.

vlm4ocr/assets/default_prompt_templates/ocr_user_prompt.txt
ADDED
Convert contents in this image into markdown.
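
The default user prompt above is deliberately generic. As the OCREngine docstring further down notes, a custom user_prompt that describes the document tends to help. A hypothetical override might look like this (the prompt text and variable name are illustrative only, not shipped with the package):

custom_user_prompt = (
    "Convert contents in this image into markdown. "
    "The image is a scanned invoice; keep the line-item table structure."
)
# Passed as OCREngine(vlm_engine, user_prompt=custom_user_prompt), see ocr_engines.py below.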
vlm4ocr/ocr_engines.py
ADDED
import os
from typing import List, Dict, Union, Generator, Iterable
import importlib.resources
import asyncio
import warnings
from vlm4ocr.utils import get_images_from_pdf, get_image_from_file, clean_markdown
from vlm4ocr.vlm_engines import VLMEngine

SUPPORTED_IMAGE_EXTS = ['.png', '.jpg', '.jpeg', '.bmp', '.gif', '.webp']

class OCREngine:
    def __init__(self, vlm_engine:VLMEngine, output_mode:str="markdown", system_prompt:str=None, user_prompt:str=None, page_delimiter:str="\n\n---\n\n"):
        """
        This class takes an image or PDF file path and processes it using a VLM inference engine. Outputs plain text or markdown.

        Parameters:
        -----------
        vlm_engine : VLMEngine
            The VLM inference engine to use for OCR.
        output_mode : str, Optional
            The output format. Can be 'markdown' or 'text'.
        system_prompt : str, Optional
            Custom system prompt. We recommend using the default system prompt by leaving this blank.
        user_prompt : str, Optional
            Custom user prompt. It is good to include some information regarding the document. If not specified, a default will be used.
        page_delimiter : str, Optional
            The delimiter to use between PDF pages.
        """
        # Check inference engine
        if not isinstance(vlm_engine, VLMEngine):
            raise TypeError("vlm_engine must be an instance of VLMEngine")
        self.vlm_engine = vlm_engine

        # Check output mode
        if output_mode not in ["markdown", "text"]:
            raise ValueError("output_mode must be 'markdown' or 'text'")
        self.output_mode = output_mode

        # System prompt
        if isinstance(system_prompt, str) and system_prompt:
            self.system_prompt = system_prompt
        else:
            file_path = importlib.resources.files('vlm4ocr.assets.default_prompt_templates').joinpath(f'ocr_{self.output_mode}_system_prompt.txt')
            with open(file_path, 'r', encoding='utf-8') as f:
                self.system_prompt = f.read()

        # User prompt
        if isinstance(user_prompt, str) and user_prompt:
            self.user_prompt = user_prompt
        else:
            file_path = importlib.resources.files('vlm4ocr.assets.default_prompt_templates').joinpath('ocr_user_prompt.txt')
            with open(file_path, 'r', encoding='utf-8') as f:
                self.user_prompt = f.read()

        # Page delimiter
        if isinstance(page_delimiter, str):
            self.page_delimiter = page_delimiter
        else:
            raise ValueError("page_delimiter must be a string")

    def stream_ocr(self, file_path: str, max_new_tokens:int=4096, temperature:float=0.0, **kwrs) -> Generator[str, None, None]:
        """
        This method takes a file path (image or PDF) and streams OCR results in real-time. This is useful for frontend applications.

        Parameters:
        -----------
        file_path : str
            The path to the image or PDF file. Must be one of '.pdf', '.png', '.jpg', '.jpeg', '.bmp', '.gif', '.webp'
        max_new_tokens : int, Optional
            The maximum number of tokens to generate.
        temperature : float, Optional
            The temperature to use for sampling.

        Returns:
        --------
        Generator[str, None, None]
            A generator that yields the output.
        """
        # Check file path
        if not isinstance(file_path, str):
            raise TypeError("file_path must be a string")

        # Check file extension
        file_ext = os.path.splitext(file_path)[1].lower()
        if file_ext not in SUPPORTED_IMAGE_EXTS and file_ext != '.pdf':
            raise ValueError(f"Unsupported file type: {file_ext}. Supported types are: {SUPPORTED_IMAGE_EXTS + ['.pdf']}")

        # PDF
        if file_ext == '.pdf':
            images = get_images_from_pdf(file_path)
            if not images:
                raise ValueError(f"No images extracted from PDF: {file_path}")
            for i, image in enumerate(images):
                messages = self.vlm_engine.get_ocr_messages(self.system_prompt, self.user_prompt, image)
                response_stream = self.vlm_engine.chat(
                    messages,
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    stream=True,
                    **kwrs
                )
                for chunk in response_stream:
                    yield chunk

                if i < len(images) - 1:
                    yield self.page_delimiter

        # Image
        else:
            image = get_image_from_file(file_path)
            messages = self.vlm_engine.get_ocr_messages(self.system_prompt, self.user_prompt, image)
            response_stream = self.vlm_engine.chat(
                messages,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                stream=True,
                **kwrs
            )
            for chunk in response_stream:
                yield chunk

    def run_ocr(self, file_paths: Union[str, Iterable[str]], max_new_tokens:int=4096, temperature:float=0.0,
                verbose:bool=False, concurrent:bool=False, concurrent_batch_size:int=32, **kwrs) -> List[str]:
        """
        This method takes a file path or a list of file paths (image or PDF) and performs OCR using the VLM inference engine.

        Parameters:
        -----------
        file_paths : Union[str, Iterable[str]]
            A file path or a list of file paths to process. Must be one of '.pdf', '.png', '.jpg', '.jpeg', '.bmp', '.gif', '.webp'
        max_new_tokens : int, Optional
            The maximum number of tokens to generate.
        temperature : float, Optional
            The temperature to use for sampling.
        verbose : bool, Optional
            If True, the function will print the output in terminal.
        concurrent : bool, Optional
            If True, the function will process the files concurrently.
        concurrent_batch_size : int, Optional
            The number of images/pages to process concurrently.
        """
        # If file_paths is a string, convert it to a list
        if isinstance(file_paths, str):
            file_paths = [file_paths]

        if not isinstance(file_paths, Iterable):
            raise TypeError("file_paths must be a string or an iterable of strings")

        # Check that all file paths are valid
        for file_path in file_paths:
            if not isinstance(file_path, str):
                raise TypeError("file_paths must be a string or an iterable of strings")
            file_ext = os.path.splitext(file_path)[1].lower()
            if file_ext not in SUPPORTED_IMAGE_EXTS and file_ext != '.pdf':
                raise ValueError(f"Unsupported file type: {file_ext}. Supported types are: {SUPPORTED_IMAGE_EXTS + ['.pdf']}")

        # Concurrent processing
        if concurrent:
            # Check concurrent_batch_size
            if concurrent_batch_size <= 0:
                raise ValueError("concurrent_batch_size must be greater than 0")

            if verbose:
                warnings.warn("verbose is not supported for concurrent processing.", UserWarning)

            return asyncio.run(self._run_ocr_async(file_paths,
                                                   max_new_tokens=max_new_tokens,
                                                   temperature=temperature,
                                                   concurrent_batch_size=concurrent_batch_size,
                                                   **kwrs))

        # Sync processing
        return self._run_ocr(file_paths, max_new_tokens=max_new_tokens, temperature=temperature, verbose=verbose, **kwrs)

    def _run_ocr(self, file_paths: Union[str, Iterable[str]], max_new_tokens:int=4096,
                 temperature:float=0.0, verbose:bool=False, **kwrs) -> Iterable[str]:
        """
        This method takes a file path or a list of file paths (image or PDF) and performs OCR using the VLM inference engine.

        Parameters:
        -----------
        file_paths : Union[str, Iterable[str]]
            A file path or a list of file paths to process. Must be one of '.pdf', '.png', '.jpg', '.jpeg', '.bmp', '.gif', '.webp'
        max_new_tokens : int, Optional
            The maximum number of tokens to generate.
        temperature : float, Optional
            The temperature to use for sampling.
        verbose : bool, Optional
            If True, the function will print the output in terminal.

        Returns:
        --------
        Iterable[str]
            A list of strings containing the OCR results.
        """
        ocr_results = []
        for file_path in file_paths:
            file_ext = os.path.splitext(file_path)[1].lower()
            # PDF
            if file_ext == '.pdf':
                images = get_images_from_pdf(file_path)
                if not images:
                    raise ValueError(f"No images extracted from PDF: {file_path}")
                pdf_results = []
                for image in images:
                    messages = self.vlm_engine.get_ocr_messages(self.system_prompt, self.user_prompt, image)
                    response = self.vlm_engine.chat(
                        messages,
                        max_new_tokens=max_new_tokens,
                        temperature=temperature,
                        verbose=verbose,
                        stream=False,
                        **kwrs
                    )
                    pdf_results.append(response)

                ocr_text = self.page_delimiter.join(pdf_results)
            # Image
            else:
                image = get_image_from_file(file_path)
                messages = self.vlm_engine.get_ocr_messages(self.system_prompt, self.user_prompt, image)
                ocr_text = self.vlm_engine.chat(
                    messages,
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    verbose=verbose,
                    stream=False,
                    **kwrs
                )

            # Clean markdown
            if self.output_mode == "markdown":
                ocr_text = clean_markdown(ocr_text)
            ocr_results.append(ocr_text)

        return ocr_results

    async def _run_ocr_async(self, file_paths: Union[str, Iterable[str]], max_new_tokens:int=4096,
                             temperature:float=0.0, concurrent_batch_size:int=32, **kwrs) -> List[str]:
        """
        This is the async version of the _run_ocr method.
        """
        # Flatten pages/images in file_paths
        flat_page_list = []
        for file_path in file_paths:
            file_ext = os.path.splitext(file_path)[1].lower()
            # PDF
            if file_ext == '.pdf':
                images = get_images_from_pdf(file_path)
                if not images:
                    # No pages extracted; fail loudly, as in the sync path
                    raise ValueError(f"No images extracted from PDF: {file_path}")
                for page_num, image in enumerate(images):
                    flat_page_list.append({'file_path': file_path, 'file_type': "PDF", "image": image, "page_num": page_num, "total_page_count": len(images)})
            # Image
            else:
                image = get_image_from_file(file_path)
                flat_page_list.append({'file_path': file_path, 'file_type': "image", "image": image})

        # Process images with asyncio.Semaphore
        semaphore = asyncio.Semaphore(concurrent_batch_size)
        async def semaphore_helper(page:Dict[str,str], max_new_tokens:int, temperature:float, **kwrs):
            try:
                messages = self.vlm_engine.get_ocr_messages(self.system_prompt, self.user_prompt, page["image"])
                async with semaphore:
                    async_task = self.vlm_engine.chat_async(
                        messages,
                        max_new_tokens=max_new_tokens,
                        temperature=temperature,
                        **kwrs
                    )
                    return await async_task
            except Exception as e:
                print(f"Error processing image: {e}")
                return f"[Error: {e}]"

        tasks = []
        for page in flat_page_list:
            async_task = semaphore_helper(
                page,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                **kwrs
            )
            tasks.append(asyncio.create_task(async_task))

        responses = await asyncio.gather(*tasks)

        # Restructure the results
        ocr_results = []
        pdf_page_text_buffer = ""
        for page, ocr_text in zip(flat_page_list, responses):
            # PDF
            if page['file_type'] == "PDF":
                pdf_page_text_buffer += ocr_text + self.page_delimiter
                if page['page_num'] == page['total_page_count'] - 1:
                    if self.output_mode == "markdown":
                        pdf_page_text_buffer = clean_markdown(pdf_page_text_buffer)
                    ocr_results.append(pdf_page_text_buffer)
                    pdf_page_text_buffer = ""
            # Image
            if page['file_type'] == "image":
                if self.output_mode == "markdown":
                    ocr_text = clean_markdown(ocr_text)
                ocr_results.append(ocr_text)

        return ocr_results
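
Taken together, ocr_engines.py exposes a synchronous path (_run_ocr), a concurrent path (_run_ocr_async behind asyncio.run), and a streaming path (stream_ocr). A minimal usage sketch, assuming a local Ollama vision model is available; the model tag and file names below are placeholders, not part of the package:

from vlm4ocr.vlm_engines import OllamaVLMEngine
from vlm4ocr.ocr_engines import OCREngine

vlm_engine = OllamaVLMEngine(model_name="llama3.2-vision")   # placeholder Ollama model tag
ocr = OCREngine(vlm_engine, output_mode="markdown")

# Batch mode: one result string per input file; PDF pages are joined with page_delimiter
results = ocr.run_ocr(["scan_1.pdf", "photo_1.png"], concurrent=True, concurrent_batch_size=8)

# Streaming mode: chunks are yielded as the VLM generates them
for chunk in ocr.stream_ocr("scan_1.pdf"):
    print(chunk, end="", flush=True)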
vlm4ocr/utils.py
ADDED
import os
import io
import base64
from typing import List
from pdf2image import convert_from_path
from PIL import Image

def get_images_from_pdf(file_path: str) -> List[Image.Image]:
    """ Extracts images from a PDF file. """
    try:
        images = convert_from_path(file_path)
        if not images:
            print(f"Warning: No images extracted from PDF: {file_path}")
        return images
    except Exception as e:
        print(f"Error converting PDF to images: {e}")
        raise ValueError(f"Failed to process PDF file '{os.path.basename(file_path)}'. Ensure poppler is installed and the file is valid.") from e


def get_image_from_file(file_path: str) -> Image.Image:
    """ Loads a single image file. """
    try:
        image = Image.open(file_path)
        image.load()
        return image
    except FileNotFoundError:
        raise FileNotFoundError(f"Image file not found: {file_path}")
    except Exception as e:
        raise ValueError(f"Failed to load image file '{os.path.basename(file_path)}': {e}") from e


def image_to_base64(image:Image.Image, format:str="png") -> str:
    """ Converts an image to a base64 string. """
    try:
        buffered = io.BytesIO()
        image.save(buffered, format=format)
        img_bytes = buffered.getvalue()
        encoded_bytes = base64.b64encode(img_bytes)
        base64_encoded_string = encoded_bytes.decode('utf-8')
        return base64_encoded_string
    except Exception as e:
        print(f"Error converting image to base64: {e}")
        raise ValueError(f"Failed to convert image to base64: {e}") from e

def clean_markdown(text:str) -> str:
    cleaned_text = text.replace("```markdown", "").replace("```", "")
    return cleaned_text
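
The helpers in utils.py are what the engines below build on: image_to_base64 feeds the chat-message builders, and clean_markdown strips the ```markdown fences that the markdown system prompt asks the model to emit. A small sketch (the file name is a placeholder):

from vlm4ocr.utils import get_image_from_file, image_to_base64, clean_markdown

img = get_image_from_file("page_1.png")                      # returns a PIL Image
data_url = f"data:image/png;base64,{image_to_base64(img, format='png')}"

raw = "```markdown\n# Heading\nBody text\n```"
print(clean_markdown(raw))                                    # fence markers removed, content kept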
vlm4ocr/vlm_engines.py
ADDED
import abc
import importlib.util
from typing import List, Dict, Union, Generator
import warnings
from PIL import Image
from vlm4ocr.utils import image_to_base64


class VLMEngine:
    @abc.abstractmethod
    def __init__(self):
        """
        This is an abstract class to provide interfaces for VLM inference engines.
        Child classes that inherit from this class can be used by the OCR engine. Must implement the chat() method.
        """
        return NotImplemented

    @abc.abstractmethod
    def chat(self, messages:List[Dict[str,str]], max_new_tokens:int=4096, temperature:float=0.0,
             verbose:bool=False, stream:bool=False, **kwrs) -> Union[str, Generator[str, None, None]]:
        """
        This method inputs chat messages and outputs VLM generated text.

        Parameters:
        ----------
        messages : List[Dict[str,str]]
            a list of dict with role and content. role must be one of {"system", "user", "assistant"}
        max_new_tokens : int, Optional
            the max number of new tokens VLM can generate.
        temperature : float, Optional
            the temperature for token sampling.
        verbose : bool, Optional
            if True, VLM generated text will be printed in terminal in real-time.
        stream : bool, Optional
            if True, returns a generator that yields the output in real-time.
        """
        return NotImplemented

    @abc.abstractmethod
    def chat_async(self, messages:List[Dict[str,str]], max_new_tokens:int=4096, temperature:float=0.0, **kwrs) -> str:
        """
        The async version of chat method. Streaming is not supported.
        """
        return NotImplemented

    @abc.abstractmethod
    def get_ocr_messages(self, system_prompt:str, user_prompt:str, image:Image.Image) -> List[Dict[str,str]]:
        """
        This method inputs an image and returns the corresponding chat messages for the inference engine.

        Parameters:
        ----------
        system_prompt : str
            the system prompt.
        user_prompt : str
            the user prompt.
        image : Image.Image
            the image for OCR.
        """
        return NotImplemented

class OllamaVLMEngine(VLMEngine):
    def __init__(self, model_name:str, num_ctx:int=4096, keep_alive:int=300, **kwrs):
        """
        The Ollama inference engine.

        Parameters:
        ----------
        model_name : str
            the model name exactly as shown in >> ollama ls
        num_ctx : int, Optional
            context length that VLM will evaluate.
        keep_alive : int, Optional
            seconds to hold the VLM after the last API call.
        """
        if importlib.util.find_spec("ollama") is None:
            raise ImportError("ollama-python not found. Please install ollama-python (```pip install ollama```).")

        from ollama import Client, AsyncClient
        self.client = Client(**kwrs)
        self.async_client = AsyncClient(**kwrs)
        self.model_name = model_name
        self.num_ctx = num_ctx
        self.keep_alive = keep_alive

    def chat(self, messages:List[Dict[str,str]], max_new_tokens:int=4096, temperature:float=0.0,
             verbose:bool=False, stream:bool=False, **kwrs) -> Union[str, Generator[str, None, None]]:
        """
        This method inputs chat messages and outputs VLM generated text.

        Parameters:
        ----------
        messages : List[Dict[str,str]]
            a list of dict with role and content. role must be one of {"system", "user", "assistant"}
        max_new_tokens : int, Optional
            the max number of new tokens VLM can generate.
        temperature : float, Optional
            the temperature for token sampling.
        verbose : bool, Optional
            if True, VLM generated text will be printed in terminal in real-time.
        stream : bool, Optional
            if True, returns a generator that yields the output in real-time.
        """
        options={'temperature':temperature, 'num_ctx': self.num_ctx, 'num_predict': max_new_tokens, **kwrs}
        if stream:
            def _stream_generator():
                response_stream = self.client.chat(
                    model=self.model_name,
                    messages=messages,
                    options=options,
                    stream=True,
                    keep_alive=self.keep_alive
                )
                for chunk in response_stream:
                    content_chunk = chunk.get('message', {}).get('content')
                    if content_chunk:
                        yield content_chunk

            return _stream_generator()

        elif verbose:
            response = self.client.chat(
                model=self.model_name,
                messages=messages,
                options=options,
                stream=True,
                keep_alive=self.keep_alive
            )

            res = ''
            for chunk in response:
                content_chunk = chunk.get('message', {}).get('content')
                if content_chunk:
                    print(content_chunk, end='', flush=True)
                    res += content_chunk
            print('\n')
            return res

        # Non-streaming, non-verbose: make a blocking call and return the full response
        response = self.client.chat(
            model=self.model_name,
            messages=messages,
            options=options,
            stream=False,
            keep_alive=self.keep_alive
        )
        return response.get('message', {}).get('content', '')


    async def chat_async(self, messages:List[Dict[str,str]], max_new_tokens:int=4096, temperature:float=0.0, **kwrs) -> str:
        """
        Async version of chat method. Streaming is not supported.
        """
        response = await self.async_client.chat(
            model=self.model_name,
            messages=messages,
            options={'temperature':temperature, 'num_ctx': self.num_ctx, 'num_predict': max_new_tokens, **kwrs},
            stream=False,
            keep_alive=self.keep_alive
        )

        return response.get('message', {}).get('content', '')

    def get_ocr_messages(self, system_prompt:str, user_prompt:str, image:Image.Image) -> List[Dict[str,str]]:
        """
        This method inputs an image and returns the corresponding chat messages for the inference engine.

        Parameters:
        ----------
        system_prompt : str
            the system prompt.
        user_prompt : str
            the user prompt.
        image : Image.Image
            the image for OCR.
        """
        base64_str = image_to_base64(image)
        return [
            {"role": "system", "content": system_prompt},
            {
                "role": "user",
                "content": user_prompt,
                "images": [base64_str]
            }
        ]

class OpenAIVLMEngine(VLMEngine):
    def __init__(self, model:str, reasoning_model:bool=False, **kwrs):
        """
        The OpenAI API inference engine. Supports OpenAI models and OpenAI compatible servers:
        - vLLM OpenAI compatible server (https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html)

        For parameters and documentation, refer to https://platform.openai.com/docs/api-reference/introduction

        Parameters:
        ----------
        model : str
            model name as described in https://platform.openai.com/docs/models
        reasoning_model : bool, Optional
            indicator for OpenAI reasoning models ("o" series).
        """
        if importlib.util.find_spec("openai") is None:
            raise ImportError("OpenAI Python API library not found. Please install OpenAI (```pip install openai```).")

        from openai import OpenAI, AsyncOpenAI
        self.client = OpenAI(**kwrs)
        self.async_client = AsyncOpenAI(**kwrs)
        self.model = model
        self.reasoning_model = reasoning_model

    def chat(self, messages:List[Dict[str,str]], max_new_tokens:int=4096, temperature:float=0.0,
             verbose:bool=False, stream:bool=False, **kwrs) -> Union[str, Generator[str, None, None]]:
        """
        This method inputs chat messages and outputs VLM generated text.

        Parameters:
        ----------
        messages : List[Dict[str,str]]
            a list of dict with role and content. role must be one of {"system", "user", "assistant"}
        max_new_tokens : int, Optional
            the max number of new tokens VLM can generate.
        temperature : float, Optional
            the temperature for token sampling.
        verbose : bool, Optional
            if True, VLM generated text will be printed in terminal in real-time.
        stream : bool, Optional
            if True, returns a generator that yields the output in real-time.
        """
        # For reasoning models
        if self.reasoning_model:
            # Reasoning models do not support temperature parameter
            if temperature != 0.0:
                warnings.warn("Reasoning models do not support temperature parameter. Will be ignored.", UserWarning)

            # Reasoning models do not support system prompts
            if any(msg['role'] == 'system' for msg in messages):
                warnings.warn("Reasoning models do not support system prompts. Will be ignored.", UserWarning)
                messages = [msg for msg in messages if msg['role'] != 'system']

            if stream:
                def _stream_generator():
                    response_stream = self.client.chat.completions.create(
                        model=self.model,
                        messages=messages,
                        max_completion_tokens=max_new_tokens,
                        stream=True,
                        **kwrs
                    )
                    for chunk in response_stream:
                        if len(chunk.choices) > 0:
                            if chunk.choices[0].delta.content is not None:
                                yield chunk.choices[0].delta.content
                            if chunk.choices[0].finish_reason == "length":
                                warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
                                if self.reasoning_model:
                                    warnings.warn("max_new_tokens includes reasoning tokens and output tokens.", UserWarning)
                return _stream_generator()

            elif verbose:
                response = self.client.chat.completions.create(
                    model=self.model,
                    messages=messages,
                    max_completion_tokens=max_new_tokens,
                    stream=True,
                    **kwrs
                )
                res = ''
                for chunk in response:
                    if len(chunk.choices) > 0:
                        if chunk.choices[0].delta.content is not None:
                            res += chunk.choices[0].delta.content
                            print(chunk.choices[0].delta.content, end="", flush=True)
                        if chunk.choices[0].finish_reason == "length":
                            warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
                            if self.reasoning_model:
                                warnings.warn("max_new_tokens includes reasoning tokens and output tokens.", UserWarning)

                print('\n')
                return res
            else:
                response = self.client.chat.completions.create(
                    model=self.model,
                    messages=messages,
                    max_completion_tokens=max_new_tokens,
                    stream=False,
                    **kwrs
                )
                return response.choices[0].message.content

        # For non-reasoning models
        else:
            if stream:
                def _stream_generator():
                    response_stream = self.client.chat.completions.create(
                        model=self.model,
                        messages=messages,
                        max_tokens=max_new_tokens,
                        temperature=temperature,
                        stream=True,
                        **kwrs
                    )
                    for chunk in response_stream:
                        if len(chunk.choices) > 0:
                            if chunk.choices[0].delta.content is not None:
                                yield chunk.choices[0].delta.content
                            if chunk.choices[0].finish_reason == "length":
                                warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
                                if self.reasoning_model:
                                    warnings.warn("max_new_tokens includes reasoning tokens and output tokens.", UserWarning)
                return _stream_generator()

            elif verbose:
                response = self.client.chat.completions.create(
                    model=self.model,
                    messages=messages,
                    max_tokens=max_new_tokens,
                    temperature=temperature,
                    stream=True,
                    **kwrs
                )
                res = ''
                for chunk in response:
                    if len(chunk.choices) > 0:
                        if chunk.choices[0].delta.content is not None:
                            res += chunk.choices[0].delta.content
                            print(chunk.choices[0].delta.content, end="", flush=True)
                        if chunk.choices[0].finish_reason == "length":
                            warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
                            if self.reasoning_model:
                                warnings.warn("max_new_tokens includes reasoning tokens and output tokens.", UserWarning)

                print('\n')
                return res

            else:
                response = self.client.chat.completions.create(
                    model=self.model,
                    messages=messages,
                    max_tokens=max_new_tokens,
                    temperature=temperature,
                    stream=False,
                    **kwrs
                )

                return response.choices[0].message.content


    async def chat_async(self, messages:List[Dict[str,str]], max_new_tokens:int=4096, temperature:float=0.0, **kwrs) -> str:
        """
        Async version of chat method. Streaming is not supported.
        """
        if self.reasoning_model:
            # Reasoning models do not support temperature parameter
            if temperature != 0.0:
                warnings.warn("Reasoning models do not support temperature parameter. Will be ignored.", UserWarning)

            # Reasoning models do not support system prompts
            if any(msg['role'] == 'system' for msg in messages):
                warnings.warn("Reasoning models do not support system prompts. Will be ignored.", UserWarning)
                messages = [msg for msg in messages if msg['role'] != 'system']

            response = await self.async_client.chat.completions.create(
                model=self.model,
                messages=messages,
                max_completion_tokens=max_new_tokens,
                stream=False,
                **kwrs
            )

        else:
            response = await self.async_client.chat.completions.create(
                model=self.model,
                messages=messages,
                max_tokens=max_new_tokens,
                temperature=temperature,
                stream=False,
                **kwrs
            )

        if response.choices[0].finish_reason == "length":
            warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
            if self.reasoning_model:
                warnings.warn("max_new_tokens includes reasoning tokens and output tokens.", UserWarning)

        return response.choices[0].message.content

    def get_ocr_messages(self, system_prompt:str, user_prompt:str, image:Image.Image, format:str='png', detail:str="high") -> List[Dict[str,str]]:
        """
        This method inputs an image and returns the corresponding chat messages for the inference engine.

        Parameters:
        ----------
        system_prompt : str
            the system prompt.
        user_prompt : str
            the user prompt.
        image : Image.Image
            the image for OCR.
        format : str, Optional
            the image format.
        detail : str, Optional
            the detail level of the image. Default is "high".
        """
        base64_str = image_to_base64(image)
        return [
            {"role": "system", "content": system_prompt},
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/{format};base64,{base64_str}",
                            "detail": detail
                        },
                    },
                    {"type": "text", "text": user_prompt},
                ],
            },
        ]

class AzureOpenAIVLMEngine(OpenAIVLMEngine):
    def __init__(self, model:str, api_version:str, reasoning_model:bool=False, **kwrs):
        """
        The Azure OpenAI API inference engine.
        For parameters and documentation, refer to
        - https://azure.microsoft.com/en-us/products/ai-services/openai-service
        - https://learn.microsoft.com/en-us/azure/ai-services/openai/quickstart

        Parameters:
        ----------
        model : str
            model name as described in https://platform.openai.com/docs/models
        api_version : str
            the Azure OpenAI API version
        reasoning_model : bool, Optional
            indicator for OpenAI reasoning models ("o" series).
        """
        if importlib.util.find_spec("openai") is None:
            raise ImportError("OpenAI Python API library not found. Please install OpenAI (```pip install openai```).")

        from openai import AzureOpenAI, AsyncAzureOpenAI
        self.model = model
        self.api_version = api_version
        self.client = AzureOpenAI(api_version=self.api_version,
                                  **kwrs)
        self.async_client = AsyncAzureOpenAI(api_version=self.api_version,
                                             **kwrs)
        self.reasoning_model = reasoning_model
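
The three concrete engines share the VLMEngine interface, so they are interchangeable inside OCREngine. A construction sketch, assuming the openai package is installed; the model names, deployment, endpoint, API version, and key below are placeholders that are simply forwarded through **kwrs to the underlying clients:

from vlm4ocr.vlm_engines import OpenAIVLMEngine, AzureOpenAIVLMEngine

# OpenAI, or any OpenAI-compatible server such as vLLM (base_url/api_key go to the OpenAI client)
openai_engine = OpenAIVLMEngine(model="gpt-4o-mini")

# Azure OpenAI (azure_endpoint/api_key go to the AzureOpenAI client)
azure_engine = AzureOpenAIVLMEngine(
    model="my-gpt4o-deployment",
    api_version="2024-02-01",
    azure_endpoint="https://example.openai.azure.com",
    api_key="...",
)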
vlm4ocr-0.0.1.dist-info/METADATA
ADDED
Metadata-Version: 2.1
Name: vlm4ocr
Version: 0.0.1
Summary: OCR with vision language models.
License: MIT
Author: Enshuo (David) Hsu
Requires-Python: >=3.11,<4.0
Classifier: License :: OSI Approved :: MIT License
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Requires-Dist: pdf2image (>=1.16.0)
Requires-Dist: pillow (>=10.0.0)
Description-Content-Type: text/markdown

This is the readme for vlm4ocr Python package.
vlm4ocr-0.0.1.dist-info/RECORD
ADDED
vlm4ocr/__init__.py,sha256=gD7WuuxH9_iUu8dgbfonRQD89xiveVP3e9YP4RAHeIQ,219
vlm4ocr/assets/default_prompt_templates/ocr_markdown_system_prompt.txt,sha256=pIsYO2G3jkZ5EWg7MJixre3Itz1oPqJSduUZT34_RNY,436
vlm4ocr/assets/default_prompt_templates/ocr_text_system_prompt.txt,sha256=WbLSOerqFjlYGaGWJ-w2enhky1WhnPl011s0fgRPgnQ,398
vlm4ocr/assets/default_prompt_templates/ocr_user_prompt.txt,sha256=61EJv8POsQGIIUVwCjDU73lMXJE7F3qhPIYl6zSbl1Q,45
vlm4ocr/ocr_engines.py,sha256=gNwN_itHqIQLF3cfMoGV51lAuY8ZJI2A78LWJivqVi4,13687
vlm4ocr/utils.py,sha256=3o6TLY8YIEorYHluHrNoJesEM8th89uRO_KNEzfdDA8,1763
vlm4ocr/vlm_engines.py,sha256=3Aymjo3xa-CTHbktOIHSeLSLAdbo7fk9qRw8d3uRmOM,18836
vlm4ocr-0.0.1.dist-info/METADATA,sha256=_8E0__RQ6bJy_WLlkirJRvf3G8LqmG5_WtOF1ddBaZw,520
vlm4ocr-0.0.1.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
vlm4ocr-0.0.1.dist-info/RECORD,,